{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# FL with compression\n", "\n", "In this notebook, we show how to use a lossy compressor in FL to compress the model parameters and reduce the communication cost." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install compressors\n", "\n", "To install the compressors, first make sure that you have installed the necessary packages for the compressors by running one of the following commands in the `APPFL` directory.\n", "\n", "```bash\n", "pip install -e . # if you installed from source code\n", "```\n", "\n", "or\n", "\n", "```bash\n", "pip install appfl # if you installed directly from PyPI\n", "```\n", "\n", "Then, you can install the lossy compressors by running the following command from any directory. It downloads all the source code for the compressors under `APPFL/.compressor`.\n", "\n", "```bash\n", "appfl-install-compressor\n", "```\n", "\n", "💡 `appfl.compressor` supports four compressors: [SZ2](https://github.com/szcompressor/SZ), [SZ3](https://github.com/szcompressor/SZ3), [ZFP](https://pypi.org/project/zfpy/), and [SZx](https://github.com/szcompressor/SZx). However, SZx requires special permission to access, so we omit its installation here. If you want to try SZx, please contact its authors." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load server and client configurations\n", "\n", "Following the same steps as in the [serial FL example](serial_fl.ipynb), we load and modify the configurations for the server and five clients." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "client_configs:\n", " train_configs:\n", " trainer: VanillaTrainer\n", " mode: step\n", " num_local_steps: 100\n", " optim: Adam\n", " optim_args:\n", " lr: 0.001\n", " loss_fn_path: ../../examples/resources/loss/celoss.py\n", " loss_fn_name: CELoss\n", " do_validation: true\n", " do_pre_validation: true\n", " metric_path: ../../examples/resources/metric/acc.py\n", " metric_name: accuracy\n", " use_dp: false\n", " epsilon: 1\n", " clip_grad: false\n", " clip_value: 1\n", " clip_norm: 1\n", " train_batch_size: 64\n", " val_batch_size: 64\n", " train_data_shuffle: true\n", " val_data_shuffle: false\n", " model_configs:\n", " model_path: ../../examples/resources/model/cnn.py\n", " model_name: CNN\n", " model_kwargs:\n", " num_channel: 1\n", " num_classes: 10\n", " num_pixel: 28\n", " comm_configs:\n", " compressor_configs:\n", " enable_compression: false\n", " lossy_compressor: SZ2Compressor\n", " lossless_compressor: blosc\n", " error_bounding_mode: REL\n", " error_bound: 0.001\n", " param_cutoff: 1024\n", "server_configs:\n", " num_clients: 5\n", " scheduler: SyncScheduler\n", " scheduler_kwargs:\n", " same_init_model: true\n", " aggregator: FedAvgAggregator\n", " aggregator_kwargs:\n", " client_weights_mode: equal\n", " device: cpu\n", " num_global_epochs: 3\n", " logging_output_dirname: ./output\n", " logging_output_filename: result\n", " comm_configs:\n", " grpc_configs:\n", " server_uri: localhost:50051\n", " max_message_size: 1048576\n", " use_ssl: false\n", "\n" ] } ], "source": [ "import copy\n", "from omegaconf import OmegaConf\n", "\n", "num_clients = 5\n", "\n", "server_config_file = \"../../examples/resources/configs/mnist/server_fedavg.yaml\"\n", "server_config = OmegaConf.load(server_config_file)\n", "server_config.client_configs.train_configs.loss_fn_path = (\n", " \"../../examples/resources/loss/celoss.py\"\n", ")\n", 
"server_config.client_configs.train_configs.metric_path = (\n", " \"../../examples/resources/metric/acc.py\"\n", ")\n", "server_config.client_configs.model_configs.model_path = (\n", " \"../../examples/resources/model/cnn.py\"\n", ")\n", "server_config.server_configs.num_global_epochs = 3\n", "server_config.server_configs.num_clients = num_clients\n", "\n", "client_config_file = \"../../examples/resources/configs/mnist/client_1.yaml\"\n", "client_config = OmegaConf.load(client_config_file)\n", "client_configs = [copy.deepcopy(client_config) for _ in range(num_clients)]\n", "for i in range(num_clients):\n", " client_configs[i].client_id = f\"Client{i + 1}\"\n", " client_configs[\n", " i\n", " ].data_configs.dataset_path = \"../../examples/resources/dataset/mnist_dataset.py\"\n", " client_configs[i].data_configs.dataset_kwargs.num_clients = num_clients\n", " client_configs[i].data_configs.dataset_kwargs.client_id = i\n", " client_configs[i].data_configs.dataset_kwargs.visualization = (\n", " True if i == 0 else False\n", " )\n", "\n", "print(OmegaConf.to_yaml(server_config))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Enable compression\n", "\n", "To enable compression, we just need to set `server_config.client_configs.comm_configs.compressor_configs.enable_compression` to True.\n", "\n", "💡 You may notices that both `server_config.server_configs` and `server_config.client_configs` have a `comm_configs` fields. Actually, when creating the server agent, its communication configurations will be the merging of `server_config.server_configs.comm_configs` and `server_config.client_configs.comm_configs`. However, `server_config.client_configs.comm_configs` will also be shared with clients, while `server_config.server_configs.comm_configs` will not. As we want the clients to be aware of the compressor configurations, we put `compressor_configs` under `server_config.client_configs.comm_configs` to share with the clients during the FL experiment." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create the agents and start the experiment" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:06:18,444 server]: Logging to ./output/result_Server_2025-01-08-10-06-18.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:06:18,460 Client1]: Logging to ./output/result_Client1_2025-01-08-10-06-18.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:06:25,732 Client2]: Logging to ./output/result_Client2_2025-01-08-10-06-25.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:06:32,544 Client3]: Logging to ./output/result_Client3_2025-01-08-10-06-32.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:06:39,832 Client4]: Logging to ./output/result_Client4_2025-01-08-10-06-39.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:06:47,010 Client5]: Logging to ./output/result_Client5_2025-01-08-10-06-47.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:06:54,282 Client1]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:06:55,360 Client1]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:06:59,106 Client1]: 0 N 3.7451 0.3113 84.9531 6.7566 56.9400\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:06:59,161 Client2]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:00,201 Client2]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:03,147 Client2]: 0 N 2.9452 0.3075 82.9688 8.5662 46.1100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:03,198 Client3]: Round Pre Val? 
Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:04,228 Client3]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:07,038 Client3]: 0 N 2.8099 0.1685 86.3125 6.6134 58.9900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:07,091 Client4]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:08,100 Client4]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:10,845 Client4]: 0 N 2.7446 0.2278 85.4844 6.8239 58.5100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:10,898 Client5]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:11,911 Client5]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:14,735 Client5]: 0 N 2.8234 0.2344 82.7500 9.4964 46.6900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:15,808 Client1]: 1 Y 1.6293 52.3100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:18,685 Client1]: 1 N 2.8768 0.0952 96.2031 4.4368 57.2000\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:19,806 Client2]: 1 Y 1.6293 52.3100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:22,713 Client2]: 1 N 2.9055 0.0961 95.0000 7.2504 46.2000\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:23,812 Client3]: 1 Y 1.6293 52.3100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:26,750 Client3]: 1 N 2.9370 0.0727 94.5469 5.1526 59.5900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:27,835 Client4]: 1 Y 1.6293 52.3100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:30,725 Client4]: 1 N 2.8892 0.0815 95.3750 5.5127 59.6400\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:31,805 Client5]: 1 Y 1.6293 52.3100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:34,711 Client5]: 1 N 
2.9055 0.1016 93.1875 5.8957 46.8500\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:35,810 Client1]: 2 Y 0.9978 70.3300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:38,703 Client1]: 2 N 2.8922 0.0485 97.9219 3.5707 57.9400\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:39,782 Client2]: 2 Y 0.9978 70.3300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:42,690 Client2]: 2 N 2.9079 0.0610 96.7656 5.0698 47.3200\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:43,794 Client3]: 2 Y 0.9978 70.3300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:46,690 Client3]: 2 N 2.8956 0.0407 96.6875 3.9961 60.2800\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:47,788 Client4]: 2 Y 0.9978 70.3300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:50,689 Client4]: 2 N 2.9001 0.0479 97.3906 3.5627 60.1800\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:51,773 Client5]: 2 Y 0.9978 70.3300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:07:54,684 Client5]: 2 N 2.9105 0.0704 95.1719 5.3279 47.7900\n" ] } ], "source": [ "from appfl.agent import ServerAgent, ClientAgent\n", "\n", "# Create server and client agents\n", "server_config.client_configs.comm_configs.compressor_configs.enable_compression = True\n", "server_agent = ServerAgent(server_agent_config=server_config)\n", "client_agents = [\n", " ClientAgent(client_agent_config=client_configs[i]) for i in range(num_clients)\n", "]\n", "\n", "# Get additional client configurations from the server\n", "client_config_from_server = server_agent.get_client_configs()\n", "for client_agent in client_agents:\n", " client_agent.load_config(client_config_from_server)\n", "\n", "# Load initial global model from the server\n", "init_global_model = server_agent.get_parameters(serial_run=True)\n", "for client_agent in client_agents:\n", " client_agent.load_parameters(init_global_model)\n", "\n", "# [Optional] Set number of local data to the 
server\n", "for i in range(num_clients):\n", " sample_size = client_agents[i].get_sample_size()\n", " server_agent.set_sample_size(\n", " client_id=client_agents[i].get_id(), sample_size=sample_size\n", " )\n", "\n", "while not server_agent.training_finished():\n", " new_global_models = []\n", " for client_agent in client_agents:\n", " # Client local training\n", " client_agent.train()\n", " local_model = client_agent.get_parameters()\n", " if isinstance(local_model, tuple):\n", " local_model, metadata = local_model[0], local_model[1]\n", " else:\n", " metadata = {}\n", " # \"Send\" local model to server and get a Future object for the new global model\n", " # The Future object will be resolved when the server receives local models from all clients\n", " new_global_model_future = server_agent.global_update(\n", " client_id=client_agent.get_id(),\n", " local_model=local_model,\n", " blocking=False,\n", " **metadata,\n", " )\n", " new_global_models.append(new_global_model_future)\n", " # Load the new global model from the server\n", " for client_agent, new_global_model_future in zip(client_agents, new_global_models):\n", " client_agent.load_parameters(new_global_model_future.result())" ] } ], "metadata": { "kernelspec": { "display_name": "appfl", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }