smaller reconstruction error at the end of training.⁴
2.2. Stage Two: Learning the Prior
In the second stage, we fix φ and θ, and learn the prior distribution over the text and image tokens by maximizing the ELB with respect to ψ. Here, p_ψ is represented by a 12-billion parameter sparse transformer (Child et al., 2019).
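For context, the ELB referred to here is Equation 1 earlier in the paper; restated in the notation above (up to minor notational differences), it reads

\[
\ln p_{\theta,\psi}(x, y) \;\geq\; \mathbb{E}_{z \sim q_{\varphi}(z \mid x)}\Big( \ln p_{\theta}(x \mid y, z) \;-\; \beta\, D_{\mathrm{KL}}\big(q_{\varphi}(y, z \mid x),\, p_{\psi}(y, z)\big) \Big).
\]

Since φ and θ are frozen in this stage, only the KL term depends on ψ, so maximizing the ELB with respect to ψ amounts to maximizing the expected log-likelihood E[ln p_ψ(y, z)] of the concatenated text and image tokens under the prior.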
Given a text-image pair, we BPE-encode (Sennrich et al., 2015) the lowercased caption using at most 256 tokens⁵ with vocabulary size 16,384, and encode the image using 32 × 32 = 1024 tokens with vocabulary size 8192. The image tokens are obtained using argmax sampling from the dVAE encoder logits, without adding any gumbel noise.⁶ Finally, the text and image tokens are concatenated and modeled autoregressively as a single stream of data.
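As a concrete illustration (not the authors' code), the sketch below shows how a single training example could be turned into the joint token stream; the helper name, shapes, and id layout are illustrative assumptions, since the paper does not spell out these details here.

```python
import numpy as np

TEXT_VOCAB = 16384        # BPE vocabulary size for captions
MAX_TEXT_TOKENS = 256     # captions use at most 256 BPE tokens
IMAGE_VOCAB = 8192        # dVAE codebook size
IMAGE_GRID = 32           # 32 x 32 = 1024 image tokens
# Illustrative id layout: BPE ids, then 256 per-position padding ids
# (see the padding scheme below), then the image codebook.
IMAGE_OFFSET = TEXT_VOCAB + MAX_TEXT_TOKENS

def build_token_stream(text_token_ids, dvae_encoder_logits):
    """Concatenate caption and image tokens into one autoregressive stream.

    `dvae_encoder_logits` is assumed to have shape (32, 32, IMAGE_VOCAB).
    Image tokens are taken by argmax over the encoder logits, with no gumbel noise.
    """
    assert len(text_token_ids) <= MAX_TEXT_TOKENS
    image_token_ids = dvae_encoder_logits.argmax(axis=-1).reshape(-1)   # (1024,)
    return np.concatenate([np.asarray(text_token_ids),
                           image_token_ids + IMAGE_OFFSET])
```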
The transformer is a decoder-only model in which each image token can attend to all text tokens in any one of its 64 self-attention layers. The full architecture is described in Appendix B.1. There are three different kinds of self-attention masks used in the model. The part of the attention masks corresponding to the text-to-text attention is the standard causal mask, and the part for the image-to-image attention uses either a row, column, or convolutional attention mask.⁷
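The sketch below builds one such combined mask for a toy sequence. It is only illustrative: a causal trailing window stands in for the row mask, whereas the actual row, column, and convolutional patterns follow Child et al. (2019) and Appendix B.1.

```python
import numpy as np

def combined_attention_mask(text_len=256, grid=32, row_width=32):
    """Toy combined self-attention mask (True = attention allowed).

    Text-to-text uses the standard causal mask, every image token may attend to
    all text tokens, and image-to-image uses a causal trailing window of
    `row_width` positions as a rough stand-in for a row mask.
    """
    n = text_len + grid * grid
    mask = np.zeros((n, n), dtype=bool)
    # text attends causally to text
    mask[:text_len, :text_len] = np.tril(np.ones((text_len, text_len), dtype=bool))
    # image attends to all text positions
    mask[text_len:, :text_len] = True
    # image attends causally to nearby image tokens (row-style sparsity)
    for i in range(grid * grid):
        q = text_len + i
        start = text_len + max(0, i - row_width + 1)
        mask[q, start:q + 1] = True
    return mask
```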
We limit the length of a text caption to 256 tokens, though it is not totally clear what to do for the “padding” positions in between the last text token and the start-of-image token. One option is to set the logits for these tokens to −∞ in the self-attention operations. Instead, we opt to learn a special padding token separately for each of the 256 text positions. This token is used only when no text token is available. In preliminary experiments on Conceptual Captions (Sharma et al., 2018), we found that this resulted in higher validation loss, but better performance on out-of-distribution captions.
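A minimal sketch of this per-position padding scheme follows; the id range reserved for the 256 learned padding tokens is an assumption made for illustration, not the paper's actual convention.

```python
import numpy as np

PAD_ID_BASE = 16384    # illustrative: first id after the 16,384 BPE ids
MAX_TEXT_TOKENS = 256

def pad_caption(text_token_ids):
    """Fill unused caption positions with a position-specific padding id.

    Position i that carries no BPE token receives id PAD_ID_BASE + i, so the
    model can learn a separate embedding for each of the 256 padding positions.
    """
    ids = list(text_token_ids)[:MAX_TEXT_TOKENS]
    ids += [PAD_ID_BASE + i for i in range(len(ids), MAX_TEXT_TOKENS)]
    return np.asarray(ids)
```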
⁴ This is contrary to the usual tradeoff between the two terms. We speculate that for smaller values of β, the noise from the relaxation causes the optimizer to reduce codebook usage toward the beginning of training, resulting in worse ELB at convergence.
⁵ During training, we apply 10% BPE dropout (Provilkov et al., 2019), whose use is common in the neural machine translation literature.
⁶ Strictly speaking, Equation 1 requires us to sample from the categorical distribution specified by the dVAE encoder logits, rather than taking the argmax. In preliminary experiments on ImageNet, we found that this was a useful regularizer in the overparameterized regime, and allows the transformer to be trained using soft targets for the cross-entropy loss. We decided against this here since the model under consideration is in the underparameterized regime.
⁷ We found that using a single attention operation for all three interactions (“text attends to text”, “image attends to text”, and “image attends to image”) performs better than using separate attention operations that are independently normalized.
Figure 4. Illustration of per-resblock gradient scaling for a transformer resblock. The solid line indicates the sequence of operations for forward propagation, and the dashed line the sequence of operations for backpropagation. We scale the incoming gradient for each resblock by its gradient scale, and unscale the outgoing gradient before it is added to the sum of the gradients from the successive resblocks. The activations and gradients along the identity path are stored in 32-bit precision. The “filter” operation sets all Inf and NaN values in the activation gradient to zero. Without this, a nonfinite event in the current resblock would cause the gradient scales for all preceding resblocks to unnecessarily drop, thereby resulting in underflow.
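A rough PyTorch-style sketch of the scale/filter/unscale pattern described in the caption is given below. It is not the authors' implementation: the resblock body is treated as a black box, and the logic for updating each resblock's gradient scale after a nonfinite event is omitted.

```python
import torch

class _ScaleGrad(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by `scale`."""
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None

class _FilterUnscaleGrad(torch.autograd.Function):
    """Identity forward; in backward, zeroes Inf/NaN ("filter") then divides by `scale`."""
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        g = torch.nan_to_num(grad_output, nan=0.0, posinf=0.0, neginf=0.0)
        return g / ctx.scale, None

def scaled_resblock(x, resblock_fn, grad_scale):
    # The gradient entering the resblock from above is scaled; the gradient
    # leaving it is filtered and unscaled before autograd adds it to the
    # identity-path gradient, which (like the identity activations) would be
    # kept in 32-bit precision.
    h = _FilterUnscaleGrad.apply(x, grad_scale)
    h = resblock_fn(h)
    h = _ScaleGrad.apply(h, grad_scale)
    return x + h
```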
We normalize the cross-entropy losses for the text and image tokens by the total number of each kind in a batch of data. Since we are primarily interested in image modeling, we multiply the cross-entropy loss for the text by 1/8 and the cross-entropy loss for the image by 7/8. The objective is optimized using Adam with exponentially weighted iterate averaging; Appendix B.2 describes the training procedure in more detail. We reserved about 606,000 images for validation, and found no signs of overfitting at convergence.
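As a sketch of this loss weighting (PyTorch-style, with the function name and flattened shapes chosen for illustration):

```python
import torch.nn.functional as F

def prior_loss(text_logits, text_targets, image_logits, image_targets):
    """Mix the per-token-normalized cross-entropies with weights 1/8 and 7/8.

    `*_logits` have shape (num_tokens, vocab) and `*_targets` shape (num_tokens,),
    flattened over the batch; the default reduction='mean' normalizes each loss
    by the number of tokens of that kind in the batch.
    """
    text_ce = F.cross_entropy(text_logits, text_targets)
    image_ce = F.cross_entropy(image_logits, image_targets)
    return text_ce / 8 + 7 * image_ce / 8
```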
2.3. Data Collection
Our preliminary experiments for models up to 1.2 billion parameters were carried out on Conceptual Captions, a dataset of 3.3 million text-image pairs that was developed as an extension to MS-COCO (Lin et al., 2014).
To scale up to 12 billion parameters, we created a dataset of a similar scale to JFT-300M (Sun et al., 2017) by collecting 250 million text-image pairs from the internet. This dataset does not include MS-COCO, but does include Conceptual Captions and a filtered subset of YFCC100M (Thomee et al., 2016). As MS-COCO was created from the latter, our training data includes a fraction of the MS-COCO validation images (but none of the captions). We control for this in the quantitative results presented in Section 3 and find that it has no appreciable bearing on the results. We provide further