tensorflow中模型保存加載操作的學習體會

姓名:樂仁華 學號:16140220023

【嵌牛導讀】:本文簡述了學習tensorflow中保存加載模型的體會及總結

【嵌牛鼻子】:tensorflow,檢查點文件

【嵌牛提問】:tensorflow中保存加載模型有什么方法?

【嵌牛正文】:

先簡單提一下模型參數保存及加載的函數

tf.train.Saver()

tf.train.Saver()是tensorflow中加載,保存模型參數的一個類
使用方法:

#實例化類
saver = tf.train.Saver()

#調用save方法保存參數,ckpt為保存的模型參數名,如'run_dir/model.ckpt',
#其中run_dir表示模型所在的文件夾
#step表示迭代步數
saver.save(sess,ckpt,gloabal_step=step)

#如果需要加載參數
restorer = tf.train.Saver()
#這里的ckpt與保存過程的ckpt一致
restorer.restore(sess,ckpt)

更多詳細的用法可以看官方文檔

檢查點文件格式


保存的檢查點文件如上圖所示,
.meta文件保存了當前圖結構
.index文件保存了當前參數名
.data文件保存了當前參數值
每調用一次save方法會產生新的文件

獲取最新保存的檢查點文件

#假設check_path為保存這些檢查點文件的文件夾
#tf.train.get_checkpoint_state(check_point)表示查看check_point文件夾下是否有檢查點文件
ckpt = tf.train.get_checkpoint_state(check_point)
#獲取最新保存的模型檢查點文件
ckpt.model_checkpoint_path

還有其他的方法,不過我沒怎么用過,大家可以自己上網查查

查看檢查點文件中的各tensor

有時我們會需要查看檢查點文件中各變量,這時可以調用tensorflow中的方法查看

from tensorflow.python import pywrap_tensorflow

# 從檢查點文件中讀取數據
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# 顯示變量名及其值
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

保存及加載圖結構

我們知道tensorflow是以圖表示計算過程的,各節點操作都在圖上,自然也就有保存圖結構的方法

tf.train.write_graph()

具體參數看這:

tf.train.write_graph(graph_or_graph_def, logdir, name, as_text=True)
# Writes a graph proto to a file.
#      graph_or_graph_def: A `Graph` or a `GraphDef` protocol buffer.
#      logdir: Directory where to write the graph. This can refer to remote
#        filesystems, such as Google Cloud Storage (GCS).
#      name: Filename for the graph.
#      as_text: If `True`, writes the graph as an ASCII proto.
    
#    Returns:
#      The path of the output proto file.
(從內置文檔摘來的,相信大家都看得懂^_^)

如果要加載的話就用這個:

tf.train.import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None)
#參數如下
#meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
#       the path) containing a `MetaGraphDef`.
#    clear_devices: Whether or not to clear the device field for an `Operation`
#        or `Tensor` during import.
#     import_scope: Optional `string`. Name scope to add. Only used when
#        initializing from protocol buffer.
#      **kwargs: Optional keyed arguments.
    
#    Returns:
#      A saver constructed from `saver_def` in `MetaGraphDef` or None.
    
#      A None value is returned if no variables exist in the `MetaGraphDef`
(還是相信大家^_^)

細心的讀者可能發現了前頭提到的檢查點文件里面也有個保存結構的文件,那這兩者有啥區別嗎,說實話我也不清楚。。。。。

?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容