Tensorflow 有监督学习

Submitted by Lizhe on Fri, 08/18/2017 - 09:20

有监督的学习在于

使用某个带标注信息的输入数据集合(其中的每个样本都标注了真是的或者期望的输出), 去训练一个推断模型,

该模型能够覆盖一个数据集, 并可对不存在于初始训练集合中的新样本的输出进行干预

这个推断模型实际上是一系列的数学运算, 运算使用的公式是固定的,但是参与运算的值 都是模型的参数, 在训练过程中会被不断更新, 以使模型能够学习,并对其输出进行调整

462

首先需要对模型参数进行初始化. 通常采用对参数随机赋值的方法, 但对于比较简单的模型,也可以将各个参数的初始值全部设置为0

读取训练数据(包括每个数据样本及其期望输出). 通常人们会在这些数据送入模型之前随机打乱样本的次序

在训练数据上执行推断模型. 这样, 在当前模型的参数配置下, 每个训练样本都会得到一个输出值

计算损失. 损失是一个能够刻画模型在最后一步得到的输出与来自训练集的期望输出之间差距的概括性指标. 损失函数有多重类型.

调整模型参数. 这一步对应于实际的学习过程. 给定损失函数, 学习的目的在于通过大量训练步骤改善各参数的值, 从而将损失最小化.常见的策略是使用梯度下降算法

 

训练结束之后便进入评估阶段. 评估使用的数据集是模型预先无法获知的.

通过评估, 可以了解到所训练的模型在训练集之外的推广能力. 一种常见的方法是将原始数据集分成2部分, 将70%的样本用于训练, 剩下的30%的样本用于评估

 

下面是训练和评估的基本代码框架

import tensorflow as tf

def inference(x):
    #计算推断模型在数据x上的输出,返回结果
def loss(x,y):
    #根据训练数据x及期望获得的结果y计算损失

def inputs():
    #读取或生成训练数据x极其期望输出y

def train(total_loss):
    #根据计算的总损失训练或调整模型参数

def evaluate(sess,x,y):
    #对训练得到的模型进行评估
    
saver = tf.train.Saver()

#在一个会话对象中启动数据流图, 搭建流程

with tf.Session() as sess:
    
    tf.initialize_all_variables().run()
    
    x,y = inputs()
    
    total_loss = loss(x,y)
    train_op = train(total_loss)
    
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    
    #实际的训练迭代次数
    training_steps = 1000
    
    for step in range(training_steps):
        sess.run([train_op])
        #输出损失递件过程
        if step%10==0:
            print "loss:", sess.run([total_loss])
            
        if step%1000==0:
            saver.save(sess,'my-checkpoint', global_step=step)
    
    evaluate(sess,x,y)
    
    coord.request_stop()
    coord.join(threads)
    
    saver.save(sess,'my-checkpoint', global_step=training_steps)
    sess.close()

 

上面的代码在每1000次迭代时都会创建一个检查点,用于持久化当前的运行状态, 默认情况下Saver对象只会保存最近的5个文件

如果希望从某个检查点恢复训练, 则应该先使用tr.train.get_checkpoint_state方法验证是否有检查点文件被保存下来了

然后使用tf.train.Saver.restore方法恢复变量的值

with tf.Session() as sess:

    initial_step = 0 
    ckpt = tf.train.get_checkpoint_state(os.path.dirname(__file__))
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        initial_step = int(ckpt.model_checkpoint_path.rsplit('-',1)[1])
    
    for step in range(training_steps):
        ...