pytorch笔记:09)Attention机制

刚从图像处理的hole中攀爬出来,刚走一步竟掉到了另一个hole(fire in the hole*▽*)

1.RNN中的attention

pytorch官方教程:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

首先,RNN的输入大小都是(1,1,hidden_size),即batch=1,seq_len=1,hidden_size=embed_size,相对于传统的encoder-decoder模型,attention机制仅在decoder处有所不同。下面具体看看:

1>保存了rnn每个词向量对应隐藏层的输出状态(encoder_outputs),用于decoder的attention机制

#train代码部分

for ei in range(input_length):

encoder_output, encoder_hidden = encoder(

input_tensor[ei], encoder_hidden)

encoder_outputs[ei] = encoder_output[0, 0]

1

2

3

4

5

2>AttnDecoderRNN的forward

1.输入的input经过embed

embedded = self.embedding(input).view(1, 1, -1)

embedded = self.dropout(embedded)

1

2

2.获取关于输入的attention权重,这里的Q=decoder_rnn的input,K=decoder_rnn的隐藏元

2.1求Q和K相似度的方法有很多,这里让全连接层自己来学习,把embedded和hidden连接在一起经过fc层(部分修改了下)

similarity=self.attn(torch.cat((embedded[0], hidden[0]), 1))

1

2.2 经过softmax获得归一化的权重

attn_weights = F.softmax(similarity, dim=1)

1

3.权重应用于encoder输出的所有词对应的词向量上(对应相乘即可)->获得attention结果

attn_applied = torch.bmm(attn_weights.unsqueeze(0),encoder_outputs.unsqueeze(0))

1

4.把attention结果和decoder的输入cat在一起,使用1个全连接层来融合二者,最终生成带注意力机制的词向量

output = torch.cat((embedded[0], attn_applied[0]), 1)

output = self.attn_combine(output).unsqueeze(0)

1

2

5.根据decoder的上一个输出单词来预测下一个单词,这里多插一句,decoder的首个输入为起始标志符’sos’,其根据encode最后的隐藏元来预测第一个单词,后面依次类推。

output = F.relu(output)

output, hidden = self.gru(output, hidden)

output = F.log_softmax(self.out(output[0]), dim=1)

return output, hidden, attn_weights

1

2

3

4

2.transformer中的attention

“Attention is All You Need”(霸气标题),pytorch代码推荐2篇:

哈佛大学NLP研究组:http://nlp.seas.harvard.edu/2018/04/03/attention.html

台湾小哥的代码(较通俗):https://github.com/jadore801120/attention-is-all-you-need-pytorch:

下面以soft_attention为例(*input和output的attention,仅和self_attention做下区分,第1篇代码标记src_attn,第2篇代码标记dec_enc_attn),soft_attention的目标:给定序列Q(query,长度记为lq,维度dk),键序列K(key,长度记为lk,维度dk),值序列V(value,长度记为lv,维度dv),计算Q和K的相似度权重,最后再乘上V。

下面直接贴上attention-is-all-you-need-pytorch中MultiHeadAttention代码

def forward(self, q, k, v, mask=None):

d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

sz_b, len_q, _ = q.size()

sz_b, len_k, _ = k.size()

sz_b, len_v, _ = v.size()

residual = q

q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)

k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)

v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

#这里把batch和分块数放在一起,便于使用bmm

q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk

k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk

v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv

mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..

output, attn = self.attention(q, k, v, mask=mask)

output = output.view(n_head, sz_b, len_q, d_v)

output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)

output = self.dropout(self.fc(output))

output = self.layer_norm(output + residual)

return output, attn

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

和RNN中的attention的不同,这里的batch_size和seq_len均不为1,其把序列视为一个整体,求Q和V的相似度可使用点乘(V可以视为上面提及的encoder_outputs),获得的是一个相似度矩阵,比如Q是一个长度为10的序列,K是一个长度为16的序列,其相似度矩阵就是一个10*16的矩阵,再如矩阵第一行表示Q的第一个单词和K序列所有单词的相似度。

similarity:=(lq,dk)∗(dk,lk)=(lq,lk) similarity:=(lq,dk)*(dk,lk)=(lq,lk)

similarity:=(lq,dk)∗(dk,lk)=(lq,lk)

然后,生成带注意力机制的词向量(通常K和V取相同的值,因而有lv=lk),另外上面整合attn_applied和input使用的是cat操作,而这里使用的是残差(类似于unet和resnet),最后使用PositionwiseFeedForward(2个fc层)来融合attn_applied和input,最终生成带注意力机制的词向量。

attention_applied=(lq,lk)∗(lv,dv)=(lq,dv) attention\_applied=(lq,lk)*(lv,dv)=(lq,dv)

attention_applied=(lq,lk)∗(lv,dv)=(lq,dv)

细节部分

在数据预处理部分,对序列s都进行了首尾标记,比如s=’’+ s + ‘’,刚看transform(之前跳过了seq2seq),对下面的代码甚是不解

decoder_input=target_seq[:, :-1] #这里不是去掉终止标记<eos>,去掉的可能是padding_0,只为兼容target_ground_y的序列长度?

encoder_input=input_seq[:, 1:] #encoder的输入序列去掉了起始标记<sos>

target_ground_y= target_seqtrg[:, 1:] #用于计算模型loss的target,去掉了起始标记<sos>

1

2

3

其实在pytorch官方教程中说的比较清楚,看下图

encoder的输入序列和ground_true只需要一个终止符即可,而decoder的输入序列开始必须指定一个起始符,让其根据context预测输出序列的第一个单词,后面根据前一个单词再预测下一个单词,依次类推直到当前预测的单词为终止标记’eos’,才计算loss.

---------------------

作者:PJ-Javis

来源:CSDN

原文:https://blog.csdn.net/jiangpeng59/article/details/84859640

版权声明:本文为博主原创文章,转载请附上博文链接!