位置编码是目前大模型所涉及到的基础概念之一,本文将从位置编码满足的条件,分析常见的两种编码方式:绝对位置编码和旋转位置编码,及其在多模态大模型中的直接拓展。

位置编码需要满足的条件

位置编码的根本目的是为序列中的每个token引入位置信息,否则自注意力机制对序列是「无序」的。一个好的位置编码方法需要满足以下条件:

  • 唯一性:每个位置有唯一表示,避免不同位置混淆。
  • 可区分性:相邻位置编码差异明显,模型能感知局部顺序。
  • 可组合性:能有效表达位置间的相对关系。
  • 泛化性:能处理比训练时更长的序列。

二进制编码

其实,根据上面的条件,很容易想到的一个方案就是二进制编码。显而易见,二进制编码一定是唯一的。但是不具有可区分性,泛化性和可组合性。

二进制编码的可区分性

上面已经论述过,对于二进制编码,其是one-to-one的映射,每个位置都会有一个唯一的二进制串。因此在理论上,它不会出现两个位置相同的编码,完全满足可区分性。但在向量空间的相似性上,它并不理想:比如 7 = 0111 和 8 = 1000 的二进制差异非常大,汉明距离为4,实际上它们相邻,但编码却完全不相似。这意味着“局部邻近性”缺失,不利于捕捉局部顺序信息。

二进制编码的可组合性

在二进制编码中,位置差必须依赖模型学会“二进制到整数”的解码,再做减法。例如 0101 (5) 和 1000 (8),模型需要先恢复数值才能得出差=3。

二进制编码的泛化性

训练时序列的最长长度为$L$,那把位置 $pos$ 转成二进制串所需的比特数为: $$ \lceil \log_2(L) \rceil $$

比如,训练时最大序列长度为 256,那么会使用8位二进制。在测试时,假设位置 $=300$:

  • 300 的二进制是 $100101100$,需要 9 位。

但训练时 embedding 只有 8 维(每一位有一个 embedding),此时二进制编码体系不够用了。在训练长度之外会直接“溢出”,不能自然推广。

绝对位置编码

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$PE(pos) = [\sin(\omega_0 \cdot pos), \cos(\omega_0 \cdot pos), \sin(\omega_1 \cdot pos), \cos(\omega_1 \cdot pos), \dots, \sin(\omega_{d_{model}/2-1} \cdot pos), \cos(\omega_{d_{model}/2-1} \cdot pos)]$

其中,每个维度对的角频率 $\omega_i$ 的定义为:

$\omega_i = \frac{1}{10000^{2i/d_{model}}}$

这个行向量清晰地展示了位置编码的结构:

  • 偶数索引的维度 使用 $\sin(\cdot)$ 函数。
  • 奇数索引的维度 使用 $\cos(\cdot)$ 函数。

每一对 $\sin(\cdot)$ 和 $\cos(\cdot)$ 函数都共享一个独特的、随着维度 $i$ 的增加而逐渐减小的频率,从而为序列中的每个位置提供了独一无二的编码。

首先,需要考虑频率问题,高频在前还是低频在前。高频代表的是主要负责区分相邻词元的位置(高频变化快)。低频 主要负责区分全局位置(低频变化慢)。 关于频率问题,有两点需要注意。

绝对位置编码的可区分性

import numpy as np
import matplotlib.pyplot as plt

def positional_encoding_correct(pos, d_model=2, base=10000):
    encoding = np.zeros(d_model)
    
    for i in range(d_model // 2):
        # 计算频率项,与i成正比
        frequency_term = pos / (base ** ((2 * i) / d_model))
        
        # 填充偶数维度 (2*i)
        encoding[2 * i] = np.sin(frequency_term)
        
        # 填充奇数维度 (2*i + 1)
        encoding[2 * i + 1] = np.cos(frequency_term)
        
    return encoding

# 重新设定基数
bases = [1, 5, 10, 100]
positions = np.arange(5) # 前5个位置

plt.figure(figsize=(12, 12))
for idx, base in enumerate(bases):
    encodings = np.array([positional_encoding(pos, d_model=2, base=base) for pos in positions])
    
    plt.subplot(2, 2, idx+1)
    plt.scatter(encodings[:,0], encodings[:,1], c=positions, cmap="viridis", s=100, marker="o")
    for i, (x, y) in enumerate(encodings):
        plt.text(x+0.02, y+0.02, f"pos {i}", fontsize=10, weight="bold")
    plt.title(f"Positional Encoding (base={base})")
    plt.xlabel("sin component")
    plt.ylabel("cos component")
    plt.axis("equal")

plt.tight_layout()
plt.show()

位置编码示意图

相对距离

$$ PE_p \cdot PE_q = \sum_i \left[ \sin(p\omega_i)\sin(q\omega_i) + \cos(p\omega_i)\cos(q\omega_i) \right] = \sum_i \cos\left(\omega_i (p - q)\right) $$