python – 如何使用TensorFlow Graphkeys获取所有权重?

python – 如何使用TensorFlow Graphkeys获取所有权重?,第1张

概述Tensorflow定义了预设的集合,如下所示: https://www.tensorflow.org/versions/r0.12/api_docs/python/framework/graph_collections 我目前正在使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)来获取所有变量[*的名字;如果他们不是那么他们就不会被展示,即使他们 Tensorflow定义了预设的集合,如下所示: https://www.tensorflow.org/versions/r0.12/api_docs/python/framework/graph_collections

我目前正在使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)来获取所有变量[*的名字;如果他们不是那么他们就不会被展示,即使他们存在].

类似地,我期望tf.get_collection(tf.GraphKeys.WEIGHTS)输出权重列表,但它是一个空数组.这也适用于GraphKeys.BIASES和.ACTIVATIONS.

这里发生了什么?

在我看来,这里似乎有两种可能性.首先,它们实际上从未自动定义,只是推荐的集合名称.其次,我的网络非常破碎,但似乎并非如此.

有人有这方面的经验吗?

解决方法 默认情况下,所有变量都绑定到tf.GraphKeys.GLOBAL_VARIABLES集合.方便的方法是将每个权重设置为集合tf.GraphKeys.WEIGHTS,如下所示:

In [2]: w = tf.Variable([1,2,3],collections=[tf.GraphKeys.WEIGHTS],dtype=tf.float32)In [3]: w2 = tf.Variable([11,22,32],dtype=tf.float32)

然后你可以通过以下方式获取它们:

tf.get_collection_ref(tf.GraphKeys.WEIGHTS)

这是权重:

[<tf.Variable 'Variable:0' shape=(3,) dtype=float32_ref>,<tf.Variable 'Variable_1:0' shape=(3,) dtype=float32_ref>]
总结

以上是内存溢出为你收集整理的python – 如何使用TensorFlow Graphkeys获取所有权重?全部内容,希望文章能够帮你解决python – 如何使用TensorFlow Graphkeys获取所有权重?所遇到的程序开发问题。

如果觉得内存溢出网站内容还不错,欢迎将内存溢出网站推荐给程序员好友。

欢迎分享,转载请注明来源:内存溢出

原文地址:https://54852.com/langs/1193990.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-06-03
下一篇2022-06-03

发表评论

登录后才能评论

评论列表(0条)

    保存