用Tensorflow搭建神经网络的一般步骤

用Tensorflow搭建神经网络的一般步骤如下:

① 导入模块

② 创建模型变量和占位符

③ 建立模型

④ 定义loss函数

⑤ 定义优化器(optimizer), 使 loss 达到最小

⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)

⑦ 训练模型

⑧ 检验模型

⑨ 使用模型预测数据

⑩ 保存模型

⑪ 使用Tensorboard的可视化功能

下面以一个简单的线性回归问题为例:

首先是训练模型的代码: train_model.py

 1 # ① 导入模块
 2 import tensorflow as tf
 3 
 4 # ② 创建模型的变量和占位符
 5 W = tf.Variable([.3], dtype=tf.float32)
 6 b = tf.Variable([-.3], dtype=tf.float32)
 7 x = tf.placeholder(tf.float32, name="input_x")
 8 y = tf.placeholder(tf.float32, name="input_y")
 9 
10 # ③建立模型
11 linear_model = W*x + b
12 # 如果是矩阵相乘,可以写成:
13 # linear_model = tf.matmul(x, W)+b  # matmul表示矩阵相乘
14 
15 # ④ 定义loss函数
16 loss = tf.reduce_sum(tf.square(linear_model - y))
17 
18 # ⑤ 定义优化器(optimizer), 使 loss 达到最小
19 learning_rate=0.01
20 optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
21 train = optimizer.minimize(loss)
22 
23 # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤)
24 
25 # ⑦ 训练模型
26 # 假设模型是y=2x+1
27 x_train = [1, 2, 3, 4]
28 y_train = [3, 5, 7, 9]
29 
30 init = tf.global_variables_initializer() # 添加用于初始化变量的节点
31 sess = tf.Session()
32 sess.run(init) # 运行初始化操作
33 for step in range(1000):
34    sess.run(train, {x: x_train, y: y_train})
35 
36 \'\'\'
37 第⑦步和第⑩步可以合并为:
38 for step in xrange(1000000):
39     sess.run(train, {x: x_train, y: y_train})
40     if step % 1000 == 0:
41         saver.save(sess, \'my-model\', global_step=step)
42 \'\'\'
43 
44 # ⑧ 检验模型
45 curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
46 print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
47 \'\'\'
48 W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
49 \'\'\'
50 
51 # ⑨ 使用模型预测数据
52 x_predict = [-1, 0, 1, 2]
53 predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
54 # 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
55 print("result:", predicted_values)
56 \'\'\'
57 result: [-1.0000062   0.99999553  2.99999714  4.99999905]
58 \'\'\'
59 
60 # ⑩ 保存模型
61 tf.add_to_collection("predict_network", linear_model)
62 saver = tf.train.Saver()
63 saver_path=saver.save(sess, "save/model.ckpt")
64 
65 # ⑪ 使用Tensorboard的可视化功能
66 # 定义保存日志的路径
67 path = "log"  # 也可写成: path = "./log"
68 writer=tf.summary.FileWriter(path, sess.graph)
69 
70 sess.close()

然后是载入模型的代码: restore_model.py

 1 import tensorflow as tf
 2 
 3 with tf.Session() as sess:
 4     new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
 5     new_saver.restore(sess,"save/model.ckpt")
 6     # print(tf.get_collection("predict_network"))
 7     restored_y=tf.get_collection("predict_network")[0]  # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 8 
 9     graph=tf.get_default_graph()
10     restored_x=graph.get_operation_by_name("input_x").outputs[0]
11 
12     predict_data = [-2, 3, 4]
13     predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data})
14 
15     print("result:", predicted_result)  # result: [-3.00000787  7.00000048  9.00000191]