Example: Running FL Training with Batched MPI

In this example, we show how to run FL training with batched MPI communication, in which each MPI process represents multiple clients. The example script is available at examples/mpi/run_batched_mpi.py.

Difference between Non-Batched and Batched MPI

The diff below shows the differences between the non-batched and batched MPI examples: lines marked with - come from the non-batched example and are replaced by the lines marked with +.

Note

The batched MPI example only supports synchronous FL training, i.e., scheduler="SyncScheduler".

  1import argparse
  2+ import numpy as np
  3from mpi4py import MPI
  4from omegaconf import OmegaConf
  5from appfl.agent import ClientAgent, ServerAgent
  6from appfl.comm.mpi import MPIClientCommunicator, MPIServerCommunicator
  7
  8argparse = argparse.ArgumentParser()
  9argparse.add_argument("--server_config", type=str, default="./resources/configs/mnist/server_fedavg.yaml")
 10argparse.add_argument("--client_config", type=str, default="./resources/configs/mnist/client_1.yaml")
 11+ argparse.add_argument("--num_clients", type=int, default=10)
 12args = argparse.parse_args()
 13
 14comm = MPI.COMM_WORLD
 15rank = comm.Get_rank()
 16size = comm.Get_size()
 17- num_clients = size - 1
 18+ num_clients = max(args.num_clients, size - 1)
 19+ # Split the clients into batches for each rank
 20+ client_batch = [[int(num) for num in array] for array in np.array_split(np.arange(num_clients), size - 1)]
 21
 22if rank == 0:
 23    # Load and set the server configurations
 24    server_agent_config = OmegaConf.load(args.server_config)
 25    server_agent_config.server_configs.num_clients = num_clients
 26    # Create the server agent and communicator
 27    server_agent = ServerAgent(server_agent_config=server_agent_config)
 28    server_communicator = MPIServerCommunicator(comm, server_agent, logger=server_agent.logger)
 29    # Start the server to serve the clients
 30    server_communicator.serve()
 31else:
 32    # Set client configurations and create client agent
 33-   client_agent_config = OmegaConf.load(args.client_config)
 34-   client_agent_config.client_id = f'Client{rank}'
 35-   client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients
 36-   client_agent_config.data_configs.dataset_kwargs.client_id = rank - 1
 37-   client_agent_config.data_configs.dataset_kwargs.visualization = True if rank == 1 else False
 38-   client_agent = ClientAgent(client_agent_config=client_agent_config)
 39+   client_agents = []
 40+   for client_id in client_batch[rank - 1]:
 41+       client_agent_config = OmegaConf.load(args.client_config)
 42+       client_agent_config.client_id = f'Client{client_id}'
 43+       client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients
 44+       client_agent_config.data_configs.dataset_kwargs.client_id = client_id
 45+       client_agent_config.data_configs.dataset_kwargs.visualization = True if client_id == 0 else False
 46+       client_agents.append(ClientAgent(client_agent_config=client_agent_config))
 47    # Create the client communicator
 48-   client_communicator = MPIClientCommunicator(comm, server_rank=0, client_id=client_agent_config.client_id)
 49+   client_communicator = MPIClientCommunicator(comm, server_rank=0, client_ids=[f"Client{client_id}" for client_id in client_batch[rank - 1]])
 50    # Get and load the general client configurations
 51    client_config = client_communicator.get_configuration()
 52-   client_agent.load_config(client_config)
 53+   for client_agent in client_agents:
 54+       client_agent.load_config(client_config)
 55    # Get and load the initial global model
 56    init_global_model = client_communicator.get_global_model(init_model=True)
 57-   client_agent.load_parameters(init_global_model)
 58+   for client_agent in client_agents:
 59+       client_agent.load_parameters(init_global_model)
 60    # [Optional] Send the sample size to the server
 61-   sample_size = client_agent.get_sample_size()
 62-   client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
 63+   client_sample_sizes = {
 64+       client_id: {'sample_size': client_agent.get_sample_size()}
 65+       for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents)
 66+   }
 67+   client_communicator.invoke_custom_action(action='set_sample_size', kwargs=client_sample_sizes)
 68    # Generate data readiness report
 69    if hasattr(client_config.data_readiness_configs, 'generate_dr_report') and client_config.data_readiness_configs.generate_dr_report:
 70-       data_readiness = client_agent.generate_readiness_report(client_config)
 71-       client_communicator.invoke_custom_action(action='get_data_readiness_report', **data_readiness)
 72+       data_readiness = {
 73+           client_id: client_agent.generate_readiness_report(client_config)
 74+           for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents)
 75+       }
 76+       client_communicator.invoke_custom_action(action='get_data_readiness_report', kwargs=data_readiness)
 77    # Local training and global model update iterations
 78    while True:
 79-       client_agent.train()
 80-       local_model = client_agent.get_parameters()
 81-       if isinstance(local_model, tuple):
 82-           local_model, metadata = local_model[0], local_model[1]
 83-       else:
 84-           metadata = {}
 85-       new_global_model, metadata = client_communicator.update_global_model(local_model, **metadata)
 86+       client_local_models = {}
 87+       client_metadata = {}
 88+       for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents):
 89+           client_agent.train()
 90+           local_model = client_agent.get_parameters()
 91+           if isinstance(local_model, tuple):
 92+               local_model, metadata = local_model[0], local_model[1]
 93+               client_metadata[client_id] = metadata
 94+           client_local_models[client_id] = local_model
 95+       new_global_model, metadata = client_communicator.update_global_model(client_local_models, kwargs=client_metadata)
 96-       if metadata['status'] == 'DONE':
 97+       if all(metadata[client_id]['status'] == 'DONE' for client_id in metadata):
 98            break
 99-       client_agent.load_parameters(new_global_model)
100+       for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents):
101+           client_agent.load_parameters(new_global_model)
102    client_communicator.invoke_custom_action(action='close_connection')

The main changes made to the script are summarized below:

  • The script evenly splits the clients into batches across the client ranks (lines 18-20) and initializes a client agent for every client in the rank's batch (lines 39-46).

  • When creating the client MPI communicator for batched MPI, the script passes the client IDs of the whole batch to the communicator through the client_ids argument (line 49).

  • For invoked custom actions, the keyword arguments are passed as a dictionary kwargs with client IDs as keys (lines 63-67 and 72-76).

  • When updating the global model, the script passes a dictionary of trained local models with client IDs as keys (lines 86-95).

  • The metadata returned from the server is a dictionary with client IDs as keys and per-client metadata dictionaries as values (line 97); an illustrative sketch of these payload shapes is shown right after this list.
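
To make these conventions concrete, below is a small illustrative sketch of the payload shapes for a rank that hosts two clients, Client0 and Client1. All values are hypothetical placeholders; in the actual script they come from the client agents and the server.

# Illustrative payload shapes only; all values are hypothetical placeholders.
client_sample_sizes = {            # kwargs for invoke_custom_action(action='set_sample_size', ...)
    "Client0": {"sample_size": 6000},
    "Client1": {"sample_size": 6000},
}
client_local_models = {            # first argument of update_global_model
    "Client0": {"fc.weight": "<tensor>", "fc.bias": "<tensor>"},
    "Client1": {"fc.weight": "<tensor>", "fc.bias": "<tensor>"},
}
returned_metadata = {              # metadata returned by update_global_model
    "Client0": {"status": "DONE"},
    "Client1": {"status": "DONE"},
}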

Running the Batched MPI Example

You can run the batched MPI example with the following command to simulate 10 clients with 6 MPI processes, where one process is the server and the rest are clients, so each MPI client process represents two clients.

mpiexec -n 6 python ./mpi/run_batched_mpi.py --num_clients 10
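
As a standalone sanity check (not part of the example script), the snippet below reproduces the client_batch computation for this setting and shows how the 10 clients are assigned to the 5 client ranks:

import numpy as np

num_clients, size = 10, 6  # 10 simulated clients, 6 MPI processes (1 server + 5 client ranks)
client_batch = [[int(num) for num in array] for array in np.array_split(np.arange(num_clients), size - 1)]
print(client_batch)  # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]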

You can also run the batched MPI example with the following command to simulate 10 clients with 11 MPI processes, where one process is the server and the rest are clients, so each MPI client process only represents one client.

mpiexec -n 11 python ./mpi/run_batched_mpi.py --num_clients 10
# Note: this is equivalent to running the non-batched MPI example below
mpiexec -n 11 python ./mpi/run_mpi.py
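
Running the same standalone sanity check with 11 processes shows that every client rank is assigned exactly one client, which is why this setup behaves like the non-batched example:

import numpy as np

num_clients, size = 10, 11  # 10 simulated clients, 11 MPI processes (1 server + 10 client ranks)
client_batch = [[int(num) for num in array] for array in np.array_split(np.arange(num_clients), size - 1)]
print(client_batch)  # [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]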

Extra: Running the Batched MPI Example for Asynchronous FL

Although batched MPI communication is not a natural fit for asynchronous FL training, you can still have each MPI process represent multiple clients that train serially and send their updates to the server asynchronously.

The diff below shows the changes needed in the local training part to run the batched MPI example with asynchronous FL training. The example script is available at examples/mpi/run_batched_mpi_async.py.

 1# Local training and global model update iterations
 2+ finish_flag = False
 3while True:
 4-   client_local_models = {}
 5-   client_metadata = {}
 6-   for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents):
 7-       client_agent.train()
 8-       local_model = client_agent.get_parameters()
 9-       if isinstance(local_model, tuple):
10-           local_model, metadata = local_model[0], local_model[1]
11-           client_metadata[client_id] = metadata
12-       client_local_models[client_id] = local_model
13-   new_global_model, metadata = client_communicator.update_global_model(client_local_models, kwargs=client_metadata)
14-   if all(metadata[client_id]['status'] == 'DONE' for client_id in metadata):
15-       break
16-   for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents):
17-       client_agent.load_parameters(new_global_model)
18+   for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents):
19+       client_agent.train()
20+       local_model = client_agent.get_parameters()
21+       if isinstance(local_model, tuple):
22+           local_model, metadata = local_model
23+       else:
24+           metadata = {}
25+       new_global_model, metadata = client_communicator.update_global_model(local_model, client_id=client_id, **metadata)
26+       if metadata['status'] == 'DONE':
27+           finish_flag = True
28+           break
29+       client_agent.load_parameters(new_global_model)
30+   if finish_flag:
31+       break
32client_communicator.invoke_custom_action(action='close_connection')

The main change is that the client MPI process sends an update_global_model request serially for each client in its batch and specifies the corresponding client ID (line 25).

You can run the asynchronous batched MPI example with the following command to simulate 10 clients with 6 MPI processes, where one process is the server and the rest are clients, so each MPI client process represents two clients.

mpiexec -n 6 python ./mpi/run_batched_mpi_async.py --num_clients 10