《智能系统与技术丛书 生成对抗网络入门指南》—2.2.3Tensorflow实例:图像分类

举报
华章计算机 发表于 2019/05/29 15:54:46 2019/05/29
【摘要】 本书摘自《智能系统与技术丛书 生成对抗网络入门指南》一文中的第2章,第2.2.3节,作者是史丹青。

2.2.3 Tensorflow实例:图像分类

       图像分类是机器学习任务中非常常见的问题,这里我们查看一个TensorFlow的官方案例:如何使用TensorFlow的高级接口Estimator来实现鸢尾花的图像分类。

       鸢尾花有多种类型,可以通过花萼和花瓣的不同特征来加以区分(见图2-12)。TensorFlow提供的数据集中包含了下面四个植物学特征。

       花萼长度

       花萼宽度

       花瓣长度

       花瓣宽度

       每一条数据也对应了一种鸢尾花分类的标签,如下所示。

       山鸢尾(0)

       变色鸢尾(1)

       维吉尼亚鸢尾(2)

image.png

图2-12 不同类型的鸢尾花

       我们可以根据此数据集来设置输入函数,以提供用于训练、评估和预测的数据。

def input_evaluation_set():

    features = {'SepalLength': np.array([6.4, 5.0]),

                'SepalWidth':  np.array([2.8, 2.3]),

                'PetalLength': np.array([5.6, 3.3]),

                'PetalWidth':  np.array([2.2, 1.0])}

    labels = np.array([2, 1])

    return features, labels

      接着对于数据的特征需要设置特征列。

FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

feature_columns = [

    tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS]

       这里使用深度神经网络模型来作为分类器进行训练,其包含一个输入层、三个隐含层以及一个输出层。其中输入层为四个节点,分别对应四种特征,隐含层分别为10、20、10个单元,最后的输出层为三个节点,分别对应三个分类。TensorFlow的代码实现如下。

classifier = tf.estimator.DNNClassifier(

    feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3)

       模型搭建完毕就可以基于数据集进行训练了。

classifier.train(

    input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),

    steps=args.train_steps)

       最终对训练完毕的模型使用测试数据进行准确性评估,这样一整套基于深度模型的图像分类模型就搭建完毕了。

eval_result = classifier.evaluate(

    input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))


【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。