作者小头像 Lv.3
202 成长值

个人介绍

云计算网络运维专业人员

感兴趣或擅长的领域

云计算、自动化运维
个人勋章
  • 活跃之星
成长雷达
130
12
25
15
20

个人资料

个人介绍

云计算网络运维专业人员

感兴趣或擅长的领域

云计算、自动化运维

达成规则

发布时间 2022/07/01 10:09:46 最后回复 irrational 2024/03/04 09:05:51 版块 人工智能
83668 2839 405
他的回复:
"""Generate a classical Chinese poem with a trained character-level RNN.

The corpus is re-processed to rebuild the character vocabulary, the latest
checkpoint in ``model_dir`` is restored, and characters are sampled one at a
time until the end token is produced or a length cap is reached.

NOTE(review): this uses TensorFlow 1.x graph-mode APIs (``tf.placeholder``,
``tf.Session``, ``tf.reset_default_graph``, ``tf.train.Saver``). Under
TensorFlow 2.x these only exist under ``tf.compat.v1`` -- confirm which version
``poems.model`` targets.
"""
import sys

import numpy as np
import tensorflow as tf

from poems.model import rnn_model
from poems.poems import process_poems

start_token = 'B'
end_token = 'E'
model_dir = './model/'
corpus_file = './data/poems.txt'

lr = 0.0002


def to_word(predict, vocabs):
    """Sample one vocabulary entry from the model's output weights.

    Args:
        predict: value of ``end_points['prediction']`` returned by
            ``sess.run``. Only its first row (``predict[0]``) is used, as
            non-negative, unnormalised weights over class ids.
        vocabs: sequence of vocabulary entries indexed by class id.

    Returns:
        The sampled vocabulary entry. If the sampled class id is past the end
        of ``vocabs`` (the output row may be wider than the vocabulary), the
        last entry of ``vocabs`` is returned instead.
    """
    # Normalise a float64 *copy*. This leaves the caller's array untouched
    # and keeps the sum within np.random.choice's tolerance of 1.
    probs = np.asarray(predict[0], dtype=np.float64)
    probs = probs / probs.sum()
    sample = np.random.choice(len(probs), p=probs)
    # `>=`, not `>`: index len(vocabs) is already out of range.
    if sample >= len(vocabs):
        return vocabs[-1]
    return vocabs[sample]


def gen_poem(begin_word, max_len=25):
    """Generate a poem, optionally seeded with its first character.

    Args:
        begin_word: first character of the poem. If falsy (e.g. ``''``), the
            first character is sampled from the model's prediction after the
            start token instead.
        max_len: maximum number of tokens emitted (default 25, the original
            hard-coded cap). At least one token is always emitted unless the
            sampled first token is the end token.

    Returns:
        The generated poem as a string, without the end token.

    Raises:
        KeyError: if ``begin_word`` is not in the corpus vocabulary.
        FileNotFoundError: if ``model_dir`` contains no checkpoint.
    """
    batch_size = 1
    print('## loading corpus from %s' % corpus_file)
    tf.reset_default_graph()

    _, word_int_map, vocabularies = process_poems(corpus_file)

    # Fail early with a readable message instead of a bare KeyError deep
    # inside the sampling loop.
    if begin_word and begin_word not in word_int_map:
        raise KeyError('%r is not in the corpus vocabulary' % (begin_word,))

    input_data = tf.placeholder(tf.int32, [batch_size, None])
    # NOTE(review): batch_size=64 differs from the placeholder's batch of 1.
    # Presumably rnn_model ignores it when output_data is None -- confirm.
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None,
                           vocab_size=len(vocabularies), rnn_size=128,
                           num_layers=2, batch_size=64, learning_rate=lr)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        print('## loading model from %s' % model_dir)
        checkpoint = tf.train.latest_checkpoint(model_dir)
        if checkpoint is None:
            # latest_checkpoint returns None when nothing is found; passing
            # that to saver.restore fails with an obscure error.
            raise FileNotFoundError('no checkpoint found in %s' % model_dir)
        saver.restore(sess, checkpoint)

        # Prime the network with the start token to get the first prediction
        # and the recurrent state to carry forward.
        x = np.array([list(map(word_int_map.get, start_token))])
        predict, last_state = sess.run(
            [end_points['prediction'], end_points['last_state']],
            feed_dict={input_data: x})

        word = begin_word or to_word(predict, vocabularies)
        tokens = []
        while word != end_token:
            tokens.append(word)
            if len(tokens) >= max_len:
                break
            # Feed one token at a time, threading the previous state through.
            x = np.array([[word_int_map[word]]])
            predict, last_state = sess.run(
                [end_points['prediction'], end_points['last_state']],
                feed_dict={input_data: x,
                           end_points['initial_state']: last_state})
            word = to_word(predict, vocabularies)
        return ''.join(tokens)


def pretty_print_poem(poem_):
    """Print each sentence of ``poem_`` on its own line.

    The poem is split on the full-width period '。'. Fragments of 10
    characters or fewer are skipped; this includes the empty tail after a
    final '。' and, presumably, truncated lines.
    """
    for sentence in poem_.split('。'):
        if len(sentence) > 10:
            print(sentence + '。')


def main():
    """Prompt once for a first character, then print the generated poem."""
    begin_char = input('## (输入 quit 退出)请输入第一个字 please input the first character: ').strip()
    if begin_char == 'quit':
        # sys.exit instead of the site-injected exit(), which is missing
        # under `python -S` and in frozen executables.
        sys.exit()
    poem = gen_poem(begin_char)
    pretty_print_poem(poem_=poem)


if __name__ == '__main__':
    main()

# Example session from the original post (console prompt and user input):
# ## (输入 quit 退出)请输入第一个字 please input the first character: 于