tensorflow 之tf.Session

tensorflow 之tf.Session,第1张

样例代码

hello.py文件内容如下。这也是tensorflow入门级案例。其中创建了tf.Session. 本文分析一下session的相关代码及依赖。

tensorflow安装在:envs/python3.10/lib/python3.10/site-packages/tensorflow目录下。也就是使用minconda的envs目录下。

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.disable_eager_execution()
print(tf.__version__)

v = tf.constant(3)
print(v)
print(tf.Session)
with tf.Session() as sess:
    vv = sess.run(v)
    print(vv)

以下是输出
$ python hello.py 
WARNING:tensorflow:From /data0/huozai/miniconda2/envs/python3.10/lib/python3.10/site-packages/tensorflow/python/compat/v2_compat.py:107: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
2.10.0
Tensor("Const:0", shape=(), dtype=int32)

2022-05-15 21:08:36.578359: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-15 21:08:36.582211: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
3
session.py

而tf.Session就是此目录下python/client/session.py模块中的一个类

tf.Session继承了

BaseSession

也是在这个文件中。BaseSession的__init__方法中使用了tf_session(from tensorflow.python.client import pywrap_tf_session as tf_session)


    self._session = None
    opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
    try:
      # pylint: disable=protected-access
      self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts)
      # pylint: enable=protected-access
    finally:
      tf_session.TF_DeleteSessionOptions(opts)
pywrap_tf_session

是一个wrapper,也就是说把C++代码打包成python库。也就是目录下的_pywrap_tf_session.so。这个so是怎么打包出来的呢?

 _pywrap_tf_session.so

这是在源码目录下(不是安装目录):tensorfow/python/client/BUILD文件中。

tf_session_wrapper.cc是用pybind11对Session的包装。

wrapper中的TF_NewSessionRef和TF_NewSession

 TF_NewSessionRef在tf_sesson_helper.cc中调用 了TF_NewSession

TF_NewSession

 

Status NewSession(const SessionOptions& options, Session** out_session) {
  SessionFactory* factory;
  Status s = SessionFactory::GetFactory(options, &factory);
  if (!s.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << "Failed to get session factory: " << s;
    return s;
  }
  // Starts exporting metrics through a platform-specific monitoring API (if
  // provided). For builds using "tensorflow/core/platform/default", this is
  // currently a no-op.
  session_created->GetCell()->Set(true);
  s = factory->NewSession(options, out_session);
  if (!s.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << "Failed to create session: " << s;
  }
  return s;
}

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/921242.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-05-16
下一篇2022-05-16

发表评论

登录后才能评论

评论列表(0条)

    保存