APPFL Communicator¶
The APPFL communicator exchanges various types of model data (e.g., parameters, gradients, and compressed bytes) as well as metadata (e.g., configurations and control signals) that specify the tasks for the server and client agents to run. APPFL supports the following communication protocols:
gRPC: Google Remote Procedure Call¶
gRPC can be used either for simulating federated learning on a single machine or a cluster, or for running federated learning on real-world distributed machines. It is composed of two parts:
- gRPC Server Communicator (appfl.comm.grpc.GRPCServerCommunicator), which creates a server that listens for incoming client requests for various tasks.
- gRPC Client Communicator (appfl.comm.grpc.GRPCClientCommunicator), which sends requests to the server for various tasks.
gRPC Server Communicator¶
On the server side, one only needs to create an instance of GRPCServerCommunicator and call the serve method (available in appfl.comm.grpc) to start the server, which then listens for incoming client requests. The server can handle the following tasks:
- Get the configurations shared among all clients via the GetConfiguration method.
- Get the global model via the GetGlobalModel method.
- Update the global model with a client's local model via the UpdateGlobalModel method.
- Invoke a custom action on the server via the InvokeCustomAction method.
Note
You can add custom tasks by implementing the corresponding handlers in the InvokeCustomAction method, as sketched below.
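As a sketch, a custom action could be handled by subclassing GRPCServerCommunicator and overriding InvokeCustomAction. The action tag and the server-agent method below are illustrative, and the module path and message names of the generated gRPC stubs are assumptions that may differ across APPFL versions:

import yaml
from appfl.comm.grpc import GRPCServerCommunicator
# NOTE: the stub module path and message names below are assumptions;
# adjust them to match your APPFL version.
from appfl.comm.grpc.grpc_communicator_pb2 import (
    CustomActionResponse,
    ServerHeader,
    ServerStatus,
)

class CustomServerCommunicator(GRPCServerCommunicator):
    def InvokeCustomAction(self, request, context):
        if request.action == "set_learning_rate":  # hypothetical action tag
            meta_data = yaml.safe_load(request.meta_data) if request.meta_data else {}
            # `set_learning_rate` is an assumed user-defined ServerAgent method.
            self.server_agent.set_learning_rate(meta_data["lr"])
            return CustomActionResponse(
                header=ServerHeader(status=ServerStatus.RUN),
                meta_data=yaml.dump({}),
            )
        # Fall back to the built-in handlers for all other actions.
        return super().InvokeCustomAction(request, context)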
class GRPCServerCommunicator(GRPCCommunicatorServicer):
def __init__(
self,
server_agent: ServerAgent,
max_message_size: int = 2 * 1024 * 1024,
logger: Optional[ServerAgentFileLogger] = None,
) -> None:
"""
Creates a gRPC server communicator.
:param `server_agent`: `ServerAgent` object
:param `max_message_size`: Maximum message size in bytes to be sent/received.
Objects larger than this will be split into multiple messages.
:param `logger`: A logger object for logging messages
"""
def GetConfiguration(self, request, context):
"""
Client requests the FL configurations that are shared among all clients from the server.
:param: `request.header.client_id`: A unique client ID
:param: `request.meta_data`: YAML serialized metadata dictionary (if needed)
:return `response.header.status`: Server status
:return `response.configuration`: YAML serialized FL configurations
"""
def GetGlobalModel(self, request, context):
"""
Return the global model to clients. This method is supposed to be called by
clients to get the initial and final global model. Returns are sent back as a
stream of messages.
:param: `request.header.client_id`: A unique client ID
:param: `request.meta_data`: YAML serialized metadata dictionary (if needed)
:return `response.header.status`: Server status
:return `response.global_model`: Serialized global model
"""
def UpdateGlobalModel(self, request_iterator, context):
"""
Update the global model with the local model from a client. This method will
return the updated global model to the client as a stream of messages.
:param: request_iterator: A stream of `DataBuffer` messages, which together contain
a serialized request of type `UpdateGlobalModelRequest`.
Concatenating all messages in `request_iterator` yields the full `request`, where:
:param: request.header.client_id: A unique client ID
:param: request.local_model: Serialized local model
:param: request.meta_data: YAML serialized metadata dictionary (if needed)
"""
def InvokeCustomAction(self, request, context):
"""
This function is the entry point for any custom action that the server agent
can perform. The server agent should implement the custom action, which this
function invokes to serve the client's request.
:param: `request.header.client_id`: A unique client ID
:param: `request.action`: A string tag representing the custom action
:param: `request.meta_data`: YAML serialized metadata dictionary for the custom action (if needed)
:return `response.header.status`: Server status
:return `response.meta_data`: YAML serialized metadata dictionary for return values (if needed)
"""
gRPC Client Communicator¶
During the federated learning process, the client communicates with the server by invoking the corresponding methods of the GRPCClientCommunicator class. For example, after a client finishes a local training round, it can send its local model to the server for global aggregation by calling the update_global_model method.
Note
You can add custom tasks by invoking the invoke_custom_action method (or by adding new methods to the GRPCClientCommunicator class). Also make sure that the server implements the corresponding handler in its InvokeCustomAction method.
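For reference, the built-in set_sample_size action used in the full example below is invoked from the client like this (the sample size value is illustrative, and client_communicator is assumed to be an already-constructed GRPCClientCommunicator):

# Report the client's local dataset size to the server via a custom action;
# the server handles "set_sample_size" in its InvokeCustomAction method.
response = client_communicator.invoke_custom_action(
    action="set_sample_size",
    sample_size=1000,  # illustrative value
)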
class GRPCClientCommunicator:
def __init__(
self,
client_id: Union[str, int],
*,
server_uri: str,
use_ssl: bool = False,
use_authenticator: bool = False,
root_certificate: Optional[Union[str, bytes]] = None,
authenticator: Optional[str] = None,
authenticator_args: Dict[str, Any] = {},
max_message_size: int = 2 * 1024 * 1024,
**kwargs,
):
"""
Create a channel to the server and initialize the gRPC client stub.
:param client_id: A unique client ID.
:param server_uri: The URI of the server to connect to.
:param use_ssl: Whether to use SSL/TLS to authenticate the server and encrypt communicated data.
:param use_authenticator: Whether to use an authenticator to authenticate the client in each RPC. Must have `use_ssl=True` if `True`.
:param root_certificate: The PEM-encoded root certificates as a byte string, or `None` to retrieve them from a default location chosen by gRPC runtime.
:param authenticator: The name of the authenticator to use for authenticating the client in each RPC.
:param authenticator_args: The arguments to pass to the authenticator.
:param max_message_size: The maximum message size in bytes.
"""
def get_configuration(self, **kwargs) -> DictConfig:
"""
Get the federated learning configurations from the server.
:param kwargs: additional metadata to be sent to the server
:return: the federated learning configurations
"""
def get_global_model(self, **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Union[Dict, OrderedDict], Dict]]:
"""
Get the global model from the server.
:param kwargs: additional metadata to be sent to the server
:return: the global model with additional metadata (if any)
"""
def update_global_model(self, local_model: Union[Dict, OrderedDict, bytes], **kwargs) -> Tuple[Union[Dict, OrderedDict], Dict]:
"""
Send local model to FL server for global update, and return the new global model.
:param local_model: the local model to be sent to the server for global aggregation
:param kwargs: additional metadata to be sent to the server
:return: the updated global model with additional metadata. Specifically, `meta_data["status"]` is either "RUNNING" or "DONE".
"""
def invoke_custom_action(self, action: str, **kwargs) -> Dict:
"""
Invoke a custom action on the server.
:param action: the action to be invoked
:param kwargs: additional metadata to be sent to the server
:return: the response from the server
"""
Example Usage (gRPC)¶
Below is an example of how to start a server using GRPCServerCommunicator, which then waits for incoming client requests.
import argparse
from omegaconf import OmegaConf
from appfl.agent import ServerAgent
from appfl.comm.grpc import GRPCServerCommunicator, serve
argparser = argparse.ArgumentParser()
argparser.add_argument(
"--config",
type=str,
default="./resources/configs/mnist/server_fedavg.yaml",
help="Path to the configuration file.",
)
args = argparser.parse_args()
server_agent_config = OmegaConf.load(args.config)
server_agent = ServerAgent(server_agent_config=server_agent_config)
communicator = GRPCServerCommunicator(
server_agent,
logger=server_agent.logger,
**server_agent_config.server_configs.comm_configs.grpc_configs,
)
serve(
communicator,
**server_agent_config.server_configs.comm_configs.grpc_configs,
)
To interact with the server and start an FL experiment, you can start a client using GRPCClientCommunicator as shown below.
import argparse
from omegaconf import OmegaConf
from appfl.agent import ClientAgent
from appfl.comm.grpc import GRPCClientCommunicator
argparser = argparse.ArgumentParser()
argparser.add_argument(
"--config",
type=str,
default="./resources/configs/mnist/client_1.yaml",
help="Path to the configuration file.",
)
args = argparser.parse_args()
client_agent_config = OmegaConf.load(args.config)
client_agent = ClientAgent(client_agent_config=client_agent_config)
client_communicator = GRPCClientCommunicator(
client_id=client_agent.get_id(),
**client_agent_config.comm_configs.grpc_configs,
)
client_config = client_communicator.get_configuration()
client_agent.load_config(client_config)
init_global_model = client_communicator.get_global_model(init_model=True)
client_agent.load_parameters(init_global_model)
# Send the number of local data to the server
sample_size = client_agent.get_sample_size()
client_communicator.invoke_custom_action(
action="set_sample_size", sample_size=sample_size
)
# Generate data readiness report
if (
hasattr(client_config, "data_readiness_configs")
and hasattr(client_config.data_readiness_configs, "generate_dr_report")
and client_config.data_readiness_configs.generate_dr_report
):
data_readiness = client_agent.generate_readiness_report(client_config)
client_communicator.invoke_custom_action(
action="get_data_readiness_report", **data_readiness
)
while True:
client_agent.train()
local_model = client_agent.get_parameters()
if isinstance(local_model, tuple):
local_model, metadata = local_model[0], local_model[1]
else:
metadata = {}
new_global_model, metadata = client_communicator.update_global_model(
local_model, **metadata
)
if metadata["status"] == "DONE":
break
if "local_steps" in metadata:
client_agent.trainer.train_configs.num_local_steps = metadata["local_steps"]
client_agent.load_parameters(new_global_model)
client_communicator.invoke_custom_action(action="close_connection")
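To run this example, save the two snippets above as separate scripts (e.g., hypothetical run_server.py and run_client.py), start the server script first, and then launch each client script in its own process, potentially on a different machine, each with its own configuration file.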
MPI: Message Passing Interface¶
MPI can be used for simulating federated learning on a single machine or a cluster of machines. It is composed of two parts:
- MPI Server Communicator (appfl.comm.mpi.MPIServerCommunicator), which starts a server that listens for incoming client requests for various tasks.
- MPI Client Communicator (appfl.comm.mpi.MPIClientCommunicator), which sends requests to the server for various tasks.
MPI Server Communicator¶
On the server side, one only needs to create an instance of MPIServerCommunicator and call its serve method to start the server, which then listens for incoming client requests. The server can handle the following tasks:
- Get the configurations shared among all clients via the _get_configuration method.
- Get the global model via the _get_global_model method.
- Update the global model with a client's local model via the _update_global_model method.
- Invoke a custom action on the server via the _invoke_custom_action method.
Note
The server will automatically stop itself after reaching the specified num_global_epochs.
Note
You can add custom tasks by implementing the corresponding handlers in the _invoke_custom_action method.
class MPIServerCommunicator:
def __init__(
self,
comm,
server_agent: ServerAgent,
logger: Optional[ServerAgentFileLogger] = None,
) -> None:
"""
Create an MPI server communicator.
:param `comm`: MPI communicator object
:param `server_agent`: `ServerAgent` object
:param `logger`: A logger object for logging messages
"""
def serve(self):
"""
Start the server to serve the clients.
"""
def _get_configuration(
self, client_rank: int, request: MPITaskRequest
) -> MPITaskResponse:
"""
Client requests the FL configurations that are shared among all clients from the server.
:param: `client_rank`: The rank of the client in MPI
:param: `request.meta_data`: YAML serialized metadata dictionary (if needed)
:return `response.status`: Server status
:return `response.meta_data`: YAML serialized FL configurations
"""
def _get_global_model(
self, client_rank: int, request: MPITaskRequest
) -> Optional[MPITaskResponse]:
"""
Return the global model to clients. This method is supposed to provide clients with
the initial and final global model.
:param: `client_rank`: The rank of the client(s) in MPI
:param: `request.meta_data`: YAML serialized metadata dictionary (if needed)
- `meta_data['_client_ids']`: A list of client ids to get the global model for batched clients
- `meta_data['init_model']`: Whether to get the initial global model or not
:return `response.status`: Server status
:return `response.payload`: Serialized global model
:return `response.meta_data`: YAML serialized metadata dictionary (if needed)
"""
def _update_global_model(
self, client_rank: int, request: MPITaskRequest
) -> Optional[MPITaskResponse]:
"""
Update the global model with the local model from the client,
and return the updated global model to the client.
:param: `client_rank`: The rank of the client in MPI
:param: `request.payload`: Serialized local model
:param: `request.meta_data`: YAML serialized metadata dictionary (if needed)
:return `response.status`: Server status
:return `response.payload`: Serialized updated global model
:return `response.meta_data`: YAML serialized metadata dictionary (if needed)
"""
def _invoke_custom_action(
self,
client_rank: int,
request: MPITaskRequest,
) -> Optional[MPITaskResponse]:
"""
Invoke custom action on the server.
:param: `client_rank`: The rank of the client in MPI
:param: `request.meta_data`: YAML serialized metadata dictionary (if needed)
:return `response.status`: Server status
:return `response.meta_data`: YAML serialized metadata dictionary (if needed)
"""
MPI Client Communicator¶
During the federated learning process, the client communicates with the server by invoking the corresponding methods of the MPIClientCommunicator class. For example, after a client finishes a local training round, it can send its local model to the server for global aggregation by calling the update_global_model method.
Note
You can add custom tasks by invoking the invoke_custom_action method (or by adding new methods to the MPIClientCommunicator class). Also make sure that the server implements the corresponding handler in its _invoke_custom_action method.
class MPIClientCommunicator:
"""
MPI client communicator for federated learning.
:param comm: the MPI communicator from mpi4py
:param server_rank: the rank of the server in the MPI communicator
:param client_id: [optional] a client ID for a single client, used for logging purposes; mutually exclusive with client_ids
:param client_ids: [optional] a list of client IDs for batched clients;
only required when the MPI process represents multiple clients
"""
def get_configuration(self, **kwargs) -> DictConfig:
"""
Get the federated learning configurations from the server.
:param kwargs: additional metadata to be sent to the server
:return: the federated learning configurations
"""
def get_global_model(
self, **kwargs
) -> Union[Union[Dict, OrderedDict], Tuple[Union[Dict, OrderedDict], Dict]]:
"""
Get the global model from the server.
:param kwargs: additional metadata to be sent to the server
:return: the global model with additional metadata (if any)
"""
def update_global_model(
self,
local_model: Union[Dict, OrderedDict, bytes],
client_id: Optional[Union[str, int]] = None,
**kwargs,
) -> Tuple[Union[Dict, OrderedDict], Dict]:
"""
Send local model(s) to the FL server for global update, and return the new global model.
:param local_model: the local model to be sent to the server for global aggregation
- `local_model` can be a single model if one MPI process has only one client or one MPI process
has multiple clients but the user wants to send one model at a time
- `local_model` can be a dictionary of multiple models as well if one MPI process has multiple clients
and the user wants to send all models
:param client_id (optional): the client ID for the local model. It is only required when the MPI process has multiple clients
and the user only wants to send one model at a time.
:param kwargs (optional): additional metadata to be sent to the server. When sending local models for multiple clients,
use the client ID as the key and the metadata as the value, e.g.,
```
update_global_model(
local_model=...,
kwargs = {
client_id1: {key1: value1, key2: value2},
client_id2: {key1: value1, key2: value2},
}
)
```
:return model: the updated global model
- Note: only one global model is returned even if multiple local models are sent, which means
the server should use synchronous aggregation. If asynchronous aggregation is needed, the user
should pass the local models one by one.
:return meta_data: additional metadata from the server. When updating local models for multiple clients, the response will
be a dictionary with the client ID as the key and the response as the value, e.g.,
```
{
client_id1: {ret1: value1, ret2: value2},
client_id2: {ret1: value1, ret2: value2},
}
```
"""
def invoke_custom_action(
self, action: str, client_id: Optional[Union[str, int]] = None, **kwargs
) -> Dict:
"""
Invoke a custom action on the server.
:param action: the action to be invoked
:param client_id (optional): the client ID for the action. It is only required when the MPI process has multiple clients
and the action is specific to a client instead of all clients.
:param kwargs (optional): additional metadata to be sent to the server. When invoking a custom action for multiple clients,
use the client ID as the key and the metadata as the value, e.g.,
```
invoke_custom_action(
action=...,
kwargs = {
client_id1: {key1: value1, key2: value2},
client_id2: {key1: value1, key2: value2},
}
)
```
:return: the response from the server (if any). When invoking a custom action for multiple clients, the response will
be a dictionary with the client ID as the key and the response as the value, e.g.,
```
{
client_id1: {ret1: value1, ret2: value2},
client_id2: {ret1: value1, ret2: value2},
}
```
"""
Example Usage (MPI)¶
Here is an example of how to use the MPI communicator in APPFL to start FL experiments.
import argparse
from mpi4py import MPI
from omegaconf import OmegaConf
from appfl.agent import ClientAgent, ServerAgent
from appfl.comm.mpi import MPIClientCommunicator, MPIServerCommunicator
argparser = argparse.ArgumentParser()
argparser.add_argument(
"--server_config",
type=str,
default="./resources/configs/mnist/server_fedcompass.yaml",
)
argparser.add_argument(
"--client_config", type=str, default="./resources/configs/mnist/client_1.yaml"
)
args = argparser.parse_args()
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
num_clients = size - 1
if rank == 0:
# Load and set the server configurations
server_agent_config = OmegaConf.load(args.server_config)
server_agent_config.server_configs.num_clients = num_clients
# Create the server agent and communicator
server_agent = ServerAgent(server_agent_config=server_agent_config)
server_communicator = MPIServerCommunicator(
comm, server_agent, logger=server_agent.logger
)
# Start the server to serve the clients
server_communicator.serve()
else:
# Set the client configurations
client_agent_config = OmegaConf.load(args.client_config)
client_agent_config.client_id = f"Client{rank}"
client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients
client_agent_config.data_configs.dataset_kwargs.client_id = rank - 1
client_agent_config.data_configs.dataset_kwargs.visualization = (
True if rank == 1 else False
)
# Create the client agent and communicator
client_agent = ClientAgent(client_agent_config=client_agent_config)
client_communicator = MPIClientCommunicator(
comm, server_rank=0, client_id=client_agent_config.client_id
)
# Load the configurations and initial global model
client_config = client_communicator.get_configuration()
client_agent.load_config(client_config)
init_global_model = client_communicator.get_global_model(init_model=True)
client_agent.load_parameters(init_global_model)
# Send the sample size to the server
sample_size = client_agent.get_sample_size()
client_communicator.invoke_custom_action(
action="set_sample_size", sample_size=sample_size
)
# Generate data readiness report
if (
hasattr(client_config, "data_readiness_configs")
and hasattr(client_config.data_readiness_configs, "generate_dr_report")
and client_config.data_readiness_configs.generate_dr_report
):
data_readiness = client_agent.generate_readiness_report(client_config)
client_communicator.invoke_custom_action(
action="get_data_readiness_report", **data_readiness
)
# Local training and global model update iterations
while True:
client_agent.train()
local_model = client_agent.get_parameters()
if isinstance(local_model, tuple):
local_model, metadata = local_model[0], local_model[1]
else:
metadata = {}
new_global_model, metadata = client_communicator.update_global_model(
local_model, **metadata
)
if metadata["status"] == "DONE":
break
if "local_steps" in metadata:
client_agent.trainer.train_configs.num_local_steps = metadata["local_steps"]
client_agent.load_parameters(new_global_model)
client_communicator.invoke_custom_action(action="close_connection")
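Assuming the script above is saved as run_mpi.py (a hypothetical file name), it can be launched with, e.g., mpiexec -n 3 python run_mpi.py, which starts one server process (rank 0) and two client processes.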
Globus Compute¶
Globus Compute is a distributed function-as-a-service platform that can be used for running federated learning on real-world distributed machines. It turns each client into an endpoint on which the server can remotely invoke functions to run federated learning tasks. It is composed of two parts:
- Globus Compute Server Communicator (appfl.comm.globus_compute.GlobusComputeServerCommunicator), which sends tasks to the client endpoints and receives their results.
- Globus Compute Client Entry Point (appfl.comm.globus_compute.GlobusComputeClientCommunicator.globus_compute_client_entry_point), which is the entry point for a client to execute a task and send the results back to the server.
Globus Compute Server Communicator¶
Globus Compute is a server-driven communication protocol: the server sends tasks to the clients and receives their results to drive the FL process. The GlobusComputeServerCommunicator provides the following methods for sending tasks to the clients and receiving their results:
class GlobusComputeServerCommunicator:
"""
Communicator used by the federated learning server which plans to use Globus Compute
for orchestrating the federated learning experiments.
Globus Compute is a distributed function-as-a-service platform that allows users to run
functions on specified remote endpoints. For more details, check the Globus Compute SDK
documentation at https://globus-compute.readthedocs.io/en/latest/endpoints.html.
:param `gcc`: Globus Compute client object
:param `server_agent_config`: The server agent configuration
:param `client_agent_configs`: A list of client agent configurations.
:param [Optional] `logger`: Optional logger object.
"""
def send_task_to_all_clients(
self,
task_name: str,
*,
model: Optional[Union[Dict, OrderedDict, bytes]] = None,
metadata: Union[Dict, List[Dict]] = {},
need_model_response: bool = False,
):
"""
Send a specific task to all clients.
:param `task_name`: Name of the task to be executed on the clients
:param [Optional] `model`: Model to be sent to the clients
:param [Optional] `metadata`: Additional metadata to be sent to the clients
:param `need_model_response`: Whether the task requires a model response from the clients
If so, the server will provide a pre-signed URL for the clients to upload the model if using S3.
"""
pass
def send_task_to_one_client(
self,
client_endpoint_id: str,
task_name: str,
*,
model: Optional[Union[Dict, OrderedDict, bytes]] = None,
metadata: Optional[Dict] = {},
need_model_response: bool = False,
):
"""
Send a specific task to one specific client endpoint.
:param `client_endpoint_id`: The client endpoint id to which the task is sent.
:param `task_name`: Name of the task to be executed on the clients
:param [Optional] `model`: Model to be sent to the clients
:param [Optional] `metadata`: Additional metadata to be sent to the clients
:param `need_model_response`: Whether the task requires a model response from the clients
If so, the server will provide a pre-signed URL for the clients to upload the model if using S3.
"""
pass
def recv_result_from_all_clients(self) -> Tuple[Dict, Dict]:
"""
Receive task results from all clients that have running tasks.
:return `client_results`: A dictionary containing the results from all clients - Dict[client_endpoint_id, client_model]
:return `client_metadata`: A dictionary containing the metadata from all clients - Dict[client_endpoint_id, client_metadata]
"""
pass
def recv_result_from_one_client(self) -> Tuple[str, Any, Dict]:
"""
Receive task results from the first client that finishes the task.
:return `client_endpoint_id`: The client endpoint id from which the result is received.
:return `client_model`: The model returned from the client
:return `client_metadata`: The metadata returned from the client
"""
pass
def shutdown_all_clients(self):
"""Cancel all the running tasks on the clients and shutdown the globus compute executor."""
pass
def cancel_all_tasks(self):
"""Cancel all on-the-fly client tasks."""
pass
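To illustrate how these methods compose into a server-driven FL round, here is a minimal sketch of a single synchronous round. The train task name matches the client entry point below, while server_agent.get_parameters and server_agent.global_update are assumed ServerAgent accessors that may differ from the actual API:

# One synchronous FL round driven by the server communicator.
communicator.send_task_to_all_clients(
    task_name="train",
    model=server_agent.get_parameters(),  # assumed accessor for the global model
    need_model_response=True,
)
client_results, client_metadata = communicator.recv_result_from_all_clients()
for endpoint_id, local_model in client_results.items():
    # Hand each local model to the server agent for aggregation
    # (method name is illustrative).
    server_agent.global_update(endpoint_id, local_model, **client_metadata[endpoint_id])
# Shut down all client endpoints once training is finished.
communicator.shutdown_all_clients()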
Globus Compute Client Entry Point¶
Every task the server sends to a client must be implemented in the client entry point function, which executes the task and sends the results back to the server. Below is the current client entry point function, which only supports the get_sample_size and train tasks. Users can freely add more tasks by implementing the corresponding logic in the entry point, as sketched after the listing.
def globus_compute_client_entry_point(
task_name="N/A",
client_agent_config=None,
model=None,
meta_data=None,
):
"""
Entry point for the Globus Compute client endpoint for federated learning.
:param `task_name`: The name of the task to be executed.
:param `client_agent_config`: The configuration for the client agent.
:param `model`: [Optional] The model to be used for the task.
:param `meta_data`: [Optional] The metadata for the task.
:return `model_local`: The local model after the task is executed. [Return `None` if the task does not return a model.]
:return `meta_data_local`: The local metadata after the task is executed. [Return `{}` if the task does not return metadata.]
"""
from appfl.agent import ClientAgent
from appfl.comm.globus_compute.utils.client_utils import load_global_model, send_local_model
client_agent = ClientAgent(client_agent_config=client_agent_config)
if model is not None:
model = load_global_model(client_agent.client_agent_config, model)
client_agent.load_parameters(model)
if task_name == "get_sample_size":
return None, {
"sample_size": client_agent.get_sample_size()
}
elif task_name == "train":
client_agent.train()
local_model = client_agent.get_parameters()
if isinstance(local_model, tuple):
local_model, meta_data_local = local_model
else:
meta_data_local = {}
local_model = send_local_model(
client_agent.client_agent_config,
local_model,
meta_data["local_model_key"],
meta_data["local_model_url"],
)
return local_model, meta_data_local
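To support an additional task, extend the if/elif dispatch above with another branch. A minimal sketch for a hypothetical evaluate task (assuming the client agent exposes a matching method) would slot in after the train branch:

    elif task_name == "evaluate":
        # Hypothetical task: run local evaluation and return metrics, no model.
        metrics = client_agent.evaluate()  # assumed ClientAgent method
        return None, {"metrics": metrics}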