Tensorflow.Estimators笔记 -预制模型

举报
Edison 发表于 2018/10/18 21:16:54 2018/10/18
【摘要】 tensorflow自带的评估器(estimator)有Pre-made模式,可以轻松启动第一个深度学习模型。

一、创建一个机遇预制模型的评估器, 需要先构建以下几个任务:

1.1、创建一个或多个输入函数

1.2、定义模型的特征列

1.3、实例化一个评估器、指定特征列和超参数

1.4、调用评估器对象的方法,采用适当的输入函数作为数据源

二、创建输入函数

输入函数能够返回一个tf.data.Dataset对象,一个含有两个元素的元组。

2.1 特征:一个python字典——

2.1.1 key是特征名

2.1.2 value 是包括所有特征值的数组

2.2 标签:包括所有例子的标签值

一个输入函数例子: 

def input_evaluation_set():

    features = {'SepalLength': np.array([6.4, 5.0]),

                'SepalWidth':  np.array([2.8, 2.3]),

                'PetalLength': np.array([5.6, 3.3]),

                'PetalWidth':  np.array([2.2, 1.0])}

    labels = np.array([2, 1])

    return features, labels

Tensorflow的Dataset API示意图:

image.png

图例:

Dataset——包含创建和转换数据集的方法的基类

TextLineDataset——从文本文件中逐行读取

TFRecordDataset——读取TFRecord文件

FixedLengthRecordDataset ——读取二进制文件中的固定格式记录

Iterator——提供一种访问数据集元素的方法

Dataset API 可以胜任很多通用案例。举个例子,使用DatasetAPI可以很轻易的读取海量并行文件中的记录,同时将他们加入一个单独的流中。

举个例子,在这个训练程序例子的输入函数,在iris_data.py中可用:

def train_input_fn(features, labels, batch_size):

    """An input function for training"""

    # Convert the inputs to a Dataset.

    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.

    return dataset.shuffle(1000).repeat().batch(batch_size)

三、定义特征列

特征列是描述模型应该如何使用来自特征字典中输入的原始数据的对象,当你简历一个评估器模型,可以通过一个包含特征列的列表来描述每个你想在模型中使用的特征。tf.feature_column模块提供了很多选项。

对于鸢尾花,四组特征是数值型的,所以我们可以建立一个特征列列表高速评估器模型,去以32位浮点数分别代表这四组特征。因而,创建特征列的代码如下:

# Feature columns describe how to use the input.

my_feature_columns = []

for key in train_x.keys(): 

                         my_feature_columns.append(tf.feature_column.numeric_column(key=key))

特征列远不止如此,更多的细节将在getting started guide 的后续中继续介绍。

四、实例化一个评估器

鸢尾花问题是一个经典的分类问题,幸运的是tensorflow提供了多种预制分类评估器,包括:

在这个例子中,tf.estimator.DNNClassifier看起来是最好的选择,以下是实例化代码:

# Build a DNN with 2 hidden layers and 10 nodes in each hidden layer.

classifier = tf.estimator.DNNClassifier(

    feature_columns=my_feature_columns,

    # Two hidden layers of 10 nodes each.

    hidden_units=[10, 10],

    # The model must choose between 3 classes.

    n_classes=3)

五、训练、评估、预测

现在我们已经有了一个评估器对象,我们可以调用以下方法:

·训练模型

·评估训练后的模型

·使用训练后的模型进行预测

5.1 训练模型

调用评估器的train方法来训练模型方法:

# Train the Model.

classifier.train(

    input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),

    steps=args.train_steps)

注释:lambda函数调用了input_fn来捕获参数,同时提供一个不带参数的输入函数,作为评估器的期望。step参数代表模型的训练次数。

5.2 评估训练后的模型

有了训练后的模型, 我们可以得到一些模型性能的统计数据,评估代码示例:

# Evaluate the model.

eval_result = classifier.evaluate(

    input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

与train方法不同的是,在评估过程中没有steps参数,evel_input_fn仅生成一次迭代后的数据。

运行这段代码生成以下输出:

Test set accuracy: 0.967

5.3 由训练后的模型做出预测

现在我们有了能够产生很好评估结果的模型,现在用这个训练后的模型预测基于未标注的鸢尾花的物种。根据训练和评估,我们做了一个预测函数:

# Generate predictions from the model

expected = ['Setosa', 'Versicolor', 'Virginica']

predict_x = {

    'SepalLength': [5.1, 5.9, 6.9],

    'SepalWidth': [3.3, 3.0, 3.1],

    'PetalLength': [1.7, 4.2, 5.4],

    'PetalWidth': [0.5, 1.5, 2.1],

}

predictions = classifier.predict(

    input_fn=lambda:iris_data.eval_input_fn(predict_x,                                            batch_size=args.batch_size))

predict 方法返回一个Python迭代器。生成一个包含预测结果的字典,代码如下。代码将展示一些预测和他们的准确性。

template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')

for pred_dict, expec in zip(predictions, expected):

    class_id = pred_dict['class_ids'][0]

    probability = pred_dict['probabilities'][class_id]

    print(template.format(iris_data.SPECIES[class_id],

                          100 * probability, expec))

运行结果:

...

Prediction is "Setosa" (99.6%), expected "Setosa"

Prediction is "Versicolor" (99.8%), expected "Versicolor"

Prediction is "Virginica" (97.9%), expected "Virginica"

六、总结

1、预制评估器能够快速有效地建立标准模型

2、现在可以开始写Tensorflow程序啦,一些材料可以考虑。


【版权声明】本文为华为云社区用户翻译文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容, 举报邮箱:cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。