TensorFlow BasicRNNCell 调试

运行以下代码，在 pdb 断点处单步进入 `~/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py` 和 `~/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell.py`，逐行跟踪调试。

在调试会话中执行 `import tensorflow as tf`，然后用 `tf.Session().run(variable)` 打印张量的值（若要打印的是 `tf.Variable`，需先在该 Session 中运行变量初始化）。

查看BasicRNNCell和dynamic_rnn的实现方式

 1 #-*-coding:utf8-*-
 2 
 3 __author = "buyizhiyou"
 4 __date = "2017-11-20"
 5 
 6 '''
 7 单步调试,学习rnn的tf实现
 8 '''
 9 import tensorflow as tf 
10 import numpy as np
11 import pdb  
12   
13 X = tf.random_normal(shape=[2,3,4], dtype=tf.float32)#(2,3,4)==>(Batch_size,Time_steps(序列长度),Data_Vector)
14 pdb.set_trace()  
15 cell = tf.nn.rnn_cell.BasicRNNCell(10)#output_size:10,也可以换成GRUCell,LSTMAACell,BasicRNNCell  
16 state = cell.zero_state(2, tf.float32)#batch_size:2  
17 output, state = tf.nn.dynamic_rnn(cell, X, initial_state=state, time_major=False)  
18 with tf.Session() as sess:  
19     sess.run(tf.global_variables_initializer())  
20     print (output.get_shape())
21     print (sess.run(state))