FL Configurations
In this notebook, we will showcase how to load and set configurations for federated learning (FL) server and clients in order to launch FL experiments.
Load Configuration from a YAML File
APPFL employs the OmegaConf package, a hierarchical configuration system, to load configurations for the FL server and clients from YAML files.
For example, examples/resources/configs/mnist/server_fedavg.yaml
contains the server configurations for an FL experiment on the MNIST dataset using the FedAvg
server aggregation algorithm.
As shown below, the configuration file is primarily composed of two main components:

- client_configs
- server_configs
It may look a bit confusing at first glance that the server configuration file also contains client_configs. This is because, in federated learning, we usually want certain client-side configurations to be the same among all clients, for example, the local trainer and its corresponding hyperparameters, the ML model architecture, and the compression settings. Therefore, it is much more convenient to specify all those settings and configurations on the server side to ensure uniformity, and then send them to all clients at the beginning of the FL experiment.
[1]:
from omegaconf import OmegaConf

server_config_file = "../../examples/resources/configs/mnist/server_fedavg.yaml"
server_config = OmegaConf.load(server_config_file)

print("============Level one configuration fields============")
for key in server_config:
    print(key)
print("============Detailed server configurations============")
print(OmegaConf.to_yaml(server_config))
print("======================================================")
============Level one configuration fields============
client_configs
server_configs
============Detailed server configurations============
client_configs:
  train_configs:
    trainer: VanillaTrainer
    mode: step
    num_local_steps: 100
    optim: Adam
    optim_args:
      lr: 0.001
    loss_fn_path: ./resources/loss/celoss.py
    loss_fn_name: CELoss
    do_validation: true
    do_pre_validation: true
    metric_path: ./resources/metric/acc.py
    metric_name: accuracy
    use_dp: false
    epsilon: 1
    clip_grad: false
    clip_value: 1
    clip_norm: 1
    train_batch_size: 64
    val_batch_size: 64
    train_data_shuffle: true
    val_data_shuffle: false
  model_configs:
    model_path: ./resources/model/cnn.py
    model_name: CNN
    model_kwargs:
      num_channel: 1
      num_classes: 10
      num_pixel: 28
  comm_configs:
    compressor_configs:
      enable_compression: false
      lossy_compressor: SZ2Compressor
      lossless_compressor: blosc
      error_bounding_mode: REL
      error_bound: 0.001
      param_cutoff: 1024
server_configs:
  num_clients: 2
  scheduler: SyncScheduler
  scheduler_kwargs:
    same_init_model: true
  aggregator: FedAvgAggregator
  aggregator_kwargs:
    client_weights_mode: equal
  device: cpu
  num_global_epochs: 10
  logging_output_dirname: ./output
  logging_output_filename: result
  comm_configs:
    grpc_configs:
      server_uri: localhost:50051
      max_message_size: 1048576
      use_ssl: false
======================================================
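Since OmegaConf configurations support attribute-style access, any field in the loaded configuration can also be read with dot notation. For example (the values match the YAML output above):

# Dot-notation access into the loaded configuration.
print(server_config.client_configs.train_configs.trainer)  # VanillaTrainer
print(server_config.server_configs.aggregator)  # FedAvgAggregator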
Client Configurations
The client configurations that are shared among all clients are composed of three main components:

- train_configs: This component contains all training-related configurations, which can be further classified into the following sub-components:
  - Trainer configurations: It should be noted that the required trainer configurations depend on the trainer you use. You can also define your own trainer with any additional configurations you need, and then provide those configurations under client_configs.train_configs in the server configuration YAML file.
    - trainer: The class name of the trainer you would like to use for client local training. The trainer should be defined in src/appfl/trainer. For example, VanillaTrainer simply updates the model for a certain number of epochs or batches.
    - mode: For VanillaTrainer, mode is a required configuration with allowable values epoch or step, specifying whether you want to train for a certain number of epochs or a certain number of steps/batches.
    - num_local_steps/num_local_epochs: Number of steps (if mode=step) or epochs (if mode=epoch) for an FL client in each local training round.
    - optim: Name of the optimizer to use from the torch.optim module.
    - optim_args: Keyword arguments for the selected optimizer.
    - do_validation: Whether to perform client-side validation in each training round.
    - do_pre_validation: Whether to perform client-side validation prior to local training.
    - use_dp: Whether to use differential privacy.
    - epsilon, clip_grad, clip_value, clip_norm: Parameters used if differential privacy is enabled.
  - Loss function: To specify the loss function to use during local training, we provide two options (a sketch of the second option is given after this list):
    - Loss function from torch: By providing the name of a loss function available in torch.nn (e.g., CrossEntropyLoss) in loss_fn and the corresponding keyword arguments in loss_fn_kwargs, the user can employ any loss function available in PyTorch.
    - Loss function defined in a local file: The user can define their own loss function by inheriting nn.Module and defining its forward() function. The user then needs to provide the path to the loss function file in loss_fn_path and the class name of the loss function in loss_fn_name.
  - Metric function: To specify the metric function used during validation, the user needs to provide the path to the file containing the metric function in metric_path and the name of the metric function in metric_name (also sketched after this list).
  - Dataloader settings: While the server-side configuration does not contain any information about each client's local dataset, it can specify the configurations used when converting the dataset to a dataloader, such as the batch size and whether to shuffle.
- model_configs: This component contains the definition of the machine learning model used in the FL experiment. The model architecture should be defined as a torch.nn.Module in a local file on the server side, and the following information should then be provided:
  - model_path: Path to the model definition file.
  - model_name: Class name of the defined model.
  - model_kwargs: Keyword arguments for initializing the model.
- comm_configs: This component contains the settings for the communication between the FL server and clients, such as the compressor_configs.
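As referenced in the list above, a custom loss function is simply a torch.nn.Module subclass defined in its own file. Below is a minimal sketch of what a file like ./resources/loss/celoss.py might contain; the exact arguments a given trainer passes to forward() are an assumption here for illustration.

import torch.nn as nn

class CELoss(nn.Module):
    """Cross-entropy loss wrapped as an nn.Module, loaded via
    loss_fn_path / loss_fn_name in the server configuration."""

    def __init__(self):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()

    def forward(self, prediction, target):
        return self.ce(prediction, target)

Similarly, the metric referenced by metric_path / metric_name is a plain function. A sketch of a file like ./resources/metric/acc.py, assuming a (y_true, y_pred) convention where y_pred holds per-class scores:

import numpy as np

def accuracy(y_true, y_pred):
    # Convert per-class scores to predicted labels, then compare.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.argmax(np.asarray(y_pred), axis=-1).reshape(-1)
    return float(np.mean(y_true == y_pred))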
In addition to the client_configs
contained in the server configuration YAML file, each client also needs its own client configuration YAML file to specify client-specific configurations, such as the device to use and the way to load the private local dataset. Let’s first take a look at the client-specific configuration.
- client_id: The unique ID (among all clients in the FL experiment) used for logging purposes.
- train_configs: This contains the device to use in training and some logging configurations.
- data_configs: This is the most important component in the client configuration file. It contains the path on the client's local machine to the file which defines how to load the local dataset (dataset_path), the name of the function in that file which loads the dataset (dataset_name), and any keyword arguments if needed (dataset_kwargs). A sketch of such a loader file follows this list.
- comm_configs: The client may also need to specify some communication settings in order to connect to the server. For example, if the experiment uses gRPC as the communication method, the client needs to specify server_uri, max_message_size, and use_ssl to establish the connection to the server.
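Here is a minimal sketch of what a dataset loader file like ./resources/dataset/mnist_dataset.py might contain. The function name get_mnist and the keyword arguments match the client configuration shown below, but the partitioning logic and the (train, val) return convention are assumptions for illustration.

import torchvision
from torch.utils.data import Subset
from torchvision import transforms

def get_mnist(num_clients, client_id, **kwargs):
    """Load MNIST and return this client's (train, val) datasets."""
    transform = transforms.ToTensor()
    train_data = torchvision.datasets.MNIST(
        "./_data", train=True, download=True, transform=transform
    )
    val_data = torchvision.datasets.MNIST(
        "./_data", train=False, download=True, transform=transform
    )
    # Simple IID split for illustration: give each client every
    # num_clients-th sample (a real class_noniid strategy would differ).
    indices = list(range(client_id, len(train_data), num_clients))
    return Subset(train_data, indices), val_data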
[2]:
client_config_file = "../../examples/resources/configs/mnist/client_1.yaml"
client_config = OmegaConf.load(client_config_file)
print(OmegaConf.to_yaml(client_config))
client_id: Client1
train_configs:
  device: cpu
  logging_output_dirname: ./output
  logging_output_filename: result
data_configs:
  dataset_path: ./resources/dataset/mnist_dataset.py
  dataset_name: get_mnist
  dataset_kwargs:
    num_clients: 2
    client_id: 0
    partition_strategy: class_noniid
    visualization: true
    output_dirname: ./output
    output_filename: visualization.pdf
comm_configs:
  grpc_configs:
    server_uri: localhost:50051
    max_message_size: 1048576
    use_ssl: false
Merging the general client configurations from the server configuration YAML file with the client-specific configurations from the client's own YAML file gives everything a client needs to conduct the FL experiment. Note that OmegaConf.merge gives precedence to arguments appearing later, so client-specific values override the shared ones, as the toy example below shows.
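A quick self-contained illustration of this merge precedence (the keys and values here are made up for the demo):

from omegaconf import OmegaConf

base = OmegaConf.create({"train_configs": {"device": "cpu", "optim": "Adam"}})
override = OmegaConf.create({"train_configs": {"device": "cuda"}})
merged = OmegaConf.merge(base, override)
print(merged.train_configs.device)  # cuda (overridden by the later config)
print(merged.train_configs.optim)   # Adam (kept from the base config)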
[3]:
general_client_config = server_config.client_configs
all_client_config = OmegaConf.merge(general_client_config, client_config)
print(OmegaConf.to_yaml(all_client_config))
train_configs:
  trainer: VanillaTrainer
  mode: step
  num_local_steps: 100
  optim: Adam
  optim_args:
    lr: 0.001
  loss_fn_path: ./resources/loss/celoss.py
  loss_fn_name: CELoss
  do_validation: true
  do_pre_validation: true
  metric_path: ./resources/metric/acc.py
  metric_name: accuracy
  use_dp: false
  epsilon: 1
  clip_grad: false
  clip_value: 1
  clip_norm: 1
  train_batch_size: 64
  val_batch_size: 64
  train_data_shuffle: true
  val_data_shuffle: false
  device: cpu
  logging_output_dirname: ./output
  logging_output_filename: result
model_configs:
  model_path: ./resources/model/cnn.py
  model_name: CNN
  model_kwargs:
    num_channel: 1
    num_classes: 10
    num_pixel: 28
comm_configs:
  compressor_configs:
    enable_compression: false
    lossy_compressor: SZ2Compressor
    lossless_compressor: blosc
    error_bounding_mode: REL
    error_bound: 0.001
    param_cutoff: 1024
  grpc_configs:
    server_uri: localhost:50051
    max_message_size: 1048576
    use_ssl: false
client_id: Client1
data_configs:
  dataset_path: ./resources/dataset/mnist_dataset.py
  dataset_name: get_mnist
  dataset_kwargs:
    num_clients: 2
    client_id: 0
    partition_strategy: class_noniid
    visualization: true
    output_dirname: ./output
    output_filename: visualization.pdf
Server Configurations
Now, let's take a look at what is needed for the configuration of the FL server. Specifically, it contains the following key components:
- Scheduler configurations: The user can specify the name of the scheduler (scheduler) and the corresponding keyword arguments (scheduler_kwargs). All supported schedulers are available at src/appfl/scheduler.
- Aggregator configurations: The user can specify the name of the aggregator (aggregator) and the corresponding keyword arguments (aggregator_kwargs). All supported aggregators are available at src/appfl/aggregator.
[4]:
print(OmegaConf.to_yaml(server_config.server_configs))
num_clients: 2
scheduler: SyncScheduler
scheduler_kwargs:
  same_init_model: true
aggregator: FedAvgAggregator
aggregator_kwargs:
  client_weights_mode: equal
device: cpu
num_global_epochs: 10
logging_output_dirname: ./output
logging_output_filename: result
comm_configs:
  grpc_configs:
    server_uri: localhost:50051
    max_message_size: 1048576
    use_ssl: false
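If you want to adjust any of these server settings without editing the YAML file by hand, you can modify the loaded configuration in place and write it back with OmegaConf.save. A small sketch (the output filename here is made up for illustration):

# Tweak a server setting and persist the modified configuration.
server_config.server_configs.num_global_epochs = 20
OmegaConf.save(config=server_config, f="./server_fedavg_20epochs.yaml")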