Serial FL Simulation

In this notebook, we show how to simulate an FL experiment on a single machine by running the clients serially. Note that serial execution only makes sense for synchronous FL algorithms. We use 10 clients in this example.

[1]:
num_clients = 10

Load server configurations

In this example, we use the FedAvg server aggregation algorithm and the MNIST dataset, loading the server configurations from examples/resources/configs/mnist/server_fedavg.yaml.

[2]:
from omegaconf import OmegaConf

server_config_file = "../../examples/resources/configs/mnist/server_fedavg.yaml"
server_config = OmegaConf.load(server_config_file)
print(OmegaConf.to_yaml(server_config))
client_configs:
  train_configs:
    trainer: VanillaTrainer
    mode: step
    num_local_steps: 100
    optim: Adam
    optim_args:
      lr: 0.001
    loss_fn_path: ./resources/loss/celoss.py
    loss_fn_name: CELoss
    do_validation: true
    do_pre_validation: true
    metric_path: ./resources/metric/acc.py
    metric_name: accuracy
    use_dp: false
    epsilon: 1
    clip_grad: false
    clip_value: 1
    clip_norm: 1
    train_batch_size: 64
    val_batch_size: 64
    train_data_shuffle: true
    val_data_shuffle: false
  model_configs:
    model_path: ./resources/model/cnn.py
    model_name: CNN
    model_kwargs:
      num_channel: 1
      num_classes: 10
      num_pixel: 28
  comm_configs:
    compressor_configs:
      enable_compression: false
      lossy_compressor: SZ2Compressor
      lossless_compressor: blosc
      error_bounding_mode: REL
      error_bound: 0.001
      param_cutoff: 1024
server_configs:
  num_clients: 2
  scheduler: SyncScheduler
  scheduler_kwargs:
    same_init_model: true
  aggregator: FedAvgAggregator
  aggregator_kwargs:
    client_weights_mode: equal
  device: cpu
  num_global_epochs: 10
  logging_output_dirname: ./output
  logging_output_filename: result
  comm_configs:
    grpc_configs:
      server_uri: localhost:50051
      max_message_size: 1048576
      use_ssl: false

💡 Configuration fields such as loss_fn_path, metric_path, and model_path are paths to the corresponding files, so we need to update these relative paths to make sure they point to the right files.

💡 We also change num_global_epochs from 10 to 3.

⚠️ We also need to change num_clients in server_configs to 10.

[3]:
server_config.client_configs.train_configs.loss_fn_path = (
    "../../examples/resources/loss/celoss.py"
)
server_config.client_configs.train_configs.metric_path = (
    "../../examples/resources/metric/acc.py"
)
server_config.client_configs.model_configs.model_path = (
    "../../examples/resources/model/cnn.py"
)
server_config.server_configs.num_global_epochs = 3
server_config.server_configs.num_clients = num_clients

Load client configurations

In this example, we set num_clients=10 and load the basic configuration shared by all clients from examples/resources/configs/mnist/client_1.yaml. Let’s first take a look at this basic configuration.

[4]:
client_config_file = "../../examples/resources/configs/mnist/client_1.yaml"
client_config = OmegaConf.load(client_config_file)
print(OmegaConf.to_yaml(client_config))
client_id: Client1
train_configs:
  device: cpu
  logging_output_dirname: ./output
  logging_output_filename: result
data_configs:
  dataset_path: ./resources/dataset/mnist_dataset.py
  dataset_name: get_mnist
  dataset_kwargs:
    num_clients: 2
    client_id: 0
    partition_strategy: class_noniid
    visualization: true
    output_dirname: ./output
    output_filename: visualization.pdf
comm_configs:
  grpc_configs:
    server_uri: localhost:50051
    max_message_size: 1048576
    use_ssl: false

In the configuration above, data_configs contains the configurations needed to load the simulated “local” dataset for each client. Specifically,

  • dataset_path is the path to the file that contains the function to load the dataset

  • dataset_name is the name of the function in that file that loads the dataset

  • dataset_kwargs are the keyword arguments for that function
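The path/name/kwargs pattern described above can be sketched with the standard library alone. This is a minimal illustration, not APPFL's actual loader: the helper load_function, the file toy_dataset.py, and the function get_toy_dataset are all hypothetical names introduced here.

```python
import importlib.util
import os
import tempfile


def load_function(file_path, function_name):
    """Import the Python file at file_path and return the attribute called function_name."""
    # Hypothetical helper; "user_module" is an arbitrary module name for the import
    spec = importlib.util.spec_from_file_location("user_module", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, function_name)


# Hypothetical stand-in for a dataset file such as mnist_dataset.py
source = (
    "def get_toy_dataset(num_clients, client_id):\n"
    "    # Give each client every num_clients-th sample of a 12-sample toy dataset\n"
    "    return list(range(client_id, 12, num_clients))\n"
)

with tempfile.TemporaryDirectory() as tmp_dir:
    dataset_path = os.path.join(tmp_dir, "toy_dataset.py")
    with open(dataset_path, "w") as f:
        f.write(source)
    # dataset_path + dataset_name select the function; dataset_kwargs are passed to it
    dataset_fn = load_function(dataset_path, "get_toy_dataset")
    dataset_kwargs = {"num_clients": 3, "client_id": 1}
    toy_split = dataset_fn(**dataset_kwargs)

print(toy_split)  # [1, 4, 7, 10]
```

Keeping the loader function in a separate file and referring to it by path and name lets the same configuration schema work with any user-provided dataset.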

The get_mnist function partitions the MNIST dataset into num_clients client splits in either an IID (independent and identically distributed) or a non-IID manner.
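To give a feel for what a class-based non-IID partition looks like, here is a small sketch on toy labels. It is an assumption-laden illustration, not the implementation behind get_mnist or the class_noniid strategy: the function class_noniid_split, its parameters, and the round-robin assignment are all hypothetical choices made for this example.

```python
import random
from collections import defaultdict


def class_noniid_split(labels, num_clients, min_classes=3, max_classes=5, seed=0):
    """Return (splits, client_classes): sample indices and the class subset per client."""
    rng = random.Random(seed)
    classes = sorted(set(labels))
    # Pick a small random subset of classes for every client
    client_classes = [
        set(rng.sample(classes, rng.randint(min_classes, max_classes)))
        for _ in range(num_clients)
    ]
    # Group sample indices by label
    indices_by_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)
    # Deal the samples of each class round-robin to the clients holding that class
    splits = [[] for _ in range(num_clients)]
    for cls, indices in indices_by_class.items():
        holders = [c for c in range(num_clients) if cls in client_classes[c]]
        if not holders:
            continue  # in this sketch, a class that no client picked is simply dropped
        for position, index in enumerate(indices):
            splits[holders[position % len(holders)]].append(index)
    return splits, client_classes


labels = [i % 10 for i in range(1000)]  # toy labels: 10 classes, 100 samples each
splits, client_classes = class_noniid_split(labels, num_clients=10)
for client_id, split in enumerate(splits[:3]):
    print(client_id, sorted({labels[i] for i in split}), len(split))
```

Because every client sees only a few classes, a model trained on one split alone tends to generalize poorly to the full label set, which is the situation federated aggregation is meant to address.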

Now we need to modify the general client configurations for different clients. Specifically, we make the following changes:

  • Change client_id for each client

  • Change the relative path of dataset_path to make it point to the right file

  • Change dataset_kwargs.num_clients to 10 and dataset_kwargs.client_id to 0, 1, ..., 9 for the different clients.

  • Change dataset_kwargs.visualization to False for nine of the clients so that only one data distribution visualization plot is generated.

[5]:
import copy

client_configs = [copy.deepcopy(client_config) for _ in range(num_clients)]
for i, config in enumerate(client_configs):
    config.client_id = f"Client{i + 1}"
    config.data_configs.dataset_path = (
        "../../examples/resources/dataset/mnist_dataset.py"
    )
    config.data_configs.dataset_kwargs.num_clients = num_clients
    config.data_configs.dataset_kwargs.client_id = i
    # Only the first client plots the data distribution
    config.data_configs.dataset_kwargs.visualization = i == 0
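The cell above uses copy.deepcopy so that each client gets an independent configuration. The sketch below shows why this matters, using plain dictionaries as a stand-in for the configuration objects (OmegaConf objects may behave differently under a shallow copy; that is not tested here).

```python
import copy

base_config = {
    "client_id": "Client1",
    "dataset_kwargs": {"num_clients": 2, "client_id": 0},
}

# Shallow copies share the same nested dataset_kwargs dictionary,
# so every write overwrites the value seen by all copies.
shallow = [copy.copy(base_config) for _ in range(3)]
for i, cfg in enumerate(shallow):
    cfg["dataset_kwargs"]["client_id"] = i
shallow_ids = [cfg["dataset_kwargs"]["client_id"] for cfg in shallow]

# Reset the shared nested value before the second experiment
base_config["dataset_kwargs"]["client_id"] = 0

# Deep copies duplicate the nested dictionary, so each client keeps its own value.
deep = [copy.deepcopy(base_config) for _ in range(3)]
for i, cfg in enumerate(deep):
    cfg["dataset_kwargs"]["client_id"] = i
deep_ids = [cfg["dataset_kwargs"]["client_id"] for cfg in deep]

print(shallow_ids)  # [2, 2, 2]
print(deep_ids)     # [0, 1, 2]
```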

Create FL server agent and client agents

In APPFL, agents act on behalf of the FL server and the FL clients to carry out the steps of a federated learning experiment. Users can easily create the agents from the server/client configurations we loaded (and slightly modified) from the configuration YAML files. Creating the client agents loads the local datasets and plots the data distribution visualization.

[6]:
from appfl.agent import ServerAgent, ClientAgent

server_agent = ServerAgent(server_agent_config=server_config)
client_agents = [
    ClientAgent(client_agent_config=client_configs[i]) for i in range(num_clients)
]
appfl: ✅[2025-01-08 09:57:56,430 server]: Logging to ./output/result_Server_2025-01-08-09-57-56.txt
appfl: ✅[2025-01-08 09:57:56,435 Client1]: Logging to ./output/result_Client1_2025-01-08-09-57-56.txt
appfl: ✅[2025-01-08 09:58:03,678 Client2]: Logging to ./output/result_Client2_2025-01-08-09-58-03.txt
appfl: ✅[2025-01-08 09:58:10,433 Client3]: Logging to ./output/result_Client3_2025-01-08-09-58-10.txt
appfl: ✅[2025-01-08 09:58:16,842 Client4]: Logging to ./output/result_Client4_2025-01-08-09-58-16.txt
appfl: ✅[2025-01-08 09:58:23,217 Client5]: Logging to ./output/result_Client5_2025-01-08-09-58-23.txt
appfl: ✅[2025-01-08 09:58:29,646 Client6]: Logging to ./output/result_Client6_2025-01-08-09-58-29.txt
appfl: ✅[2025-01-08 09:58:36,078 Client7]: Logging to ./output/result_Client7_2025-01-08-09-58-36.txt
appfl: ✅[2025-01-08 09:58:42,554 Client8]: Logging to ./output/result_Client8_2025-01-08-09-58-42.txt
appfl: ✅[2025-01-08 09:58:49,468 Client9]: Logging to ./output/result_Client9_2025-01-08-09-58-49.txt
appfl: ✅[2025-01-08 09:58:56,215 Client10]: Logging to ./output/result_Client10_2025-01-08-09-58-56.txt

Start the training process

The server configuration file contains many client configurations that apply to ALL clients. Now we need to get those configurations from the server and provide them to the client agents.

[7]:
# Get additional client configurations from the server
client_config_from_server = server_agent.get_client_configs()
for client_agent in client_agents:
    client_agent.load_config(client_config_from_server)

Then, let the clients load the initial global model from the server and, optionally, send their local sample sizes to the server for weighted aggregation.

💡 Note: Typically, server_agent.get_parameters() either blocks or returns a Future object, and only returns the global model after receiving num_clients calls. This synchronizes the clients when they fetch the initial model and ensures that all clients start from the same initial model weights. However, since we are running a serial simulation, we don’t want the blocking, so we pass serial_run=True when calling the function to get the global model immediately.

[8]:
# Load initial global model from the server
init_global_model = server_agent.get_parameters(serial_run=True)
for client_agent in client_agents:
    client_agent.load_parameters(init_global_model)

# [Optional] Send the number of local data samples to the server
for i in range(num_clients):
    sample_size = client_agents[i].get_sample_size()
    server_agent.set_sample_size(
        client_id=client_agents[i].get_id(), sample_size=sample_size
    )
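To illustrate what the sample sizes can be used for, the sketch below contrasts an equally weighted average with a sample-size-weighted average of toy scalar parameters. It is a simplified illustration under stated assumptions, not the FedAvgAggregator implementation: weighted_average and the toy models are hypothetical. Note that the server configuration printed earlier sets client_weights_mode: equal.

```python
def weighted_average(local_models, weights):
    """Average per-client parameter dicts, weighting each client by weights[client_id]."""
    total = sum(weights.values())
    keys = next(iter(local_models.values())).keys()
    return {
        key: sum(weights[cid] * model[key] for cid, model in local_models.items())
        / total
        for key in keys
    }


# Toy "models" with two scalar parameters each (hypothetical values)
local_models = {
    "Client1": {"w": 1.0, "b": 0.0},
    "Client2": {"w": 3.0, "b": 2.0},
}

# Equal weighting: every client counts the same
equal = weighted_average(local_models, {"Client1": 1, "Client2": 1})

# Sample-size weighting: clients with more data pull the average toward their model
sample_sizes = {"Client1": 100, "Client2": 300}
by_size = weighted_average(local_models, sample_sizes)

print(equal)    # {'w': 2.0, 'b': 1.0}
print(by_size)  # {'w': 2.5, 'b': 1.5}
```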

Now, we can start the training iterations. Please note the following points:

  • server_agent.training_finished returns a boolean flag indicating whether the training has reached the specified num_global_epochs

  • client_agent.train trains the client local model using the client’s “local” data

  • server_agent.global_update takes one client’s local model together with a client ID (obtained from client_agent.get_id) and schedules the global update for that local model. Synchronous server aggregation algorithms such as FedAvg do not update the global model until they have received local models from all num_clients=10 clients. With blocking=False, the call to global_update therefore returns a concurrent.futures.Future object, which is resolved after all client local models have been sent for global update. Without blocking=False, the call would block forever in a serial simulation.

  • In the output log, Pre Val? indicates whether the validation was run before the local training. As each client only holds data from 3 to 5 classes, the validation accuracy can even drop after local training. However, the global model accuracy continues to increase, showcasing the capability of federated learning to improve the generalizability of the trained machine learning model.

[9]:
while not server_agent.training_finished():
    new_global_models = []
    for client_agent in client_agents:
        # Client local training
        client_agent.train()
        local_model = client_agent.get_parameters()
        if isinstance(local_model, tuple):
            local_model, metadata = local_model[0], local_model[1]
        else:
            metadata = {}
        # "Send" local model to server and get a Future object for the new global model
        # The Future object will be resolved when the server receives local models from all clients
        new_global_model_future = server_agent.global_update(
            client_id=client_agent.get_id(),
            local_model=local_model,
            blocking=False,
            **metadata,
        )
        new_global_models.append(new_global_model_future)
    # Load the new global model from the server
    for client_agent, new_global_model_future in zip(client_agents, new_global_models):
        client_agent.load_parameters(new_global_model_future.result())
appfl: ✅[2025-01-08 09:59:02,820 Client1]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 09:59:03,825 Client1]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 09:59:06,651 Client1]:          0          N     2.8258     0.4323        90.8109    15.4995      30.3500
appfl: ✅[2025-01-08 09:59:06,653 Client2]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 09:59:07,690 Client2]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 09:59:10,635 Client2]:          0          N     2.9442     0.3130        87.7656     9.9495      48.6000
appfl: ✅[2025-01-08 09:59:10,636 Client3]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 09:59:11,665 Client3]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 09:59:14,515 Client3]:          0          N     2.8484     0.3909        85.0781    10.4302      38.2900
appfl: ✅[2025-01-08 09:59:14,517 Client4]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 09:59:15,541 Client4]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 09:59:18,495 Client4]:          0          N     2.9528     0.4242        89.8377    11.0662      39.8000
appfl: ✅[2025-01-08 09:59:18,496 Client5]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 09:59:19,528 Client5]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 09:59:22,589 Client5]:          0          N     3.0604     0.1851        92.1406    22.1558      28.8700
appfl: ✅[2025-01-08 09:59:22,591 Client6]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 09:59:23,603 Client6]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 09:59:26,441 Client6]:          0          N     2.8377     0.6525        85.3739    14.9090      38.8200
appfl: ✅[2025-01-08 09:59:26,443 Client7]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 09:59:27,462 Client7]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 09:59:30,243 Client7]:          0          N     2.7805     0.2848        86.4844    10.2574      48.8900
appfl: ✅[2025-01-08 09:59:30,245 Client8]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 09:59:31,256 Client8]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 09:59:34,121 Client8]:          0          N     2.8636     0.3402        92.6434    14.2500      31.1400
appfl: ✅[2025-01-08 09:59:34,123 Client9]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 09:59:35,132 Client9]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 09:59:37,892 Client9]:          0          N     2.7590     0.3460        86.6562     9.1354      50.0000
appfl: ✅[2025-01-08 09:59:37,894 Client10]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 09:59:38,885 Client10]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 09:59:41,568 Client10]:          0          N     2.6825     0.2586        88.4375    12.8184      40.2200
appfl: ✅[2025-01-08 09:59:42,569 Client1]:          1          Y                                          1.8781      31.3900
appfl: ✅[2025-01-08 09:59:45,282 Client1]:          1          N     2.7129     0.1185        98.0902     9.2444      30.4200
appfl: ✅[2025-01-08 09:59:46,287 Client2]:          1          Y                                          1.8781      31.3900
appfl: ✅[2025-01-08 09:59:49,161 Client2]:          1          N     2.8722     0.1277        95.7188     8.0948      48.7200
appfl: ✅[2025-01-08 09:59:50,222 Client3]:          1          Y                                          1.8781      31.3900
appfl: ✅[2025-01-08 09:59:52,936 Client3]:          1          N     2.7130     0.1673        94.5312     8.9735      38.5400
appfl: ✅[2025-01-08 09:59:53,939 Client4]:          1          Y                                          1.8781      31.3900
appfl: ✅[2025-01-08 09:59:56,585 Client4]:          1          N     2.6448     0.1617        96.1872     9.9464      39.7100
appfl: ✅[2025-01-08 09:59:57,592 Client5]:          1          Y                                          1.8781      31.3900
appfl: ✅[2025-01-08 10:00:00,295 Client5]:          1          N     2.7021     0.0906        97.3438    13.3640      29.5100
appfl: ✅[2025-01-08 10:00:01,289 Client6]:          1          Y                                          1.8781      31.3900
appfl: ✅[2025-01-08 10:00:04,012 Client6]:          1          N     2.7218     0.1973        96.3711     9.5049      39.4300
appfl: ✅[2025-01-08 10:00:05,028 Client7]:          1          Y                                          1.8781      31.3900
appfl: ✅[2025-01-08 10:00:07,715 Client7]:          1          N     2.6853     0.1087        95.3750     5.8535      49.2100
appfl: ✅[2025-01-08 10:00:08,714 Client8]:          1          Y                                          1.8781      31.3900
appfl: ✅[2025-01-08 10:00:11,396 Client8]:          1          N     2.6800     0.0983        98.0466     9.4541      31.3000
appfl: ✅[2025-01-08 10:00:12,399 Client9]:          1          Y                                          1.8781      31.3900
appfl: ✅[2025-01-08 10:00:15,078 Client9]:          1          N     2.6767     0.1480        94.8281     6.5887      49.2400
appfl: ✅[2025-01-08 10:00:16,078 Client10]:          1          Y                                          1.8781      31.3900
appfl: ✅[2025-01-08 10:00:18,718 Client10]:          1          N     2.6380     0.1009        96.3438    11.2293      40.7800
appfl: ✅[2025-01-08 10:00:19,716 Client1]:          2          Y                                          0.9154      68.4300
appfl: ✅[2025-01-08 10:00:22,358 Client1]:          2          N     2.6401     0.0553        99.0451     6.5585      30.5100
appfl: ✅[2025-01-08 10:00:23,347 Client2]:          2          Y                                          0.9154      68.4300
appfl: ✅[2025-01-08 10:00:26,005 Client2]:          2          N     2.6579     0.0708        97.4844     5.9560      49.2300
appfl: ✅[2025-01-08 10:00:27,009 Client3]:          2          Y                                          0.9154      68.4300
appfl: ✅[2025-01-08 10:00:29,809 Client3]:          2          N     2.7990     0.0948        96.4219     7.6792      38.7600
appfl: ✅[2025-01-08 10:00:30,845 Client4]:          2          Y                                          0.9154      68.4300
appfl: ✅[2025-01-08 10:00:33,755 Client4]:          2          N     2.9089     0.0759        98.0936     8.1204      40.0200
appfl: ✅[2025-01-08 10:00:34,857 Client5]:          2          Y                                          0.9154      68.4300
appfl: ✅[2025-01-08 10:00:37,737 Client5]:          2          N     2.8792     0.0454        98.3750     8.8908      29.7500
appfl: ✅[2025-01-08 10:00:38,732 Client6]:          2          Y                                          0.9154      68.4300
appfl: ✅[2025-01-08 10:00:41,369 Client6]:          2          N     2.6357     0.0995        98.0120     8.2203      39.4500
appfl: ✅[2025-01-08 10:00:42,367 Client7]:          2          Y                                          0.9154      68.4300
appfl: ✅[2025-01-08 10:00:45,160 Client7]:          2          N     2.7908     0.0554        97.5781     4.9709      49.0400
appfl: ✅[2025-01-08 10:00:46,179 Client8]:          2          Y                                          0.9154      68.4300
appfl: ✅[2025-01-08 10:00:48,859 Client8]:          2          N     2.6786     0.0540        98.8658     8.9072      31.2200
appfl: ✅[2025-01-08 10:00:49,859 Client9]:          2          Y                                          0.9154      68.4300
appfl: ✅[2025-01-08 10:00:52,476 Client9]:          2          N     2.6149     0.0861        97.0000     5.5963      50.2700
appfl: ✅[2025-01-08 10:00:53,481 Client10]:          2          Y                                          0.9154      68.4300
appfl: ✅[2025-01-08 10:00:56,112 Client10]:          2          N     2.6278     0.0638        97.3594     8.1234      40.8700
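
The training loop collects one Future per client and only calls result() after every client has submitted its local model. The sketch below reproduces that pattern with a toy synchronous server built on concurrent.futures.Future. ToySyncServer is a hypothetical stand-in written for this illustration, not the APPFL ServerAgent, and the "models" are plain floats.

```python
from concurrent.futures import Future


class ToySyncServer:
    """Hypothetical synchronous server: aggregates once all clients have reported."""

    def __init__(self, num_clients):
        self.num_clients = num_clients
        self.pending = {}  # client_id -> (local_model, future)

    def global_update(self, client_id, local_model):
        future = Future()
        self.pending[client_id] = (local_model, future)
        if len(self.pending) == self.num_clients:
            # Last client arrived: average the local models and resolve every future
            new_global = sum(m for m, _ in self.pending.values()) / self.num_clients
            for _, pending_future in self.pending.values():
                pending_future.set_result(new_global)
            self.pending = {}
        return future


server = ToySyncServer(num_clients=3)
futures = []
done_flags = []
for client_id, local_model in enumerate([1.0, 2.0, 6.0]):
    futures.append(server.global_update(client_id, local_model))
    # Track whether the FIRST client's future has been resolved yet
    done_flags.append(futures[0].done())

results = [f.result() for f in futures]
print(done_flags)  # [False, False, True]
print(results)     # [3.0, 3.0, 3.0]
```

Calling result() on the first future before the other clients have submitted would wait forever in a single-threaded run, which is why the loop gathers all futures first and resolves them afterwards.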