Serial FL Simulation
In this notebook, we show how to simulate an FL experiment on a single machine by running each client serially. Note that only synchronous FL algorithms are meaningful to simulate this way, since asynchronous algorithms depend on clients training concurrently. We use 10 clients running serially in this example.
[1]:
num_clients = 10
Load server configurations
In this example, we use the FedAvg server aggregation algorithm and the MNIST dataset, loading the server configurations from examples/resources/configs/mnist/server_fedavg.yaml.
[2]:
from omegaconf import OmegaConf
server_config_file = "../../examples/resources/configs/mnist/server_fedavg.yaml"
server_config = OmegaConf.load(server_config_file)
print(OmegaConf.to_yaml(server_config))
client_configs:
  train_configs:
    trainer: VanillaTrainer
    mode: step
    num_local_steps: 100
    optim: Adam
    optim_args:
      lr: 0.001
    loss_fn_path: ./resources/loss/celoss.py
    loss_fn_name: CELoss
    do_validation: true
    do_pre_validation: true
    metric_path: ./resources/metric/acc.py
    metric_name: accuracy
    use_dp: false
    epsilon: 1
    clip_grad: false
    clip_value: 1
    clip_norm: 1
    train_batch_size: 64
    val_batch_size: 64
    train_data_shuffle: true
    val_data_shuffle: false
  model_configs:
    model_path: ./resources/model/cnn.py
    model_name: CNN
    model_kwargs:
      num_channel: 1
      num_classes: 10
      num_pixel: 28
  comm_configs:
    compressor_configs:
      enable_compression: false
      lossy_compressor: SZ2Compressor
      lossless_compressor: blosc
      error_bounding_mode: REL
      error_bound: 0.001
      param_cutoff: 1024
server_configs:
  num_clients: 2
  scheduler: SyncScheduler
  scheduler_kwargs:
    same_init_model: true
  aggregator: FedAvgAggregator
  aggregator_kwargs:
    client_weights_mode: equal
  device: cpu
  num_global_epochs: 10
  logging_output_dirname: ./output
  logging_output_filename: result
  comm_configs:
    grpc_configs:
      server_uri: localhost:50051
      max_message_size: 1048576
      use_ssl: false
💡 Note that configuration fields such as loss_fn_path, metric_path, and model_path are paths to the corresponding files, so we need to update their relative paths to make sure they point to the right files from this notebook's directory.
💡 We also change num_global_epochs from 10 to 3.
⚠️ We also need to change num_clients in server_configs to 10.
[3]:
server_config.client_configs.train_configs.loss_fn_path = (
    "../../examples/resources/loss/celoss.py"
)
server_config.client_configs.train_configs.metric_path = (
    "../../examples/resources/metric/acc.py"
)
server_config.client_configs.model_configs.model_path = (
    "../../examples/resources/model/cnn.py"
)
server_config.server_configs.num_global_epochs = 3
server_config.server_configs.num_clients = num_clients
Load client configurations
In this example, we suppose that num_clients=10 and load the basic configurations for all clients from examples/resources/configs/mnist/client_1.yaml. Let's first take a look at this basic configuration.
[4]:
client_config_file = "../../examples/resources/configs/mnist/client_1.yaml"
client_config = OmegaConf.load(client_config_file)
print(OmegaConf.to_yaml(client_config))
client_id: Client1
train_configs:
  device: cpu
  logging_output_dirname: ./output
  logging_output_filename: result
data_configs:
  dataset_path: ./resources/dataset/mnist_dataset.py
  dataset_name: get_mnist
  dataset_kwargs:
    num_clients: 2
    client_id: 0
    partition_strategy: class_noniid
    visualization: true
    output_dirname: ./output
    output_filename: visualization.pdf
comm_configs:
  grpc_configs:
    server_uri: localhost:50051
    max_message_size: 1048576
    use_ssl: false
For the configuration above, note that data_configs contains the necessary configurations to load the simulated "local" dataset for each client. Specifically,

- dataset_path is the path to the file that contains the dataset-loading function
- dataset_name is the name of that dataset-loading function in the above file
- dataset_kwargs are the keyword arguments for that function

The get_mnist function above partitions the MNIST dataset into num_clients client splits in an IID or non-IID (IID: independent and identically distributed) manner, depending on partition_strategy. A sketch of the interface such a loader is expected to follow is shown below.
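To make this contract concrete, here is a hypothetical sketch of a loader with the interface described above. The function name get_mnist_sketch, the naive IID split, and the assumption that the loader returns a (train_dataset, val_dataset) pair are all illustrative, not the actual code in mnist_dataset.py:

import torch
import torchvision
from torchvision.transforms import ToTensor

def get_mnist_sketch(num_clients: int, client_id: int, **kwargs):
    # Hypothetical loader following the dataset_name/dataset_kwargs contract
    # described above; not the actual implementation in mnist_dataset.py.
    train_data = torchvision.datasets.MNIST(
        "./_data", train=True, download=True, transform=ToTensor()
    )
    val_data = torchvision.datasets.MNIST(
        "./_data", train=False, download=True, transform=ToTensor()
    )
    # Naive IID partition: every num_clients-th sample goes to this client
    train_split = torch.utils.data.Subset(
        train_data, list(range(client_id, len(train_data), num_clients))
    )
    return train_split, val_data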
Now we need to modify the general client configurations for the different clients. Specifically, we make the following changes:

- Change client_id for each client.
- Change the relative path of dataset_path to make it point to the right file.
- Change dataset_kwargs.num_clients to 10 and dataset_kwargs.client_id to [0, 1, ..., 9] for the different clients.
- Change dataset_kwargs.visualization to False for nine of the clients, so that only one data distribution visualization plot is generated.
[5]:
import copy

client_configs = [copy.deepcopy(client_config) for _ in range(num_clients)]
for i in range(num_clients):
    client_configs[i].client_id = f"Client{i + 1}"
    client_configs[i].data_configs.dataset_path = (
        "../../examples/resources/dataset/mnist_dataset.py"
    )
    client_configs[i].data_configs.dataset_kwargs.num_clients = num_clients
    client_configs[i].data_configs.dataset_kwargs.client_id = i
    client_configs[i].data_configs.dataset_kwargs.visualization = i == 0
Create FL server agent and client agents
In APPFL, we use agents to act on behalf of the FL server and FL clients and carry out the necessary steps of a federated learning experiment. Users can easily create the agents from the server/client configurations we loaded (and slightly modified) from the configuration YAML files. Creating the client agents loads the local datasets and plots the data distribution visualization, as shown below.
[6]:
from appfl.agent import ServerAgent, ClientAgent

server_agent = ServerAgent(server_agent_config=server_config)
client_agents = [
    ClientAgent(client_agent_config=client_configs[i]) for i in range(num_clients)
]
[2025-01-08 09:57:56,430 server]: Logging to ./output/result_Server_2025-01-08-09-57-56.txt
[2025-01-08 09:57:56,435 Client1]: Logging to ./output/result_Client1_2025-01-08-09-57-56.txt
[2025-01-08 09:58:03,678 Client2]: Logging to ./output/result_Client2_2025-01-08-09-58-03.txt
[2025-01-08 09:58:10,433 Client3]: Logging to ./output/result_Client3_2025-01-08-09-58-10.txt
[2025-01-08 09:58:16,842 Client4]: Logging to ./output/result_Client4_2025-01-08-09-58-16.txt
[2025-01-08 09:58:23,217 Client5]: Logging to ./output/result_Client5_2025-01-08-09-58-23.txt
[2025-01-08 09:58:29,646 Client6]: Logging to ./output/result_Client6_2025-01-08-09-58-29.txt
[2025-01-08 09:58:36,078 Client7]: Logging to ./output/result_Client7_2025-01-08-09-58-36.txt
[2025-01-08 09:58:42,554 Client8]: Logging to ./output/result_Client8_2025-01-08-09-58-42.txt
[2025-01-08 09:58:49,468 Client9]: Logging to ./output/result_Client9_2025-01-08-09-58-49.txt
[2025-01-08 09:58:56,215 Client10]: Logging to ./output/result_Client10_2025-01-08-09-58-56.txt
Start the training process
The server configuration file contains client configurations that should apply to ALL clients. Now we need to get those configurations from the server and provide them to the client agents.
[7]:
# Get additional client configurations from the server
client_config_from_server = server_agent.get_client_configs()
for client_agent in client_agents:
    client_agent.load_config(client_config_from_server)
Then, let the clients load the initial global model from the server and optionally send their numbers of local data samples to the server for weighted aggregation.
💡 Note: Typically, server_agent.get_parameters() either blocks or returns a Future object, and only returns the global model after receiving num_clients calls, synchronizing the clients so that they all start from the same initial model weights. However, since we are running a serial simulation, we don't want this blocking behavior, so we pass serial_run=True when calling the function to get the global model immediately.
[8]:
# Load the initial global model from the server
init_global_model = server_agent.get_parameters(serial_run=True)
for client_agent in client_agents:
    client_agent.load_parameters(init_global_model)

# [Optional] Send the number of local data samples to the server
for i in range(num_clients):
    sample_size = client_agents[i].get_sample_size()
    server_agent.set_sample_size(
        client_id=client_agents[i].get_id(), sample_size=sample_size
    )
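For intuition about what these sample sizes can be used for: with sample-size weighting (as opposed to the client_weights_mode: equal setting in the server configuration above), FedAvg weights each client's model by its share of the total data. Below is a minimal illustrative sketch of such weighted averaging; it is not APPFL's FedAvgAggregator implementation, and weighted_fedavg is a hypothetical helper name.

import torch

def weighted_fedavg(local_models, sample_sizes):
    # Hypothetical sketch, not APPFL's FedAvgAggregator.
    # Average a list of model state dicts, weighting each client
    # by its share of the total number of training samples.
    total = sum(sample_sizes)
    return {
        key: sum(
            (n / total) * model[key].float()
            for model, n in zip(local_models, sample_sizes)
        )
        for key in local_models[0]
    }

For example, with sample sizes [600, 400], the first client's parameters receive weight 0.6 and the second's 0.4.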
Now, we can start the training iterations. Please note the following points:

- server_agent.training_finished returns a boolean flag indicating whether the training has reached the specified num_global_epochs.
- client_agent.train trains the client's local model using the client's "local" data.
- server_agent.global_update takes one client's local model together with a client id (obtained via client_agent.get_id) to schedule a global update. Synchronous aggregation algorithms such as FedAvg do not update the global model until they receive local models from all num_clients=10 clients, so the call to global_update returns a concurrent.futures.Future object (if you set blocking=False; otherwise it would block forever in a serial simulation), whose result is set after all client local models have been sent for the global update. A brief illustration of Future semantics follows this list.
- In the output log, Pre Val? indicates whether the row reports validation performed prior to the local training. Because each client only holds data from 3 to 5 classes, the validation accuracy can even drop after local training. However, the global model accuracy continues to increase, showcasing the ability of federated learning to improve the generalizability of the trained machine learning model.
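As background, a concurrent.futures.Future (from the Python standard library) is a placeholder whose result is filled in later; this is exactly the semantics the scheduler relies on:

from concurrent.futures import Future

future = Future()                  # placeholder, no result yet
print(future.done())               # False
future.set_result("global model")  # set once the last local model arrives
print(future.result())             # "global model" -- returns immediately once set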
[9]:
while not server_agent.training_finished():
    new_global_models = []
    for client_agent in client_agents:
        # Client local training
        client_agent.train()
        local_model = client_agent.get_parameters()
        if isinstance(local_model, tuple):
            local_model, metadata = local_model[0], local_model[1]
        else:
            metadata = {}
        # "Send" local model to server and get a Future object for the new global model
        # The Future object will be resolved when the server receives local models from all clients
        new_global_model_future = server_agent.global_update(
            client_id=client_agent.get_id(),
            local_model=local_model,
            blocking=False,
            **metadata,
        )
        new_global_models.append(new_global_model_future)
    # Load the new global model from the server
    for client_agent, new_global_model_future in zip(client_agents, new_global_models):
        client_agent.load_parameters(new_global_model_future.result())
[2025-01-08 09:59:02,820 Client1]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy
[2025-01-08 09:59:03,825 Client1]: 0 Y 2.3006 15.9300
[2025-01-08 09:59:06,651 Client1]: 0 N 2.8258 0.4323 90.8109 15.4995 30.3500
[2025-01-08 09:59:06,653 Client2]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy
[2025-01-08 09:59:07,690 Client2]: 0 Y 2.3006 15.9300
[2025-01-08 09:59:10,635 Client2]: 0 N 2.9442 0.3130 87.7656 9.9495 48.6000
[2025-01-08 09:59:10,636 Client3]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy
[2025-01-08 09:59:11,665 Client3]: 0 Y 2.3006 15.9300
[2025-01-08 09:59:14,515 Client3]: 0 N 2.8484 0.3909 85.0781 10.4302 38.2900
[2025-01-08 09:59:14,517 Client4]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy
[2025-01-08 09:59:15,541 Client4]: 0 Y 2.3006 15.9300
[2025-01-08 09:59:18,495 Client4]: 0 N 2.9528 0.4242 89.8377 11.0662 39.8000
[2025-01-08 09:59:18,496 Client5]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy
[2025-01-08 09:59:19,528 Client5]: 0 Y 2.3006 15.9300
[2025-01-08 09:59:22,589 Client5]: 0 N 3.0604 0.1851 92.1406 22.1558 28.8700
[2025-01-08 09:59:22,591 Client6]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy
[2025-01-08 09:59:23,603 Client6]: 0 Y 2.3006 15.9300
[2025-01-08 09:59:26,441 Client6]: 0 N 2.8377 0.6525 85.3739 14.9090 38.8200
[2025-01-08 09:59:26,443 Client7]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy
[2025-01-08 09:59:27,462 Client7]: 0 Y 2.3006 15.9300
[2025-01-08 09:59:30,243 Client7]: 0 N 2.7805 0.2848 86.4844 10.2574 48.8900
[2025-01-08 09:59:30,245 Client8]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy
[2025-01-08 09:59:31,256 Client8]: 0 Y 2.3006 15.9300
[2025-01-08 09:59:34,121 Client8]: 0 N 2.8636 0.3402 92.6434 14.2500 31.1400
[2025-01-08 09:59:34,123 Client9]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy
[2025-01-08 09:59:35,132 Client9]: 0 Y 2.3006 15.9300
[2025-01-08 09:59:37,892 Client9]: 0 N 2.7590 0.3460 86.6562 9.1354 50.0000
[2025-01-08 09:59:37,894 Client10]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy
[2025-01-08 09:59:38,885 Client10]: 0 Y 2.3006 15.9300
[2025-01-08 09:59:41,568 Client10]: 0 N 2.6825 0.2586 88.4375 12.8184 40.2200
[2025-01-08 09:59:42,569 Client1]: 1 Y 1.8781 31.3900
[2025-01-08 09:59:45,282 Client1]: 1 N 2.7129 0.1185 98.0902 9.2444 30.4200
[2025-01-08 09:59:46,287 Client2]: 1 Y 1.8781 31.3900
[2025-01-08 09:59:49,161 Client2]: 1 N 2.8722 0.1277 95.7188 8.0948 48.7200
[2025-01-08 09:59:50,222 Client3]: 1 Y 1.8781 31.3900
[2025-01-08 09:59:52,936 Client3]: 1 N 2.7130 0.1673 94.5312 8.9735 38.5400
[2025-01-08 09:59:53,939 Client4]: 1 Y 1.8781 31.3900
[2025-01-08 09:59:56,585 Client4]: 1 N 2.6448 0.1617 96.1872 9.9464 39.7100
[2025-01-08 09:59:57,592 Client5]: 1 Y 1.8781 31.3900
[2025-01-08 10:00:00,295 Client5]: 1 N 2.7021 0.0906 97.3438 13.3640 29.5100
[2025-01-08 10:00:01,289 Client6]: 1 Y 1.8781 31.3900
[2025-01-08 10:00:04,012 Client6]: 1 N 2.7218 0.1973 96.3711 9.5049 39.4300
[2025-01-08 10:00:05,028 Client7]: 1 Y 1.8781 31.3900
[2025-01-08 10:00:07,715 Client7]: 1 N 2.6853 0.1087 95.3750 5.8535 49.2100
[2025-01-08 10:00:08,714 Client8]: 1 Y 1.8781 31.3900
[2025-01-08 10:00:11,396 Client8]: 1 N 2.6800 0.0983 98.0466 9.4541 31.3000
[2025-01-08 10:00:12,399 Client9]: 1 Y 1.8781 31.3900
[2025-01-08 10:00:15,078 Client9]: 1 N 2.6767 0.1480 94.8281 6.5887 49.2400
[2025-01-08 10:00:16,078 Client10]: 1 Y 1.8781 31.3900
[2025-01-08 10:00:18,718 Client10]: 1 N 2.6380 0.1009 96.3438 11.2293 40.7800
[2025-01-08 10:00:19,716 Client1]: 2 Y 0.9154 68.4300
[2025-01-08 10:00:22,358 Client1]: 2 N 2.6401 0.0553 99.0451 6.5585 30.5100
[2025-01-08 10:00:23,347 Client2]: 2 Y 0.9154 68.4300
[2025-01-08 10:00:26,005 Client2]: 2 N 2.6579 0.0708 97.4844 5.9560 49.2300
[2025-01-08 10:00:27,009 Client3]: 2 Y 0.9154 68.4300
[2025-01-08 10:00:29,809 Client3]: 2 N 2.7990 0.0948 96.4219 7.6792 38.7600
[2025-01-08 10:00:30,845 Client4]: 2 Y 0.9154 68.4300
[2025-01-08 10:00:33,755 Client4]: 2 N 2.9089 0.0759 98.0936 8.1204 40.0200
[2025-01-08 10:00:34,857 Client5]: 2 Y 0.9154 68.4300
[2025-01-08 10:00:37,737 Client5]: 2 N 2.8792 0.0454 98.3750 8.8908 29.7500
[2025-01-08 10:00:38,732 Client6]: 2 Y 0.9154 68.4300
[2025-01-08 10:00:41,369 Client6]: 2 N 2.6357 0.0995 98.0120 8.2203 39.4500
[2025-01-08 10:00:42,367 Client7]: 2 Y 0.9154 68.4300
[2025-01-08 10:00:45,160 Client7]: 2 N 2.7908 0.0554 97.5781 4.9709 49.0400
[2025-01-08 10:00:46,179 Client8]: 2 Y 0.9154 68.4300
[2025-01-08 10:00:48,859 Client8]: 2 N 2.6786 0.0540 98.8658 8.9072 31.2200
[2025-01-08 10:00:49,859 Client9]: 2 Y 0.9154 68.4300
[2025-01-08 10:00:52,476 Client9]: 2 N 2.6149 0.0861 97.0000 5.5963 50.2700
[2025-01-08 10:00:53,481 Client10]: 2 Y 0.9154 68.4300
[2025-01-08 10:00:56,112 Client10]: 2 N 2.6278 0.0638 97.3594 8.1234 40.8700