注意力机制

注意力机制概述

自主性的与非自主性的注意力提示解释了人类的注意力的方式, 下面来看看如何通过这两种注意力提示, 用神经网络来设计注意力机制的框架,

首先,考虑一个相对简单的状况, 即只使用非自主性提示。 要想将选择偏向于感官输入, 则可以简单地使用参数化的全连接层, 甚至是非参数化的最大汇聚层或平均汇聚层。

因此,“是否包含自主性提示”将注意力机制与全连接层或汇聚层区别开来。 在注意力机制的背景下,自主性提示被称为查询(query)。 给定任何查询,注意力机制通过注意力汇聚(attention pooling) 将选择引导至感官输入(sensory inputs,例如中间特征表示)。 在注意力机制中,这些感官输入被称为(value)。 更通俗的解释,每个值都与一个(key)配对, 这可以想象为感官输入的非自主提示。 如下图所示,可以通过设计注意力汇聚的方式, 便于给定的查询(自主性提示)与键(非自主性提示)进行匹配, 这将引导得出最匹配的值(感官输入)。

此外,注意力机制的设计有许多替代方案。 例如可以设计一个不可微的注意力模型, 该模型可以使用强化学习方法 (Mnih et al., 2014)进行训练

注意力汇聚

考虑下面这个回归问题:

给定的成对的“输入-输出”数据集 {(x1,y1),,(xn,yn)}\{(x_1, y_1), \ldots, (x_n, y_n)\} ,如何学习 ff 来预测任意新输入 xx 的输出 y^=f(x)\hat{y} = f(x)

根据下面的非线性函数生成一个人工数据集,其中加入的噪声项为 ϵ\epsilon

yi=2sin(xi)+xi0.8+ϵ,y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon,

其中 ϵ\epsilon 服从均值为 00 和标准差为 0.50.5 的正态分布。在这里生成了 5050 个训练样本和 5050 个测试样本。

平均汇聚

先使用最简单的估计器来解决回归问题。基于平均汇聚来计算所有训练样本输出值的平均值:

f(x)=1ni=1nyi,f(x) = \frac{1}{n}\sum_{i=1}^n y_i,

如下图所示,这个估计器确实不够聪明。(粉色为预测值,蓝色为真实值)

非参数注意力汇聚

显然,平均汇聚忽略了输入 xix_i 。我们可以根据输入的位置对输出 yiy_i 进行加权:

f(x)=i=1nK(xxi)j=1nK(xxj)yi,f(x) = \sum_{i=1}^n \frac{K(x - x_i)}{\sum_{j=1}^n K(x - x_j)} y_i,

其中 KK(kernel)。公式所描述的估计器被称为Nadaraya-Watson核回归(Nadaraya-Watson kernel regression)。

这里不会深入讨论核函数的细节,但受此启发,我们可以从注意力机制框架的角度,重写该式成为一个更加通用的注意力汇聚(attention pooling)公式:

f(x)=i=1nα(x,xi)yi,f(x) = \sum_{i=1}^n \alpha(x, x_i) y_i,

其中 xx 是查询, (xi,yi)(x_i, y_i) 是键值对。注意力汇聚是 yiy_i 的加权平均。将查询 xx 和键 xix_i 之间的关系建模为注意力权重(attention weight) α(x,xi)\alpha(x, x_i) ,如上式所示,这个权重将被分配给每一个对应值 yiy_i 。对于任何查询,模型在所有键值对注意力权重都是一个有效的概率分布:它们是非负的,并且总和为1。

为了更好地理解注意力汇聚,下面考虑一个高斯核(Gaussian kernel),其定义为:

K(u)=12πexp(u22).K(u) = \frac{1}{\sqrt{2\pi}} \exp(-\frac{u^2}{2}).

将高斯核代入上面两个公式,可以得到:

f(x)=i=1nα(x,xi)yi=i=1nexp(12(xxi)2)j=1nexp(12(xxj)2)yi=i=1nsoftmax(12(xxi)2)yi.\begin{aligned} f(x) &=\sum_{i=1}^n \alpha(x, x_i) y_i\\ &= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}(x - x_i)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}(x - x_j)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i. \end{aligned}

在上式中,如果一个键 xix_i 越是接近给定的查询 xx ,那么分配给这个键对应值 yiy_i 的注意力权重就会越大,也就“获得了更多的注意力”。

值得注意的是,Nadaraya-Watson核回归是一个非参数模型。

因此,上式是非参数的注意力汇聚(nonparametric attention pooling)模型。

带参数注意力汇聚

在下面的查询 xx 和键 xix_i 之间的距离乘以可学习参数 ww

f(x)=i=1nα(x,xi)yi=i=1nexp(12((xxi)w)2)j=1nexp(12((xxj)w)2)yi=i=1nsoftmax(12((xxi)w)2)yi.\begin{aligned}f(x) &= \sum_{i=1}^n \alpha(x, x_i) y_i \\&= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}((x - x_i)w)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}((x - x_j)w)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}((x - x_i)w)^2\right) y_i.\end{aligned}

即:

1
2
3
4
5
6
7
8
9
10
11
12
13
class NWKernelRegression(nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

def forward(self, queries, keys, values):
# queries和attention_weights的形状为(查询个数,“键-值”对个数)
queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
self.attention_weights = nn.functional.softmax(
-((queries - keys) * self.w)**2 / 2, dim=1)
# values的形状为(查询个数,“键-值”对个数)
return torch.bmm(self.attention_weights.unsqueeze(1),
values.unsqueeze(-1)).reshape(-1)

训练

接下来,将训练数据集变换为键和值用于训练注意力模型。在带参数的注意力汇聚模型中,任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算,从而得到其对应的预测输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from torch import nn
import matplotlib.pyplot as plt


def plot_kernel_reg(x_test, y_truth, y_hat, x_train, y_train):
plt.plot(x_test, y_truth, label='Truth')
plt.plot(x_test, y_hat, label='Pred')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.xlim(0, 5)
plt.ylim(-1, 5)
plt.plot(x_train, y_train, 'o', alpha=0.5)
plt.show()


def f(x):
return 2 * torch.sin(x) + x ** 0.8


class NWKernelRegression(nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

def forward(self, queries, keys, values):
# queries和attention_weights的形状为(查询个数,“键-值”对个数)
queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
self.attention_weights = nn.functional.softmax(
-((queries - keys) * self.w) ** 2 / 2, dim=1)
# values的形状为(查询个数,“键-值”对个数)
return torch.bmm(self.attention_weights.unsqueeze(1),
values.unsqueeze(-1)).reshape(-1)


n_train = 50 # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # 训练样本的输出
x_test = torch.arange(0, 5, 0.1) # 测试样本
y_truth = f(x_test) # 测试样本的真实输出
n_test = len(x_test) # 测试样本数

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))

weights = torch.ones((2, 10)) * 0.1

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)

for epoch in range(5):
trainer.zero_grad()
l = loss(net(x_train, keys, values), y_train)
l.sum().backward()
trainer.step()
print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(x_test, y_truth, y_hat, x_train, y_train)

注意力评分函数

上一节中高斯核指数部分可以视为注意力评分函数(attention scoring function), 然后把这个函数的输出结果输入到softmax函数中进行运算。 通过上述步骤,将得到与键对应的值的概率分布(即注意力权重)。 最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。

下图说明了如何将注意力汇聚的输出计算成为值的加权和,其中 aa 表示注意力评分函数。 由于注意力权重是概率分布,因此加权和其本质上是加权平均值。

用数学语言描述,假设有一个查询 qRq\mathbf{q} \in \mathbb{R}^qmm 个“键-值”对 (k1,v1),,(km,vm)(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m) ,其中 kiRk\mathbf{k}_i \in \mathbb{R}^kviRv\mathbf{v}_i \in \mathbb{R}^v 。注意力汇聚函数 ff 就被表示成值的加权和:

f(q,(k1,v1),,(km,vm))=i=1mα(q,ki)viRv,f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,

其中查询 q\mathbf{q} 和键 ki\mathbf{k}_i 的注意力权重(标量)是通过注意力评分函数 aa 将两个向量映射成标量,再经过softmax运算得到的:

α(q,ki)=softmax(a(q,ki))=exp(a(q,ki))j=1mexp(a(q,kj))R.\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}.

正如上图所示,选择不同的注意力评分函数 aa 会导致不同的注意力汇聚操作。本节将介绍两个流行的评分函数,稍后将用他们来实现更复杂的注意力机制。

掩蔽softmax操作

正如上面提到的,softmax操作用于输出一个概率分布作为注意力权重。 在某些情况下,并非所有的值都应该被纳入到注意力汇聚中。 例如,为了高效处理小批量数据集, 某些文本序列被填充了没有意义的特殊词元。 为了仅将有意义的词元作为值来获取注意力汇聚, 可以指定一个有效序列长度(即词元的个数), 以便在计算softmax时过滤掉超出指定范围的位置。 下面的masked_softmax函数 实现了这样的掩蔽softmax操作(masked softmax operation), 其中任何超出有效长度的位置都被掩蔽并置为0。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def masked_softmax(X, valid_lens):
"""通过在最后一个轴上掩蔽元素来执行softmax操作"""
# X:3D张量,valid_lens:1D或2D张量
if valid_lens is None:
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
value=-1e6)
return nn.functional.softmax(X.reshape(shape), dim=-1)

加性注意力

一般来说,当查询和键是不同长度的矢量时,可以使用加性注意力作为评分函数。

给定查询 qRq\mathbf{q} \in \mathbb{R}^q 和键 kRk\mathbf{k} \in \mathbb{R}^k加性注意力(additive attention)的评分函数为

a(q,k)=wvtanh(Wqq+Wkk)R,a(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},

其中可学习的参数是 WqRh×q\mathbf W_q\in\mathbb R^{h\times q}WkRh×k\mathbf W_k\in\mathbb R^{h\times k}wvRh\mathbf w_v\in\mathbb R^{h}

如上式所示,将查询和键连结起来后输入到一个多层感知机(MLP)中,感知机包含一个隐藏层,其隐藏单元数是一个超参数 hh 。通过使用 tanh\tanh 作为激活函数,并且禁用偏置项。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class AdditiveAttention(nn.Module):
"""加性注意力"""
def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
self.w_v = nn.Linear(num_hiddens, 1, bias=False)
self.dropout = nn.Dropout(dropout)

def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys)
# 在维度扩展后,
# queries的形状:(batch_size,查询的个数,1,num_hidden)
# key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)
# 使用广播方式进行求和
features = queries.unsqueeze(2) + keys.unsqueeze(1)
features = torch.tanh(features)
# self.w_v仅有一个输出,因此从形状中移除最后那个维度。
# scores的形状:(batch_size,查询的个数,“键-值”对的个数)
scores = self.w_v(features).squeeze(-1)
self.attention_weights = masked_softmax(scores, valid_lens)
# values的形状:(batch_size,“键-值”对的个数,值的维度)
return torch.bmm(self.dropout(self.attention_weights), values)

缩放点积注意力

使用点积可以得到计算效率更高的评分函数,但是点积操作要求查询和键具有相同的长度 dd 。假设查询和键的所有元素都是独立的随机变量,并且都满足零均值和单位方差,那么两个向量的点积的均值为 00 ,方差为 dd 。为确保无论向量长度如何,点积的方差在不考虑向量长度的情况下仍然是 11 ,我们再将点积除以 d\sqrt{d} ,则缩放点积注意力(scaled dot-product attention)评分函数为:

a(q,k)=qk/d.a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k} /\sqrt{d}.

在实践中,我们通常从小批量的角度来考虑提高效率,例如基于 nn 个查询和 mm 个键-值对计算注意力,其中查询和键的长度为 dd ,值的长度为 vv 。查询 QRn×d\mathbf Q\in\mathbb R^{n\times d} 、键 KRm×d\mathbf K\in\mathbb R^{m\times d} 和值 VRm×v\mathbf V\in\mathbb R^{m\times v} 的缩放点积注意力是:

softmax(QKd)VRn×v.\mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.

下面的缩放点积注意力的实现使用了暂退法进行模型正则化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class DotProductAttention(nn.Module):
"""缩放点积注意力"""
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)

# queries的形状:(batch_size,查询的个数,d)
# keys的形状:(batch_size,“键-值”对的个数,d)
# values的形状:(batch_size,“键-值”对的个数,值的维度)
# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# 设置transpose_b=True为了交换keys的最后两个维度
scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)

Bahdanau 注意力

对于机器翻译问题: 通过设计一个基于两个循环神经网络的编码器-解码器架构, 用于序列到序列学习。 具体来说,循环神经网络编码器将长度可变的序列转换为固定形状的上下文变量, 然后循环神经网络解码器根据生成的词元和上下文变量 按词元生成输出(目标)序列词元。 然而,即使并非所有输入(源)词元都对解码某个词元都有用, 在每个解码步骤中仍使用编码相同的上下文变量。 有什么方法能改变上下文变量呢?

受学习对齐想法的启发,Bahdanau等人提出了一个没有严格单向对齐限制的可微注意力模型 (Bahdanau et al., 2014)。 在预测词元时,如果不是所有输入词元都相关,模型将仅对齐(或参与)输入序列中与当前预测相关的部分。这是通过将上下文变量视为注意力集中的输出来实现的。

假设输入序列中有 TT 个词元,解码时间步 tt' 的上下文变量是注意力集中的输出:

ct=t=1Tα(st1,ht)ht,\mathbf{c}_{t'} = \sum_{t=1}^T \alpha(\mathbf{s}_{t' - 1}, \mathbf{h}_t) \mathbf{h}_t,

其中,时间步 t1t' - 1 时的解码器隐状态 st1\mathbf{s}_{t' - 1} 是查询,编码器隐状态 ht\mathbf{h}_t 既是键,也是值,注意力权重 α\alpha 是使用加性注意力打分函数计算的。

如下图,一个带有Bahdanau注意力的循环神经网络编码器-解码器模型

  • 在预测词元时,如果不是所有输入词元都是相关的,那么具有Bahdanau注意力的循环神经网络编码器-解码器会有选择地统计输入序列的不同部分。这是通过将上下文变量视为加性注意力池化的输出来实现的。
  • 在循环神经网络编码器-解码器中,Bahdanau注意力将上一时间步的解码器隐状态视为查询,在所有时间步的编码器隐状态同时视为键和值。

多头注意力

在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。

为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的 hh 组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这 hh 组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这 hh 个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention) (Vaswani et al., 2017)。 对于 hh 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。 下图展示了使用全连接层来实现可学习的线性变换的多头注意力。

给定查询 qRdq\mathbf{q} \in \mathbb{R}^{d_q} 、键 kRdk\mathbf{k} \in \mathbb{R}^{d_k} 和值 vRdv\mathbf{v} \in \mathbb{R}^{d_v} ,每个注意力头 hi\mathbf{h}_ii=1,,hi = 1, \ldots, h )的计算方法为:

hi=f(Wi(q)q,Wi(k)k,Wi(v)v)Rpv,\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},

其中,可学习的参数包括 Wi(q)Rpq×dq\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}Wi(k)Rpk×dk\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}Wi(v)Rpv×dv\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} ,以及代表注意力汇聚的函数 ffff 可以是加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 hh 个头连结后的结果,因此其可学习参数是

WoRpo×hpv\mathbf W_o\in\mathbb R^{p_o\times h p_v}

Wo[h1hh]Rpo.\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.

基于这种设计,每个头都可能会关注输入的不同部分,可以表示比简单加权平均值更复杂的函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import math
import torch
from torch import nn


def sequence_mask(X, valid_lens, value=-1e6):
max_len = X.size(1)
mask = torch.arange(max_len)[None, :].to(X.device) < valid_lens[:, None]
X[~mask] = value
return X


def masked_softmax(X, valid_lens):
"""通过在最后一个轴上掩蔽元素来执行softmax操作"""
# X:3D张量,valid_lens:1D或2D张量
if valid_lens is None:
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
value=-1e6)
return nn.functional.softmax(X.reshape(shape), dim=-1)


class DotProductAttention(nn.Module):
"""缩放点积注意力"""

def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)

# queries的形状:(batch_size,查询的个数,d)
# keys的形状:(batch_size,“键-值”对的个数,d)
# values的形状:(batch_size,“键-值”对的个数,值的维度)
# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# 设置transpose_b=True为了交换keys的最后两个维度
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)


def transpose_qkv(X, num_heads):
"""为了多注意力头的并行计算而变换形状"""
# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
# num_hiddens/num_heads)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
X = X.permute(0, 2, 1, 3)

# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
"""逆转transpose_qkv函数的操作"""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)


class MultiHeadAttention(nn.Module):
"""多头注意力"""

def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

def forward(self, queries, keys, values, valid_lens):
# queries,keys,values的形状:
# (batch_size,查询或者“键-值”对的个数,num_hiddens)
# valid_lens 的形状:
# (batch_size,)或(batch_size,查询的个数)
# 经过变换后,输出的queries,keys,values 的形状:
# (batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)

if valid_lens is not None:
# 在轴0,将第一项(标量或者矢量)复制num_heads次,
# 然后如此复制第二项,然后诸如此类。
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)

# output的形状:(batch_size*num_heads,查询的个数,
# num_hiddens/num_heads)
output = self.attention(queries, keys, values, valid_lens)

# output_concat的形状:(batch_size,查询的个数,num_hiddens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)


num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
print(attention.eval())

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape)

自注意力

有了注意力机制之后,我们将词元序列输入注意力池化中, 以便同一组词元同时充当查询、键和值。 具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。 由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention) (Lin et al., 2017, Vaswani et al., 2017), 也被称为内部注意力(intra-attention) (Cheng et al., 2016, Parikh et al., 2016, Paulus et al., 2017)。 本节将使用自注意力进行序列编码,以及使用序列的顺序作为补充信息。

给定一个由词元组成的输入序列 x1,,xn\mathbf{x}_1, \ldots, \mathbf{x}_n ,其中任意 xiRd\mathbf{x}_i \in \mathbb{R}^d1in1 \leq i \leq n )。该序列的自注意力输出为一个长度相同的序列 y1,,yn\mathbf{y}_1, \ldots, \mathbf{y}_n ,其中:

yi=f(xi,(x1,x1),,(xn,xn))Rd\mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d

自注意力模型采用查询-键-值(Query-Key-Value,QKV)模式。

(1) 计算查询矩阵Q,键矩阵K,值矩阵V

假设输入序列为 X=[x1,...,xN]RDx×NX=[x_1,...,x_N] \in \Bbb{R}^{D_x×N} ,经过词嵌入得到 A=[a1,...,aN]RDa×NA=[a_1,...,a_N] \in \Bbb{R}^{D_a×N} ;将词嵌入矩阵线性映射到三个不同的空间,得到

  1. 查询矩阵 Q=[q1,...,qN]RDk×NQ=[q_1,...,q_N] \in \Bbb{R}^{D_k×N}
  2. 键矩阵 K=[k1,...,kN]RDk×NK=[k_1,...,k_N] \in \Bbb{R}^{D_k×N}
  3. 值矩阵 V=[v1,...,vN]RDv×NV=[v_1,...,v_N] \in \Bbb{R}^{D_v×N} ;

矩阵运算如下:

Q=WqA,WqRDk×DaQ = W^qA, \quad W^q \in \Bbb{R}^{D_k×D_a}

K=WkA,WkRDk×DaK = W^kA, \quad W^k \in \Bbb{R}^{D_k×D_a}

V=WvA,WvRDv×DaV = W^vA, \quad W^v \in \Bbb{R}^{D_v×D_a}

(2) 计算注意力分布

对于每个查询向量 qiq_i 使用键值对注意力机制,得到注意力分布 a^1,1,...,a^1,N\hat{a}_{1,1},...,\hat{a}_{1,N}

矩阵运算如下:

A=KTQDkA = \frac{K^TQ}{\sqrt{D_k}}

A^=softmax(A)\hat{A} = softmax(A)

其中注意力得分选用缩放点积Scaled Dot-Product,其原因是后续的Softmax函数对较大或较小的输入非常敏感(容易映射到 1100 ),因此通过因子 Dk\sqrt{D_k} 进行缩放;Softmax函数按运算。

(3) 加权求和

根据注意力分布 A^\hat{A}加权求和得到输出:

矩阵运算如下:

B=VA^B = V\hat{A}

自注意力模型的优点

  1. 提高并行计算效率;
  2. 捕捉长距离的依赖关系。

自注意力模型可以看作在一个线性投影空间中建立 XX 中不同向量之间的交互关系。上述自注意力运算的计算复杂度为 O(N2)O(N^2) 。实践中有些问题并不需要捕捉全局结构,只依赖于局部信息,此时可以使用restricted自注意力机制,即假设当前词只与前后 rr 个词发生联系(类似于卷积中的滑动窗口),此时计算复杂度为 O(rN)O(rN)


下面的代码片段是基于多头注意力对一个张量完成自注意力的计算,张量的形状为(批量大小,时间步的数目或词元序列的长度,d)。输出与输入的张量形状相同。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import math
import torch
from torch import nn


def sequence_mask(X, valid_lens, value=-1e6):
max_len = X.size(1)
mask = torch.arange(max_len)[None, :].to(X.device) < valid_lens[:, None]
X[~mask] = value
return X


def masked_softmax(X, valid_lens):
"""通过在最后一个轴上掩蔽元素来执行softmax操作"""
# X:3D张量,valid_lens:1D或2D张量
if valid_lens is None:
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
value=-1e6)
return nn.functional.softmax(X.reshape(shape), dim=-1)


class DotProductAttention(nn.Module):
"""缩放点积注意力"""

def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)

# queries的形状:(batch_size,查询的个数,d)
# keys的形状:(batch_size,“键-值”对的个数,d)
# values的形状:(batch_size,“键-值”对的个数,值的维度)
# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# 设置transpose_b=True为了交换keys的最后两个维度
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)


def transpose_qkv(X, num_heads):
"""为了多注意力头的并行计算而变换形状"""
# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
# num_hiddens/num_heads)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
X = X.permute(0, 2, 1, 3)

# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
"""逆转transpose_qkv函数的操作"""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)


class MultiHeadAttention(nn.Module):
"""多头注意力"""

def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

def forward(self, queries, keys, values, valid_lens):
# queries,keys,values的形状:
# (batch_size,查询或者“键-值”对的个数,num_hiddens)
# valid_lens 的形状:
# (batch_size,)或(batch_size,查询的个数)
# 经过变换后,输出的queries,keys,values 的形状:
# (batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)

if valid_lens is not None:
# 在轴0,将第一项(标量或者矢量)复制num_heads次,
# 然后如此复制第二项,然后诸如此类。
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)

# output的形状:(batch_size*num_heads,查询的个数,
# num_hiddens/num_heads)
output = self.attention(queries, keys, values, valid_lens)

# output_concat的形状:(batch_size,查询的个数,num_hiddens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)


num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
print(attention.eval())

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
print(attention(X, X, X, valid_lens).shape)

比较卷积神经网络、循环神经网络和自注意力

考虑一个卷积核大小为 kk 的卷积层,由于序列长度是 nn ,输入和输出的通道数量都是 dd ,所以卷积层的计算复杂度为 O(knd2)\mathcal{O}(knd^2)

如上图所示,卷积神经网络是分层的,因此为有 O(1)\mathcal{O}(1) 个顺序操作,最大路径长度为 O(n/k)\mathcal{O}(n/k) 。例如, x1\mathbf{x}_1x5\mathbf{x}_5 处于上图中卷积核大小为3的双层卷积神经网络的感受野内。

当更新循环神经网络的隐状态时, d×dd \times d 权重矩阵和 dd 维隐状态的乘法计算复杂度为 O(d2)\mathcal{O}(d^2) 。由于序列长度为 nn ,因此循环神经网络层的计算复杂度为 O(nd2)\mathcal{O}(nd^2) 。根据上图,有 O(n)\mathcal{O}(n) 个顺序操作无法并行化,最大路径长度也是 O(n)\mathcal{O}(n)

在自注意力中,查询、键和值都是 n×dn \times d 矩阵。考虑缩放的”点-积“注意力,其中 n×dn \times d 矩阵乘以 d×nd \times n 矩阵。之后输出的 n×nn \times n 矩阵乘以 n×dn \times d 矩阵。因此,自注意力具有 O(n2d)\mathcal{O}(n^2d) 计算复杂性。

正如上图,每个词元都通过自注意力直接连接到任何其他词元。因此,有 O(1)\mathcal{O}(1) 个顺序操作可以并行计算,最大路径长度也是 O(1)\mathcal{O}(1) 。总而言之,卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。

位置编码

Transformer中的自注意力机制无法捕捉位置信息,这是因为其计算过程具有置换不变性(permutation invariant),导致打乱输入序列的顺序对输出结果不会产生任何影响。

对于Transformer模型 f()f(\cdot) ,标记输入序列的两个向量 xm,xnx_m,x_n ,则Transformer具有全对称性

f(,xm,,xn,)=f(,xn,,xm,)f(\cdots, x_m, \cdots, x_n, \cdots) = f(\cdots, x_n, \cdots, x_m, \cdots)

**位置编码(Position Encoding)**通过把位置信息引入输入序列中,以打破模型的全对称性。为简化问题,考虑在 m,nm,n 位置处加上不同位置编码 pm,pnp_m,p_n

f~(,xm,,xn,)=f(,xm+pm,,xn+pn,)\tilde{f}(\cdots, x_m, \cdots, x_n, \cdots) = f(\cdots, x_m+p_m, \cdots, x_n+p_n, \cdots)

对上式进行二阶Taylor展开:

f~f+pmTfxm+pnTfxn+pmT2fxm2pm+pnT2fxnpn绝对位置信息+pmT2fxmxnpn相对位置信息\tilde{f} ≈ f + \underbrace{p_m^T \frac{\partial f}{\partial x_m} + p_n^T \frac{\partial f}{\partial x_n} + p_m^T \frac{\partial^2 f}{\partial x_m^2}p_m + p_n^T \frac{\partial^2 f}{\partial x_n}p_n}_{\text{绝对位置信息}} +\underbrace{p_m^T \frac{\partial^2 f}{\partial x_m\partial x_n}p_n}_{\text{相对位置信息}}

在上式中,第25项只依赖于单一位置,表示绝对位置信息。第6项包含 m,nm,n 位置的交互项,表示相对位置信息。因此位置编码主要有两种实现形式:

  • 绝对位置编码 (absolute PE):将位置信息加入到输入序列中,相当于引入索引的嵌入。比如Sinusoidal, Learnable, FLOATER, Complex-order, RoPE
  • 相对位置编码 (relative PE):通过微调自注意力运算过程使其能分辨不同token之间的相对位置。比如XLNet, T5, DeBERTa, URPE

绝对位置编码 Absolute Position Encoding

绝对位置编码是指在输入序列经过词嵌入后的第 kktoken向量 xkRdx_k \in \Bbb{R}^{d} 中加入(add)位置向量 pkRdp_k \in \Bbb{R}^{d} ;其过程等价于首先向输入引入(concatenate)位置索引 kkone hot向量 pk:xk+pkp_k: x_k+p_k ,再进行词嵌入;因此绝对位置编码也被称为位置嵌入(position embedding)

三角函数式(Sinusoidal)位置编码

三角函数式(Sinusoidal)位置编码是在原Transformer模型中使用的一种显式编码。以一维三角函数编码为例:

假设输入表示 XRn×d\mathbf{X} \in \mathbb{R}^{n \times d} 包含一个序列中 nn 个词元的 dd 维嵌入表示。位置编码使用相同形状的位置嵌入矩阵 PRn×d\mathbf{P} \in \mathbb{R}^{n \times d} 输出 X+P\mathbf{X} + \mathbf{P} ,矩阵第 ii 行、第 2j2j 列和 2j+12j+1 列上的元素为:

pk,2i=sin(k100002i/d)pk,2i+1=cos(k100002i/d)\begin{aligned} p_{k,2i} &= \sin(\frac{k}{10000^{2i/d}}) \\ p_{k,2i+1} &= \cos(\frac{k}{10000^{2i/d}}) \end{aligned}

其中 pk,2i,pk,2i+1p_{k,2i},p_{k,2i+1} 分别是位置索引 kk 处的编码向量的第 2i,2i+12i,2i+1 个分量。一个长度为 3232 的输入序列(每个输入向量的特征维度是 128128 )的Sinusoidal编码的可视化如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch import nn
import matplotlib.pyplot as plt


class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# 创建一个足够长的P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)

def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)


encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
plt.figure(figsize=(6, 2.5))
colors = ['b', 'g', 'r', 'c']
line_styles = ['-', '--', '-.', ':']
for i, (color, line_style) in enumerate(zip(colors, line_styles)):
plt.plot(P[0, :, i+6].permute(1, 0), label="Col %d" % (i+6), color=color, linestyle=line_style)
plt.xlabel('Row (position)')
plt.legend()
plt.show()

根据三角函数的性质,位置 α+β\alpha+\beta 处的编码向量可以表示成位置 α\alpha 和位置 β\beta 的向量的组合,因此可以外推到任意位置:

sin(α+β)=sinαcosβ+cosαsinβcos(α+β)=cosαcosβsinαsinβ\begin{aligned} \sin(\alpha+\beta) &= \sin \alpha \cos \beta + \cos \alpha \sin \beta \\ \cos(\alpha+\beta) &= \cos \alpha \cos \beta - \sin \alpha \sin \beta \end{aligned}

在图像领域,常用到二维形式的位置编码。以二维三角函数编码为例,需要分别对高度方向和宽度方向进行编码 p=[ph,pw]p=[p_h,p_w]

ph,2i=sin(h100002i/d),ph,2i+1=cos(h100002i/d)pw,2i=sin(w100002i/d),pw,2i+1=cos(w100002i/d)\begin{aligned} p_{h,2i} &= \sin(\frac{h}{10000^{2i/d}}), \quad p_{h,2i+1} = \cos(\frac{h}{10000^{2i/d}}) \\ p_{w,2i} &= \sin(\frac{w}{10000^{2i/d}}), \quad p_{w,2i+1} = \cos(\frac{w}{10000^{2i/d}}) \end{aligned}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def positionalencoding2d(d_model, height, width):
"""
:param d_model: dimension of the model
:param height: height of the positions
:param width: width of the positions
:return: d_model*height*width position matrix
"""
if d_model % 4 != 0:
raise ValueError("Cannot use sin/cos positional encoding with "
"odd dimension (got dim={:d})".format(d_model))
pe = torch.zeros(d_model, height, width)
# Each dimension use half of d_model
d_model = int(d_model / 2)
div_term = torch.exp(torch.arange(0., d_model, 2) *
-(math.log(10000.0) / d_model))
pos_w = torch.arange(0., width).unsqueeze(1)
pos_h = torch.arange(0., height).unsqueeze(1)
pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
pe[d_model+1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
return pe

可学习(Learnable)位置编码

可学习(Learnable)位置编码是指将位置编码当作可训练参数,比如输入序列(经过嵌入层后)的大小为 n×dn \times d ,则随机初始化一个 pRn×dp \in \Bbb{R}^{n \times d} 的矩阵作为位置编码,随训练过程更新。

可学习位置编码的缺点是没有外推性,即如果预训练序列的最大长度为 nn ,则无法处理长度超过 nn 的序列。此时可以将超过 nn 部分的位置编码随机初始化并微调。

FLOATER:递归式位置编码

Learning to Encode Position for Transformer with Continuous Dynamical Model

如果位置编码能够递归地生成 pk+1=f(pk)p_{k+1}=f(p_k) ,则其生成结构自带学习到位置信息的可能性。

Position Encoding with Dynamical Systems

位置编码可以表示为一个离散序列 {piRd:i=1,...,T}\{p_i \in \Bbb{R}^d : i=1,...,T\} ,若将该序列连续化为 p(t)p(t) ,为了建立序列的自相关性,使用常微分方程(Neural ODE)构建一个连续动力系统:

dp(t)dt=h(t,p(t);θh)\frac{d p(t)}{dt} = h(t,p(t);\theta_h)

或表示为积分形式:

p(t)=p(s)+sth(τ,p(τ);θh)dτp(t) = p(s) + \int_{s}^{t} h(\tau,p(\tau);\theta_h) d\tau

其中 h(τ,p(τ);θh)h(\tau,p(\tau);\theta_h) 是由 θh\theta_h 定义的神经网络。离散的位置编码可以通过对时间离散化 ti=iΔtt_i=i\cdot \Delta t 后通过递归计算得到:

pN=pN1+(N1)ΔtNΔth(τ,p(τ);θh)dτp_N = p_{N-1} + \int_{(N-1)\Delta t}^{N \Delta t} h(\tau,p(\tau);\theta_h) d\tau

在训练时采用adjoint方法计算 θh\theta_h 的梯度:

dLdθh=tsa(τ)Th(τ,p(τ);θh)θhdτ\frac{dL}{d \theta_h} = - \int_{t}^{s} a(\tau)^T \frac{\partial h(\tau,p(\tau);\theta_h)}{\partial \theta_h} d\tau

adjoint状态 a(τ)a(\tau) 可以通过求解adjoint方程得到:

da(τ)dτ=a(τ)Th(τ,p(τ);θh)p(τ)\frac{da(\tau)}{d \tau} = - a(\tau)^T \frac{\partial h(\tau,p(\tau);\theta_h)}{\partial p(\tau)}

上述方程可以使用Runge-Kutta法或中点法来求解。

Parameter Sharing and Warm-start Training

普通的可学习编码可以在Transformer的每一层都注入位置信息,但是每增加一层的位置信息,用于位置编码的参数量就会翻倍。作者设计的连续动力系统能够在每一层共享参数( θh(1)=θh(2)=θh(N)\theta_h^{(1)}=\theta_h^{(2)}=\cdots \theta_h^{(N)} )的同时,使得每一层的位置编码是不一样的。这是因为每一层的初始状态不同,所以求解得到每一层的位置编码是不一样的:

p(n)(t)=p(n)(s)+sth(n)(τ,p(n)(τ);θh(n))dτp^{(n)}(t) = p^{(n)}(s) + \int_{s}^{t} h^{(n)}(\tau,p^{(n)}(\tau);\theta_h^{(n)}) d\tau

注意到三角函数编码是FLOATER的一个特例:

pi+1[j]pi[j]={sin((i+1)cjd)sin(icjd),if j is evencos((i+1)cj1d)cos(icj1d),if j is odd={ii+1cjdcos(τcjd)dτ,if j is evenii+1cj1dsin(τcj1d)dτ,if j is odd\begin{aligned} p_{i+1}[j]-p_i[j] &= \begin{cases} \sin((i+1)\cdot c^{\frac{j}{d}})-\sin(i\cdot c^{\frac{j}{d}}), & \text{if } j \text{ is even} \\ \cos((i+1)\cdot c^{\frac{j-1}{d}})-\cos(i\cdot c^{\frac{j-1}{d}}), & \text{if } j \text{ is odd} \end{cases}\\& = \begin{cases} \int_{i}^{i+1}c^{-\frac{j}{d}} \cos(\tau \cdot c^{\frac{j}{d}}) d\tau , & \text{if } j \text{ is even} \\ \int_{i}^{i+1}-c^{-\frac{j-1}{d}} \sin(\tau \cdot c^{\frac{j-1}{d}}) d\tau , & \text{if } j \text{ is odd} \end{cases} \end{aligned}

因此可以使用三角函数编码作为FLOATER的参数初始化,然后在下游任务上微调模型。由于微分方程求解器无法利用GPU并行计算能力,常微分方程带来的额外时间开销是不容忽视的。使用三角函数编码来初始化FLOATER能够避免从头训练模型,减小时间开销。

RoPE:旋转式位置编码

RoFormer: Enhanced Transformer with Rotary Position Embedding

旋转式位置编码是指在构造查询矩阵 qq 和键矩阵 kk 时,根据其绝对位置引入旋转矩阵 R\mathcal{R}

qi=RixiWQ,kj=RjxjWKq_i = \mathcal{R}_ix_i W^Q , k_j = \mathcal{R}_jx_j W^K

旋转矩阵 R\mathcal{R} 设计为正交矩阵,且应满足 RiTRj=Rji\mathcal{R}_i^T\mathcal{R}_j=\mathcal{R}_{j-i} ,使得后续注意力矩阵的计算中隐式地包含相对位置信息:

(RixiWQ)T(RjxjWK)=(xiWQ)TRiTRjxjWK=(xiWQ)TRjixjWK(\mathcal{R}_ix_i W^Q)^T(\mathcal{R}_jx_j W^K) = (x_i W^Q)^T\mathcal{R}_i^T\mathcal{R}_jx_j W^K = (x_i W^Q)^T\mathcal{R}_{j-i}x_j W^K

相对位置编码 Relative Position Encoding

相对位置编码并不是直接建模每个输入token的位置信息,而是在计算注意力矩阵时考虑当前向量与待交互向量的位置的相对距离。

从绝对位置编码出发,其形式相当于在输入中添加入绝对位置的表示。对应的完整自注意力机制运算如下

qi=(xi+pi)WQ,kj=(xj+pj)WK,vj=(xj+pj)WVαij=softmax{(xi+pi)WQ((xj+pj)WK)T}=softmax{xiWQ(WK)TxjT+xiWQ(WK)TpjT+piWQ(WK)TxjT+piWQ(WK)TpjT}zi=j=1nαij(xjWV+pjWV)\begin{aligned} q_i &= (x_i+p_i) W^Q , k_j = (x_j+p_j) W^K ,v_j = (x_j+p_j) W^V \\ \alpha_{ij} &= \text{softmax}\{(x_i+p_i)W^Q ( (x_j+p_j)W^K)^T \} \\ &= \text{softmax}\{ x_iW^Q (W^K)^T x_j^T+x_iW^Q (W^K)^T p_j^T+p_iW^Q (W^K)^T x_j^T+p_iW^Q (W^K)^T p_j^T \} \\ z_i &= \sum_{j=1}^{n} \alpha_{ij}(x_jW^V+p_jW^V) \end{aligned}

注意到绝对位置编码相当于在自注意力运算中引入了一系列 piWQ,(pjWK)T,pjWVp_iW^Q,(p_jW^K)^T,p_jW^V 项。而相对位置编码的出发点便是将这些项调整为与相对位置 (i,j)(i,j) 有关的向量 Ri,jR_{i,j}

经典相对位置编码

在经典的相对位置编码设置中,移除了与 xix_i 的位置编码项 piWQp_iW^Q 相关的项,并将 xjx_j 的位置编码项 pjWV,pjWKp_jW^V,p_jW^K 替换为相对位置向量 Ri,jV,Ri,jKR_{i,j}^V,R_{i,j}^K

αij=softmax{xiWQ(WK)TxjT+xiWQ(Ri,jK)T}zi=j=1nαij(xjWV+Ri,jV)\begin{aligned} \alpha_{ij} &= \text{softmax}\{x_iW^Q (W^K)^T x_j^T+x_iW^Q (R_{i,j}^K)^T \} \\ z_i &= \sum_{j=1}^{n} \alpha_{ij}(x_jW^V+R_{i,j}^V) \end{aligned}

相对位置向量 Ri,jV,Ri,jKR_{i,j}^V,R_{i,j}^K 可以设置为三角函数式或可学习参数,并且通常只考虑相对位置 pminijpmaxp_{\min} \leq i-j \leq p_{\max} 的情况:

Ri,jK=wclip(ji,pmin,pmax)K(wpminK,wpmaxK)Ri,jV=wclip(ji,pmin,pmax)V(wpminV,wpmaxV)\begin{aligned} R_{i,j}^K &= w^K_{\text{clip}(j-i,p_{\min},p_{\max})} \in (w_{p_{\min}}^K,\cdots w_{p_{\max}}^K) \\ R_{i,j}^V &= w^V_{\text{clip}(j-i,p_{\min},p_{\max})} \in (w_{p_{\min}}^V,\cdots w_{p_{\max}}^V) \end{aligned}

XLNet式

XLNet模型中,移除了值向量的位置编码 pjp_j ,并将注意力计算中 xjx_j 的位置编码 pjp_j 替换为相对位置向量 RijR_{i-j} (设置为三角函数式编码), xix_i 的位置编码 pip_i 设置为可学习向量 u,vu,v

αij=softmax{xiWQ(WK)TxjT+xiWQ(WK)TRijT+uWQ(WK)TxjT+vWQ(WK)TRijT}zi=j=1nαijxjWV\begin{aligned} \alpha_{ij} &= \text{softmax}\{ x_iW^Q (W^K)^T x_j^T+x_iW^Q (W^K)^T R_{i-j}^T+uW^Q (W^K)^T x_j^T+vW^Q (W^K)^T R_{i-j}^T \} \\ z_i &= \sum_{j=1}^{n} \alpha_{ij}x_jW^V \end{aligned}

T5式

T5模型中,移除了值向量的位置编码 pjp_j 以及注意力计算中的输入-位置注意力项( xi,pjx_i,p_jpi,xjp_i,x_j ),并将位置-位置注意力项( pi,pjp_i,p_j )设置为可学习标量 ri,jr_{i,j}

αij=softmax{xiWQ(WK)TxjT+ri,j}zi=j=1nαijxjWV\begin{aligned} \alpha_{ij} &= \text{softmax}\{ x_iW^Q (W^K)^T x_j^T+r_{i,j} \} \\ z_i &= \sum_{j=1}^{n} \alpha_{ij}x_jW^V \end{aligned}

RoPE:旋转式位置编码

  • RoPE通过绝对位置编码的方式实现相对位置编码,综合了绝对位置编码和相对位置编码的优点。
  • 主要就是对attention中的q, k向量注入了绝对位置信息,然后用更新的q,k向量做attention中的内积就会引入相对位置信息了

通过线性 attention 演算,先在 qqkk 向量中引入绝对位置信息:

q~m=f(q,m),k~n=f(k,n)\tilde{\boldsymbol{q}}_{m}=\boldsymbol{f}(\boldsymbol{q}, m), \quad \tilde{\boldsymbol{k}}_{n}=\boldsymbol{f}(\boldsymbol{k}, n)

但是需要实现相对位置编码的话,需要显式融入相对。attention 运算中 qqkk 会进行内积,所以考虑在进行向量内积时考虑融入相对位置。所以假设成立恒等式:

f(q,m),f(k,n)=g(q,k,mn)\langle\boldsymbol{f}(\boldsymbol{q}, m), \boldsymbol{f}(\boldsymbol{k}, n)\rangle=g(\boldsymbol{q}, \boldsymbol{k}, m-n)

其中m-n包含着 token 之间的相对位置信息。

给上述恒等式计算设置初始条件,例如 f(q,0)=qf(q,0)=qf(k,0)=kf(k,0)=k

求解过程使用复数方式求解,将内积使用复数形式表示:

q,k=Re[qk]\langle\boldsymbol{q}, \boldsymbol{k}\rangle=\operatorname{Re}\left[\boldsymbol{q} \boldsymbol{k}^{*}\right]

转化上面内积公式可得:

Re[f(q,m)f(k,n)]=g(q,k,mn)\operatorname{Re}\left[\boldsymbol{f}(\boldsymbol{q}, m) \boldsymbol{f}^{*}(\boldsymbol{k}, n)\right]=g(\boldsymbol{q}, \boldsymbol{k}, m-n)

假设等式两边都存在复数形式,则有下式:

f(q,m)f(k,n)=g(q,k,mn)\boldsymbol{f}(\boldsymbol{q}, m) \boldsymbol{f}^{*}(\boldsymbol{k}, n)=\boldsymbol{g}(\boldsymbol{q}, \boldsymbol{k}, m-n)

将两边公式皆用复数指数形式表示:

存在 reθj=rcosθ+rsinθjr e^{\theta \mathrm{j}}=r \cos \theta+r \sin \theta \mathrm{j},即任意复数 zz 可以表示为 z=reθj\boldsymbol{z}=r e^{\theta \mathrm{j}},其中 rr 为复数的模, θ\theta 为幅角。

f(q,m)=Rf(q,m)eiΘf(q,m)f(k,n)=Rf(k,n)eiΘf(k,n)g(q,k,mn)=Rg(q,k,mn)eiΘg(q,k,mn)\begin{aligned} \boldsymbol{f}(\boldsymbol{q}, m) & =R_{f}(\boldsymbol{q}, m) e^{\mathrm{i} \Theta_{f}(\boldsymbol{q}, m)} \\ \boldsymbol{f}(\boldsymbol{k}, n) & =R_{f}(\boldsymbol{k}, n) e^{\mathrm{i} \Theta_{f}(\boldsymbol{k}, n)} \\ \boldsymbol{g}(\boldsymbol{q}, \boldsymbol{k}, m-n) & =R_{g}(\boldsymbol{q}, \boldsymbol{k}, m-n) e^{\mathrm{i} \Theta_{g}(\boldsymbol{q}, \boldsymbol{k}, m-n)}\end{aligned}

由于带入上面方程中 f(k,n)f(k,n) 带 * 是共轭复数,所以指数形式应该是 exe^{-x} 形式,带入上式公式可得方程组:

Rf(q,m)Rf(k,n)=Rg(q,k,mn)Θf(q,m)Θf(k,n)=Θg(q,k,mn)\begin{aligned} R_{f}(\boldsymbol{q}, m) R_{f}(\boldsymbol{k}, n) & =R_{g}(\boldsymbol{q}, \boldsymbol{k}, m-n) \\ \Theta_{f}(\boldsymbol{q}, m)-\Theta_{f}(\boldsymbol{k}, n) & =\Theta_{g}(\boldsymbol{q}, \boldsymbol{k}, m-n)\end{aligned}

第一个方程带入条件 m=nm=n 化简可得:

Rf(q,m)Rf(k,m)=Rg(q,k,0)=Rf(q,0)Rf(k,0)=qkR_{f}(\boldsymbol{q}, m) R_{f}(\boldsymbol{k}, m)=R_{g}(\boldsymbol{q}, \boldsymbol{k}, 0)=R_{f}(\boldsymbol{q}, 0) R_{f}(\boldsymbol{k}, 0)=\|\boldsymbol{q}\|\|\boldsymbol{k}\|

Rf(q,m)=q,Rf(k,m)=kR_{f}(\boldsymbol{q}, m)=\|\boldsymbol{q}\|, R_{f}(\boldsymbol{k}, m)=\|\boldsymbol{k}\|

从上式可以看出来复数f(q,m)f(q,m)f(k,m)f(k,m)mm 取值关系不大。

第二个方程带入 m=nm=n 化简可得:

Θf(q,m)Θf(k,m)=Θg(q,k,0)=Θf(q,0)Θf(k,0)=Θ(q)Θ(k)\Theta_{f}(\boldsymbol{q}, m)-\Theta_{f}(\boldsymbol{k}, m)=\Theta_{g}(\boldsymbol{q}, \boldsymbol{k}, 0)=\Theta_{f}(\boldsymbol{q}, 0)-\Theta_{f}(\boldsymbol{k}, 0)=\Theta(\boldsymbol{q})-\Theta(\boldsymbol{k})

上式公式变量两边挪动下得到:

Θf(q,m)Θf(k,m)=Θg(q,k,0)=Θf(q,0)Θf(k,0)=Θ(q)Θ(k)\Theta_{f}(\boldsymbol{q}, m)-\Theta_{f}(\boldsymbol{k}, m)=\Theta_{g}(\boldsymbol{q}, \boldsymbol{k}, 0)=\Theta_{f}(\boldsymbol{q}, 0)-\Theta_{f}(\boldsymbol{k}, 0)=\Theta(\boldsymbol{q})-\Theta(\boldsymbol{k})

其中上式结果相当于 mm 是自变量,结果是与 mm 相关的值,假设为 φ(m)\varphi(m) ,即 Θf(q,m)=Θ(q)+φ(m)\Theta_{f}(\boldsymbol{q}, m)=\Theta(\boldsymbol{q})+\varphi(m)

n假设为m的前一个 token,则可得n=m-1,带入上上个式子可得:

φ(m)φ(m1)=Θg(q,k,1)+Θ(k)Θ(q)\varphi(m)-\varphi(m-1)=\Theta_{g}(\boldsymbol{q}, \boldsymbol{k}, 1)+\Theta(\boldsymbol{k})-\Theta(\boldsymbol{q})

φ(m)\varphi(m) 是等差数列,假设等式右边为 θ\theta ,则mm-1位置的公差就是为 θ\theta,可推得 φ(m)=mθ\varphi(m)=m \theta

得到二维情况下用复数表示的 RoPE:

f(q,m)=Rf(q,m)eiΘf(q,m)=qei(Θ(q)+mθ)=qeimθ\boldsymbol{f}(\boldsymbol{q}, m)=R_{f}(\boldsymbol{q}, m) e^{\mathrm{i} \Theta_{f}(\boldsymbol{q}, m)}=\|q\| e^{\mathrm{i}(\Theta(\boldsymbol{q})+m \theta)}=\boldsymbol{q} e^{\mathrm{i} m \theta}

按照 eimθ=cos(mθ)+isin(mθ)e^{im\theta}=cos(m\theta)+isin(m\theta) 展开后的矩阵形式是:

f(q,m)=(cosmθsinmθsinmθcosmθ)(q0q1)=RΘd(q0q1)\left.\boldsymbol{f}(\boldsymbol{q},m)=\left(\begin{array}{cc}\cos m\theta&-\sin m\theta\\\sin m\theta&\cos m\theta\end{array}\right.\right)\left(\begin{array}{c}q_0\\q_1\end{array}\right)\\=\mathbf{R}_{\boldsymbol{\Theta}}^d\left(\begin{array}{c}q_0\\q_1\end{array}\right)

RΘd\mathbf{R}^d_{\mathbf{\Theta}}d=2d=2,拓展到多维:

其中:

Θ={θi=100002i/d,i[0,1,...,d/21]}\mathbb{\Theta}=\{\theta_i=10000^{-2i/d},i\in[0,1,...,d/2-1]\}

总结来说,RoPE 的 self-attention 操作的流程是:对于 token 序列中的每个词嵌入向量,首先计算其对应的 query 和 key 向量,然后对每个 token 位置都计算对应的旋转位置编码,接着对每个 token 位置的 query 和 key 向量的元素按照两两一组应用旋转变换,最后再计算 query 和 key 之间的内积得到 self-attention 的计算结果。

公式最后还会采用三角式一样的远程衰减,来增加周期性函数外推位置差异性。

(Wmq)(Wnk)=Re[i=0d/21q[2i:2i+1]k[2i:2i+1]ei(mn)θi]\left(\boldsymbol{W}_{m} \boldsymbol{q}\right)^{\top}\left(\boldsymbol{W}_{n} \boldsymbol{k}\right)=\operatorname{Re}\left[\sum_{i=0}^{d / 2-1} \boldsymbol{q}_{[2 i: 2 i+1]} \boldsymbol{k}_{[2 i: 2 i+1]}^{*} e^{\mathrm{i}(m-n) \theta_{i}}\right]

形式上是绝对位置的编码,内积是相对位置编码

Transformer

图中概述了Transformer的架构。从宏观角度来看,Transformer的编码器是由多个相同的层叠加而成的,每个层都有两个子层(子层表示为 sublayer\mathrm{sublayer} )。第一个子层是多头自注意力(multi-head self-attention)汇聚;第二个子层是基于位置的前馈网络(positionwise feed-forward network)。具体来说,在计算编码器的自注意力时,查询、键和值都来自前一个编码器层的输出。受残差网络的启发,每个子层都采用了残差连接(residual connection)。在Transformer中,对于序列中任何位置的任何输入 xRd\mathbf{x} \in \mathbb{R}^d ,都要求满足 sublayer(x)Rd\mathrm{sublayer}(\mathbf{x}) \in \mathbb{R}^d ,以便残差连接满足 x+sublayer(x)Rd\mathbf{x} + \mathrm{sublayer}(\mathbf{x}) \in \mathbb{R}^d 。在残差连接的加法计算之后,紧接着应用层规范化(layer normalization) 。因此,输入序列对应的每个位置,Transformer编码器都将输出一个 dd 维表示向量。

Transformer解码器也是由多个相同的层叠加而成的,并且层中使用了残差连接和层规范化。除了编码器中描述的两个子层之外,解码器还在这两个子层之间插入了第三个子层,称为编码器-解码器注意力(encoder-decoder attention)层。在编码器-解码器注意力中,查询来自前一个解码器层的输出,而键和值来自整个编码器的输出。在解码器自注意力中,查询、键和值都来自上一个解码器层的输出。但是,解码器中的每个位置只能考虑该位置之前的所有位置。这种掩蔽(masked)注意力保留了自回归(auto-regressive)属性,确保预测仅依赖于已生成的输出词元。

Reference

https://zh.d2l.ai/chapter_attention-mechanisms/index.html

https://0809zheng.github.io/2022/07/01/posencode.html