import tensorflow as tf
input = [[[1,1,1],[2,2,2]],
[[3,3,3],[4,4,4]],
[[5,5,5],[6,6,6]]]
x = tf.slice(input,[0,0,0],[1,2,3])
sess = tf.InteractiveSession()
print(sess.run(x))
>>> [[[1 1 1]
[[2 2 2]]]
首先來看tf.slice里的幾個參數,
- input代表輸入的tensor,
- [0,0,0]代表begin,起始值
- [1,2,3]代表切的大小size。
要明白tf.slice是一個切片函數,那應該怎么切呢?
注意到tf.slice從begin開始切,
例如上面就是從[0,0,0],也就是第0行第0列第0維開始切,
然后size[1,2,3]表示切出1行2列3維的大小。
所以切出來了:
???????[[[1 1 1]
???????[[2 2 2]]]
倘若是下面的代碼:
import tensorflow as tf
input = [[[1,1,1],[2,2,2]],
[[3,3,3],[4,4,4]],
[[5,5,5],[6,6,6]]]
x = tf.slice(input,[1,0,0],[2,1,3])
sess = tf.InteractiveSession()
print(sess.run(x))
>>> [[[3 3 3]]
[[5 5 5]]]
tf.slice會從第一行第0列第0維開始切,
并切出2行1列3維的大小。