Introduction
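If you have browsed TensorFlow's core code, you have probably wondered why there are so many classes representing a graph, such as GraphDef and Graph. In general, a GraphDef holds a set of attribute key/value pairs describing a Graph (it is essentially the Protocol Buffer representation of the Graph), whereas what the Executor actually runs is a Graph. When we users build a model in a high-level language such as Python, a graph structure represented by a GraphDef is quietly generated underneath. Then, when we use a Session (in Python or another language) to allocate memory, initialize parameters and run the computation graph, the TF backend converts the GraphDef built in the previous step into an executable Graph.
In this section we focus on the details of converting a GraphDef into a Graph, that is, on how the actually executable graph, the Graph, is constructed.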
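To make the difference between the two representations concrete, here is a minimal, hypothetical C++ sketch (not TensorFlow code; NodeDefLite, NodeLite, GraphLite and BuildGraph are invented names). It assumes that, as in a serialized GraphDef, nodes refer to their inputs only by name, while the in-memory graph links nodes with pointers. Real NodeDef inputs also carry output indices ("x:1") and control inputs ("^x"), which this sketch ignores.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a NodeDef: edges exist only as
// input *names*, as in a serialized GraphDef.
struct NodeDefLite {
  std::string name;
  std::string op;
  std::vector<std::string> inputs;  // names of producer nodes
};

// Hypothetical in-memory node: edges are real pointers.
struct NodeLite {
  std::string name;
  std::string op;
  std::vector<NodeLite*> in_nodes;
  std::vector<NodeLite*> out_nodes;
};

struct GraphLite {
  std::vector<std::unique_ptr<NodeLite>> nodes;  // owns the nodes
  std::map<std::string, NodeLite*> by_name;      // name -> node index
};

// Pass 1 creates every node and indexes it by name; pass 2 resolves each
// input name into a pointer edge in both directions.
GraphLite BuildGraph(const std::vector<NodeDefLite>& defs) {
  GraphLite g;
  for (const auto& def : defs) {
    if (g.by_name.count(def.name)) {
      throw std::invalid_argument("duplicate node: " + def.name);
    }
    g.nodes.push_back(
        std::make_unique<NodeLite>(NodeLite{def.name, def.op, {}, {}}));
    g.by_name[def.name] = g.nodes.back().get();
  }
  for (const auto& def : defs) {
    NodeLite* dst = g.by_name.at(def.name);
    for (const auto& in : def.inputs) {
      auto it = g.by_name.find(in);
      if (it == g.by_name.end()) {
        throw std::invalid_argument("unknown input: " + in);
      }
      dst->in_nodes.push_back(it->second);
      it->second->out_nodes.push_back(dst);
    }
  }
  return g;
}
```

The two-pass structure is what lets a node appear in the list before the nodes it consumes: names are only resolved once every node exists.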
Two key construction functions
Two functions can be used to go from a GraphDef to a Graph: ConvertGraphDefToGraph and ImportGraphDef. The former, ConvertGraphDefToGraph, builds a complete Graph from scratch using the input GraphDef. The latter, ImportGraphDef, uses the input GraphDef to extend an existing Graph object by adding to its components. We discuss each of them below; for details, see tensorflow/core/graph/graph_constructor.h.
- ConvertGraphDefToGraph
Besides the two required parameters, GraphDef and Graph, this function takes a third parameter, GraphConstructorOptions. This options struct holds all the settings that control how the conversion is carried out. As you become more familiar with TensorFlow's core code, you will see this pattern more and more often: function parameters and configuration items are bundled into a single Options struct/class.
struct GraphConstructorOptions {
GraphConstructorOptions() {}
// If true, allows internal ops in the GraphDef.
bool allow_internal_ops = false;
// If true, the graph def is expected to have fully specified
// devices for all nodes. A node in the resulting graph "g" has the
// device name set accordingly.
bool expect_device_spec = false;
};
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g);
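The sketch below illustrates how the two flags might influence construction. It is a simplified, hypothetical model: ConstructorOptionsLite, NodeLite, IsAcceptableNodeName and ApplyDeviceSpec are invented names. It assumes that "internal" is recognized by a leading underscore in the node name, and that with expect_device_spec the device written in the GraphDef is taken directly as the node's assigned device.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical mirror of the two flags in GraphConstructorOptions.
struct ConstructorOptionsLite {
  bool allow_internal_ops = false;
  bool expect_device_spec = false;
};

// Hypothetical minimal node carrying device information.
struct NodeLite {
  std::string name;
  std::string requested_device;  // device string as written in the GraphDef
  std::string assigned_device;   // device actually assigned; empty = not placed
};

// Simplified name check: a leading '_' (assumed to mark internal nodes) is
// accepted only when allow_internal_ops is true; otherwise the first
// character must be a letter, a digit or '.'.
bool IsAcceptableNodeName(const std::string& name,
                          const ConstructorOptionsLite& opts) {
  if (name.empty()) return false;
  const unsigned char first = static_cast<unsigned char>(name[0]);
  if (first == '_') return opts.allow_internal_ops;
  return std::isalnum(first) != 0 || first == '.';
}

// With expect_device_spec, the GraphDef's device is trusted as final and
// copied into the assigned device; otherwise placement is left for later.
void ApplyDeviceSpec(const std::string& def_device,
                     const ConstructorOptionsLite& opts, NodeLite* node) {
  node->requested_device = def_device;
  if (opts.expect_device_spec) node->assigned_device = def_device;
}
```

Because both flags default to false, a caller that passes a default-constructed options object gets the strictest behavior, which is one practical benefit of the Options-struct style.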
Looking at this function's definition in tensorflow/core/graph/graph_constructor.cc, we find that the actual work is delegated to a lower-level class, GraphConstructor. Here is its implementation:
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g) {
ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
/*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
/*missing_unused_input_map_keys=*/nullptr);
}
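ConvertGraphDefToGraph forwards &gdef.versions() because Construct (in the class below) first calls CheckVersions against TF_GRAPH_DEF_VERSION and TF_GRAPH_DEF_VERSION_MIN_PRODUCER. The following sketch shows the producer/consumer rules such a check is assumed to enforce, based on VersionDef's producer, min_consumer and bad_consumers fields. VersionInfo and CheckVersionsSketch are hypothetical, and an empty string stands in for an OK Status.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical mirror of the VersionDef fields carried by a GraphDef.
struct VersionInfo {
  int producer;                    // version of the code that wrote the graph
  int min_consumer;                // oldest consumer allowed to read it
  std::vector<int> bad_consumers;  // consumer versions explicitly excluded
};

// Returns "" if a runtime at version `consumer`, which still accepts
// producers >= `min_producer`, may load `v`; otherwise an error message.
std::string CheckVersionsSketch(const VersionInfo& v, int consumer,
                                int min_producer) {
  if (v.producer < min_producer) {
    return "producer " + std::to_string(v.producer) + " is too old";
  }
  if (v.min_consumer > consumer) {
    return "requires consumer >= " + std::to_string(v.min_consumer);
  }
  if (std::find(v.bad_consumers.begin(), v.bad_consumers.end(), consumer) !=
      v.bad_consumers.end()) {
    return "consumer " + std::to_string(consumer) + " is disallowed";
  }
  return "";
}
```

The check is two-sided: the runtime rejects graphs that are too old for it, and the graph itself can declare runtimes that are too old (or known-bad) for it.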
The main parts of GraphConstructor are shown below. It contains an inner struct Options, which takes its settings from the external GraphConstructorOptions struct described above (and from ImportGraphDefOptions, discussed later).
From the code below you can see that the whole process of actually importing the GraphDef, checking its validity and safety, and then building up the Graph's data structures step by step takes place in the TryImport function.
class GraphConstructor {
public:
struct Options {
Options(const GraphConstructorOptions& in) // NOLINT(runtime/explicit)
: allow_internal_ops(in.allow_internal_ops),
expect_device_spec(in.expect_device_spec),
importing(false),
validate_colocation_constraints(false) {}
Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit)
: allow_internal_ops(false),
expect_device_spec(false),
prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/")
? in.prefix
: in.prefix + "/"),
uniquify_names(in.uniquify_names),
uniquify_prefix(in.uniquify_prefix),
input_map(in.input_map),
skip_mapped_nodes(in.skip_mapped_nodes),
control_dependencies(in.control_dependencies),
return_tensors(in.return_tensors),
return_nodes(in.return_nodes),
importing(true),
validate_colocation_constraints(in.validate_colocation_constraints),
validate_shape(in.validate_shape) {}
// The following two are provided by GraphConstructorOptions
bool allow_internal_ops;
bool expect_device_spec;
// The following are provided by ImportGraphDefOptions
string prefix;
bool uniquify_names;
bool uniquify_prefix;
std::map<TensorId, TensorId> input_map;
bool skip_mapped_nodes;
std::vector<string> control_dependencies;
std::vector<TensorId> return_tensors;
std::vector<string> return_nodes;
bool importing;
bool validate_colocation_constraints;
bool validate_shape = true;
};
// The function that actually performs the construction
static Status Construct(
const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors,
std::vector<Node*>* return_nodes,
std::vector<TensorId>* missing_unused_input_map_keys) {
if (versions) {
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
"GraphDef", "graph"));
}
GraphConstructor c(opts, node_defs, versions, library, g, refiner,
return_tensors, return_nodes,
missing_unused_input_map_keys);
const Status s = c.TryImport();
if (!s.ok()) c.Undo();
return s;
}
// The sequence of steps that actually imports the GraphDef and builds the Graph
Status TryImport() {
TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
TF_RETURN_IF_ERROR(BuildNodeIndex());
TF_RETURN_IF_ERROR(InitFromEdges());
TF_RETURN_IF_ERROR(Convert());
TF_RETURN_IF_ERROR(AddBackEdges());
TF_RETURN_IF_ERROR(UpdateVersionDef());
TF_RETURN_IF_ERROR(PopulateReturnTensors());
TF_RETURN_IF_ERROR(PopulateReturnNodes());
TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
UpdateUniquifiedColocationNames();
FixupSourceAndSinkEdges(g_);
return Status::OK();
}
};
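The key design of Construct is transactional: TryImport runs each step in order and stops at the first error (the effect of the TF_RETURN_IF_ERROR chain), and Construct calls Undo on failure so that the caller's graph is left as it was. Here is a minimal sketch of that pattern; ImporterSketch, StatusLite, Step and the string-list "graph" are hypothetical simplifications, not TensorFlow types.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for tensorflow::Status: empty string means OK.
using StatusLite = std::string;
// One import step; it may append node names to the graph and may fail.
using Step = std::function<StatusLite(std::vector<std::string>*)>;

class ImporterSketch {
 public:
  explicit ImporterSketch(std::vector<std::string>* graph)
      : graph_(graph), original_size_(graph->size()) {}

  // Runs the steps in order and returns the first error, if any.
  StatusLite TryImport(const std::vector<Step>& steps) {
    for (const auto& step : steps) {
      StatusLite s = step(graph_);
      if (!s.empty()) return s;
    }
    return "";
  }

  // Drops everything appended since construction, restoring the graph.
  void Undo() { graph_->resize(original_size_); }

 private:
  std::vector<std::string>* graph_;
  std::size_t original_size_;
};

// Mirrors the shape of the static Construct(): import, undo on failure.
StatusLite Construct(std::vector<std::string>* graph,
                     const std::vector<Step>& steps) {
  ImporterSketch c(graph);
  const StatusLite s = c.TryImport(steps);
  if (!s.empty()) c.Undo();
  return s;
}
```

Keeping rollback in one Undo, rather than in each step, lets every step assume the previous ones succeeded and simply return its error.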
- ImportGraphDef
As mentioned above, this function is mainly used to extend an existing Graph: it adds new nodes to it and thereby extends the original Graph's functionality.
// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
//
// On error, returns non-OK and leaves `*g` unmodified.
//
// `refiner` can be null. It should be non-null if the caller
// intends to add additional nodes to the graph after the import. This
// allows the caller to validate shapes of those nodes (since
// ShapeRefiner::AddNode must be called in topological order).
//
// `results` must be non-null if `opts.return_tensors` or `opts.return_nodes` is
// non-empty. It can also be set to fetch the unused input map keys. If it's
// non-null, all the vector fields must be empty.
extern Status ImportGraphDef(const ImportGraphDefOptions& opts,
const GraphDef& gdef, Graph* g,
ShapeRefiner* refiner,
ImportGraphDefResults* results = nullptr);
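Besides the expected GraphDef and Graph, it takes an options parameter, ImportGraphDefOptions, and a ShapeRefiner. The ShapeRefiner ensures that, as new nodes are added to the original Graph during this call, the shapes of connected inputs and outputs match; as the comment above notes, it may be null, but it should be supplied if the caller intends to add more nodes after the import. Finally, the optional ImportGraphDefResults argument (a struct, not a function) receives the requested nodes and tensors of the imported graph.
First, let's look at the settings contained in ImportGraphDefOptions.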
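What "shapes match" means can be illustrated with a toy shape model in which -1 marks an unknown dimension: two shapes are compatible when their ranks agree and every pair of dimensions is equal or at least one is unknown, and merging fills unknown dimensions from the other shape. This is a hypothetical illustration of the idea, not ShapeRefiner's API; ShapeLite, ShapesCompatible and MergeShapes are invented names, and unknown rank is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical shape: one entry per dimension, -1 = unknown size.
using ShapeLite = std::vector<std::int64_t>;

bool DimsCompatible(std::int64_t a, std::int64_t b) {
  return a == -1 || b == -1 || a == b;
}

// A producer's output shape can feed a consumer's expected input shape if
// the ranks agree and every dimension pair is compatible.
bool ShapesCompatible(const ShapeLite& produced, const ShapeLite& expected) {
  if (produced.size() != expected.size()) return false;
  for (std::size_t i = 0; i < produced.size(); ++i) {
    if (!DimsCompatible(produced[i], expected[i])) return false;
  }
  return true;
}

// Precondition: ShapesCompatible(a, b). Each unknown dimension of `a` is
// refined with the corresponding dimension of `b`.
ShapeLite MergeShapes(const ShapeLite& a, const ShapeLite& b) {
  ShapeLite out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = (a[i] == -1) ? b[i] : a[i];
  }
  return out;
}
```

Refining shapes node by node in this way only works if producers are processed before consumers, which is why the comment above stresses that shape inference must proceed in topological order.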
struct ImportGraphDefOptions {
ImportGraphDefOptions()
: uniquify_names(false),
uniquify_prefix(false),
skip_mapped_nodes(false),
validate_shape(true) {}
// prefix, uniquify_names and uniquify_prefix ensure that the names of nodes newly
// added from the GraphDef do not collide with existing nodes in the Graph, and
// determine how a collision is handled if one occurs.
// Name prefix to use for nodes imported from the GraphDef. For example, if
// prefix="animals" and GraphDef contains a node "bunny" then the node will be
// named "animals/bunny" in *g. Must not be already used as a node name or
// prefix in the graph.
string prefix;
// If true, imported node names will be modified if their name already exists
// in the graph. If false, conflicting names will be treated as an error. Note
// that this option has no effect if `prefix` is specified, since `prefix`
// will guarantee all node names are unique.
bool uniquify_names;
// If true, `prefix` will be modified if it already exists as a node name or
// prefix in the graph. If false, a conflicting prefix will be treated as an
// error. This option has no effect if `prefix` isn't specified.
bool uniquify_prefix;
// Used while the new nodes are built: maps TensorIds referenced by NodeDefs in
// `gdef` to existing TensorIds of nodes in the Graph.
// Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef`
// corresponding to `input_map` keys will be remapped to the nodes in `g`
// corresponding to the values.
//
// Keys should not include `prefix`, i.e., a key TensorId's name should be the
// name as it originally appears in `gdef`.
//
// If this is non-empty, ImportGraphDef must be called with the shape refiner
// used to create the existing nodes referenced in `input_map`.
std::map<TensorId, TensorId> input_map;
// If true, nodes that will have all output edges removed because of
// overrides in `input_map` will not be imported.
bool skip_mapped_nodes;
// The names of existing nodes in `g` that the imported graph should have
// control dependencies on.
//
// Note that to avoid creating many redundant control edges, ImportGraphDef()
// won't add control edges to nodes that will inherit the dependencies from
// other nodes in `gdef`.
std::vector<string> control_dependencies;
// Tensors in `gdef` that will be returned via the ImportGraphDefResults
// output parameter of `ImportGraphDef()`. If this list is non-empty, the
// caller must pass a results object to `ImportGraphDef()`. The
// `return_tensors` field will be populated with the imported nodes in `g`.
//
// Entries should not include `prefix`, i.e., each TensorId's name should be
// the name as it originally appears in `gdef`.
//
// If this contains a tensor that's also being remapped via `input_map`, the
// corresponding existing tensor in `g` will be returned.
std::vector<TensorId> return_tensors;
// The names of nodes in `gdef` that will be returned via the
// ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list
// is non-empty, the caller must pass a results object to
// `ImportGraphDef()`. The `return_nodes` field will be populated with the
// imported nodes in `g`.
//
// Entries should not include `prefix`, i.e., each node's name should be the
// name as it originally appears in `gdef`.
//
// Unlike `return_tensors`, `input_map` has no effect on the nodes
// returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
std::vector<string> return_nodes;
// If true, checks that all colocation constraints are nodes in the GraphDef.
bool validate_colocation_constraints = true;
// If false skips shape validation.
bool validate_shape;
};
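As a rough illustration of how imported nodes might be named when they join an existing graph, here is a hypothetical sketch. NormalizePrefix mirrors the prefix handling in GraphConstructor::Options shown earlier (append '/' unless the prefix is empty or already ends with one); ImportNames is an invented helper, uniquify_prefix is not modeled, and the "_1", "_2" suffix used for uniquify_names is an assumed scheme. Like ImportGraphDef, the sketch commits its changes only if every name succeeds.

```cpp
#include <cassert>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// An empty prefix stays empty; otherwise a trailing '/' is guaranteed.
std::string NormalizePrefix(const std::string& prefix) {
  if (prefix.empty() || prefix.back() == '/') return prefix;
  return prefix + "/";
}

// `existing` holds the node names already in the graph. Each new name gets
// the prefix; on a clash it is an error, or, if uniquify_names is true, a
// "_<n>" suffix (an assumed scheme) is appended until the name is free.
std::vector<std::string> ImportNames(std::set<std::string>* existing,
                                     const std::vector<std::string>& new_names,
                                     const std::string& prefix,
                                     bool uniquify_names) {
  const std::string p = NormalizePrefix(prefix);
  std::set<std::string> names = *existing;  // work on a copy
  std::vector<std::string> imported;
  for (const auto& n : new_names) {
    std::string name = p + n;
    if (names.count(name)) {
      if (!uniquify_names) {
        throw std::invalid_argument("name collision: " + name);
      }
      int i = 1;
      while (names.count(name + "_" + std::to_string(i))) ++i;
      name += "_" + std::to_string(i);
    }
    names.insert(name);
    imported.push_back(name);
  }
  *existing = names;  // commit only after every name succeeded
  return imported;
}
```

With prefix="animals" and a GraphDef node "bunny", this yields "animals/bunny", matching the example in the prefix comment above.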
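You have to admire the clarity of the Google engineers' code, and their comments are well judged too. Reading it, one imagines its authors must be just as elegant and refined!
Next, a quick look at the structure of ImportGraphDefResults. Its comments and member names already explain it well enough, so no further detail is needed. :)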
struct ImportGraphDefResults {
// The requested tensors associated with
// ImportGraphDefOptions::return_tensors. Note that the index may be different
// than the requested index if the returned tensor has been remapped according
// to `input_map`.
typedef int Index;
std::vector<std::pair<Node*, Index>> return_tensors;
// The requested nodes associated with ImportGraphDefOptions::return_nodes.
std::vector<Node*> return_nodes;
// Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
// weren't used as an input to any node in `gdef`. These keys are likely due
// to typos, and callers may wish to treat their existence as an error.
std::vector<TensorId> missing_unused_input_map_keys;
};
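To see how input_map and missing_unused_input_map_keys relate, here is a hypothetical sketch: TensorIdLite is a (node name, output index) pair, NodeDefLite an invented minimal node, and RemapInputs and FindMissingUnusedKeys are illustrative names rather than TensorFlow functions. RemapInputs redirects any input found in input_map to the existing tensor in g and prefixes the rest; FindMissingUnusedKeys reports keys whose node does not appear in gdef and which no node consumes, following the comment on missing_unused_input_map_keys.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical TensorId stand-in: {node name, output index}, e.g. "conv:1".
using TensorIdLite = std::pair<std::string, int>;

// Hypothetical minimal node: a name plus the tensors it consumes.
struct NodeDefLite {
  std::string name;
  std::vector<TensorIdLite> inputs;
};

// Keys in input_map use gdef's original (unprefixed) names. A hit is
// redirected to the existing tensor in g; any other input refers to an
// imported node and gets `prefix` (assumed already normalized).
std::vector<TensorIdLite> RemapInputs(
    const std::vector<TensorIdLite>& inputs,
    const std::map<TensorIdLite, TensorIdLite>& input_map,
    const std::string& prefix) {
  std::vector<TensorIdLite> out;
  for (const auto& in : inputs) {
    auto it = input_map.find(in);
    if (it != input_map.end()) {
      out.push_back(it->second);
    } else {
      out.emplace_back(prefix + in.first, in.second);
    }
  }
  return out;
}

// Keys whose node is absent from gdef and that no node in gdef consumes;
// such keys are likely typos.
std::vector<TensorIdLite> FindMissingUnusedKeys(
    const std::map<TensorIdLite, TensorIdLite>& input_map,
    const std::vector<NodeDefLite>& gdef) {
  std::set<std::string> names;
  std::set<TensorIdLite> used;
  for (const auto& node : gdef) {
    names.insert(node.name);
    used.insert(node.inputs.begin(), node.inputs.end());
  }
  std::vector<TensorIdLite> missing;
  for (const auto& kv : input_map) {
    if (!names.count(kv.first.first) && !used.count(kv.first)) {
      missing.push_back(kv.first);
    }
  }
  return missing;
}
```

A caller could treat a non-empty result from such a check as an error, as the struct's comment suggests.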