Tensorflow Session

Submitted by Lizhe on Thu, 08/17/2017 - 09:53

tf.Session() 接收3个可选参数

target 该参数用于在分布式设置中指定连接不同的tf.train.Server实例, 一般来说默认是空字符串

graph 参数指定了将要在Session对象中加载的Session对象, 默认值是None, 表示使用当前默认的数据流

config 参数允许用户指定配置, 如cpu或gpu内核数

 

一旦创建了Session对象之后, 就可以使用session.run()方法来计算Tensor对象的输出

run方法接收一个参数fetches 以及 3个可选参数 feed_dict, options 和 run_metadata

fetches参数接收任意的数据流图元素 ( Operation 和 Tensor) , 如果是Tensor对象, run会输出一个NumPy数组, 如果是Operation则输出None

feed_dict参数用于覆盖数据流图中的Tensor对象的原始值, 它接收一个字典参数, key为要覆盖值的句柄, 值类型必需与原类型相同或者可以相互转换

下面的例子输出6 而不是 3

import tensorflow as tf

a = tf.constant(1)
b = tf.mul(a,3)

sess = tf.Session()

replace_dict = {a:2}

print(sess.run(b,feed_dict=replace_dict))

[root@localhost session]# python helloSession.py 
6
[root@localhost session]#

 

session对象使用完后, 需要使用close()方法释放相关资源

或者使用with使其自动关闭

with tf.Session() as sess:

    ....

也可以利用Session类的as_default() 方法将Session对象作为默认Session对象

a = tf.constant(1)

sess = tf.Session()

with sess.as_default():

    a.eval()

sess.close()