Example: Add a Custom Action¶
In APPFL, the server supports several actions such as getting general client configurations, getting the global model parameters, and updating the global model parameters (i.e., federated training). However, in some cases, you may want to add a custom action to the server, such as federated evaluation. In this example, we show how to add a custom action to the server to generate a data readiness report for all the clients' local datasets.
Client Side Implementation¶
In this example, we focus on the client-driven communication pattern (MPI and gRPC), where the clients send requests to the server for any actions they want to perform. The client side can simply define a function to generate the data readiness report for its local dataset and then send a request to the server to generate the report for all the clients. The server handles the action synchronously, meaning it waits to receive requests from all clients before generating the aggregated readiness report.
However, as APPFL defines the client agent appfl.agent.ClientAgent to act on behalf of the client, we highly recommend defining the custom action within the client agent, either by extending appfl.agent.ClientAgent or by adding a new method to it. In this example, we add a new method to the existing client agent to generate the data readiness report.
Note
If you think your custom action is useful for the community, please consider defining it within appfl.agent.ClientAgent directly and contributing it to the APPFL framework by creating a pull request.
New Method in Client Agent¶
In the ClientAgent, we define a new method generate_readiness_report to generate the data readiness report for the local dataset. The method is called by the client to evaluate its local dataset and returns a dictionary containing the readiness metrics and plots. Below is the implementation of generate_readiness_report in the client agent.
New method generate_readiness_report in the ClientAgent, which generates data readiness evaluations and plots for the local dataset.¶
def generate_readiness_report(self, client_config):
    """
    Generate data readiness report based on the configuration provided by the server.
    """
    if hasattr(client_config.data_readiness_configs, "dr_metrics"):
        results = {}
        plot_results = {"plots": {}}
        to_combine_results = {"to_combine": {}}
        # Determine how to retrieve data input and labels based on dataset attributes
        if hasattr(self.train_dataset, "data_label"):
            data_labels = self.train_dataset.data_label.tolist()
        else:
            data_labels = [label.item() for _, label in self.train_dataset]
        if hasattr(self.train_dataset, "data_input"):
            data_input = self.train_dataset.data_input
        else:
            data_input = torch.stack(
                [input_data for input_data, _ in self.train_dataset]
            )
        # Define metrics with corresponding computation functions
        standard_metrics = {
            "class_imbalance": lambda: round(imbalance_degree(data_labels), 2),
            "sample_size": lambda: len(data_labels),
            "num_classes": lambda: len(set(data_labels)),
            "data_shape": lambda: (len(data_input), *data_input[0].size()),
            "completeness": lambda: completeness(data_input),
            "data_range": lambda: get_data_range(data_input),
            "overall_sparsity": lambda: sparsity(data_input),
            "variance": lambda: variance(data_input),
            "skewness": lambda: skewness(data_input),
            "entropy": lambda: entropy(data_input),
            "kurtosis": lambda: kurtosis(data_input),
            "class_distribution": lambda: class_distribution(data_labels),
            "brisque": lambda: brisque(data_input),
            "total_variation": lambda: total_variation(data_input),
            "sharpness": lambda: dataset_sharpness(data_input),
            "outlier_proportion": lambda: calculate_outlier_proportion(data_input),
            "time_to_event_imbalance": lambda: quantify_time_to_event_imbalance(
                data_labels
            ),
        }
        plots = {
            "class_distribution_plot": lambda: plot_class_distribution(data_labels),
            "data_sample_plot": lambda: plot_data_sample(data_input),
            "data_distribution_plot": lambda: plot_data_distribution(data_input),
            "class_variance_plot": lambda: plot_class_variance(
                data_input, data_labels
            ),
            "outlier_detection_plot": lambda: plot_outliers(data_input),
            # "time_to_event_plot": lambda: plot_time_to_event_distribution(data_labels), # TODO: Add time to event plot
            "feature_correlation_plot": lambda: plot_feature_correlations(
                data_input
            ),
            "feature_statistics_plot": lambda: plot_feature_statistics(data_input),
        }
        combine = {
            "feature_space_distribution": lambda: get_feature_space_distribution(
                data_input
            ),
        }
        # Handle standard metrics
        for metric_name, compute_function in standard_metrics.items():
            if metric_name in client_config.data_readiness_configs.dr_metrics:
                if getattr(
                    client_config.data_readiness_configs.dr_metrics, metric_name
                ):
                    results[metric_name] = compute_function()
        # Handle plot-specific metrics
        for metric_name, compute_function in plots.items():
            if hasattr(client_config.data_readiness_configs.dr_metrics, "plot"):
                if (
                    metric_name
                    in client_config.data_readiness_configs.dr_metrics.plot
                ):
                    if getattr(
                        client_config.data_readiness_configs.dr_metrics.plot,
                        metric_name,
                    ):
                        plot_results["plots"][metric_name] = compute_function()
        # Combine results with plot results
        results.update(plot_results)
        # Handle combined metrics
        for metric_name, compute_function in combine.items():
            if hasattr(client_config.data_readiness_configs.dr_metrics, "combine"):
                if (
                    metric_name
                    in client_config.data_readiness_configs.dr_metrics.combine
                ):
                    if getattr(
                        client_config.data_readiness_configs.dr_metrics.combine,
                        metric_name,
                    ):
                        to_combine_results["to_combine"][metric_name] = (
                            compute_function()
                        )
        results.update(to_combine_results)
        return results
    else:
        return "Data readiness metrics not available in configuration"
The corresponding computation functions for the metrics and plots are defined in the src.appfl.misc.data_readiness.metrics.py and src.appfl.misc.data_readiness.plots.py files, respectively.
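As an illustration of the general shape of those computation functions, the two sketches below mimic how label-based metrics such as sample_size and class_distribution could be computed from the list of labels. This is a minimal, hypothetical sketch for illustration only, not the actual APPFL implementation:

```python
from collections import Counter

def class_distribution(data_labels):
    """Number of samples per class label (illustrative sketch, not the APPFL source)."""
    counts = Counter(data_labels)
    return {str(label): count for label, count in sorted(counts.items())}

def sample_size(data_labels):
    """Total number of local training samples (illustrative sketch)."""
    return len(data_labels)

labels = [0, 0, 0, 1, 1, 2]
print(class_distribution(labels))  # {'0': 3, '1': 2, '2': 1}
print(sample_size(labels))         # 6
```

Each metric returns a plain Python value so that the resulting dictionary can be serialized and sent to the server as metadata.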
Send Request to the Server¶
Whether you use the MPI or the gRPC communicator, APPFL provides an interface function, invoke_custom_action, for users to send any custom request to the server, as shown below. The handler for the custom action should be defined on the server side, as shown in the next section.
Note
See the Launch Experiments section for details on how to create the communicator.
client_communicator = ... # this can be either MPI or gRPC communicator, see the following sections
data_readiness = client_agent.generate_readiness_report(client_config)
client_communicator.invoke_custom_action(action='get_data_readiness_report', **data_readiness)
Server Side Implementation¶
On the server side, you need to add a handler for processing the custom action request from the clients. In this example, we show how to update the custom action handler for both the MPI and gRPC communicators.
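Both communicator handlers use the same bookkeeping scheme: each client submits a flat dictionary of its own metrics, and the server transposes these into a nested {metric_name: {client_id: value}} mapping before handing it to the server agent. In isolation, that merge step looks like this (a minimal sketch of the pattern used inside the handlers below):

```python
def merge_client_metrics(all_metrics, client_id, meta_data):
    """Fold one client's metric dict into the server-side {metric: {client_id: value}} mapping."""
    for k, v in meta_data.items():
        if k not in all_metrics:
            all_metrics[k] = {}
        all_metrics[k][client_id] = v
    return all_metrics

dr_metrics = {}
merge_client_metrics(dr_metrics, 1, {"sample_size": 100, "num_classes": 10})
merge_client_metrics(dr_metrics, 2, {"sample_size": 80, "num_classes": 10})
print(dr_metrics)
# {'sample_size': {1: 100, 2: 80}, 'num_classes': {1: 10, 2: 10}}
```

Grouping by metric name rather than by client makes it straightforward for the report generator to compare the same metric across all clients.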
MPI Communicator¶
You need to update the APPFL source code's MPI server communicator at appfl.comm.mpi.mpi_server_communicator.MPIServerCommunicator, specifically its _invoke_custom_action method, to handle the custom action request from the client. Here is an example implementation for the action get_data_readiness_report, which synchronously waits for all clients to send their readiness evaluations and metadata. Once all clients have submitted their data, it aggregates the results and passes them to the server agent to generate and output the readiness report.
def _invoke_custom_action(
    self,
    client_id: int,
    request: MPITaskRequest,
) -> Optional[MPITaskResponse]:
    ...
    if action == "set_sample_size":
        ...
    elif action == "close_connection":
        ...
    elif action == "get_data_readiness_report":
        num_clients = self.server_agent.get_num_clients()
        if not hasattr(self, "_dr_metrics_lock"):
            self._dr_metrics = {}
            self._dr_metrics_client_ids = set()
            self._dr_metrics_lock = threading.Lock()
        with self._dr_metrics_lock:
            self._dr_metrics_client_ids.add(client_id)
            for k, v in meta_data.items():
                if k not in self._dr_metrics:
                    self._dr_metrics[k] = {}
                self._dr_metrics[k][client_id] = v
            if len(self._dr_metrics_client_ids) == num_clients:
                self.server_agent.data_readiness_report(self._dr_metrics)
                response = MPITaskResponse(
                    status=MPIServerStatus.RUN.value,
                )
                response_bytes = response_to_byte(response)
                for client_id in self._dr_metrics_client_ids:
                    self.comm.Send(response_bytes, dest=client_id, tag=client_id)
                self._dr_metrics = {}
                self._dr_metrics_client_ids = set()
        return None
    else:
        raise NotImplementedError(f"Custom action {action} is not implemented.")
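The lock-and-set bookkeeping above implements a simple rendezvous: every client's submission is recorded under a lock, and the last client to arrive triggers the report generation and releases everyone. Stripped of the MPI details, the pattern looks like this (a self-contained sketch, not APPFL code):

```python
import threading

class ReportBarrier:
    """Collect one submission per client; the last arrival triggers on_complete exactly once."""

    def __init__(self, num_clients, on_complete):
        self.num_clients = num_clients
        self.on_complete = on_complete
        self.submissions = {}
        self.lock = threading.Lock()  # handlers may run on concurrent server threads

    def submit(self, client_id, payload):
        with self.lock:
            self.submissions[client_id] = payload
            if len(self.submissions) == self.num_clients:
                self.on_complete(self.submissions)  # e.g. build the aggregated report
                self.submissions = {}  # reset for a possible next round

reports = []
barrier = ReportBarrier(2, lambda subs: reports.append(dict(subs)))
barrier.submit(1, {"sample_size": 50})
print(len(reports))  # 0 -- still waiting for client 2
barrier.submit(2, {"sample_size": 100})
print(reports)  # [{1: {'sample_size': 50}, 2: {'sample_size': 100}}]
```

Resetting the state after the report is generated allows the same handler to serve another round of submissions without restarting the server.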
gRPC Communicator¶
Similar to the MPI communicator, you need to update the APPFL source code's gRPC server communicator at appfl.comm.grpc.grpc_server_communicator.GRPCServerCommunicator, specifically its InvokeCustomAction method, to handle the custom action request from the client. Here is an example implementation for the action get_data_readiness_report, which synchronously waits for all clients to send their readiness evaluations and metadata. Once all clients have submitted their data, it aggregates the results and passes them to the server agent to generate and output the readiness report.
def InvokeCustomAction(self, request, context):
    client_id = request.header.client_id
    action = request.action
    meta_data = yaml.safe_load(request.meta_data) if len(request.meta_data) > 0 else {}
    if action == "set_sample_size":
        ...
    elif action == "close_connection":
        ...
    elif action == "get_data_readiness_report":
        num_clients = self.server_agent.get_num_clients()
        if not hasattr(self, "_dr_metrics_lock"):
            self._dr_metrics = {}
            self._dr_metrics_futures = {}
            self._dr_metrics_lock = threading.Lock()
        with self._dr_metrics_lock:
            for k, v in meta_data.items():
                if k not in self._dr_metrics:
                    self._dr_metrics[k] = {}
                self._dr_metrics[k][client_id] = v
            _dr_metric_future = Future()
            self._dr_metrics_futures[client_id] = _dr_metric_future
            if len(self._dr_metrics_futures) == num_clients:
                self.server_agent.data_readiness_report(self._dr_metrics)
                for client_id, future in self._dr_metrics_futures.items():
                    future.set_result(None)
                self._dr_metrics = {}
                self._dr_metrics_futures = {}
        # Wait for the data readiness report to be generated for synchronization
        _dr_metric_future.result()
        response = CustomActionResponse(
            header=ServerHeader(status=ServerStatus.DONE),
        )
        return response
    else:
        raise NotImplementedError(f"Custom action {action} is not implemented.")
Server Agent Report Generation¶
The ServerAgent should have a method to generate the readiness report for all the clients. The method takes the readiness evaluations and metadata from all the clients and generates the aggregated readiness report. Here, the server generates an HTML and a JSON report and writes them to the ./output directory. Shown below is the implementation of the method data_readiness_report in the ServerAgent.
Method data_readiness_report in the ServerAgent that generates a single aggregated readiness report, outputting the results in both HTML and JSON formats for all clients.¶
def data_readiness_report(self, readiness_report: Dict) -> None:
    """
    Generate the data readiness report and save it to the output directory.
    """
    output_dir = self.server_agent_config.client_configs.data_readiness_configs.get(
        "output_dirname", "./output"
    )
    output_filename = (
        self.server_agent_config.client_configs.data_readiness_configs.get(
            "output_filename", "data_readiness_report"
        )
    )
    if not os.path.exists(output_dir):
        pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Save JSON report
    json_file_path = get_unique_file_path(output_dir, output_filename, "json")
    save_json_report(json_file_path, readiness_report, self.logger)
    # Generate and save HTML report
    html_file_path = get_unique_file_path(output_dir, output_filename, "html")
    html_content = generate_html_content(readiness_report)
    save_html_report(html_file_path, html_content, self.logger)
    self._data_readiness_reports = {}
The helper functions for generating the HTML and JSON readiness reports are defined in the src.appfl.misc.data_readiness.report.py file.
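A helper like get_unique_file_path typically avoids overwriting a report from an earlier run by appending a numeric suffix to the filename. The sketch below shows one plausible way such a helper could behave; the suffix scheme is an assumption for illustration, not the actual APPFL implementation:

```python
import os
import tempfile

def get_unique_file_path(output_dir, filename, ext):
    """Return output_dir/filename.ext, appending _1, _2, ... if the file already exists (sketch)."""
    candidate = os.path.join(output_dir, f"{filename}.{ext}")
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(output_dir, f"{filename}_{counter}.{ext}")
        counter += 1
    return candidate

with tempfile.TemporaryDirectory() as d:
    first = get_unique_file_path(d, "data_readiness_report", "html")
    open(first, "w").close()  # simulate an existing report from a previous run
    second = get_unique_file_path(d, "data_readiness_report", "html")
    print(os.path.basename(first), os.path.basename(second))
    # data_readiness_report.html data_readiness_report_1.html
```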
Launch Experiments¶
Data readiness report generation is integrated into the standard workflow: the report is generated before the local training and global model update iterations. To generate the report, you must set generate_dr_report: True in the server configuration, along with specifying which metrics and plots to include in the evaluation. These settings should be defined in the client_configs section of the server configuration file. Below is an example of how to set the configurations for the data readiness report generation.
client_configs:
  ...
  data_readiness_configs:
    generate_dr_report: True # Enable or disable the generation of the data readiness report
    output_dirname: "./output" # Directory to save the report
    output_filename: "data_readiness_report" # Name of the report file
    dr_metrics: # Metrics to evaluate data readiness
      class_imbalance: True # Check for class imbalance degree
      sample_size: True # Evaluate the sample size
      ...
      plot: # Plots to include in the report
        class_distribution_plot: True # Generate a class distribution plot
        ...
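On the client side, generate_readiness_report reads these flags via attribute access (hasattr/getattr) on the parsed configuration object. That lookup logic can be mimicked with a plain namespace, as in the sketch below; APPFL parses the YAML into its own configuration object, so the SimpleNamespace here is only a stand-in:

```python
from types import SimpleNamespace

# Stand-in for the parsed data_readiness_configs section of the server configuration
dr_metrics = SimpleNamespace(
    class_imbalance=True,
    sample_size=True,
    plot=SimpleNamespace(class_distribution_plot=True),
)
config = SimpleNamespace(
    data_readiness_configs=SimpleNamespace(generate_dr_report=True, dr_metrics=dr_metrics)
)

# Same style of checks as in generate_readiness_report: a metric runs only if its
# flag exists in the configuration and is set to True
enabled = [
    name
    for name in ("class_imbalance", "sample_size", "entropy")
    if hasattr(config.data_readiness_configs.dr_metrics, name)
    and getattr(config.data_readiness_configs.dr_metrics, name)
]
print(enabled)  # ['class_imbalance', 'sample_size']
```

Metrics absent from the configuration (entropy above) are simply skipped, so the server controls exactly which evaluations each client performs.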
MPI Experiment¶
Once the configurations are set, you can run the experiment from the examples directory using the following command.
mpiexec -n 6 python mpi/run_mpi.py
gRPC Experiment¶
Once the configurations are set, you can run the experiment from the examples directory. You need to start the server first in one terminal, and then start two clients in two separate terminals. You can run the server using the following command:
python grpc/run_server.py
Then, you can run the clients using the following command:
python grpc/run_client.py --config resources/configs/mnist/client_1.yaml
python grpc/run_client.py --config resources/configs/mnist/client_2.yaml # run this in a separate terminal
The server will generate the data readiness report for all the clients and save it in the output directory under the name specified in the configuration file.
examples
|--- output
|    |--- data_readiness_report.html
|    |--- data_readiness_report.json