使用 Inception-v3,实现图像识别(Python、C++)

举报
不脱发的程序猿 发表于 2020/12/31 23:35:42 2020/12/31
【摘要】 目录 简介 使用 Python API 使用 C++ API 简介 对于我们的大脑来说,视觉识别似乎是一件特别简单的事。人类不费吹灰之力就可以分辨狮子和美洲虎、看懂路标或识别人脸。但对计算机而言,这些实际上是很难处理的问题:这些问题只是看起来简单,因为大脑非常擅长理解图像。 在过去几年内,机器学习领域在解决此类难题方面取得了巨大进展。尤其是,我们发现一种称为深...

目录

简介

使用 Python API

使用 C++ API


  • 简介

对于我们的大脑来说,视觉识别似乎是一件特别简单的事。人类不费吹灰之力就可以分辨狮子和美洲虎、看懂路标或识别人脸。但对计算机而言,这些实际上是很难处理的问题:这些问题只是看起来简单,因为大脑非常擅长理解图像。

在过去几年内,机器学习领域在解决此类难题方面取得了巨大进展。尤其是,我们发现一种称为深度卷积神经网络的模型可以很好地处理较难的视觉识别任务 - 在某些领域的表现与人类大脑不相上下,甚至更胜一筹。

研究人员通过用 ImageNet(计算机视觉的一种学术基准)验证其工作成果,证明他们在计算机视觉方面取得了稳步发展。他们陆续推出了以下几个模型,每一个都比上一个有所改进,且每一次都取得了新的领先成果:QuocNet、AlexNet、Inception (GoogLeNet)、BN-Inception-v2。Google 内部和外部的研究人员均发表过关于所有这些模型的论文,但这些成果仍是难以复制的。现在我们将采取后续步骤,发布用于在我们的最新模型 Inception-v3 上进行图像识别的代码。

Inception-v3 使用 2012 年的数据针对 ImageNet 大型视觉识别挑战赛训练而成。它的层次结构如下图所示:

Inception-v3处理的是标准的计算机视觉任务,在此类任务中,模型会尝试将所有图像分成 1000 个类别,如 “斑马”、“斑点狗” 和 “洗碗机”。例如,以下是 AlexNet 对某些图像进行分类的结果:

 

为了比较各个模型,我会检查正确答案不在模型预测的最有可能的 5 个选项中的频率,称为 “top-5 错误率”。 AlexNet 在 2012 年的验证数据集上实现了 15.3% 的 top-5 错误率;Inception (GoogLeNet)、BN-Inception-v2 和 Inception-v3 的 top-5 错误率分别达到 6.67%、4.9% 和 3.46%。

人类在 ImageNet 挑战赛上的表现如何?Andrej Karpathy 曾尝试衡量自己的表现,他发表了一篇博文,提到自己的 top-5 错误率为 5.1%。

本次将介绍如何使用 Inception-v3。小伙伴们将了解如何使用 Python 或 C++ 将图像分成 1000 个类别。此外,我们还将讨论如何从该模型提取更高级别的特征,以重复用于其他视觉任务。

  • 使用 Python API

首次运行程序时,classify_image.py 会从 tensorflow.org 下载经过训练的模型。你的硬盘上需要有约 200M 的可用空间。

首先,从 GitHub 克隆 TensorFlow 模型代码库。

cd models/tutorials/image/imagenet

 

classify_image.py 程序内容如下:


  
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import argparse
  5. import os.path
  6. import re
  7. import sys
  8. import tarfile
  9. import numpy as np
  10. from six.moves import urllib
  11. import tensorflow as tf
  12. FLAGS = None
  13. # pylint: disable=line-too-long
  14. DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
  15. # pylint: enable=line-too-long
  16. class NodeLookup(object):
  17. """Converts integer node ID's to human readable labels."""
  18. def __init__(self,
  19. label_lookup_path=None,
  20. uid_lookup_path=None):
  21. if not label_lookup_path:
  22. label_lookup_path = os.path.join(
  23. FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
  24. if not uid_lookup_path:
  25. uid_lookup_path = os.path.join(
  26. FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
  27. self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
  28. def load(self, label_lookup_path, uid_lookup_path):
  29. """Loads a human readable English name for each softmax node.
  30. Args:
  31. label_lookup_path: string UID to integer node ID.
  32. uid_lookup_path: string UID to human-readable string.
  33. Returns:
  34. dict from integer node ID to human-readable string.
  35. """
  36. if not tf.gfile.Exists(uid_lookup_path):
  37. tf.logging.fatal('File does not exist %s', uid_lookup_path)
  38. if not tf.gfile.Exists(label_lookup_path):
  39. tf.logging.fatal('File does not exist %s', label_lookup_path)
  40. # Loads mapping from string UID to human-readable string
  41. proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
  42. uid_to_human = {}
  43. p = re.compile(r'[n\d]*[ \S,]*')
  44. for line in proto_as_ascii_lines:
  45. parsed_items = p.findall(line)
  46. uid = parsed_items[0]
  47. human_string = parsed_items[2]
  48. uid_to_human[uid] = human_string
  49. # Loads mapping from string UID to integer node ID.
  50. node_id_to_uid = {}
  51. proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
  52. for line in proto_as_ascii:
  53. if line.startswith(' target_class:'):
  54. target_class = int(line.split(': ')[1])
  55. if line.startswith(' target_class_string:'):
  56. target_class_string = line.split(': ')[1]
  57. node_id_to_uid[target_class] = target_class_string[1:-2]
  58. # Loads the final mapping of integer node ID to human-readable string
  59. node_id_to_name = {}
  60. for key, val in node_id_to_uid.items():
  61. if val not in uid_to_human:
  62. tf.logging.fatal('Failed to locate: %s', val)
  63. name = uid_to_human[val]
  64. node_id_to_name[key] = name
  65. return node_id_to_name
  66. def id_to_string(self, node_id):
  67. if node_id not in self.node_lookup:
  68. return ''
  69. return self.node_lookup[node_id]
  70. def create_graph():
  71. """Creates a graph from saved GraphDef file and returns a saver."""
  72. # Creates graph from saved graph_def.pb.
  73. with tf.gfile.FastGFile(os.path.join(
  74. FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
  75. graph_def = tf.GraphDef()
  76. graph_def.ParseFromString(f.read())
  77. _ = tf.import_graph_def(graph_def, name='')
  78. def run_inference_on_image(image):
  79. """Runs inference on an image.
  80. Args:
  81. image: Image file name.
  82. Returns:
  83. Nothing
  84. """
  85. if not tf.gfile.Exists(image):
  86. tf.logging.fatal('File does not exist %s', image)
  87. image_data = tf.gfile.FastGFile(image, 'rb').read()
  88. # Creates graph from saved GraphDef.
  89. create_graph()
  90. with tf.Session() as sess:
  91. # Some useful tensors:
  92. # 'softmax:0': A tensor containing the normalized prediction across
  93. # 1000 labels.
  94. # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
  95. # float description of the image.
  96. # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
  97. # encoding of the image.
  98. # Runs the softmax tensor by feeding the image_data as input to the graph.
  99. softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
  100. predictions = sess.run(softmax_tensor,
  101. {'DecodeJpeg/contents:0': image_data})
  102. predictions = np.squeeze(predictions)
  103. # Creates node ID --> English string lookup.
  104. node_lookup = NodeLookup()
  105. top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
  106. for node_id in top_k:
  107. human_string = node_lookup.id_to_string(node_id)
  108. score = predictions[node_id]
  109. print('%s (score = %.5f)' % (human_string, score))
  110. def maybe_download_and_extract():
  111. """Download and extract model tar file."""
  112. dest_directory = FLAGS.model_dir
  113. if not os.path.exists(dest_directory):
  114. os.makedirs(dest_directory)
  115. filename = DATA_URL.split('/')[-1]
  116. filepath = os.path.join(dest_directory, filename)
  117. if not os.path.exists(filepath):
  118. def _progress(count, block_size, total_size):
  119. sys.stdout.write('\r>> Downloading %s %.1f%%' % (
  120. filename, float(count * block_size) / float(total_size) * 100.0))
  121. sys.stdout.flush()
  122. filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
  123. print()
  124. statinfo = os.stat(filepath)
  125. print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  126. tarfile.open(filepath, 'r:gz').extractall(dest_directory)
  127. def main(_):
  128. maybe_download_and_extract()
  129. image = (FLAGS.image_file if FLAGS.image_file else
  130. os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
  131. run_inference_on_image(image)
  132. if __name__ == '__main__':
  133. parser = argparse.ArgumentParser()
  134. # classify_image_graph_def.pb:
  135. # Binary representation of the GraphDef protocol buffer.
  136. # imagenet_synset_to_human_label_map.txt:
  137. # Map from synset ID to a human readable string.
  138. # imagenet_2012_challenge_label_map_proto.pbtxt:
  139. # Text representation of a protocol buffer mapping a label to synset ID.
  140. parser.add_argument(
  141. '--model_dir',
  142. type=str,
  143. default=r'C:\Users\Administrator\Desktop\imagenet',
  144. help="""\
  145. Path to classify_image_graph_def.pb,
  146. imagenet_synset_to_human_label_map.txt, and
  147. imagenet_2012_challenge_label_map_proto.pbtxt.\
  148. """
  149. )
  150. parser.add_argument(
  151. '--image_file',
  152. type=str,
  153. default=r'C:\Users\Administrator\Desktop\imagenet\cropped_panda.jpg',
  154. help='Absolute path to image file.'
  155. )
  156. parser.add_argument(
  157. '--num_top_predictions',
  158. type=int,
  159. default=5,
  160. help='Display this many predictions.'
  161. )
  162. FLAGS, unparsed = parser.parse_known_args()
  163. tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

运行以下命令:

python classify_image.py
 

以上命令会对提供的大熊猫图像进行分类。

 

如果模型运行正确,脚本将生成以下输出:


  
  1. giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca (score = 0.88493)
  2. indri, indris, Indri indri, Indri brevicaudatus (score = 0.00878)
  3. lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens (score = 0.00317)
  4. custard apple (score = 0.00149)
  5. earthstar (score = 0.00127)

如果想提供其他 JPEG 图像,只需修改 --image_file 参数即可。

如果将模型数据下载到其他目录,则需要使 --model_dir 指向所使用的目录。

在Windows环境下小伙伴们可以在直接到GitHub下载该程序案例:https://github.com/tensorflow/models

但是有时下载识别模型时经常会失败,这里我给大家分享下我调试好的Demo:https://download.csdn.net/download/m0_38106923/10892062

  • 使用 C++ API

可以使用 C++ 运行同一 Inception-v3 模型,以在生产环境中使用模型。为此,可以下载包含 GraphDef 的归档文件,GraphDef 会以如下方式定义模型(从 TensorFlow 代码库的根目录运行):


  
  1. curl -L "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz" |
  2. tar -C tensorflow/examples/label_image/data -xz

接下来,我们需要编译包含加载和运行图的代码的 C++ 二进制文件。如果按照针对您平台的说明下载 TensorFlow 源安装文件,则应该能够通过从 shell 终端运行以下命令来构建该示例:

bazel build tensorflow/examples/label_image/...
 

上述命令应该会创建一个可执行的二进制文件,然后可以运行该文件,如下所示:

bazel-bin/tensorflow/examples/label_image/label_image
 

这里使用的是框架附带的默认示例图像,输出结果应与以下内容类似:


  
  1. I tensorflow/examples/label_image/main.cc:206] military uniform (653): 0.834306
  2. I tensorflow/examples/label_image/main.cc:206] mortarboard (668): 0.0218692
  3. I tensorflow/examples/label_image/main.cc:206] academic gown (401): 0.0103579
  4. I tensorflow/examples/label_image/main.cc:206] pickelhaube (716): 0.00800814
  5. I tensorflow/examples/label_image/main.cc:206] bulletproof vest (466): 0.00535088

在本例中,我们使用的是默认的海军上将格蕾丝·赫柏的图像,您可以看到,网络可正确识别她穿的是军装,分数高达 0.8。

 

有关其工作原理,请参阅 tensorflow/examples/label_image/main.cc 文件(https://www.tensorflowers.cn/t/7558)。希望此代码可帮助小伙伴们将 TensorFlow 集成到自己的应用中,因此将逐步介绍主要函数:

命令行标记可控制文件加载路径以及输入图像的属性。由于应向模型输入 299x299 RGB 的正方形图像,因此标记 input_width 和 input_height 应设成这些值。此外,我们还需要将像素值从介于 0 至 255 之间的整数缩放成浮点值,因为图执行运算时采用的是浮点数。我们使用 input_mean 和 input_std 标记控制缩放;先用每个像素值减去 input_mean,然后除以 input_std。

这些值看起来可能有点不可思议,但它们只是原模型作者根据他 / 她想要用做输入图像以用于训练的内容定义的。如果小伙伴们有自行训练的图,只需对值做出调整,使其与您在训练过程中使用的任何值一致即可。

你可以参阅 ReadTensorFromImageFile() 函数,了解这些标记是如何应用到图像的。


  
  1. // Given an image file name, read in the data, try to decode it as an image,
  2. // resize it to the requested size, and then scale the values as desired.
  3. Status ReadTensorFromImageFile(string file_name, const int input_height,
  4. const int input_width, const float input_mean,
  5. const float input_std,
  6. std::vector<Tensor>* out_tensors) {
  7. tensorflow::GraphDefBuilder b;

首先,创建一个 GraphDefBuilder 对象,它可用于指定要运行或加载的模型。


  
  1. string input_name = "file_reader";
  2. string output_name = "normalized";
  3. tensorflow::Node* file_reader =
  4. tensorflow::ops::ReadFile(tensorflow::ops::Const(file_name, b.opts()),
  5. b.opts().WithName(input_name));

然后,为要运行的小型模型创建节点,以加载、调整和缩放像素值,从而获得主模型期望作为其输入的结果。我创建的第一个节点只是一个 Const 操作,它会存储一个张量,其中包含要加载的图像的文件名。然后,该张量会作为第一个输入传递到 ReadFile 操作。小伙伴们可能会注意到,我将 b.opts() 作为最后一个参数传递到所有操作创建函数。该参数可确保该节点会添加到 GraphDefBuilder 中存储的模型定义中。此外,我还通过向 b.opts() 发起 WithName() 调用来命名 ReadFile 运算符,从而命名该节点,虽然这不是绝对必要的操作(因为如果您不执行此操作,系统会自动为该节点分配名称),但确实可简化调试过程。


  
  1. // Now try to figure out what kind of file it is and decode it.
  2. const int wanted_channels = 3;
  3. tensorflow::Node* image_reader;
  4. if (tensorflow::StringPiece(file_name).ends_with(".png")) {
  5. image_reader = tensorflow::ops::DecodePng(
  6. file_reader,
  7. b.opts().WithAttr("channels", wanted_channels).WithName("png_reader"));
  8. } else {
  9. // Assume if it's not a PNG then it must be a JPEG.
  10. image_reader = tensorflow::ops::DecodeJpeg(
  11. file_reader,
  12. b.opts().WithAttr("channels", wanted_channels).WithName("jpeg_reader"));
  13. }
  14. // Now cast the image data to float so we can do normal math on it.
  15. tensorflow::Node* float_caster = tensorflow::ops::Cast(
  16. image_reader, tensorflow::DT_FLOAT, b.opts().WithName("float_caster"));
  17. // The convention for image ops in TensorFlow is that all images are expected
  18. // to be in batches, so that they're four-dimensional arrays with indices of
  19. // [batch, height, width, channel]. Because we only have a single image, we
  20. // have to add a batch dimension of 1 to the start with ExpandDims().
  21. tensorflow::Node* dims_expander = tensorflow::ops::ExpandDims(
  22. float_caster, tensorflow::ops::Const(0, b.opts()), b.opts());
  23. // Bilinearly resize the image to fit the required dimensions.
  24. tensorflow::Node* resized = tensorflow::ops::ResizeBilinear(
  25. dims_expander, tensorflow::ops::Const({input_height, input_width},
  26. b.opts().WithName("size")),
  27. b.opts());
  28. // Subtract the mean and divide by the scale.
  29. tensorflow::ops::Div(
  30. tensorflow::ops::Sub(
  31. resized, tensorflow::ops::Const({input_mean}, b.opts()), b.opts()),
  32. tensorflow::ops::Const({input_std}, b.opts()),
  33. b.opts().WithName(output_name));

接下来,我继续添加更多节点,以便将文件数据解码为图像、将整数转换为浮点值、调整大小,最终对像素值运行减法和除法运算。


  
  1. // This runs the GraphDef network definition that we've just constructed, and
  2. // returns the results in the output tensor.
  3. tensorflow::GraphDef graph;
  4. TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));

最后,我获得一个存储在变量 b 中的模型定义,并可以使用 ToGraphDef() 函数将其转换成一个完整的图定义。


  
  1. std::unique_ptr<tensorflow::Session> session(
  2. tensorflow::NewSession(tensorflow::SessionOptions()));
  3. TF_RETURN_IF_ERROR(session->Create(graph));
  4. TF_RETURN_IF_ERROR(session->Run({}, {output_name}, {}, out_tensors));
  5. return Status::OK();

接下来,创建一个 tf.Session 对象(它是实际运行图的接口)并运行它,从而指定要从哪个节点获得输出,以及将输出数据存放在什么位置。

这为我们提供了一个由 Tensor 对象构成的向量,在此例中,我们知道它将仅是单个对象的长度。在这种情况下,可以将 Tensor 视为多维数组,它将 299 像素高、299 像素宽、3 通道的图像存储为浮点值。如果产品中已有自己的图像处理框架,则应该能够使用该框架,只要在将图像馈送到主图之前对其应用相同的转换即可。

下面是使用 C++ 动态创建小型 TensorFlow 图的简单示例,但对于预训练的 Inception 模型,我们需要从文件中加载更大的定义。可以查看 LoadGraph() 函数,了解如何做到这一点。


  
  1. // Reads a model graph definition from disk, and creates a session object you
  2. // can use to run it.
  3. Status LoadGraph(string graph_file_name,
  4. std::unique_ptr<tensorflow::Session>* session) {
  5. tensorflow::GraphDef graph_def;
  6. Status load_graph_status =
  7. ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
  8. if (!load_graph_status.ok()) {
  9. return tensorflow::errors::NotFound("Failed to load compute graph at '",
  10. graph_file_name, "'");
  11. }

如果已经浏览图像加载代码,则应该对许多术语都比较熟悉了。我会加载直接包含 GraphDef 的 protobuf 文件,而不是使用 GraphDefBuilder 生成 GraphDef 对象。


  
  1. session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
  2. Status session_create_status = (*session)->Create(graph_def);
  3. if (!session_create_status.ok()) {
  4. return session_create_status;
  5. }
  6. return Status::OK();
  7. }

然后,我从该 GraphDef 创建一个 Session 对象,并将其传递回调用程序,以便调用程序稍后可以运行它。

GetTopLabels() 函数很像图像加载,只是在本例中,我想要获取运行主图得到的结果,并将其转换成得分最高的标签的排序列表。与图像加载器类似,该函数可创建一个 GraphDefBuilder,向其添加几个节点,然后运行较短的图,从而获取一对输出张量。在本例中,它们分别表示最高结果的经过排序的得分和索引位置。


  
  1. // Analyzes the output of the Inception graph to retrieve the highest scores and
  2. // their positions in the tensor, which correspond to categories.
  3. Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
  4. Tensor* indices, Tensor* scores) {
  5. tensorflow::GraphDefBuilder b;
  6. string output_name = "top_k";
  7. tensorflow::ops::TopK(tensorflow::ops::Const(outputs[0], b.opts()),
  8. how_many_labels, b.opts().WithName(output_name));
  9. // This runs the GraphDef network definition that we've just constructed, and
  10. // returns the results in the output tensors.
  11. tensorflow::GraphDef graph;
  12. TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
  13. std::unique_ptr<tensorflow::Session> session(
  14. tensorflow::NewSession(tensorflow::SessionOptions()));
  15. TF_RETURN_IF_ERROR(session->Create(graph));
  16. // The TopK node returns two outputs, the scores and their original indices,
  17. // so we have to append :0 and :1 to specify them both.
  18. std::vector<Tensor> out_tensors;
  19. TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
  20. {}, &out_tensors));
  21. *scores = out_tensors[0];
  22. *indices = out_tensors[1];
  23. return Status::OK();

PrintTopLabels() 函数会采用这些经过排序的结果,并以友好的方式输出这些结果。CheckTopLabel() 函数与其极为相似,但出于调试目的,需确保最有可能的标签是我们预期的值。

最后,main() 将所有这些调用绑定在一起。


  
  1. int main(int argc, char* argv[]) {
  2. // We need to call this to set up global state for TensorFlow.
  3. tensorflow::port::InitMain(argv[0], &argc, &argv);
  4. Status s = tensorflow::ParseCommandLineFlags(&argc, argv);
  5. if (!s.ok()) {
  6. LOG(ERROR) << "Error parsing command line flags: " << s.ToString();
  7. return -1;
  8. }
  9. // First we load and initialize the model.
  10. std::unique_ptr<tensorflow::Session> session;
  11. string graph_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_graph);
  12. Status load_graph_status = LoadGraph(graph_path, &session);
  13. if (!load_graph_status.ok()) {
  14. LOG(ERROR) << load_graph_status;
  15. return -1;
  16. }

加载主图


  
  1. // Get the image from disk as a float array of numbers, resized and normalized
  2. // to the specifications the main graph expects.
  3. std::vector<Tensor> resized_tensors;
  4. string image_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_image);
  5. Status read_tensor_status = ReadTensorFromImageFile(
  6. image_path, FLAGS_input_height, FLAGS_input_width, FLAGS_input_mean,
  7. FLAGS_input_std, &resized_tensors);
  8. if (!read_tensor_status.ok()) {
  9. LOG(ERROR) << read_tensor_status;
  10. return -1;
  11. }
  12. const Tensor& resized_tensor = resized_tensors[0];

加载、处理输入图像并调整其大小


  
  1. // Actually run the image through the model.
  2. std::vector<Tensor> outputs;
  3. Status run_status = session->Run({ {FLAGS_input_layer, resized_tensor}},
  4. {FLAGS_output_layer}, {}, &outputs);
  5. if (!run_status.ok()) {
  6. LOG(ERROR) << "Running model failed: " << run_status;
  7. return -1;
  8. }

在本示例中,我们将图像作为输入,运行已加载的图


  
  1. // This is for automated testing to make sure we get the expected result with
  2. // the default settings. We know that label 866 (military uniform) should be
  3. // the top label for the Admiral Hopper image.
  4. if (FLAGS_self_test) {
  5. bool expected_matches;
  6. Status check_status = CheckTopLabel(outputs, 866, &expected_matches);
  7. if (!check_status.ok()) {
  8. LOG(ERROR) << "Running check failed: " << check_status;
  9. return -1;
  10. }
  11. if (!expected_matches) {
  12. LOG(ERROR) << "Self-test failed!";
  13. return -1;
  14. }
  15. }

出于测试目的,我们可以在下方检查以确保获得了预期的输出


  
  1. // Do something interesting with the results we've generated.
  2. Status print_status = PrintTopLabels(outputs, FLAGS_labels);

最后,输出我们找到的标签


  
  1. if (!print_status.ok()) {
  2. LOG(ERROR) << "Running print failed: " << print_status;
  3. return -1;
  4. }

在本示例中,我使用 TensorFlow 的 Status 对象处理错误,它非常方便,因为通过它,小伙伴们可以使用 ok() 检查工具了解是否发生了任何错误,如果有错误,则可以输出可以读懂的错误消息。

在本示例中,我演示的是对象识别,但小伙伴们应该能够对自己在各种领域找到的或自行训练的其他模型使用非常相似的代码。我希望这一小示例可就如何在自己的产品中使用 TensorFlow 为大家带来一些启发。

文章来源: handsome-man.blog.csdn.net,作者:不脱发的程序猿,版权归原作者所有,如需转载,请联系作者。

原文链接:handsome-man.blog.csdn.net/article/details/85645215

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。