type
Post
status
Published
date
Nov 21, 2024
slug
mini_rnn
summary
最近在看并行RNN相关的paper,发现很多都利用了Parallel Scanning算法。本文将从Parallel Scanning算法开始,介绍Bengio团队不久前发表的《Were RNNs All We Needed?》
tags
大模型
category
学习分享
icon
password
ㅤ | Info |
Paper | |
GitHub | |
个人博客地址 |
最近在看并行RNN相关的paper,发现很多都利用了Parallel Scanning算法。本文将从Parallel Scanning算法开始,介绍Bengio团队不久前发表的《Were RNNs All We Needed?》
1 Parallel Scanning算法介绍
首先来看定义。Parallel Scanning字面意思,就是对scan操作进行并行化,那么什么是
scan(扫描)操作呢?1.1 Scan的定义
1.1.1 inclusive scan
scan (inclusive scan)也称为all-prefix-sum,其定义如下:若给定:
- 有序集合(order set) ,
- 二元结合运算符(binary associative operation) ,并且 的单位元存在
输出一个order set,并满足
.
将满足上述规则的操作称为
scan。显然上式可以写成递归形式,时间复杂度为
注1:二元结合运算符作用于两个操作数返回一个结果,且运算满足结合率。常见的二元结合运算符包括加法()、乘法()、逻辑与()和逻辑或()等. 注2: 的单位元:若:,则称是运算 的单位元。例如,加法的单位元是0,乘法的单位元是1,向量点乘的单位元是单位向量。
1.1.2 exclusive scan
实践中,
scan另一种变体prescan(也叫exclusive scan)也经常用到,输入和scan一致,输出为:其递归形式为
inclusive scan与exclusive scan可以很方便的转化,
inclusive scan → exclusive scan,只需将输出序列向右移一个单位,并且在序列第一个元素填充单位元。
exclusive scan → inclusive scan,只需将输出序列向左移一个单位,并且用最后一个输入元素加上最后一个输出元素的结果填充最后一个元素。
1.1.3 例子: prefix sum
已知输入有序集合,二元结合运算符为加法,计算A在下的inclusive scan和exclusive scan
根据式1,易得inclusive scan的结果为:
根据式2,易得exclusive scan的结果为:
代码实现:
1.2 Parallel Scanning
前文所述基于递归式计算scan的算法称之为sequential algorithm,其计算复杂度为,并且无法并行化。那么如何并行化计算scan呢?
1.2.1 Kogge-Stone Parallel Scanning algorithm[2]
Kogge-Stone 并行扫描算法的基本计算流程如下图所示(从最底部往上看)总计分为个阶段,在每一个阶段并行计算(表示阶段, 从0开始取)。该方法的加法运算次数为多于顺序算法的,不考虑并行的情况下时间复杂度为。但在processor足够时,
Kogge-Stone 的时间复杂度为。python代码实现如下:
注意由于python原生的多线程存在GIL,无法利用多核优势,故使用numpy实现

1.2.2 Brent-Kung Parallel Scanning algorithm[3]
从上文中,
Kogge-Stone 算法虽然在并行的情况下将scan的时间复杂度从降到了,但Kogge-Stone 算法实际的计算量是比顺序执行多不少的。下面来看计算效率更高的Brent-Kung 算法。Kogge-Stone 算法分为两个阶段stage1: 上行阶段,计算reduce (up sweep)
上行阶段有 个阶段,每个阶段执行
算法流程:

下面来分析一下up sweep的时间复杂度
up sweep的计算量为
不做并行的时间复杂度为,并行时的时间复杂度为
python代码如下:
此处为了便于理解,第二个循环没有用并行
通过up sweep 我们可以得到reduce的结果,但无法得到完整的scan结果,需要继续进行down sweep。
stage2: 下行阶段(down sweep)
算法流程:
计算复杂度与up-sweep一致

python代码如下:
综上所述,我们详细介绍了
Kogge-Stone 算法,它分为up sweep和down sweep两个阶段,每个阶段的计算量为,不做并行的计算时间复杂度为:,并行时的计算复杂度为❓小练习
不妨尝试回答一下几个问题:
- 当输入序列的长度并不是2的N次幂,如何用
Brent-Kung算法进行并行?
- 如果系统的processor有限,此时的时间复杂度是多少?
2 并行RNN
通过上文的介绍我们可以用并行的方法计算递归式。那如何将其与RNN建立起联系呢?
先来回顾一下两个经典的RNN算法,1)LSTM, 2)GRU
2.1 经典RNN回顾
2.1.1 LSTM
LSTM引入记忆细胞C(t)来存储长期信息,解决传统RNN无法处理长程依赖的问题。并引入3个门(遗忘门、输入门、输出门)来控制新老信息的交互。

下面来详细看其计算流程:
给定
- 输入序列:
- 初始化隐藏状态
- 初始化记忆细胞
三个门的输出在0~1之间,通过点乘来控制信息的流入量。
2.1.2 GRU
GRU简化了LSTM的门控机制达到和LSTM类似的效果。GRU主要通过两个门(重置门、更新门)来控制信息的交互。

下面来详细看其计算流程:
给定
- 输入序列:
- 初始化隐藏状态
2.2 经典RNN并行化
2.2.1 理论基础
通过前文介绍,我们回顾了经典RNN的递归更新公式,但显然,无法直接沿用parallel scan算法进行并行
- 对于LSTM而言依赖上一个时间步的的计算,且其递归式的形式并非已知。
- 对于GRU而言,同样依赖上一个时间步的的计算,且其递归式的形式并非已知。
故他们都无法利用parallel scan算法进行并行化。
如何让LSTM,GRU能够使用parallel scan算法进行并行呢?
不考虑对以往时间步的依赖,LSTM,GRU的递归更新公式形如:
- 作者:莫叶何竹🍀
- 链接:http://www.myhz0606.com/article/mini_rnn
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章






