How to simulate FL
We present a step-by-step description of how to serially simulate federated learning (FL) on the MNIST dataset.
Installation
First, make sure that the required dependencies are installed. If they are not, uncomment the following cell and run it.
[1]:
# !pip install "appfl[examples]"
You can also install the package from the GitHub repository.
[2]:
# !git clone --single-branch --branch main https://github.com/APPFL/APPFL.git
# %cd APPFL
# !pip install -e ".[examples]"
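If you are unsure whether the installation succeeded, a quick check like the one below may help. It is only a sketch that uses the standard library; the helper name installed_version is introduced here for illustration and is not part of appfl.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a package, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Packages this tutorial imports later on.
for name in ("appfl", "torch", "torchvision"):
    print(name, installed_version(name) or "not installed")
```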
Import dependencies
We put all the imports here. Our framework, appfl, is built on torch and its neural network module torch.nn. We also import torchvision to download the MNIST dataset.
[3]:
import math
import torch
import torchvision
import numpy as np
import torch.nn as nn
import appfl.run_serial as ppfl
from omegaconf import OmegaConf
from appfl.config import Config
from appfl.misc.data import Dataset
from torchvision.transforms import ToTensor
Training datasets
Since this is a simulation of federated learning, we manually split the training dataset among the clients. This is not necessary in practice, where each client already holds its own data. In this example we simulate only two clients, but num_clients can be set to a larger value to simulate more.
[4]:
num_clients = 2
Each client needs to create a Dataset object with its training data. Here, we create the objects for all the clients.
[5]:
train_data_raw = torchvision.datasets.MNIST(
    "./_data", train=True, download=True, transform=ToTensor()
)
split_train_data_raw = np.array_split(range(len(train_data_raw)), num_clients)
train_datasets = []
for i in range(num_clients):
    train_data_input = []
    train_data_label = []
    for idx in split_train_data_raw[i]:
        train_data_input.append(train_data_raw[idx][0].tolist())
        train_data_label.append(train_data_raw[idx][1])
    train_datasets.append(
        Dataset(
            torch.FloatTensor(train_data_input),
            torch.tensor(train_data_label),
        )
    )
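To see how np.array_split assigns samples to clients, the following sketch reproduces the split on index ranges only, without loading MNIST. The helper name split_indices is hypothetical and introduced here for illustration; 60000 is the size of the MNIST training set.

```python
import numpy as np


def split_indices(num_samples, num_clients):
    """Split range(num_samples) into num_clients contiguous, near-equal parts."""
    return np.array_split(range(num_samples), num_clients)


# Print the per-client sample counts for a few client counts.
for n in (2, 3, 7):
    parts = split_indices(60000, n)
    print(n, [len(p) for p in parts])
```

The parts are contiguous blocks whose sizes differ by at most one. The split is purely positional, so it neither shuffles nor stratifies the samples.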
Test dataset
The test data also needs to be wrapped in a Dataset object.
[6]:
test_data_raw = torchvision.datasets.MNIST(
    "./_data", train=False, download=False, transform=ToTensor()
)
test_data_input = []
test_data_label = []
for idx in range(len(test_data_raw)):
    test_data_input.append(test_data_raw[idx][0].tolist())
    test_data_label.append(test_data_raw[idx][1])
test_dataset = Dataset(
    torch.FloatTensor(test_data_input), torch.tensor(test_data_label)
)
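The element-wise loops above pass every image through a Python list, which can be slow for tens of thousands of samples. A possible alternative, sketched below on synthetic data, applies the same scaling as ToTensor() to a whole uint8 image tensor at once. It assumes the raw images are available as an (N, 28, 28) uint8 tensor, which torchvision's MNIST exposes as the .data attribute (with labels in .targets). The helper name to_float_images is hypothetical.

```python
import torch


def to_float_images(raw_uint8):
    """Scale (N, H, W) uint8 images to float32 in [0, 1] with shape (N, 1, H, W)."""
    return raw_uint8.unsqueeze(1).float().div(255)


# Synthetic stand-in for test_data_raw.data (assumption: same dtype and layout).
fake_raw = torch.randint(0, 256, (8, 28, 28), dtype=torch.uint8)
images = to_float_images(fake_raw)
print(images.shape, images.dtype)
```

With the real dataset, the corresponding call would be Dataset(to_float_images(test_data_raw.data), test_data_raw.targets), assuming Dataset accepts tensors in the same way as in the cell above.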
Model
Users can define their own models by subclassing torch.nn.Module. In this simulation, we define the following convolutional neural network. The loss function is set to torch.nn.CrossEntropyLoss(). We also define our own evaluation metric function.
[7]:
class CNN(nn.Module):
    def __init__(self, num_channel=1, num_classes=10, num_pixel=28):
        super().__init__()
        self.conv1 = nn.Conv2d(
            num_channel, 32, kernel_size=5, padding=0, stride=1, bias=True
        )
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        self.act = nn.ReLU(inplace=True)

        X = num_pixel
        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)
        X = X / 2
        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)
        X = X / 2
        X = int(X)

        self.fc1 = nn.Linear(64 * X * X, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.maxpool(x)
        x = self.act(self.conv2(x))
        x = self.maxpool(x)
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return x


model = CNN()
loss_fn = torch.nn.CrossEntropyLoss()


def accuracy(y_true, y_pred):
    """
    y_true and y_pred are both of type np.ndarray
    y_true (N, d) where N is the size of the validation set, and d is the dimension of the label
    y_pred (N, D) where N is the size of the validation set, and D is the output dimension of the ML model
    """
    if len(y_pred.shape) == 1:
        y_pred = np.round(y_pred)
    else:
        y_pred = y_pred.argmax(axis=1)
    return 100 * np.sum(y_pred == y_true) / y_pred.shape[0]
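Two parts of the cell above are easy to check by hand: the arithmetic that sizes fc1, and the behavior of accuracy on model outputs. The sketch below walks through both without torch. The helper conv_out is a hypothetical name that applies the standard convolution output-size formula, and the logits and labels are made-up values for illustration only.

```python
import math

import numpy as np


def conv_out(size, kernel, padding=0, stride=1, dilation=1):
    """Output side length of a convolution along one spatial dimension."""
    return math.floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)


side = 28                        # MNIST images are 28 x 28
side = conv_out(side, kernel=5)  # conv1: 28 -> 24
side = side // 2                 # 2x2 max pooling: 24 -> 12
side = conv_out(side, kernel=5)  # conv2: 12 -> 8
side = side // 2                 # 2x2 max pooling: 8 -> 4
fc1_in = 64 * side * side        # 64 channels * 4 * 4 = 1024 inputs to fc1
print(side, fc1_in)


def accuracy(y_true, y_pred):
    """Same metric as in the tutorial: percentage of correct predictions."""
    if len(y_pred.shape) == 1:
        y_pred = np.round(y_pred)
    else:
        y_pred = y_pred.argmax(axis=1)
    return 100 * np.sum(y_pred == y_true) / y_pred.shape[0]


# Made-up logits for 4 samples and 3 classes; argmax gives classes 1, 0, 2, 0.
logits = np.array([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 3.0], [0.9, 0.8, 0.7]])
labels = np.array([1, 0, 2, 1])
score = accuracy(labels, logits)  # 3 of 4 correct
print(score)
```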
Training configurations
All the FL training-related configurations are stored in an appfl.config.Config class. You can print its values and modify the parameters as needed in the following way.
[8]:
cfg = OmegaConf.structured(Config)
cfg.num_clients = num_clients
print(OmegaConf.to_yaml(cfg))
# Change the number of local epochs to 1, the number of global epochs to 5, and the learning rate to 0.01.
# The printed configuration shows the default values, before these changes.
cfg.fed.args.num_local_epochs = 1
cfg.num_epochs = 5
cfg.fed.args.optim_args.lr = 0.01
fed:
  type: federated
  servername: ServerFedAvg
  clientname: ClientOptim
  args:
    server_learning_rate: 0.01
    server_adapt_param: 0.001
    server_momentum_param_1: 0.9
    server_momentum_param_2: 0.99
    optim: SGD
    num_local_epochs: 10
    optim_args:
      lr: 0.001
    use_dp: false
    epsilon: 1
    clip_grad: false
    clip_value: 1
    clip_norm: 1
device: cpu
device_server: cpu
num_clients: 2
num_epochs: 2
num_workers: 0
batch_training: true
train_data_batch_size: 64
train_data_shuffle: true
validation: true
test_data_batch_size: 64
test_data_shuffle: false
data_sanity: false
reproduce: true
pca_dir: ''
params_start: 0
params_end: 49
ncomponents: 40
use_tensorboard: false
load_model: false
load_model_dirname: ''
load_model_filename: ''
save_model: false
save_model_dirname: ''
save_model_filename: ''
checkpoints_interval: 2
save_model_state_dict: false
send_final_model: false
output_dirname: output
output_filename: result
logginginfo: {}
summary_file: ''
personalization: false
p_layers: []
config_name: ''
max_message_size: 104857600
operator:
  id: 1
server:
  id: 1
  host: localhost
  port: 50051
  use_tls: false
  api_key: null
client:
  id: 1
enable_compression: false
lossy_compressor: SZ2
lossless_compressor: blosc
compressor_sz2_path: ../.compressor/SZ/build/sz/libSZ.dylib
compressor_sz3_path: ../.compressor/SZ3/build/tools/sz3c/libSZ3c.dylib
compressor_szx_path: ../.compressor/SZx-main/build/lib/libSZx.dylib
error_bounding_mode: ''
error_bound: 0.0
flat_model_dtype: np.float32
param_cutoff: 1024
Run experiments
We can now start training with the configuration cfg, the model, the loss function, the training and test datasets, the dataset name, and the evaluation metric.
[9]:
ppfl.run_serial(cfg, model, loss_fn, train_datasets, test_dataset, "MNIST", accuracy)
Iter  Local(s)  Global(s)  Valid(s)  PerIter(s)  Elapsed(s)  TestLoss  TestAccu
   1     28.19       0.00      2.49       30.68       30.68    0.5289     86.02
   2     27.84       0.00      2.46       30.31       60.99    0.2948     91.33
   3     28.21       0.00      2.50       30.72       91.71    0.1994     94.11
   4     27.95       0.00      2.49       30.44      122.16    0.1596     95.38
   5     28.03       0.00      2.52       30.55      152.71    0.1430     95.79
Device=cpu
#Processors=1
#Clients=2
Server=ServerFedAvg
Clients=ClientOptim
Comm_Rounds=5
Local_Rounds=1
DP_Eps=False
Clipping=False
Elapsed_time=152.71
BestAccuracy=95.79
client_learning_rate = 0.01
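The log reports Server=ServerFedAvg. As background, the following sketch shows textbook federated averaging, in which the server combines the clients' parameters with weights proportional to each client's sample count. It is not APPFL's implementation; the function name fedavg and the dict-of-arrays layout are assumptions made for illustration.

```python
import numpy as np


def fedavg(client_weights, client_sizes):
    """Average per-client parameter dicts, weighted by each client's sample count."""
    total = float(sum(client_sizes))
    return {
        key: sum(w[key] * (n / total) for w, n in zip(client_weights, client_sizes))
        for key in client_weights[0]
    }


# Two hypothetical clients holding one small parameter vector each.
client_a = {"w": np.array([1.0, 3.0])}
client_b = {"w": np.array([3.0, 5.0])}

# Equal sample counts, as in the two-client split of this tutorial.
global_weights = fedavg([client_a, client_b], [30000, 30000])
print(global_weights["w"])
```

With equal client sizes the result is the plain mean of the client parameters; with unequal sizes, the larger client contributes proportionally more.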