Example: Running a Batched MPI Example
In this example, we will show how to run FL training with batched MPI communication, i.e., each MPI process represents multiple clients. The example script is available at examples/mpi/run_batched_mpi.py.
Difference between Non-Batched and Batched MPI
The diff below shows the differences between the non-batched and batched MPI examples.
Note
The batched MPI example only supports synchronous FL training, i.e., scheduler="SyncScheduler".
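If you are not sure whether your server configuration uses the synchronous scheduler, you can check and override it before creating the server agent. The snippet below is a minimal sketch, not part of the example script, and assumes the scheduler is specified under server_configs in the server YAML file (as in the MNIST example configs).

from omegaconf import OmegaConf

# Minimal sketch (assumption: the scheduler is specified under `server_configs`
# in the YAML file, as in the MNIST example configs)
server_agent_config = OmegaConf.load("./resources/configs/mnist/server_fedavg.yaml")
if server_agent_config.server_configs.get("scheduler") != "SyncScheduler":
    # Batched MPI requires synchronous FL training
    server_agent_config.server_configs.scheduler = "SyncScheduler"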
1import argparse
2+ import numpy as np
3from mpi4py import MPI
4from omegaconf import OmegaConf
5from appfl.agent import ClientAgent, ServerAgent
6from appfl.comm.mpi import MPIClientCommunicator, MPIServerCommunicator
7
8argparse = argparse.ArgumentParser()
9argparse.add_argument("--server_config", type=str, default="./resources/configs/mnist/server_fedavg.yaml")
10argparse.add_argument("--client_config", type=str, default="./resources/configs/mnist/client_1.yaml")
11+ argparse.add_argument("--num_clients", type=int, default=10)
12args = argparse.parse_args()
13
14comm = MPI.COMM_WORLD
15rank = comm.Get_rank()
16size = comm.Get_size()
17- num_clients = size - 1
18+ num_clients = max(args.num_clients, size - 1)
19+ # Split the clients into batches for each rank
20+ client_batch = [[int(num) for num in array] for array in np.array_split(np.arange(num_clients), size - 1)]
21
22if rank == 0:
23 # Load and set the server configurations
24 server_agent_config = OmegaConf.load(args.server_config)
25 server_agent_config.server_configs.num_clients = num_clients
26 # Create the server agent and communicator
27 server_agent = ServerAgent(server_agent_config=server_agent_config)
28 server_communicator = MPIServerCommunicator(comm, server_agent, logger=server_agent.logger)
29 # Start the server to serve the clients
30 server_communicator.serve()
31else:
32 # Set client configurations and create client agent
33 client_agent_config = OmegaConf.load(args.client_config)
34- client_agent_config.client_id = f'Client{rank}'
35- client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients
36- client_agent_config.data_configs.dataset_kwargs.client_id = rank - 1
37- client_agent_config.data_configs.dataset_kwargs.visualization = True if rank == 1 else False
38- client_agent = ClientAgent(client_agent_config=client_agent_config)
39+ client_agents = []
40+ for client_id in client_batch[rank - 1]:
41+ client_agent_config.client_id = f'Client{client_id}'
42+ client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients
43+ client_agent_config.data_configs.dataset_kwargs.client_id = client_id
44+ client_agent_config.data_configs.dataset_kwargs.visualization = True if client_id == 0 else False
45+ client_agents.append(ClientAgent(client_agent_config=client_agent_config))
46 # Create the client communicator
47- client_communicator = MPIClientCommunicator(comm, server_rank=0, client_id=client_agent_config.client_id)
48+ client_communicator = MPIClientCommunicator(comm, server_rank=0, client_ids=[f"Client{client_id}" for client_id in client_batch[rank - 1]])
49 # Get and load the general client configurations
50 client_config = client_communicator.get_configuration()
51- client_agent.load_config(client_config)
52+ for client_agent in client_agents:
53+ client_agent.load_config(client_config)
54 # Get and load the initial global model
55 init_global_model = client_communicator.get_global_model(init_model=True)
56- client_agent.load_parameters(init_global_model)
57+ for client_agent in client_agents:
58+ client_agent.load_parameters(init_global_model)
59 # [Optional] Send the sample size to the server
60- sample_size = client_agent.get_sample_size()
61- client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
62+ client_sample_sizes = {
63+ client_id: {'sample_size': client_agent.get_sample_size()}
64+ for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents)
65+ }
66+ client_communicator.invoke_custom_action(action='set_sample_size', kwargs=client_sample_sizes)
67 # Generate data readiness report
68 if hasattr(client_config.data_readiness_configs, 'generate_dr_report') and client_config.data_readiness_configs.generate_dr_report:
69- data_readiness = client_agent.generate_readiness_report(client_config)
70- client_communicator.invoke_custom_action(action='get_data_readiness_report', **data_readiness)
71+ data_readiness = {
72+ client_id: client_agent.generate_readiness_report(client_config)
73+ for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents)
74+ }
75+ client_communicator.invoke_custom_action(action='get_data_readiness_report', kwargs=data_readiness)
76 # Local training and global model update iterations
77 while True:
78- client_agent.train()
79- local_model = client_agent.get_parameters()
80- if isinstance(local_model, tuple):
81- local_model, metadata = local_model[0], local_model[1]
82- else:
83- metadata = {}
84- new_global_model, metadata = client_communicator.update_global_model(local_model, **metadata)
85+ client_local_models = {}
86+ client_metadata = {}
87+ for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents):
88+ client_agent.train()
89+ local_model = client_agent.get_parameters()
90+ if isinstance(local_model, tuple):
91+ local_model, metadata = local_model[0], local_model[1]
92+ client_metadata[client_id] = metadata
93+ client_local_models[client_id] = local_model
94+ new_global_model, metadata = client_communicator.update_global_model(client_local_models, kwargs=client_metadata)
95- if metadata['status'] == 'DONE':
96+ if all(metadata[client_id]['status'] == 'DONE' for client_id in metadata):
97 break
98- client_agent.load_parameters(new_global_model)
99+ for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents):
100+ client_agent.load_parameters(new_global_model)
101 client_communicator.invoke_custom_action(action='close_connection')
The main changes made to the script are summarized below:
The script evenly splits the clients into batches for each rank (lines 18-20) and initializes the client agents for each batch (lines 39-45); a short sketch of this batching is given after this list.
When creating the client MPI communicator for batched MPI, the script passes the client IDs of the batch to the communicator (line 48).
For the invoked custom actions, the keyword arguments are passed as a dictionary kwargs with client IDs as keys (lines 62-66 and 71-75).
For updating the global model, the script passes a dictionary of trained local models with client IDs as keys (lines 85-94).
The metadata returned from the server is a dictionary with client IDs as keys and each client's metadata dictionary as the value (lines 94 and 96).
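To make the batching concrete, the following standalone sketch reproduces the client_batch computation from the script for 10 clients and 6 MPI processes (the values used in the run command below) and shows the client IDs that the rank-1 process would host. The printed values are what NumPy's array_split produces for this input.

import numpy as np

# 10 clients split across 5 client ranks (6 MPI processes, rank 0 is the server)
num_clients, size = 10, 6
client_batch = [[int(num) for num in array] for array in np.array_split(np.arange(num_clients), size - 1)]
print(client_batch)  # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

# The rank-1 process therefore creates two client agents and registers with the
# server under two client IDs
rank = 1
print([f"Client{cid}" for cid in client_batch[rank - 1]])  # ['Client0', 'Client1']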
Running the Batched MPI Example
You can run the batched MPI example with the following command to simulate 10 clients with 6 MPI processes, where one process is the server and the rest are clients, so each MPI client process represents two clients.
mpiexec -n 6 python ./mpi/run_batched_mpi.py --num_clients 10
You can also run the batched MPI example with the following command to simulate 10 clients with 11 MPI processes, where one process is the server and the rest are clients, so each MPI client process only represents one client.
mpiexec -n 11 python ./mpi/run_batched_mpi.py --num_clients 10
# Note: this is equivalent to running the non-batched MPI example below
mpiexec -n 11 python ./mpi/run_mpi.py
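You can verify this equivalence from the batching itself: with 10 client ranks and 10 clients, np.array_split assigns exactly one client to each rank. A minimal check, reusing the split expression from the script:

import numpy as np

# 10 clients split across 10 client ranks (11 MPI processes) -> one client per batch
client_batch = [[int(num) for num in array] for array in np.array_split(np.arange(10), 11 - 1)]
print(client_batch)  # [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]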
Extra: Running the Batched MPI Example for Asynchronous FL
Although batched MPI communication is not a natural fit for asynchronous FL training, you can still have each MPI process represent multiple clients that run serially and send their updates asynchronously.
Below are the changes needed in the local training part to run the batched MPI example with asynchronous FL training. The example script is available at examples/mpi/run_batched_mpi_async.py.
1# Local training and global model update iterations
2+ finish_flag = False
3while True:
4- client_local_models = {}
5- client_metadata = {}
6- for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents):
7- client_agent.train()
8- local_model = client_agent.get_parameters()
9- if isinstance(local_model, tuple):
10- local_model, metadata = local_model[0], local_model[1]
11- client_metadata[client_id] = metadata
12- client_local_models[client_id] = local_model
13- new_global_model, metadata = client_communicator.update_global_model(client_local_models, kwargs=client_metadata)
14- if all(metadata[client_id]['status'] == 'DONE' for client_id in metadata):
15- break
16- for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents):
17- client_agent.load_parameters(new_global_model)
18+ for client_id, client_agent in zip([f"Client{client_id}" for client_id in client_batch[rank - 1]], client_agents):
19+ client_agent.train()
20+ local_model = client_agent.get_parameters()
21+ if isinstance(local_model, tuple):
22+ local_model, metadata = local_model
23+ else:
24+ metadata = {}
25+ new_global_model, metadata = client_communicator.update_global_model(local_model, client_id=client_id, **metadata)
26+ if metadata['status'] == 'DONE':
27+ finish_flag = True
28+ break
29+ client_agent.load_parameters(new_global_model)
30+ if finish_flag:
31+ break
32client_communicator.invoke_custom_action(action='close_connection')
The main change made to the script is that the client MPI process sends the update_global_model request serially for each client in the batch and specifies its client ID (line 25).
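For a side-by-side view, these are the two forms of the update_global_model call used in this tutorial, excerpted from the diffs above (Client0 is just an illustrative ID; the other variables are those defined in the script):

# Synchronous batched call: one request carrying all clients hosted by this MPI process
new_global_model, metadata = client_communicator.update_global_model(client_local_models, kwargs=client_metadata)

# Asynchronous per-client call: one request per client, identified by `client_id`
new_global_model, metadata = client_communicator.update_global_model(local_model, client_id="Client0", **metadata)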
You can run the batched MPI example with the following command to simulate 10 clients with 6 MPI processes, where one process is the server and the rest are clients, so each MPI client process represents two clients.
mpiexec -n 6 python ./mpi/run_batched_mpi_async.py --num_clients 10