77范文网 - 专业文章范例文档资料分享平台

深度学习进阶笔记之八 TensorFlow与中文手写汉字识别(3)

来源:网络收集 时间:2018-11-17 下载这篇文档 手机版
说明:文章内容仅供预览,部分内容可能不全,需要完整文档或者需要复制内容,请下载word后使用。下载word有问题请添加微信号:或QQ: 处理(尽可能给您提供完整文档),感谢您的支持与谅解。点击这里给我发消息

UCLoud中国云三强: www.ucloud.cn

Train

train函数包括从已有checkpoint中restore,得到step,快速恢复训练过程,训练主要是每一次得到mini-batch,更新参数,每隔eval_steps后做一次train batch的eval,每隔save_steps 后保存一次checkpoint。

def train():

sess = tf.Session()

file_labels = get_imagesfile(FLAGS.train_data_dir)

images, labels, coord, threads = batch_data(file_labels, sess) endpoints = network(images, labels) saver = tf.train.Saver()

sess.run(tf.global_variables_initializer())

train_writer = tf.train.SummaryWriter('./log' + '/train',sess.graph) test_writer = tf.train.SummaryWriter('./log' + '/val') start_step = 0 if FLAGS.restore:

ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if ckpt:

saver.restore(sess, ckpt)

print \ start_step += int(ckpt.split('-')[-1]) logger.info(':::Training Start:::')

UCLoud中国云三强: www.ucloud.cn

try:

while not coord.should_stop(): # logger.info('step {0} start'.format(i)) start_time = time.time()

_, loss_val, train_summary, step = sess.run([endpoints['train_op'], endpoints['loss'], endpoints['merged_summary_op'], endpoints['global_step']]) train_writer.add_summary(train_summary, step) end_time = time.time()

logger.info(\end_time-start_time, loss_val))

if step > FLAGS.max_steps: break

# logger.info(\end_time-start_time, loss_val))

if step % FLAGS.eval_steps == 1:

accuracy_val,test_summary, step = sess.run([endpoints['accuracy'], endpoints['merged_summary_op'], endpoints['global_step']]) test_writer.add_summary(test_summary, step)

logger.info('===============Eval a batch in Train data=======================')

logger.info( 'the step {0} accuracy {1}'.format(step, accuracy_val))

UCLoud中国云三强: www.ucloud.cn

logger.info('===============Eval a batch in Train data=======================') if step % FLAGS.save_steps == 1:

logger.info('Save the ckpt of {0}'.format(step))

saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=endpoints['global_step']) except tf.errors.OutOfRangeError:

# print \ logger.info('==================Train Finished================')

saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=endpoints['global_step']) finally:

coord.request_stop() coord.join(threads) sess.close()

UCLoud中国云三强: www.ucloud.cn

Graph

UCLoud中国云三强: www.ucloud.cn

Loss and Accuracy

Validation

训练完成之后,想对完成的模型在测试数据集上做一个评估,这里我也曾经尝试利用batch_data,将slice_input_producer中epoch设置为1,来做相关的工作,但是发现这里无法和train 共用,会出现epoch无初始化值的问题(train中传epoch为None),所以这里自己写了shuffle batch的逻辑,将测试集的images和labels通过feed_dict传进到网络,得到模型的输出, 然后做相关指标的计算:

百度搜索“77cn”或“免费范文网”即可找到本站免费阅读全部范文。收藏本站方便下次阅读,免费范文网,提供经典小说综合文库深度学习进阶笔记之八 TensorFlow与中文手写汉字识别(3)在线全文阅读。

深度学习进阶笔记之八 TensorFlow与中文手写汉字识别(3).doc 将本文的Word文档下载到电脑,方便复制、编辑、收藏和打印 下载失败或者文档不完整,请联系客服人员解决!
本文链接:https://www.77cn.com.cn/wenku/zonghe/281937.html(转载请注明文章来源)
Copyright © 2008-2022 免费范文网 版权所有
声明 :本网站尊重并保护知识产权,根据《信息网络传播权保护条例》,如果我们转载的作品侵犯了您的权利,请在一个月内通知我们,我们会及时删除。
客服QQ: 邮箱:tiandhx2@hotmail.com
苏ICP备16052595号-18
× 注册会员免费下载(下载后可以自由复制和排版)
注册会员下载
全站内容免费自由复制
注册会员下载
全站内容免费自由复制
注:下载文档有可能“只有目录或者内容不全”等情况,请下载之前注意辨别,如果您已付费且无法下载或内容有问题,请联系我们协助你处理。
微信: QQ: