PyTorch seq2seq 闲聊机器人:beam search 返回结果

decoder.py

"""
实现解码器
"""
import heapq

import torch.nn as nn
import config
import torch
import torch.nn.functional as F
import numpy as np
import random
from chatbot.attention import Attention


class Decoder(nn.Module):
    def __init__(self):
        """Create the decoder layers; every size is read from ``config``."""
        super().__init__()

        vocab_size = len(config.target_ws)
        embedding_dim = config.chatbot_decoder_embedding_dim
        hidden_size = config.chatbot_decoder_hidden_size

        # Token ids -> dense vectors; the PAD id is registered as padding_idx.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=config.target_ws.PAD,
        )

        # Unidirectional, batch-first GRU. Its initial hidden state is whatever
        # the caller passes in (the encoder's final hidden state in this file),
        # so the two hidden sizes presumably have to agree -- confirm against
        # the encoder configuration.
        self.gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=config.chatbot_decoder_number_layer,
            bidirectional=False,
            batch_first=True,
            dropout=config.chatbot_decoder_dropout,
        )

        # Projects the attention-combined state onto the target vocabulary.
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.attn = Attention(method="general")
        # Squeezes [context ; GRU output] (2 * hidden_size) back to hidden_size.
        self.fc_attn = nn.Linear(hidden_size * 2, hidden_size, bias=False)

    def forward(self, encoder_hidden, target, encoder_outputs, teacher_forcing_ratio=0.5):
        """
        Decode ``config.chatbot_target_max_len`` steps for a whole batch.

        A single random draw per call decides how the next-step input is chosen
        for the entire sequence: with teacher forcing the ground-truth token
        ``target[:, t]`` is fed back, otherwise the decoder's own most likely
        token is fed back.

        :param encoder_hidden: initial decoder hidden state; ``size(1)`` is used as the batch size
        :param target: ground-truth token ids, indexed as ``target[:, t]`` when teacher forcing is active
        :param encoder_outputs: encoder outputs handed to the attention in ``forward_step``
        :param teacher_forcing_ratio: probability of using teacher forcing for this call
            (``0.5`` reproduces the previous hard-coded behaviour)
        :return: ``(decoder_outputs, decoder_hidden)`` where ``decoder_outputs`` is
            ``[batch_size, max_len, vocab_size]`` holding the per-step output of ``forward_step``
        """
        decoder_hidden = encoder_hidden
        batch_size = encoder_hidden.size(1)
        # Every sequence starts from the SOS token: [batch_size, 1]
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)

        # Pre-allocated buffer for the per-step outputs: [batch_size, max_len, vocab_size]
        decoder_outputs = torch.zeros(
            [batch_size, config.chatbot_target_max_len, len(config.target_ws)]
        ).to(config.device)

        # One draw per call (not per step), exactly as before.
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        for t in range(config.chatbot_target_max_len):
            decoder_output_t, decoder_hidden = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs[:, t, :] = decoder_output_t

            if use_teacher_forcing:
                # Teacher forcing: feed the ground-truth token of this step.
                decoder_input = target[:, t].unsqueeze(-1)
            else:
                # Free running: feed back the model's own best guess.
                _, index = decoder_output_t.max(dim=-1)
                decoder_input = index.unsqueeze(-1)  # [batch_size, 1]

        return decoder_outputs, decoder_hidden


    def forward_step(self,decoder_input,decoder_hidden,encoder_outputs):
        '''
        计算一个时间步的结果
        :param decoder_input: [batch_size,1]
        :param decoder_hidden: [1,batch_size,128*2]
        :return:
        '''

        decoder_input_embeded = self.embedding(decoder_input)
        # print("decoder_input_embeded:",decoder_input_embeded.size())

        #out:[batch_size,1,128*2]
        #decoder_hidden :[1,bathc_size,128*2]
        # print(decoder_hidden.size())
        out,decoder_hidden = self.gru(decoder_input_embeded,decoder_hidden)

        ##### 开始attention ############
        ### 1. 计算attention weight
        attn_weight = self.attn(decoder_hidden,encoder_outputs)  #[batch_size,1,encoder_max_len]
        ### 2. 计算context vector
        #encoder_ouputs :[batch_size,encoder_max_len,128*2]
        context_vector = torch.bmm(attn_weight.unsqueeze(1),encoder_outputs).squeeze(1) #[batch_szie,128*2]
        ### 3. 计算 attention的结果
        #[batch_size,128*2]  #context_vector:[batch_size,128*2] --> 128*4
        #attention_result = [batch_size,128*4] --->[batch_size,128*2]
        attention_result = torch.tanh(self.fc_attn(torch.cat([context_vector,out.squeeze(1)],dim=-1)))
        # attention_result = torch.tanh(torch.cat([context_vector,out.squeeze(1)],dim=-1))
        #### attenion 结束

        # print("decoder_hidden size:",decoder_hidden.size())
        #out :【batch_size,1,hidden_size】

        # out_squeezed = out.squeeze(dim=1) #去掉为1的维度
        out_fc = F.log_softmax(self.fc(attention_result),dim=-1) #[bathc_size,vocab_size]
        # print("out_fc:",out_fc.size())
        return out_fc,decoder_hidden

    def evaluate(self,encoder_hidden,encoder_outputs):
        """
        Greedy decoding: at every step feed back the highest-scoring token.

        :param encoder_hidden: initial decoder hidden state; ``size(1)`` is used as the batch size
        :param encoder_outputs: encoder outputs handed to the attention in ``forward_step``
        :return: ``(decoder_outputs, predict_result)`` -- the per-step scores as a
            ``[batch_size, max_len, vocab_size]`` tensor, and a numpy array of
            predicted token ids with one row per batch element
        """
        max_len = config.chatbot_target_max_len
        batch_size = encoder_hidden.size(1)

        hidden = encoder_hidden
        # Every sequence starts from the SOS token: [batch_size, 1]
        step_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)

        # Pre-allocated buffer for the per-step outputs.
        decoder_outputs = torch.zeros([batch_size, max_len, len(config.target_ws)]).to(
            config.device)

        per_step_ids = []
        for step in range(max_len):
            scores, hidden = self.forward_step(step_input, hidden, encoder_outputs)
            decoder_outputs[:, step, :] = scores

            # Index of the best-scoring token for each batch row.
            best_ids = scores.max(dim=-1)[1]
            per_step_ids.append(best_ids.cpu().detach().numpy())
            step_input = best_ids.unsqueeze(-1)  # [batch_size, 1]

        # Collected as [max_len, batch_size]; transpose so each row is one sequence.
        predict_result = np.array(per_step_ids).transpose()
        return decoder_outputs, predict_result

    def evaluate_with_beam_search(self, encoder_hidden, encoder_outputs):
        """
        Decode a single sentence with beam search.

        Hypotheses are scored by the SUM of the per-step log-probabilities
        returned by ``forward_step`` (which applies ``log_softmax``). Summing
        logs corresponds to multiplying probabilities; multiplying the logs
        themselves would flip the sign on every step and rank hypotheses
        incorrectly. No length normalisation is applied.

        :param encoder_hidden: initial decoder hidden state; ``size(1)`` must be 1
        :param encoder_outputs: encoder outputs handed to the attention in ``forward_step``
        :return: list of predicted token ids with a leading SOS and a trailing EOS removed
        """
        decoder_hidden = encoder_hidden
        batch_size = encoder_hidden.size(1)
        assert batch_size == 1, "beam search的过程中,batch_size只能为1"
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)  # [1, 1]

        prev_beam = Beam()
        # log(1) == 0: the start hypothesis carries no penalty.
        prev_beam.add(0.0, False, [decoder_input], decoder_input, decoder_hidden)

        while True:
            cur_beam = Beam()
            for score, complete, seq_list, step_input, step_hidden in prev_beam:
                if complete:
                    # Already ended with EOS in an earlier step but was not the
                    # best hypothesis then; carry it over unchanged.
                    cur_beam.add(score, complete, seq_list, step_input, step_hidden)
                    continue

                log_probs, next_hidden = self.forward_step(step_input, step_hidden, encoder_outputs)
                top_log_probs, top_tokens = torch.topk(log_probs, config.beam_width)

                for log_prob, token in zip(top_log_probs[0].tolist(), top_tokens[0].tolist()):
                    next_input = torch.LongTensor([[token]]).to(config.device)
                    cur_beam.add(
                        score + log_prob,
                        token == config.target_ws.EOS,
                        seq_list + [next_input],
                        next_input,
                        next_hidden,
                    )

            # Rank on the score only; comparing whole entries would fall through
            # to tensor comparisons whenever two scores are equal.
            _, best_complete, best_seq, _, _ = max(cur_beam, key=lambda item: item[0])

            # best_seq still contains the leading SOS, hence the "- 1".
            if best_complete or len(best_seq) - 1 >= config.chatbot_target_max_len:
                best_seq = [i.item() for i in best_seq]
                if best_seq[0] == config.target_ws.SOS:
                    best_seq = best_seq[1:]
                if best_seq[-1] == config.target_ws.EOS:
                    best_seq = best_seq[:-1]
                return best_seq

            prev_beam = cur_beam


class Beam:
    """
    Bounded container holding the best ``beam_width`` hypotheses of one time step.

    Internally a min-heap keyed on ``prob``: whenever the heap grows beyond
    ``beam_width`` the lowest-scoring entry is dropped. An insertion counter is
    stored right after the score so that two entries with equal scores are
    never ordered by comparing their payloads (token tensors, hidden states).
    """

    def __init__(self, beam_width=None):
        """
        :param beam_width: maximum number of hypotheses kept; defaults to ``config.beam_width``
        """
        # Heap entries are (prob, insertion_index, payload) tuples.
        self.heapq = []
        self.beam_width = config.beam_width if beam_width is None else beam_width
        self._inserted = 0

    def add(self, prob, complete, seq_list, decoder_input, decoder_hidden):
        """
        Insert a hypothesis and evict the lowest-scoring one if the beam overflows.

        :param prob: score used for ranking (higher is better)
        :param complete: whether the hypothesis has already produced EOS
        :param seq_list: tokens generated so far
        :param decoder_input: input for the next decoding step
        :param decoder_hidden: hidden state for the next decoding step
        """
        payload = [prob, complete, seq_list, decoder_input, decoder_hidden]
        heapq.heappush(self.heapq, (prob, self._inserted, payload))
        self._inserted += 1
        # Keep at most beam_width hypotheses.
        if len(self.heapq) > self.beam_width:
            heapq.heappop(self.heapq)

    def __iter__(self):
        """Yield ``[prob, complete, seq_list, decoder_input, decoder_hidden]`` items in heap order."""
        for _, _, payload in self.heapq:
            yield payload

  seq2seq.py

"""
完成seq2seq模型
"""
import torch.nn as nn
from chatbot.encoder import Encoder
from chatbot.decoder import Decoder


class Seq2Seq(nn.Module):
    """Encoder/decoder pair; every method encodes first and then delegates to the decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, input, input_len, target):
        """Training path: return only the decoder outputs, dropping its final hidden state."""
        enc_outputs, enc_hidden = self.encoder(input, input_len)
        outputs, _ = self.decoder(enc_hidden, target, enc_outputs)
        return outputs

    def evaluate(self, input, input_len):
        """Return ``(decoder_outputs, predict_result)`` from ``Decoder.evaluate``."""
        enc_outputs, enc_hidden = self.encoder(input, input_len)
        outputs, predictions = self.decoder.evaluate(enc_hidden, enc_outputs)
        return outputs, predictions

    def evaluate_with_beam_search(self, input, input_len):
        """Return the result of ``Decoder.evaluate_with_beam_search`` for the encoded input."""
        enc_outputs, enc_hidden = self.encoder(input, input_len)
        return self.decoder.evaluate_with_beam_search(enc_hidden, enc_outputs)

  eval.py

"""
进行模型的评估
"""

import torch
import torch.nn.functional as F
from chatbot.dataset import get_dataloader
from tqdm import tqdm
import config
import numpy as np
import pickle
from chatbot.seq2seq import Seq2Seq

def eval():
    """
    Compute and print the mean NLL loss of the saved model on the evaluation data.

    Loads ``./models/model.pkl``, decodes every batch of
    ``get_dataloader(train=False)`` with ``model.evaluate`` and averages the loss.
    Padding positions of the target are excluded from the loss.
    """
    model = Seq2Seq().to(config.device)
    # map_location lets a checkpoint saved on another device (e.g. GPU) be
    # loaded onto whatever config.device is on this machine.
    model.load_state_dict(torch.load("./models/model.pkl", map_location=config.device))
    model.eval()

    loss_list = []
    data_loader = get_dataloader(train=False)
    bar = tqdm(data_loader,total=len(data_loader),desc="当前进行评估")
    with torch.no_grad():
        for idx, (batch_input, target, input_len, _) in enumerate(bar):
            batch_input = batch_input.to(config.device)
            target = target.to(config.device)
            input_len = input_len.to(config.device)

            # decoder_outputs: [batch_size, max_len, vocab_size]
            decoder_outputs, predict_result = model.evaluate(batch_input, input_len)
            # The targets are encoded with target_ws, so its PAD id is the one to ignore.
            loss = F.nll_loss(
                decoder_outputs.view(-1, len(config.target_ws)),
                target.view(-1),
                ignore_index=config.target_ws.PAD,
            )
            loss_list.append(loss.item())
            bar.set_description("idx:{} loss:{:.6f}".format(idx,np.mean(loss_list)))

    # Avoid numpy's empty-mean warning when the loader yields no batches.
    mean_loss = np.mean(loss_list) if loss_list else float("nan")
    print("当前的平均损失为:", mean_loss)


def interface():
    """
    Interactive chat loop using greedy decoding.

    Loads ``./models/model.pkl`` once, then repeatedly reads a line from stdin,
    tokenises it with ``cut(..., by_word=True)``, decodes a reply with
    ``model.evaluate`` and prints it. Runs until the process is interrupted.
    """
    from chatbot.cut_sentence import cut
    import config

    # Load the model.
    model = Seq2Seq().to(config.device)
    # map_location lets a checkpoint saved on another device load onto config.device.
    model.load_state_dict(torch.load("./models/model.pkl", map_location=config.device))
    model.eval()

    while True:
        origin_input = input("me>>:")
        tokens = cut(origin_input, by_word=True)
        if len(tokens) == 0:
            # Nothing to encode; presumably the encoder cannot handle a
            # zero-length sequence, so just prompt again.
            continue

        # transform() is called with max_len, so the reported length must not
        # exceed that limit.
        seq_len = min(len(tokens), config.chatbot_input_max_len)
        input_len = torch.LongTensor([seq_len]).to(config.device)
        _input = torch.LongTensor(
            [config.input_ws.transform(tokens, max_len=config.chatbot_input_max_len)]
        ).to(config.device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs, predict = model.evaluate(_input, input_len)
        result = config.target_ws.inverse_transform(predict[0])
        print("chatbot>>:",result)


def interface_with_beamsearch():
    """
    Interactive chat loop using beam-search decoding.

    Loads ``./models/model.pkl`` once, then repeatedly reads a line from stdin,
    tokenises it with ``cut(..., by_word=True)``, decodes a reply with
    ``model.evaluate_with_beam_search`` and prints it. Runs until the process
    is interrupted.
    """
    from chatbot.cut_sentence import cut
    import config

    # Load the model.
    model = Seq2Seq().to(config.device)
    # map_location lets a checkpoint saved on another device load onto config.device.
    model.load_state_dict(torch.load("./models/model.pkl", map_location=config.device))
    model.eval()

    while True:
        origin_input = input("me>>:")
        tokens = cut(origin_input, by_word=True)
        if len(tokens) == 0:
            # Nothing to encode; presumably the encoder cannot handle a
            # zero-length sequence, so just prompt again.
            continue

        # transform() is called with max_len, so the reported length must not
        # exceed that limit.
        seq_len = min(len(tokens), config.chatbot_input_max_len)
        input_len = torch.LongTensor([seq_len]).to(config.device)
        _input = torch.LongTensor(
            [config.input_ws.transform(tokens, max_len=config.chatbot_input_max_len)]
        ).to(config.device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            best_seq = model.evaluate_with_beam_search(_input, input_len)
        result = config.target_ws.inverse_transform(best_seq)
        print("chatbot>>:", result)




if __name__ == "__main__":
    # Start the beam-search chat loop; call interface() instead for greedy decoding.
    interface_with_beamsearch()