Simulating PPFL (MPI)
In this section, we describe how to simulate PPFL on a single machine or a cluster by running the server and each client in separate MPI processes. This setup can be used to simulate both synchronous and asynchronous FL algorithms.
Note
To run the MPI simulation, you need to launch several MPI processes using the mpiexec
command. For example, to run with 4 MPI processes (one server and three clients), you can use the following command:
mpiexec -n 4 python mpi_code.py
First, the user needs to load the configuration files for the client and server agents. The total number of clients is equal to the total number of MPI processes minus one (as one process is used for the server), so the loaded configurations need to be updated to be compatible with num_clients. With the configurations ready, we can create the server and client agents.
from mpi4py import MPI
from omegaconf import OmegaConf
from appfl.agent import ClientAgent, ServerAgent
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
num_clients = size - 1
if rank == 0:
    # Load and update server configuration
    server_agent_config = OmegaConf.load("<path_to_server_config>.yaml")
    server_agent_config.server_configs.scheduler_kwargs.num_clients = num_clients
    if hasattr(server_agent_config.server_configs.aggregator_kwargs, "num_clients"):
        server_agent_config.server_configs.aggregator_kwargs.num_clients = num_clients
    # Create the server agent
    server_agent = ServerAgent(server_agent_config=server_agent_config)
else:
    # Load and set client configuration
    client_agent_config = OmegaConf.load("<path_to_client_config>.yaml")
    client_agent_config.train_configs.logging_id = f"Client{rank}"
    client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients
    client_agent_config.data_configs.dataset_kwargs.client_id = rank - 1
    client_agent_config.data_configs.dataset_kwargs.visualization = rank == 1
    # Create the client agent
    client_agent = ClientAgent(client_agent_config=client_agent_config)
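For reference, below is a minimal sketch of the configuration fields touched above, built programmatically with OmegaConf.create instead of loading from YAML. Only the fields referenced in this section are shown, with illustrative values for a 4-process run (3 clients); an actual APPFL configuration file will typically contain additional settings not listed here.
from omegaconf import OmegaConf

# Illustrative server-side fields (values are placeholders)
server_agent_config = OmegaConf.create({
    "server_configs": {
        "scheduler_kwargs": {"num_clients": 3},
        "aggregator_kwargs": {"num_clients": 3},
    }
})

# Illustrative client-side fields for the client running on rank 1
client_agent_config = OmegaConf.create({
    "train_configs": {"logging_id": "Client1"},
    "data_configs": {
        "dataset_kwargs": {"num_clients": 3, "client_id": 0, "visualization": True}
    },
})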
Then, on the FL server side, we can create an MPI communicator and serve requests from the clients using its serve method.
from appfl.comm.mpi import MPIServerCommunicator
if rank == 0:
    server_communicator = MPIServerCommunicator(
        comm,
        server_agent,
        logger=server_agent.logger,
    )
    server_communicator.serve()
For the clients, we can start the FL training process via the following steps:
1. Create an MPI communicator for the client.
2. Get and load the shared client configurations (such as the trainer and model architecture) from the server.
3. Get and load the initial global model.
4. Send the local sample size to the server via a custom action.
5. Start the training process by calling the client_agent.train() method, then send the updated model (client_agent.get_parameters()) to the server, and repeat until the end of the training process.
from appfl.comm.mpi import MPIClientCommunicator
if rank != 0:
    client_communicator = MPIClientCommunicator(comm, server_rank=0)
    # Load the configurations and initial global model
    client_config = client_communicator.get_configuration()
    client_agent.load_config(client_config)
    init_global_model = client_communicator.get_global_model(init_model=True)
    client_agent.load_parameters(init_global_model)
    # Send the sample size to the server
    sample_size = client_agent.get_sample_size()
    client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
    # Local training and global model update iterations
    while True:
        client_agent.train()
        local_model = client_agent.get_parameters()
        new_global_model, metadata = client_communicator.update_global_model(local_model)
        if metadata['status'] == 'DONE':
            break
        if 'local_steps' in metadata:
            client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
        client_agent.load_parameters(new_global_model)
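Putting the pieces together, every rank runs the same script (e.g., the mpi_code.py from the note above): rank 0 takes the server branch and all other ranks take the client branch. Below is a minimal skeleton of such a script, where run_server and run_client are hypothetical wrappers around the server and client code shown above.
from mpi4py import MPI

def run_server(comm):
    # Rank 0: load the server config, create the ServerAgent and
    # MPIServerCommunicator, then call serve() as shown above
    ...

def run_client(comm, rank):
    # Other ranks: load the client config, create the ClientAgent and
    # MPIClientCommunicator, then run the training loop as shown above
    ...

if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank == 0:
        run_server(comm)
    else:
        run_client(comm, rank)
Launching this script with mpiexec -n 4 python mpi_code.py, as in the note at the top of this section, then simulates an FL run with one server and three clients.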