type
status
date
slug
summary
tags
category
icon
password
Info
Paper
GitHub
个人博客地址
最近在看并行RNN相关的paper,发现很多都利用了Parallel Scanning算法。本文将从Parallel Scanning算法开始,介绍Bengio团队不久前发表的《Were RNNs All We Needed?》

1 Parallel Scanning算法介绍

首先来看定义。Parallel Scanning字面意思,就是对scan操作进行并行化,那么什么是scan(扫描)操作呢?

1.1 Scan的定义

1.1.1 inclusive scan

scan (inclusive scan)也称为all-prefix-sum,其定义如下:
若给定:
  • 有序集合(order set) ,
  • 二元结合运算符(binary associative operation) ,并且 的单位元存在
输出一个order set,并满足
.
将满足上述规则的操作称为scan
显然上式可以写成递归形式,时间复杂度为
注1:二元结合运算符作用于两个操作数返回一个结果,且运算满足结合率。常见的二元结合运算符包括加法()、乘法()、逻辑与()和逻辑或()等. 注2: 的单位元:若:,则称是运算 的单位元。例如,加法的单位元是0,乘法的单位元是1,向量点乘的单位元是单位向量。

1.1.2 exclusive scan

实践中,scan另一种变体prescan(也叫exclusive scan)也经常用到,输入和scan一致,输出为:
其递归形式为
inclusive scan与exclusive scan可以很方便的转化,
inclusive scan → exclusive scan,只需将输出序列向右移一个单位,并且在序列第一个元素填充单位元。
exclusive scan → inclusive scan,只需将输出序列向左移一个单位,并且用最后一个输入元素加上最后一个输出元素的结果填充最后一个元素。

1.1.3 例子: prefix sum

已知输入有序集合,二元结合运算符为加法,计算A在下的inclusive scan和exclusive scan
根据式1,易得inclusive scan的结果为:
根据式2,易得exclusive scan的结果为:
代码实现:

1.2 Parallel Scanning

前文所述基于递归式计算scan的算法称之为sequential algorithm,其计算复杂度为,并且无法并行化。那么如何并行化计算scan呢?

1.2.1 Kogge-Stone Parallel Scanning algorithm[2]

Kogge-Stone 并行扫描算法的基本计算流程如下图所示(从最底部往上看)
总计分为个阶段,在每一个阶段并行计算(表示阶段, 从0开始取)。该方法的加法运算次数为多于顺序算法的,不考虑并行的情况下时间复杂度为。但在processor足够时,Kogge-Stone 的时间复杂度为
python代码实现如下:
注意由于python原生的多线程存在GIL,无法利用多核优势,故使用numpy实现
notion image

1.2.2 Brent-Kung Parallel Scanning algorithm[3]

从上文中,Kogge-Stone 算法虽然在并行的情况下将scan的时间复杂度从降到了,但Kogge-Stone 算法实际的计算量是比顺序执行多不少的。下面来看计算效率更高的Brent-Kung 算法。
Kogge-Stone 算法分为两个阶段
stage1: 上行阶段,计算reduce (up sweep)
上行阶段有 个阶段,每个阶段执行
算法流程:
notion image
下面来分析一下up sweep的时间复杂度
up sweep的计算量为
不做并行的时间复杂度为,并行时的时间复杂度为
python代码如下:
此处为了便于理解,第二个循环没有用并行
通过up sweep 我们可以得到reduce的结果,但无法得到完整的scan结果,需要继续进行down sweep。
stage2: 下行阶段(down sweep)
算法流程:
计算复杂度与up-sweep一致
notion image
python代码如下:
综上所述,我们详细介绍了Kogge-Stone 算法,它分为up sweep和down sweep两个阶段,每个阶段的计算量为,不做并行的计算时间复杂度为:,并行时的计算复杂度为

❓小练习

不妨尝试回答一下几个问题:
  1. 当输入序列的长度并不是2的N次幂,如何用 Brent-Kung 算法进行并行?
  1. 如果系统的processor有限,此时的时间复杂度是多少?

2 并行RNN

通过上文的介绍我们可以用并行的方法计算递归式。那如何将其与RNN建立起联系呢?
先来回顾一下两个经典的RNN算法,1)LSTM, 2)GRU

2.1 经典RNN回顾

2.1.1 LSTM

LSTM引入记忆细胞C(t)来存储长期信息,解决传统RNN无法处理长程依赖的问题。并引入3个门(遗忘门、输入门、输出门)来控制新老信息的交互。
notion image
下面来详细看其计算流程:
给定
  • 输入序列:
  • 初始化隐藏状态
  • 初始化记忆细胞
三个门的输出在0~1之间,通过点乘来控制信息的流入量。

2.1.2 GRU

GRU简化了LSTM的门控机制达到和LSTM类似的效果。GRU主要通过两个门(重置门、更新门)来控制信息的交互。
notion image
下面来详细看其计算流程:
给定
  • 输入序列:
  • 初始化隐藏状态

2.2 经典RNN并行化

2.2.1 理论基础

通过前文介绍,我们回顾了经典RNN的递归更新公式,但显然,无法直接沿用parallel scan算法进行并行
递归更新公式
LSTM
GRU
  • 对于LSTM而言依赖上一个时间步的的计算,且其递归式的形式并非已知。
  • 对于GRU而言,同样依赖上一个时间步的的计算,且其递归式的形式并非已知。
故他们都无法利用parallel scan算法进行并行化。
如何让LSTM,GRU能够使用parallel scan算法进行并行呢?
不考虑对以往时间步的依赖,LSTM,GRU的递归更新公式形如:
已知。这个式子和标准的scan多了一个偏置项。文献[6]指出,只需对式6进行适当变形,即可用两次parallel scan算法对式6进行并行计算。
推导前,不妨将式(6)简写为:
通过归纳,不难得出
对上式子两边取对数,有
从上述递归式可以看出,有两处可以用两次parallel scan算法
第一次parallel scan计算有序集合
第二次parallel scan计算有序集合
有了他们,我们可以并行计算有序集合
下面来看,如何将LSTM,GRU转变为式(6)的形式

2.2.2 LSTM的并行化

Step 1: Drop previous hidden state dependencies from gates
Step 2: Drop range restriction of candidate states
Step 3: Ensure output is time-independent in scale
通过上述的操作,结合文献[6]的技巧(式9)完成LSTM的并行化。

2.2.3 GRU的并行化

GRU的并行化的操作和LSTM类似
Step 1: Drop previous hidden state dependencies from gates
Step 2: Drop range restriction of candidate states

3 小结

本文从parallel scan算法出发,介绍了如何将经典RNN算法——LSTM,GRU进行变换,使其能够并行化。实验结果本文不做介绍,请参考原论文。

Reference:

[1] Prefix Sums and Their Applications
[3] A Regular Layout for Parallel Adders
[4] LONG SHORT-TERM MEMORY
[5] Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling
[6] Efficient Parallelization of a Ubiquitous Sequential Computation
diffusion model(十九) :SDE视角下的扩散模型SigLIP技术小结
  • Twikoo