for the input. GN here indicates replacing all Batch Normalization layers in ResNet-50 with Group Normalization. The patch embedding is applied to 1 × 1 patches extracted from the CNN feature map instead of from raw images.
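To make this concrete, the sketch below shows one way such a hybrid patch embedding can be implemented in PyTorch: a 1 × 1 convolution over the CNN feature map is equivalent to linearly projecting each spatial position into a token. The module and variable names, channel count, and embedding dimension are illustrative assumptions, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

class HybridPatchEmbed(nn.Module):
    """Patch embedding over 1 x 1 patches of a CNN feature map (sketch)."""

    def __init__(self, backbone, feat_channels=2048, embed_dim=768):
        super().__init__()
        self.backbone = backbone  # e.g. ResNet-50 truncated before pooling (with GN)
        # A 1x1 convolution is a linear projection of each spatial position,
        # i.e. a patch embedding with patch size 1 x 1.
        self.proj = nn.Conv2d(feat_channels, embed_dim, kernel_size=1)

    def forward(self, x):
        feat = self.backbone(x)                    # (B, C, H', W') feature map
        tokens = self.proj(feat)                   # (B, D, H', W')
        return tokens.flatten(2).transpose(1, 2)   # (B, H' * W', D) token sequence

# Dummy stand-in for the CNN backbone, just to make the sketch runnable.
dummy_backbone = nn.Conv2d(3, 2048, kernel_size=16, stride=16)
tokens = HybridPatchEmbed(dummy_backbone)(torch.randn(2, 3, 224, 224))  # (2, 196, 768)
```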
3.2 Federated Learning Methods
We apply one of the most popular parallel methods (FedAVG [43]) and one of the most popular serial methods (CWT [7]) as training algorithms (see schematic descriptions in Figure 1).
Federated Averaging.
FedAVG combines local stochastic gradient descent (SGD) on each client with iterative model averaging [43]. Specifically, in each communication round a fraction of the local clients is randomly sampled, and the server sends the current global model to each of these clients. Each selected client then performs E epochs of local SGD on its local training data and synchronously sends its local gradients back to the central server for aggregation. The server then applies the averaged gradients to update its global model, and the process repeats.
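For concreteness, the following is a minimal sketch of one FedAVG communication round. It assumes each client exposes a hypothetical train_local(model, epochs) helper that runs local SGD on the client's private data and returns the updated weights together with the local sample count; averaging the returned weights is equivalent to applying the averaged local updates to the shared starting point.

```python
import copy
import random

def fedavg_round(global_model, clients, frac=1.0, local_epochs=1):
    """One FedAVG communication round (illustrative sketch)."""
    # 1. Randomly sample a fraction of the local clients.
    m = max(1, int(frac * len(clients)))
    selected = random.sample(clients, m)

    local_states, local_sizes = [], []
    for client in selected:
        # 2. The server sends the current global model to the client,
        #    which performs E epochs of local SGD on its own data.
        local_model = copy.deepcopy(global_model)
        state, n = client.train_local(local_model, epochs=local_epochs)
        local_states.append(state)
        local_sizes.append(n)

    # 3. The server aggregates the local results by weighted averaging
    #    and overwrites the global model before the next round.
    total = sum(local_sizes)
    avg_state = {
        key: sum((n / total) * state[key] for state, n in zip(local_states, local_sizes))
        for key in local_states[0]
    }
    global_model.load_state_dict(avg_state)
    return global_model
```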
Cyclic Weight Transfer.
In contrast to FedAVG, where the local clients are trained synchronously and in parallel, the local clients in CWT are trained serially and cyclically. In each round of training, CWT trains a global model on one local client with its local data for E epochs, and then cyclically transfers this global model to the next client for training, until every local client has been trained on once [7]. The training process then cycles through the clients repeatedly until the model converges or a predefined number of communication rounds is reached.
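A corresponding sketch of CWT, under the same assumption of a hypothetical train_local helper (here returning the updated model), highlights the contrast: there is no aggregation step, and the single global model is simply handed from client to client in a fixed cyclic order.

```python
def cwt_train(global_model, clients, comm_rounds=100, local_epochs=1):
    """Cyclic Weight Transfer (illustrative sketch)."""
    for _ in range(comm_rounds):
        for client in clients:  # fixed cyclic order over the local clients
            # Train the current global model on this client's local data,
            # then transfer the updated weights to the next client.
            global_model = client.train_local(global_model, epochs=local_epochs)
    return global_model
```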
4 Experiments
Our experiments are designed to answer the following research questions, which are important for the practical deployment of FL methods while also aiding our understanding of (vision) Transformer architectures.
• Are Transformers able to learn a better global model in FL settings compared to CNNs, which have been the de-facto approach for FL tasks (section 4.2)?
• Are Transformers especially capable of handling heterogeneous data partitions (section 4.3.1)?
• Do Transformers reduce communication costs as compared to CNNs (section 4.3.2)?
• What practical tips can help practitioners deploy Transformers in FL (section 4.4)?
Experimental code is included in the supplement and will be made public after blind review.
4.1 Experimental Setup
Following [7, 20], we evaluate different FL methods on the Kaggle Diabetic Retinopathy competition dataset (denoted as Retina) [26] and the CIFAR-10 dataset [29]. Specifically, we binarize the labels in the Retina dataset into Healthy (positive) and Diseased (negative), randomly selecting 6,000 balanced images for training, 3,000 images as the global validation dataset, and 3,000 images as the global testing dataset, following [7]. We use the original CIFAR-10 test set as the global test dataset, set aside 5,000 images from the original training dataset as the global validation dataset, and use the remaining 45,000 images as the training dataset. Detailed image pre-processing steps for the Retina and CIFAR-10 datasets are given in Appendix A.1. We simulate three data partitions for both Retina and CIFAR-10: one IID partition and two non-IID partitions with label distribution skew. Each data partition contains 4 simulated clients for Retina and 5 for CIFAR-10. We use the mean Kolmogorov-Smirnov (KS) statistic between every pair of clients to measure the degree of label distribution skew: KS = 0 indicates an IID partition, while KS = 1 indicates an extremely non-IID partition. The detailed data partitions are shown in Appendix A.1.
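As one concrete way to realize this measure (a sketch only; the exact partition construction is given in Appendix A.1), the two-sample KS statistic can be computed on the clients' label arrays with scipy.stats.ks_2samp and averaged over all client pairs; the example partition below is hypothetical.

```python
from itertools import combinations

import numpy as np
from scipy.stats import ks_2samp

def mean_pairwise_ks(client_labels):
    """Mean KS statistic over all pairs of clients' label arrays.

    0 means every client has the same label distribution (IID); values near 1
    mean the label distributions barely overlap (extreme non-IID).
    """
    stats = [ks_2samp(a, b).statistic for a, b in combinations(client_labels, 2)]
    return float(np.mean(stats))

# Hypothetical example: 5 CIFAR-10-like clients, each holding only two classes.
rng = np.random.default_rng(0)
clients = [rng.choice([2 * i, 2 * i + 1], size=9000) for i in range(5)]
print(mean_pairwise_ks(clients))  # 1.0 here: completely disjoint label sets
```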
We use a linear learning rate warm-up and decay scheduler for VIT-FL. The learning rate scheduler for FL with CNNs is selected from either linear warm-up and decay or step decay. Gradient clipping (at global norm 1) is applied to stabilize training. We set the number of local training epochs in all FL methods to 1 and the total number of communication rounds to 100, unless otherwise stated. More implementation details are given in Appendix A.2.
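Below is a minimal PyTorch sketch of the linear warm-up and decay schedule combined with gradient clipping at global norm 1; the model, base learning rate, warm-up length, and step counts are placeholders rather than the paper's settings.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

def linear_warmup_decay(optimizer, warmup_steps, total_steps):
    """LR rises linearly to its base value over `warmup_steps`,
    then decays linearly to zero by `total_steps`."""
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return LambdaLR(optimizer, lr_lambda)

# Placeholder model and data standing in for a client's local training loop.
model = torch.nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
scheduler = linear_warmup_decay(optimizer, warmup_steps=500, total_steps=5_000)

for step in range(5_000):
    x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients at global norm 1 to stabilize training.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
```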