{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# FL Configurations\n", "In this notebook, we will showcase how to load and set configurations for the federated learning (FL) server and clients in order to launch FL experiments." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Configuration from a YAML File\n", "APPFL employs the [OmegaConf](https://omegaconf.readthedocs.io/) package, a hierarchical configuration system, for loading the configurations of the FL server and clients from YAML files.\n", "\n", "For example, `examples/resources/configs/mnist/server_fedavg.yaml` contains the server configurations for an FL experiment on the MNIST dataset using the `FedAvg` server aggregation algorithm. \n", "\n", "As shown below, the configuration file is primarily composed of two main components:\n", "\n", "- `client_configs` \n", "- `server_configs`\n", "\n", "Does it look a bit confusing at first glance that the server configuration file also contains `client_configs`? This is because, in federated learning, we usually want certain client-side configurations to be the same among all the clients, for example, the local trainer and its corresponding hyperparameters, the ML model architecture, and the compression settings. Therefore, it is much more convenient to first specify all those settings and configurations on the server side to ensure uniformity, and then send those configurations to all clients at the beginning of the FL experiment." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============Level one configuration fields============\n", "client_configs\n", "server_configs\n", "============Detailed server configurations============\n", "client_configs:\n", " train_configs:\n", " trainer: VanillaTrainer\n", " mode: step\n", " num_local_steps: 100\n", " optim: Adam\n", " optim_args:\n", " lr: 0.001\n", " loss_fn_path: ./resources/loss/celoss.py\n", " loss_fn_name: CELoss\n", " do_validation: true\n", " do_pre_validation: true\n", " metric_path: ./resources/metric/acc.py\n", " metric_name: accuracy\n", " use_dp: false\n", " epsilon: 1\n", " clip_grad: false\n", " clip_value: 1\n", " clip_norm: 1\n", " train_batch_size: 64\n", " val_batch_size: 64\n", " train_data_shuffle: true\n", " val_data_shuffle: false\n", " model_configs:\n", " model_path: ./resources/model/cnn.py\n", " model_name: CNN\n", " model_kwargs:\n", " num_channel: 1\n", " num_classes: 10\n", " num_pixel: 28\n", " comm_configs:\n", " compressor_configs:\n", " enable_compression: false\n", " lossy_compressor: SZ2Compressor\n", " lossless_compressor: blosc\n", " error_bounding_mode: REL\n", " error_bound: 0.001\n", " param_cutoff: 1024\n", "server_configs:\n", " num_clients: 2\n", " scheduler: SyncScheduler\n", " scheduler_kwargs:\n", " same_init_model: true\n", " aggregator: FedAvgAggregator\n", " aggregator_kwargs:\n", " client_weights_mode: equal\n", " device: cpu\n", " num_global_epochs: 10\n", " logging_output_dirname: ./output\n", " logging_output_filename: result\n", " comm_configs:\n", " grpc_configs:\n", " server_uri: localhost:50051\n", " max_message_size: 1048576\n", " use_ssl: false\n", "\n", "======================================================\n" ] } ], "source": [ "from omegaconf import OmegaConf\n", "\n", "server_config_file = \"../../examples/resources/configs/mnist/server_fedavg.yaml\"\n", "server_config = 
OmegaConf.load(server_config_file)\n", "print(\"============Level one configuration fields============\")\n", "for key in server_config:\n", " print(key)\n", "print(\"============Detailed server configurations============\")\n", "print(OmegaConf.to_yaml(server_config))\n", "print(\"======================================================\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Client Configurations\n", "The client configurations that are shared among all clients are composed of three main components:\n", "\n", "- `train_configs`: This component contains all training-related configurations, which can be further classified into the following sub-components:\n", "\n", " - *Trainer configurations*: It should be noted that the required trainer configurations depend on the trainer you use. You can also define your own trainer with any additional configurations you need, and then provide those configurations under `client_configs.train_configs` in the server configuration YAML file.\n", "\n", " - `trainer`: The class name of the trainer you would like to use for client local training. The trainer name should be defined in `src/appfl/trainer`. 
For example, `VanillaTrainer` simply updates the model for a certain number of epochs or batches.\n", " - `mode`: For `VanillaTrainer`, `mode` is a required configuration with allowable values `epoch` or `step`, specifying whether you want to train for a certain number of epochs or only a certain number of steps/batches.\n", " - `num_local_steps`/`num_local_epochs`: Number of steps (if `mode=step`) or epochs (if `mode=epoch`) for an FL client in each local training round.\n", " - `optim`: Name of the optimizer to use from the `torch.optim` module.\n", " - `optim_args`: Keyword arguments for the selected optimizer.\n", " - `do_validation`: Whether to perform client-side validation in each training round.\n", " - `do_pre_validation`: Whether to perform client-side validation prior to local training.\n", " - `use_dp`: Whether to use differential privacy.\n", " - `epsilon`, `clip_grad`, `clip_value`, `clip_norm`: Parameters used if differential privacy is enabled.\n", " - *Loss function*: To specify the loss function to use during local training, we provide two options:\n", " - Loss function from `torch`: By providing the name of a loss function available in `torch.nn` (e.g., `CrossEntropyLoss`) in `loss_fn` and the corresponding arguments in `loss_fn_kwargs`, users can employ a loss function available in PyTorch.\n", " - Loss function defined in a local file: Users can define their own loss function by inheriting `nn.Module` and defining its `forward()` function. They then need to provide the path to the loss function file in `loss_fn_path`, and the class name of the defined loss function in `loss_fn_name`.\n", " - *Metric function*: To specify the metric function used during validation, users need to provide the path to the file containing the metric function in `metric_path` and the name of the metric function in `metric_name`. 
\n", " - *Dataloader settings*: While the server-side configuration does not contain any information about each client's local dataset, it can specify the configurations used when converting the dataset to a dataloader, such as the batch size and whether to shuffle.\n", "- `model_configs`: This component contains the definition of the machine learning model used in the FL experiment. The model architecture should be defined as a `torch.nn.Module` in a local file on the server side, and the following information should then be provided:\n", "\n", " - `model_path`: Path to the model definition file.\n", " - `model_name`: Class name of the defined model.\n", " - `model_kwargs`: Keyword arguments for instantiating the model.\n", "- `comm_configs`: This component contains the settings for the communication between the FL server and clients, such as the `compressor_configs`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to the `client_configs` contained in the server configuration YAML file, each client also needs its own client configuration YAML file to specify client-specific configurations, such as the device to use and the way to load the private local dataset. Let's first take a look at the client-specific configuration.\n", "\n", "- `client_id`: This is the **unique** ID (among all clients in the FL experiment) used for logging purposes.\n", "- `train_configs`: This contains the device to use in training, and some logging configurations.\n", "- `data_configs`: This is the most important component in the client configuration file. It contains the path on the client's local machine to the file which defines how to load the local dataset (`dataset_path`), the name of the function in the file to load the dataset (`dataset_name`), and any keyword arguments if needed (`dataset_kwargs`).\n", "- `comm_configs`: The client may also need to specify some communication settings in order to connect to the server. 
For example, if the experiment uses gRPC as the communication method, then the client needs to specify the `server_uri`, `max_message_size`, and `use_ssl` to establish the connection to the server." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "client_id: Client1\n", "train_configs:\n", " device: cpu\n", " logging_output_dirname: ./output\n", " logging_output_filename: result\n", "data_configs:\n", " dataset_path: ./resources/dataset/mnist_dataset.py\n", " dataset_name: get_mnist\n", " dataset_kwargs:\n", " num_clients: 2\n", " client_id: 0\n", " partition_strategy: class_noniid\n", " visualization: true\n", " output_dirname: ./output\n", " output_filename: visualization.pdf\n", "comm_configs:\n", " grpc_configs:\n", " server_uri: localhost:50051\n", " max_message_size: 1048576\n", " use_ssl: false\n", "\n" ] } ], "source": [ "client_config_file = \"../../examples/resources/configs/mnist/client_1.yaml\"\n", "client_config = OmegaConf.load(client_config_file)\n", "print(OmegaConf.to_yaml(client_config))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Merging the general client configurations from the server configuration YAML file with the client-specific configurations in the client configuration YAML file gives everything a client needs to conduct the FL experiment." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train_configs:\n", " trainer: VanillaTrainer\n", " mode: step\n", " num_local_steps: 100\n", " optim: Adam\n", " optim_args:\n", " lr: 0.001\n", " loss_fn_path: ./resources/loss/celoss.py\n", " loss_fn_name: CELoss\n", " do_validation: true\n", " do_pre_validation: true\n", " metric_path: ./resources/metric/acc.py\n", " metric_name: accuracy\n", " use_dp: false\n", " epsilon: 1\n", " clip_grad: false\n", " clip_value: 1\n", " clip_norm: 1\n", " train_batch_size: 64\n", " val_batch_size: 64\n", " train_data_shuffle: true\n", " val_data_shuffle: false\n", " device: cpu\n", " logging_output_dirname: ./output\n", " logging_output_filename: result\n", "model_configs:\n", " model_path: ./resources/model/cnn.py\n", " model_name: CNN\n", " model_kwargs:\n", " num_channel: 1\n", " num_classes: 10\n", " num_pixel: 28\n", "comm_configs:\n", " compressor_configs:\n", " enable_compression: false\n", " lossy_compressor: SZ2Compressor\n", " lossless_compressor: blosc\n", " error_bounding_mode: REL\n", " error_bound: 0.001\n", " param_cutoff: 1024\n", " grpc_configs:\n", " server_uri: localhost:50051\n", " max_message_size: 1048576\n", " use_ssl: false\n", "client_id: Client1\n", "data_configs:\n", " dataset_path: ./resources/dataset/mnist_dataset.py\n", " dataset_name: get_mnist\n", " dataset_kwargs:\n", " num_clients: 2\n", " client_id: 0\n", " partition_strategy: class_noniid\n", " visualization: true\n", " output_dirname: ./output\n", " output_filename: visualization.pdf\n", "\n" ] } ], "source": [ "general_client_config = server_config.client_configs\n", "all_client_config = OmegaConf.merge(general_client_config, client_config)\n", "print(OmegaConf.to_yaml(all_client_config))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Server Configurations\n", "Now, let's take a look at the configurations needed for the FL server. 
Specifically, the server configuration contains the following key components:\n", "\n", "- *Scheduler configurations*: Users can specify the name of the scheduler (`scheduler`) and the corresponding keyword arguments (`scheduler_kwargs`). All supported schedulers are available at `src/appfl/scheduler`.\n", "- *Aggregator configurations*: Users can specify the name of the aggregator (`aggregator`) and the corresponding keyword arguments (`aggregator_kwargs`). All supported aggregators are available at `src/appfl/aggregator`." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "num_clients: 2\n", "scheduler: SyncScheduler\n", "scheduler_kwargs:\n", " same_init_model: true\n", "aggregator: FedAvgAggregator\n", "aggregator_kwargs:\n", " client_weights_mode: equal\n", "device: cpu\n", "num_global_epochs: 10\n", "logging_output_dirname: ./output\n", "logging_output_filename: result\n", "comm_configs:\n", " grpc_configs:\n", " server_uri: localhost:50051\n", " max_message_size: 1048576\n", " use_ssl: false\n", "\n" ] } ], "source": [ "print(OmegaConf.to_yaml(server_config.server_configs))" ] } ], "metadata": { "kernelspec": { "display_name": "appfl", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }