介紹
計算圖本身是一個有向無環(huán)圖,它主要由一組節(jié)點(Node,抽象表示一個Op函數(shù)執(zhí)行)與表示節(jié)點之間相互依賴的邊(Edge,表示Nodes之間的輸入、輸出或次序控制依賴關(guān)系)組成。
本節(jié)當(dāng)中我們將詳細分析Tensorflow里面節(jié)點與邊的一些實現(xiàn)。
節(jié)點
以下為Node class在Tensorflow里面的定義。詳見:tensorflow/core/graph/graph.h。我們將在其代碼中逐個分析其不同的屬性與方法。
- 基本方法與屬性
class Node {
public:
string DebugString() const;
int id() const { return id_; } // 每個節(jié)點都會分配這么一個固定的id,同一副圖里面的不同Node有著其唯一的標(biāo)識id
int cost_id() const { return cost_id_; } // 此處主要標(biāo)明Node 內(nèi)存分配相關(guān)的id,有些Node為ref類型Node,可能其實現(xiàn)當(dāng)中并不實際分配內(nèi)存而只是引用其它Node節(jié)點里面分配的內(nèi)存;這樣它們將擁有相同的cost_id,它在對圖的內(nèi)存分配優(yōu)化及優(yōu)先級策略上有指導(dǎo)、幫助的意義
const string& name() const;
const string& type_string() const; // 顯示不同的type,如有的為Conv,有的為Multmul,還有則為Send或Recv等
const NodeDef& def() const; //輸出Node的protocol buffer definition
const OpDef& op_def() const; //輸出此Node相關(guān)聯(lián)的Op的protocol buffer definition
/* 以下主要為Node的輸入、輸出Tensor類型、數(shù)量及其引用等,容易理解 */
// input and output types
int32 num_inputs() const;
DataType input_type(int32 i) const;
const DataTypeVector& input_types() const;
int32 num_outputs() const;
DataType output_type(int32 o) const;
const DataTypeVector& output_types() const;
/* 用戶可指定或查詢某Node節(jié)點執(zhí)行所用的device,但其在真正執(zhí)行時,executor只是參考此建議,最終真正所用的device還是由executor綜合考慮后決定 */
// The device requested by the user. For the actual assigned device,
// use assigned_device_name() below.
const string& requested_device() const;
// This changes the user requested device but not necessarily the device that
// on which the operation will run.
void set_requested_device(const string& device);
// 以下一組函數(shù)可用來查詢/添加/刪除此Node所具有的屬性
// Read only access to attributes
AttrSlice attrs() const;
template <typename T>
void AddAttr(const string& name, const T& val) {
SetAttrValue(val, AddAttrHelper(name));
}
void ClearAttr(const string& name);
// Inputs requested by the NodeDef. For the actual inputs, use in_edges.
const protobuf::RepeatedPtrField<string>& requested_inputs() const;
//以下為一組功能函數(shù),具體來查詢輸入、輸出的Edges/Nodes,并使用不同的數(shù)據(jù)結(jié)構(gòu)返回,因為此類操作在Tensorflow中使用非常頻繁,因此需要考慮數(shù)據(jù)結(jié)構(gòu)的效率、內(nèi)存使用等特點
// Get the neighboring nodes via edges either in or out of this node. This
// includes control edges.
gtl::iterator_range<NeighborIter> in_nodes() const;
gtl::iterator_range<NeighborIter> out_nodes() const;
const EdgeSet& in_edges() const { return in_edges_; }
const EdgeSet& out_edges() const { return out_edges_; }
// Returns into '*n' the node that has an output connected to the
// 'idx' input of this Node.
Status input_node(int idx, const Node** n) const;
Status input_node(int idx, Node** n) const;
private:
friend class Graph; //Graph與Node經(jīng)常會相互調(diào)用彼此函數(shù),這里設(shè)為友類
Node();
NodeProperties* properties() const { return props_.get(); }
void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props);
// Releases memory from props_, in addition to restoring *this to its
// uninitialized state.
void Clear();
};
- 節(jié)點類型
Tensorflow的程序設(shè)計當(dāng)中,一切計算、控制操作都會由節(jié)點來表示。因此不只像傳統(tǒng)意義上大家認(rèn)為構(gòu)成主要模型的Conv/Relu/Matmul/FC等計算操作被表示為節(jié)點,其它像變量初始化(VARIABLE),常量賦值等操作都會有相應(yīng)的Node節(jié)點存在在圖中。然后由Session統(tǒng)一驅(qū)動執(zhí)行。這就是靜態(tài)圖構(gòu)建與執(zhí)行的基本原理。
以下為所有的節(jié)點類型。我們平時說的Conv/Relu/Matmul/BN等計算節(jié)點都被歸于NC_OTHER里面。。而其它在這里有名有姓的則為圖上的控制節(jié)點,也稱為特殊節(jié)點。
// A set of mutually exclusive classes for different kinds of nodes,
// class_ is initialized in the Node::Initialize routine based on the
// node's type_string().
enum NodeClass {
NC_UNINITIALIZED,
NC_SWITCH,
NC_MERGE,
NC_ENTER,
NC_EXIT,
NC_NEXT_ITERATION,
NC_LOOP_COND,
NC_CONTROL_TRIGGER,
NC_SEND,
NC_HOST_SEND,
NC_RECV,
NC_HOST_RECV,
NC_CONSTANT,
NC_VARIABLE,
NC_IDENTITY,
NC_GET_SESSION_HANDLE,
NC_GET_SESSION_TENSOR,
NC_DELETE_SESSION_TENSOR,
NC_METADATA,
NC_SCOPED_ALLOCATOR,
NC_COLLECTIVE,
NC_OTHER // Not a special kind of node
};
- 節(jié)點輸入/輸出
以下兩個結(jié)構(gòu)分別抽象表示Node的輸入、輸出張量(Tensor),本質(zhì)上Tensorflow圖上流動的正是如此一個個Input/Output tensors。
// Represents an input of a node, i.e., the `index`-th input to `node`.
struct InputTensor {
const Node* node;
int index;
InputTensor(const Node* n, int i) : node(n), index(i) {}
InputTensor() : node(nullptr), index(0) {}
};
// Represents an output of a node, i.e., the `index`-th output of `node`. Note
// that a single `OutputTensor` can correspond to multiple `Edge`s if the output
// is consumed by multiple destination nodes.
struct OutputTensor {
const Node* node;
int index;
OutputTensor(const Node* n, int i) : node(n), index(i) {}
OutputTensor() : node(nullptr), index(0) {}
};
- 節(jié)點屬性
tf中每個節(jié)點的屬性包含其輸入、輸出Tensors的類型以及此節(jié)點的protocol定義NodeDef及其所關(guān)聯(lián)的Op的定義OpDef。
class NodeProperties {
public:
NodeProperties(const OpDef* op_def, const NodeDef& node_def,
const DataTypeSlice inputs, const DataTypeSlice outputs)
: op_def(op_def),
node_def(node_def),
input_types(inputs.begin(), inputs.end()),
output_types(outputs.begin(), outputs.end()) {}
const OpDef* op_def; // not owned
NodeDef node_def;
const DataTypeVector input_types;
const DataTypeVector output_types;
};
邊
在下面我們從class Edge的代碼里來分析下TF中邊的實現(xiàn)。詳細可見:tensorflow/core/graph/graph.h
class Edge {
public:
//我們介紹過邊表示Nodes之間的依賴關(guān)系,此處即為dst節(jié)點執(zhí)行所需的某個輸入依賴于來自src節(jié)點的某個輸出或者作為控制邊要求src節(jié)點的執(zhí)行先于節(jié)點dst完成
Node* src() const { return src_; }
Node* dst() const { return dst_; }
int id() const { return id_; } //TF Graph當(dāng)中與Node一樣,每個邊也有其唯一的標(biāo)識id
// Return the index of the source output that produces the data
// carried by this edge. The special value kControlSlot is used
// for control dependencies.
int src_output() const { return src_output_; }
// Return the index of the destination input that consumes the data
// carried by this edge. The special value kControlSlot is used
// for control dependencies.
int dst_input() const { return dst_input_; }
// Return true iff this is an edge that indicates a control-flow
// (as opposed to a data-flow) dependency.
bool IsControlEdge() const;
string DebugString() const;
private:
Edge() {}
friend class EdgeSetTest;
friend class Graph;
Node* src_;
Node* dst_;
int id_;
int src_output_;
int dst_input_;
};