Figure 1. The encoder-decoder framework for our proposed MASS. The token "_" represents the mask symbol [M].
accuracy on multiple language understanding tasks in the
GLUE benchmark (Wang et al., 2018) and SQuAD (Rajpurkar et al., 2016).
There are also some works pre-training the encoder-decoder
model for language generation. Dai & Le (2015); Ramachandran et al. (2016) leverage a language model or
auto-encoder to pre-train the encoder and decoder. Their
improvements, although observed, are limited and not as
general and significant as the pre-training methods (e.g.,
BERT) for language understanding. Zhang & Zong (2016)
designed a sentence reordering task for pre-training, but
only for the encoder part of the encoder-decoder model.
Zoph et al. (2016); Firat et al. (2016) pre-train the model on similar rich-resource language pairs and fine-tune it on the target language pair, which relies on supervised data from other language pairs. Recently, XLM (Lample & Conneau, 2019) pre-trained BERT-like models for both the encoder and the decoder, and achieved the previous state-of-the-art results on unsupervised machine translation. However, the encoder and decoder in XLM are pre-trained separately, and the encoder-decoder attention mechanism cannot be pre-trained, which is sub-optimal for sequence to sequence based language generation tasks.
Different from previous works, our proposed MASS is care-
fully designed to pre-train both the encoder and decoder
jointly using only unlabeled data, and can be applied to
most language generation tasks.
3. MASS
In this section, we first introduce the basic framework of
sequence to sequence learning, and then propose MASS
(MAsked Sequence to Sequence pre-training). We then
discuss the differences between MASS and previous pre-
training methods including the masked language modeling
in BERT and standard language modeling.
3.1. Sequence to Sequence Learning
We denote $(x, y) \in (\mathcal{X}, \mathcal{Y})$ as a sentence pair, where $x = (x_1, x_2, ..., x_m)$ is the source sentence with $m$ tokens, and $y = (y_1, y_2, ..., y_n)$ is the target sentence with $n$ tokens, and $\mathcal{X}$ and $\mathcal{Y}$ are the source and target domains. A sequence to sequence model learns the parameter $\theta$ to estimate the conditional probability $P(y|x; \theta)$, and usually uses log likelihood as the objective function: $L(\theta; (\mathcal{X}, \mathcal{Y})) = \sum_{(x,y) \in (\mathcal{X}, \mathcal{Y})} \log P(y|x; \theta)$. The conditional probability $P(y|x; \theta)$ can be further factorized according to the chain rule: $P(y|x; \theta) = \prod_{t=1}^{n} P(y_t | y_{<t}, x; \theta)$,
where $y_{<t}$ denotes the tokens preceding position $t$.
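As a concrete illustration of this factorization, the following is a minimal sketch of how the teacher-forced log likelihood $\sum_t \log P(y_t | y_{<t}, x; \theta)$ is typically computed. The PyTorch code assumes a hypothetical `model(src, decoder_input)` that returns per-position logits, and the `<bos>` token id is a placeholder of ours.

```python
import torch
import torch.nn.functional as F

def seq2seq_log_likelihood(model, src, tgt, bos_id=1):
    """Compute log P(y | x; theta) = sum_t log P(y_t | y_{<t}, x; theta).

    src: (batch, m) source token ids; tgt: (batch, n) target token ids.
    `model(src, decoder_input)` is assumed to return logits of shape
    (batch, n, vocab) for every target position under teacher forcing.
    """
    # Shift the target right so the prediction at position t only sees y_{<t}.
    bos = torch.full_like(tgt[:, :1], bos_id)
    decoder_input = torch.cat([bos, tgt[:, :-1]], dim=1)

    logits = model(src, decoder_input)                     # (batch, n, vocab)
    log_probs = F.log_softmax(logits, dim=-1)

    # Gather log P(y_t | y_{<t}, x) of the observed tokens and sum over t.
    token_log_probs = log_probs.gather(-1, tgt.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum(dim=-1)                     # (batch,)
```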
A major approach to sequence to sequence learning is the
encoder-decoder framework: The encoder reads the source
sequence and generates a set of representations; the decoder
estimates the conditional probability of each target token
given the source representations and its preceding tokens.
An attention mechanism (Bahdanau et al., 2015a) is further introduced between the encoder and decoder to find which source representations to focus on when predicting the current token.
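To make this interface concrete, here is a minimal sketch of a single decoder step with dot-product attention over the encoder representations. The module, its layer choices (a GRU cell), and the hidden sizes are illustrative simplifications of ours rather than the architecture used in this paper; they only show how the decoder attends to the source representations when predicting the current token.

```python
import torch
import torch.nn as nn

class AttentiveDecoderStep(nn.Module):
    """One decoder step with dot-product attention over encoder states.

    Hypothetical minimal module: the interface is the same as in larger
    models -- the decoder attends to the set of source representations
    produced by the encoder when predicting the current target token.
    """
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.rnn = nn.GRUCell(hidden_size, hidden_size)
        self.out = nn.Linear(2 * hidden_size, vocab_size)

    def forward(self, prev_embed, prev_state, enc_states):
        # prev_embed: (batch, hidden)   embedding of the previous target token
        # enc_states: (batch, m, hidden) encoder representations of x
        state = self.rnn(prev_embed, prev_state)

        # Attention: score each source position against the decoder state.
        scores = torch.bmm(enc_states, state.unsqueeze(-1)).squeeze(-1)   # (batch, m)
        weights = torch.softmax(scores, dim=-1)
        context = torch.bmm(weights.unsqueeze(1), enc_states).squeeze(1)  # (batch, hidden)

        # Predict the distribution over the current token from state + context.
        logits = self.out(torch.cat([state, context], dim=-1))
        return logits, state
```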
3.2. Masked Sequence to Sequence Pre-training
We introduce a novel unsupervised prediction task in this
section. Given an unpaired source sentence
$x \in \mathcal{X}$, we denote $x^{\setminus u:v}$ as a modified version of $x$ where its fragment from position $u$ to $v$ is masked, $0 < u < v < m$, and $m$ is the number of tokens of sentence $x$. We denote $k = v - u + 1$ as the number of tokens being masked from position $u$ to $v$. We replace each masked token by a special symbol [M], and the length of the masked sentence is not changed. $x^{u:v}$ denotes the sentence fragment of $x$ from position $u$ to $v$.
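A minimal sketch of this masking step on a list of token ids is given below; the mask id `MASK`, the helper name `mask_fragment`, and the uniform choice of the start position $u$ are illustrative assumptions of ours.

```python
import random

MASK = 0  # assumed id of the special symbol [M]

def mask_fragment(x, k):
    """Return (masked sentence x^{\\u:v}, fragment x^{u:v}, start position u).

    x: list of token ids of length m; k = v - u + 1 consecutive tokens are
    masked. Positions are 1-based, matching the notation in the text.
    """
    m = len(x)
    u = random.randint(1, m - k)       # start position, so that v = u + k - 1 < m
    v = u + k - 1
    # Replace the fragment by [M]; the length of the sentence is unchanged.
    masked = x[:u - 1] + [MASK] * k + x[v:]
    fragment = x[u - 1:v]              # the k tokens the model must predict
    return masked, fragment, u
```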
MASS pre-trains a sequence to sequence model by predicting the sentence fragment $x^{u:v}$, taking the masked sequence $x^{\setminus u:v}$ as input. We also use the log likelihood as the objective function:
$$L(\theta; \mathcal{X}) = \frac{1}{|\mathcal{X}|} \sum_{x \in \mathcal{X}} \log P(x^{u:v} | x^{\setminus u:v}; \theta) = \frac{1}{|\mathcal{X}|} \sum_{x \in \mathcal{X}} \log \prod_{t=u}^{v} P(x^{u:v}_{t} | x^{u:v}_{<t}, x^{\setminus u:v}; \theta). \quad (1)$$
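The following sketch shows how the summand of Equation (1) can be computed for a single sentence, reusing the hypothetical `mask_fragment` helper and `MASK` id from above: the encoder consumes the masked sentence $x^{\setminus u:v}$, the decoder input carries the previous fragment token at positions $u+1, ..., v$ and [M] elsewhere (as in Figure 1), and the loss is taken only over the masked positions $u, ..., v$. The `model(enc_input, dec_input)` interface is again an assumption.

```python
import torch
import torch.nn.functional as F

def mass_loss(model, x, k):
    """Negative log likelihood of the masked fragment for one sentence x.

    Reuses mask_fragment / MASK from the sketch above; `model` is assumed to
    return logits of shape (1, m, vocab) over the decoder positions.
    """
    masked, fragment, u = mask_fragment(x, k)
    v, m = u + k - 1, len(x)

    # Encoder input: the sentence with its fragment x^{u:v} replaced by [M].
    enc_input = torch.tensor([masked])                         # (1, m)

    # Decoder input: x_{t-1} at fragment positions u+1..v, [M] everywhere else,
    # so position t only sees the previous tokens of the fragment it predicts.
    dec = [MASK] * m
    for t in range(u + 1, v + 1):                              # 1-based positions
        dec[t - 1] = x[t - 2]
    dec_input = torch.tensor([dec])                            # (1, m)

    log_probs = F.log_softmax(model(enc_input, dec_input), dim=-1)

    # Sum log P(x_t | x^{u:v}_{<t}, x^{\u:v}) over the masked positions only.
    target = torch.tensor(fragment).unsqueeze(-1)              # (k, 1)
    frag_log_probs = log_probs[0, u - 1:v].gather(-1, target)  # (k, 1)
    return -frag_log_probs.sum()
```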
We show an example in Figure 1, where the input sequence has 8 tokens with the fragment $x_3 x_4 x_5 x_6$ being masked. Note that the model only predicts the masked fragment $x_3 x_4 x_5 x_6$, given $x_3 x_4 x_5$ as the decoder input for positions 4-6, and the decoder takes the special mask symbol [M] as inputs for the other positions (e.g., positions 1-3 and