1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| ''' 位置编码器 每一个词在不同的位置上面有不同的意思 ''' import torch.nn as nn import torch import math from torch.autograd import Variable
class PositionEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    The same token gets a different input vector at different positions in
    the sequence. The table is precomputed once for ``max_len`` positions as::

        PE[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
        PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

    Args:
        d_model: Embedding size; must match the last axis of the input.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model, dropout, max_len=5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        # (max_len, 1) column of positions; float so the products below are float.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model), computed via exp/log. The original wrote
        # `(log, d_model)`, which built a tuple and crashed; it must be a division.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # unsqueeze is not in-place: keep the result so the table is
        # (1, max_len, d_model) and broadcasts over the batch axis.
        pe = pe.unsqueeze(0)
        # A buffer follows .to(device) and is saved in state_dict, but it is
        # not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + PE[:seq_len])``.

        Args:
            x: Tensor indexed as (batch, seq_len, d_model). Axis 1 is the
                sequence axis and the last axis must equal ``d_model``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` given at
                construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )
        # Slice the first seq_len positions; the original indexed the single
        # position `seq_len`. Buffers never require grad, so no Variable wrapper.
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)
|