Cannot assign a device to node

举报
风吹稻花香 发表于 2021/06/05 00:56:45 2021/06/05
【摘要】 [python]  view plain  copy import sys   import tensorflow as tf   #from icon_reg_net import GoogleNet&nbs...


[python]  view plain  copy
  1. import sys  
  2. import tensorflow as tf  
  3. #from icon_reg_net import GoogleNet  
  4. from det_icon_reg_deploy import GoogleNet_Reg  
  5. from det_icon_cls_deploy import GoogleNet_Cls  
  6. import numpy as np  
  7.   
  8.   
  9. image_path = './1113014_7.jpg'  
  10.   
  11. mean = np.array([148.148.148.])  
  12. new_shape = [224224]  
  13.   
  14.   
  15. def device_for_node(n):  
  16.     if n.type == "MatMul":  
  17.         return "/gpu:0"  
  18.     else:  
  19.         return "/cpu:0"  
  20.   
  21.   
  22. def test(image_path):  
  23.     g = tf.Graph()  
  24.     with g.as_default():   
  25.         with g.device(device_for_node):    #直接写 "/gpu:0" 会出问题,详见下述1  
  26.             file_data = tf.read_file(image_path)  
  27.             # Decode the image data  
  28.             img = tf.image.decode_jpeg(file_data, channels=3)  
  29.             #img = tf.reverse(img, [False, False, True])  
  30.             img = tf.image.resize_images(img, new_shape[0], new_shape[1])  
  31.             img = tf.to_float(img) - mean  
  32.   
  33.     with tf.Session(graph=g) as sess1:  
  34.         #tf.initialize_all_variables().run()  
  35.         print type(img), img.get_shape()  
  36.         img = sess1.run(img)   # 这里需要先执行这句,详见下述2  
  37.         print type(img), img.shape  
  38.         img = np.reshape(img, (12242243))  
  39.         input_node = tf.placeholder(tf.float32, shape=(None, new_shape[0], new_shape[0], 3))  
  40.         net = GoogleNet_Reg({'data': input_node})  
  41.         model_path = './det_icon_reg_iter_110000.npy'  
  42.         net.load(data_path=model_path, session=sess1)   
  43.   
  44.         probs = sess1.run(net.get_output(), feed_dict={input_node: img})  
  45.     print probs  
  46.   
  47.     pos = probs[0]  
  48.     x1 = int(pos[0] * new_shape[0])  
  49.     y1 = int(pos[1] * new_shape[1])  
  50.     x2 = int(pos[2] * new_shape[0])  
  51.     y2 = int(pos[3] * new_shape[1])  
  52.   
  53.     #tf.reset_default_graph() #如果前面没有g = tf.Graph(),那么如果不加上这句可能会出错,详见下述3  
  54.     #g2 = tf.Graph()  
  55.     roiimg = tf.slice(img, begin=tf.pack([0, x1, y1, 0]), size=tf.pack([1, x2-x1, y2-y1, 3]))  
  56.     roiimg = tf.image.resize_images(roiimg, new_shape[0], new_shape[1])  
  57.   
  58.     #g = tf.Graph()  
  59.     with tf.Session() as sess2:      
  60.         roiimg = sess2.run(roiimg)  
  61.         roiimg = np.reshape(img, (12242243))  
  62.         print g  
  63.         print tf.get_default_graph()  
  64.         test = tf.constant(1)  
  65.         print test.graph     #这里使用的是默认的图,tf.get_default_graph() == test.graph  
  66.           
  67.         input_node = tf.placeholder(tf.float32, shape=(None, new_shape[0], new_shape[0], 3))  
  68.         net = GoogleNet_Cls({'data': input_node})  
  69.   
  70.         model_path = './det_icon_cls_iter_50000.npy'  
  71.         net.load(data_path=model_path, session=sess2)  
  72.           
  73.         probs = sess2.run(net.get_output(), feed_dict={input_node: roiimg})  
  74.   
  75.         print probs  
  76.   
  77.     scores = probs[0]  
  78.     rank = np.argsort(-scores)  
  79.     print rank[0], scores[rank[0]]  
  80.   
  81.   
  82. if __name__ == '__main__':  
  83.     if len(sys.argv)>1 :  
  84.         print sys.argv  
  85.         func = getattr(sys.modules[__name__], sys.argv[1])  
  86.         func(*sys.argv[2:])  
  87.     else:  
  88.         print >> sys.stderr,'%s command' % (__file__)  
  89.           
  90.           


1.  在指定使用GPU时,如果直接指定   "/gpu:0"

[python]  view plain  copy
  1. with g.device("/gpu:0"):  
这样会报错,类似:


  
  1. tensorflow.python.framework.errors.InvalidArgumentError: Cannot assign a device to node 'GradientDescent/update_Variable_2/ScatterSub': Could not satisfy explicit device specification '' because the node was colocated with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/GPU:0'
  2. [[Node: GradientDescent/update_Variable_2/ScatterSub = ScatterSub[T=DT_FLOAT, Tindices=DT_INT64, use_locking=false](Variable_2, gradients/concat_1, GradientDescent/update_Variable_2/mul)]]
  3. Caused by op u'GradientDescent/update_Variable_2/ScatterSub', defined at:
原因是并不是图中的所有的操作都支持GPU运算:

It seems a whole bunch of operations used in this example aren't supported on a GPU. A quick workaround is to restrict operations such that only matrix muls are ran on the GPU.

There's an example in the docs: http://tensorflow.org/api_docs/python/framework.md

See the section on tf.Graph.device(device_name_or_function)

简答的办法是指定只有矩阵乘法才在GPU上进行:


  
  1. def device_for_node(n):
  2. if n.type == "MatMul":
  3. return "/gpu:0"
  4. else:
  5. return "/cpu:0"
  6. with graph.as_default():
  7. with graph.device(device_for_node):
  8. ...


2.  在没有执行 img = sess1.run(img)   前,img 是 tf.read_file 然后 tf.image.decode_jpeg 后得到的,但是这里都是添加到图中的操作,并没有真正被执行,此时 img 的类型是

[python]  view plain  copy
  1. <class 'tensorflow.python.framework.ops.Tensor'> (2242243)  
在 img = sess1.run(img) 后,img 变成了
[python]  view plain  copy
  1. <type 'numpy.ndarray'> (2242243)  



3. 所有操作如果不指定图,则会使用默认图  tf.get_default_graph() ,上述代码中加载了两个模型,如果两个模型中出现里相同的name,就会出错。


[python]  view plain  copy
  1. g = tf.Graph()  
  2.    with g.as_default():  
  3.      ...  
  4. lt;pre name="code" class="python">with tf.Session(graph=g) as sess1:  
...


上面的方式将第一个模型放在了图g中,session执行的是图g,而不是默认图,这样后面就可以不用显示指定图,直接使用默认图。如果没有使用图g, 那么在第二个模型时需要先将默认图重置以清空默认图中之前的添加的ops

[python]  view plain  copy
  1. tf.reset_default_graph()  

文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/77622857

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。