TensorFlow -- BeamSearch

GitHub: https://github.com/zle1992/Seq2Seq-Chatbot

1. Note that at the inference stage the decoding variable scopes must be reused (reuse=True), so the predicting decoder shares the weights created during training (a minimal reuse demo follows this list).

2. If you are using the BeamSearchDecoder with a cell wrapped in AttentionWrapper, then you must ensure that:

  • The encoder output has been tiled to beam_width via tf.contrib.seq2seq.tile_batch (NOT tf.tile); a short comparison of the two follows this list.
  • The batch_size argument passed to the zero_state method of this wrapper is equal to true_batch_size * beam_width.
  • The initial state created with zero_state above contains a cell_state value containing the properly tiled final state from the encoder.
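
Point 1 in action: a minimal, self-contained sketch of the reuse pattern (the scope and variable names here are illustrative only; the full listing below applies the same pattern with its own scopes):

import tensorflow as tf

def tiny_layer(x):
    # get_variable creates 'w' on the first call and reuses it afterwards
    w = tf.get_variable('w', shape=[4, 4])
    return tf.matmul(x, w)

x = tf.random_normal([2, 4])
with tf.variable_scope('decode'):              # training graph: creates 'decode/w'
    train_out = tiny_layer(x)
with tf.variable_scope('decode', reuse=True):  # inference graph: reuses 'decode/w'
    infer_out = tiny_layer(x)

print(len(tf.trainable_variables()))  # -> 1: both graphs share one 'decode/w'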
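
And for the first bullet, a toy comparison of why tile_batch is required: it repeats each batch entry beam_width times consecutively, which is the layout BeamSearchDecoder expects, whereas tf.tile repeats the whole batch (toy values, not from the listing below):

import tensorflow as tf

batch = tf.constant([[1, 1], [2, 2]])  # true batch_size = 2, beam_width = 3
with tf.Session() as sess:
    print(sess.run(tf.contrib.seq2seq.tile_batch(batch, multiplier=3)))
    # [[1 1] [1 1] [1 1] [2 2] [2 2] [2 2]]  -- each entry repeated 3 times
    print(sess.run(tf.tile(batch, [3, 1])))
    # [[1 1] [2 2] [1 1] [2 2] [1 1] [2 2]]  -- wrong layout for beam search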
import tensorflow as tf
from tensorflow.python.layers.core import Dense


BEAM_WIDTH = 5
BATCH_SIZE = 128


# INPUTS
X = tf.placeholder(tf.int32, [BATCH_SIZE, None])
Y = tf.placeholder(tf.int32, [BATCH_SIZE, None])
X_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])
Y_seq_len = tf.placeholder(tf.int32, [BATCH_SIZE])


# ENCODER
encoder_out, encoder_state = tf.nn.dynamic_rnn(
    cell = tf.nn.rnn_cell.BasicLSTMCell(128),
    inputs = tf.contrib.layers.embed_sequence(X, 10000, 128),
    sequence_length = X_seq_len,
    dtype = tf.float32)


# DECODER COMPONENTS
Y_vocab_size = 10000
decoder_embedding = tf.Variable(tf.random_uniform([Y_vocab_size, 128], -1.0, 1.0))
projection_layer = Dense(Y_vocab_size)


# ATTENTION (TRAINING)
with tf.variable_scope('shared_attention_mechanism'):
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units = 128,
        memory = encoder_out,
        memory_sequence_length = X_seq_len)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell = tf.nn.rnn_cell.BasicLSTMCell(128),
    attention_mechanism = attention_mechanism,
    attention_layer_size = 128)


# DECODER (TRAINING)
training_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs = tf.nn.embedding_lookup(decoder_embedding, Y),
    sequence_length = Y_seq_len,
    time_major = False)
training_decoder = tf.contrib.seq2seq.BasicDecoder(
    cell = decoder_cell,
    helper = training_helper,
    initial_state = decoder_cell.zero_state(BATCH_SIZE, tf.float32).clone(cell_state=encoder_state),
    output_layer = projection_layer)
with tf.variable_scope('decode_with_shared_attention'):
    training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder = training_decoder,
        impute_finished = True,
        maximum_iterations = tf.reduce_max(Y_seq_len))
training_logits = training_decoder_output.rnn_output


# BEAM SEARCH TILE
# Tile the encoder results to beam_width with tile_batch (NOT tf.tile). New
# names are used so the placeholders above are not shadowed and stay feedable.
tiled_encoder_out = tf.contrib.seq2seq.tile_batch(encoder_out, multiplier=BEAM_WIDTH)
tiled_X_seq_len = tf.contrib.seq2seq.tile_batch(X_seq_len, multiplier=BEAM_WIDTH)
tiled_encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=BEAM_WIDTH)


# ATTENTION (PREDICTING) -- same scope with reuse=True to share the trained weights
with tf.variable_scope('shared_attention_mechanism', reuse=True):
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units = 128,
        memory = tiled_encoder_out,
        memory_sequence_length = tiled_X_seq_len)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell = tf.nn.rnn_cell.BasicLSTMCell(128),
    attention_mechanism = attention_mechanism,
    attention_layer_size = 128)


# DECODER (PREDICTING)
predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell = decoder_cell,
    embedding = decoder_embedding,
    start_tokens = tf.tile(tf.constant([1], dtype=tf.int32), [BATCH_SIZE]),
    end_token = 2,
    # zero_state must be built with true_batch_size * beam_width
    initial_state = decoder_cell.zero_state(BATCH_SIZE * BEAM_WIDTH, tf.float32).clone(cell_state=tiled_encoder_state),
    beam_width = BEAM_WIDTH,
    output_layer = projection_layer,
    length_penalty_weight = 0.0)
with tf.variable_scope('decode_with_shared_attention', reuse=True):
    predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder = predicting_decoder,
        impute_finished = False,  # must be False for beam search
        maximum_iterations = 2 * tf.reduce_max(Y_seq_len))
# predicted_ids holds token ids of shape [batch, time, beam_width]; take the top beam
predicting_ids = predicting_decoder_output.predicted_ids[:, :, 0]

print('successful')
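
To smoke-test the graph above, a session run along these lines should work (dummy token ids; 1 and 2 are the start/end token ids assumed by the listing, and the sequence length 20 is arbitrary):

import numpy as np

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {
        X: np.random.randint(3, 10000, size=(BATCH_SIZE, 20)),  # random source ids
        Y: np.random.randint(3, 10000, size=(BATCH_SIZE, 20)),  # random target ids
        X_seq_len: np.full(BATCH_SIZE, 20, dtype=np.int32),
        Y_seq_len: np.full(BATCH_SIZE, 20, dtype=np.int32),
    }
    ids = sess.run(predicting_ids, feed)  # top beam, shape [BATCH_SIZE, time]
    print(ids.shape)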

References:

https://gist.github.com/higepon/eb81ba0f6663a57ff1908442ce753084

https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/BeamSearchDecoder

https://github.com/tensorflow/nmt#beam-search