《智能系统与技术丛书 生成对抗网络入门指南》—2.2.3Tensorflow实例:图像分类
2.2.3 Tensorflow实例:图像分类
图像分类是机器学习任务中非常常见的问题,这里我们查看一个TensorFlow的官方案例:如何使用TensorFlow的高级接口Estimator来实现鸢尾花的图像分类。
鸢尾花有多种类型,可以通过花萼和花瓣的不同特征来加以区分(见图2-12)。TensorFlow提供的数据集中包含了下面四个植物学特征。
花萼长度
花萼宽度
花瓣长度
花瓣宽度
每一条数据也对应了一种鸢尾花分类的标签,如下所示。
山鸢尾(0)
变色鸢尾(1)
维吉尼亚鸢尾(2)
图2-12 不同类型的鸢尾花
我们可以根据此数据集来设置输入函数,以提供用于训练、评估和预测的数据。
def input_evaluation_set():
features = {'SepalLength': np.array([6.4, 5.0]),
'SepalWidth': np.array([2.8, 2.3]),
'PetalLength': np.array([5.6, 3.3]),
'PetalWidth': np.array([2.2, 1.0])}
labels = np.array([2, 1])
return features, labels
接着对于数据的特征需要设置特征列。
FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
feature_columns = [
tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS]
这里使用深度神经网络模型来作为分类器进行训练,其包含一个输入层、三个隐含层以及一个输出层。其中输入层为四个节点,分别对应四种特征,隐含层分别为10、20、10个单元,最后的输出层为三个节点,分别对应三个分类。TensorFlow的代码实现如下。
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3)
模型搭建完毕就可以基于数据集进行训练了。
classifier.train(
input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
steps=args.train_steps)
最终对训练完毕的模型使用测试数据进行准确性评估,这样一整套基于深度模型的图像分类模型就搭建完毕了。
eval_result = classifier.evaluate(
input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))
- 点赞
- 收藏
- 关注作者
评论(0)