Tensorflow入门实战-mnist手写体识别

 1 \'\'\'
 2 tensorflow 教程
 3 mnist样例
 4 \'\'\'
 5 import tensorflow as tf 
 6 from tensorflow.examples.tutorials.mnist import input_data
 7 
 8 #参数设置
 9 INPUT_NODE=784
10 OUTPUT_NODE=10
11 LAYER1_NODE=500
12 BATCH_SIZE=100
13 LEARNING_RATE_BASE=0.8
14 LEARNING_RATE_DECAY=0.99
15 REGULARIZATION_RATE=0.0001
16 TRAINING_STEPS=10000
17 MOVEING_AVEARGE_DECAY=0.99
18 
19 
20 def inference(input_tensor,avg_class,weights1,biases1,weights2,biases2):
21     \'\'\'
22     定义前向计算的过程:
23     avg_class是滑动平均函数,使权重平滑过渡,保留历史数据,
24     为None时,表示普通的参数更新过程
25     \'\'\'
26     if avg_class==None:
27         layer1=tf.nn.relu(tf.matmul(input_tensor,weights1)+biases1)
28         return tf.matmul(layer1,weights2)+biases2
29     else:
30         layer1=tf.nn.relu(tf.matmul(input_tensor,avg_class.average(weights1)+avg_class.average(biases1)))
31         return tf.matmul(layer1,avg_class.average(weights2)+avg_class.average(biases2))
32 
33 def train(mnist):
34     #设置输入变量 placerholder表示占位,开启会话训练的时候需要传入数据
35     x=tf.placeholder(tf.float32,[None,INPUT_NODE],name=\'x-input\')
36     y_=tf.placeholder(tf.float32,[None,OUTPUT_NODE],name=\'y-input\')
37 
38     #设置权重变量,variable表示训练时需要自动更新
39     weights1=tf.Variable(tf.random_normal([INPUT_NODE,LAYER1_NODE],stddev=0.1))
40     biases1=tf.Variable(tf.constant(0.1,shape=[LAYER1_NODE]))
41     weights2=tf.Variable(tf.random_normal([LAYER1_NODE,OUTPUT_NODE],stddev=0.1))
42     biases2=tf.Variable(tf.constant(0.1,shape=[OUTPUT_NODE]))
43     
44     #y=inference(x,None,weights1,biases1,weights2,biases2)
45     
46     global_step=tf.Variable(0,trainable=False)#不可更新参数
47     variable_averages=tf.train.ExponentialMovingAverage(MOVEING_AVEARGE_DECAY,global_step)#min(decay,(1+step)/(10+step)) 后面的变量会越来越大,表示参数的更新越来越稳定,大都依赖于历史数据
48     variable_averages_op=variable_averages.apply(tf.trainable_variables())
49     average_y=inference(x,variable_averages,weights1,biases1,weights2,biases2)
50 
51     cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=average_y,labels=tf.argmax(y_,1))#计算图的输出是每个分类的得分,但是要求输入的标签是正确答案的下标
52     cross_entropy_mean=tf.reduce_mean(cross_entropy)
53 
54     regularizer=tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
55     regularization=regularizer(weights1)+regularizer(weights2)
56     loss=cross_entropy+regularization
57 
58     learning_rate=tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples/BATCH_SIZE,LEARNING_RATE_DECAY)#学习率成阶梯状衰减 每个epoch衰减一次,也就是一整轮数据训练完衰减一次
59     train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)
60 
61     train_op=tf.group(train_step,variable_averages_op)#把反向传播是需要更新的参数打包,不使用滑动平均不需要这句话,因为只更新权重。滑动平均还要利用历史数据更新并更新历史数据
62 
63     correct_prediction=tf.equal(tf.argmax(average_y,1),tf.argmax(y_,1))
64     accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
65     
66 
67     with tf.Session() as sess:
68         tf.global_variables_initializer().run()
69         validate_feed={x:mnist.validation.images,y_:mnist.validation.labels}
70         test_feed={x:mnist.test.images,y_:mnist.test.labels}
71 
72         for i in range(TRAINING_STEPS):
73             if i%1000==0:
74                 validate_acc=sess.run(accuracy,feed_dict=validate_feed)
75                 print(\'After %d training steps,validation accuracy using average model is %g\' %(i,validate_acc))
76                 
77             xs,ys=mnist.train.next_batch(BATCH_SIZE)
78             sess.run(train_op,feed_dict={x:xs,y_:ys})
79 
80         test_acc=sess.run(accuracy,feed_dict=test_feed)
81         print(\'After %d training steps,test accuracy using average model is %g\' %(TRAINING_STEPS,test_acc))    
82 
83 def main(argv=None):
84     mnist=input_data.read_data_sets("/tmp/data",one_hot=True)
85     train(mnist)
86 
87 if __name__ == \'__main__\':
88     tf.app.run()