{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Serial FL Simulation\n", "In this notebook, we show how to simulate an FL experiment on a single machine by running the clients serially, one after another. Note that serial simulation only makes sense for synchronous FL algorithms. We use 10 clients in this example." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "num_clients = 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load server configurations\n", "In this example, we use the `FedAvg` server aggregation algorithm and the MNIST dataset, loading the server configurations from `examples/resources/configs/mnist/server_fedavg.yaml`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "client_configs:\n", " train_configs:\n", " trainer: VanillaTrainer\n", " mode: step\n", " num_local_steps: 100\n", " optim: Adam\n", " optim_args:\n", " lr: 0.001\n", " loss_fn_path: ./resources/loss/celoss.py\n", " loss_fn_name: CELoss\n", " do_validation: true\n", " do_pre_validation: true\n", " metric_path: ./resources/metric/acc.py\n", " metric_name: accuracy\n", " use_dp: false\n", " epsilon: 1\n", " clip_grad: false\n", " clip_value: 1\n", " clip_norm: 1\n", " train_batch_size: 64\n", " val_batch_size: 64\n", " train_data_shuffle: true\n", " val_data_shuffle: false\n", " model_configs:\n", " model_path: ./resources/model/cnn.py\n", " model_name: CNN\n", " model_kwargs:\n", " num_channel: 1\n", " num_classes: 10\n", " num_pixel: 28\n", " comm_configs:\n", " compressor_configs:\n", " enable_compression: false\n", " lossy_compressor: SZ2Compressor\n", " lossless_compressor: blosc\n", " error_bounding_mode: REL\n", " error_bound: 0.001\n", " param_cutoff: 1024\n", "server_configs:\n", " scheduler: SyncScheduler\n", " scheduler_kwargs:\n", " 
num_clients: 2\n", " same_init_model: true\n", " aggregator: FedAvgAggregator\n", " aggregator_kwargs:\n", " client_weights_mode: equal\n", " device: cpu\n", " num_global_epochs: 10\n", " logging_output_dirname: ./output\n", " logging_output_filename: result\n", " comm_configs:\n", " grpc_configs:\n", " server_uri: localhost:50051\n", " max_message_size: 1048576\n", " use_ssl: false\n", "\n" ] } ], "source": [ "from omegaconf import OmegaConf\n", "server_config_file = \"../../examples/resources/configs/mnist/server_fedavg.yaml\"\n", "server_config = OmegaConf.load(server_config_file)\n", "print(OmegaConf.to_yaml(server_config))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "💡 Note that configuration fields such as `loss_fn_path`, `metric_path`, and `model_path` are relative paths to the corresponding files, so we need to update them to make sure they point to the right files from this notebook's location.\n", "\n", "💡 We also change `num_global_epochs` from 10 to 5.\n", "\n", "⚠️ We also need to change `num_clients` in `server_configs.scheduler_kwargs` to 10." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "server_config.client_configs.train_configs.loss_fn_path = '../../examples/resources/loss/celoss.py'\n", "server_config.client_configs.train_configs.metric_path = '../../examples/resources/metric/acc.py'\n", "server_config.client_configs.model_configs.model_path = '../../examples/resources/model/cnn.py'\n", "server_config.server_configs.num_global_epochs = 5\n", "server_config.server_configs.scheduler_kwargs.num_clients = num_clients" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load client configurations\n", "In this example, we use `num_clients=10` and load the base configuration for all the clients from `examples/resources/configs/mnist/client_1.yaml`. Let's first take a look at this base configuration." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train_configs:\n", " device: cpu\n", " logging_id: Client1\n", " logging_output_dirname: ./output\n", " logging_output_filename: result\n", "data_configs:\n", " dataset_path: ./resources/dataset/mnist_dataset.py\n", " dataset_name: get_mnist\n", " dataset_kwargs:\n", " num_clients: 2\n", " client_id: 0\n", " partition_strategy: class_noniid\n", " visualization: true\n", " output_dirname: ./output\n", " output_filename: visualization.pdf\n", "comm_configs:\n", " grpc_configs:\n", " server_uri: localhost:50051\n", " max_message_size: 1048576\n", " use_ssl: false\n", "\n" ] } ], "source": [ "client_config_file = \"../../examples/resources/configs/mnist/client_1.yaml\"\n", "client_config = OmegaConf.load(client_config_file)\n", "print(OmegaConf.to_yaml(client_config))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the configuration above, `data_configs` contains the configurations needed to load the simulated \"local\" dataset for each client. Specifically,\n", "\n", "- `dataset_path` is the path to the file that contains the function to load the dataset\n", "- `dataset_name` is the name of the function in that file that loads the dataset\n", "- `dataset_kwargs` are the keyword arguments for that function\n", "\n", "The `get_mnist` function partitions the MNIST dataset into `num_clients` client splits in either an IID (independent and identically distributed) or a non-IID manner." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we need to modify the general client configurations for different clients. 
Specifically, we make the following changes:\n", "\n", "- Change `logging_id` for each client\n", "- Change the relative path of `dataset_path` to make it point to the right file\n", "- Change `num_clients` to 10 and set `client_id` to 0, 1, ..., 9 for the different clients\n", "- Change `visualization` to False for nine of the clients so that only one data distribution plot is generated" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import copy\n", "client_configs = [\n", "    copy.deepcopy(client_config) for _ in range(num_clients)\n", "]\n", "for i in range(num_clients):\n", "    client_configs[i].train_configs.logging_id = f'Client_{i+1}'\n", "    client_configs[i].data_configs.dataset_path = '../../examples/resources/dataset/mnist_dataset.py'\n", "    client_configs[i].data_configs.dataset_kwargs.num_clients = num_clients\n", "    client_configs[i].data_configs.dataset_kwargs.client_id = i\n", "    # Only the first client plots the data distribution\n", "    client_configs[i].data_configs.dataset_kwargs.visualization = (i == 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create FL server agent and client agents\n", "In APPFL, agents act on behalf of the FL server and the FL clients to carry out the steps of a federated learning experiment. Users can easily create the agents from the server/client configurations we loaded from the YAML files and modified above. Creating the client agents loads the local datasets and plots the data distribution visualization, as shown below." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2024-12-01 22:18:34,805 INFO server]: Logging to ./output/result_Server_2024-12-01-22:18:34.txt\n", "[2024-12-01 22:18:34,810 INFO Client_1]: Logging to ./output/result_Client_1_2024-12-01-22:18:34.txt\n", "[2024-12-01 22:18:42,062 INFO Client_2]: Logging to ./output/result_Client_2_2024-12-01-22:18:42.txt\n", "[2024-12-01 22:18:49,106 INFO Client_3]: Logging to ./output/result_Client_3_2024-12-01-22:18:49.txt\n", "[2024-12-01 22:18:55,840 INFO Client_4]: Logging to ./output/result_Client_4_2024-12-01-22:18:55.txt\n", "[2024-12-01 22:19:02,595 INFO Client_5]: Logging to ./output/result_Client_5_2024-12-01-22:19:02.txt\n", "[2024-12-01 22:19:09,300 INFO Client_6]: Logging to ./output/result_Client_6_2024-12-01-22:19:09.txt\n", "[2024-12-01 22:19:15,998 INFO Client_7]: Logging to ./output/result_Client_7_2024-12-01-22:19:15.txt\n", "[2024-12-01 22:19:22,849 INFO Client_8]: Logging to ./output/result_Client_8_2024-12-01-22:19:22.txt\n", "[2024-12-01 22:19:29,778 INFO Client_9]: Logging to ./output/result_Client_9_2024-12-01-22:19:29.txt\n", "[2024-12-01 22:19:36,495 INFO Client_10]: Logging to ./output/result_Client_10_2024-12-01-22:19:36.txt\n" ] } ], "source": [ "from appfl.agent import ServerAgent, ClientAgent\n", "server_agent = ServerAgent(server_agent_config=server_config)\n", "client_agents = [\n", "    ClientAgent(client_agent_config=client_configs[i]) for i in range(num_clients)\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Start the training process" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The server configuration file contains many client configurations that apply to ALL clients. Now, we need to get those configurations from the server and provide them to the client agents." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Get additional client configurations from the server\n", "client_config_from_server = server_agent.get_client_configs()\n", "for client_agent in client_agents:\n", "    client_agent.load_config(client_config_from_server)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, let the clients load the initial global model from the server and optionally send their local sample sizes to the server for weighted aggregation.\n", "\n", "💡 **Note**: Typically, `server_agent.get_parameters()` either blocks or returns a `Future` object, and only returns the global model after receiving `num_clients` calls. This synchronizes the clients' requests for the initial model and ensures that all clients start from the same initial model weights. However, since we are running a serial simulation, we don't want the call to block, so we pass `serial_run=True` to get the global model immediately." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Load initial global model from the server\n", "init_global_model = server_agent.get_parameters(serial_run=True)\n", "for client_agent in client_agents:\n", "    client_agent.load_parameters(init_global_model)\n", "\n", "# [Optional] Send the number of local samples to the server\n", "for i in range(num_clients):\n", "    sample_size = client_agents[i].get_sample_size()\n", "    server_agent.set_sample_size(\n", "        client_id=client_agents[i].get_id(),\n", "        sample_size=sample_size\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can start the training iterations. 
Please note the following points:\n", "\n", "- `server_agent.training_finished` returns a boolean flag indicating whether the training has reached the specified `num_global_epochs`.\n", "- `client_agent.train` trains the client's local model using the client's \"local\" data.\n", "- `server_agent.global_update` takes one client's local model together with the client ID (obtained via `client_agent.get_id`) and schedules the global update for that local model. Synchronous server aggregation algorithms such as `FedAvg` do not update the global model until they have received local models from all `num_clients=10` clients. With `blocking=False`, the call to `global_update` therefore returns a `concurrent.futures.Future` object, which is resolved once all client local models have been sent for the global update. Without `blocking=False`, the call would block forever in a serial simulation.\n", "- In the output log, `Pre Val?` indicates whether the validation was run before the local training. As each client only holds data from 3 to 5 classes, the validation accuracy can even drop after local training (for example, from 68.43% to 30.51% for `Client_1` in round 2). However, the global model accuracy (the rows with `Pre Val?` set to `Y`) continues to increase across rounds, showcasing the capability of federated learning to improve the generalizability of the trained machine learning model." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[2024-12-01 22:19:43,300 INFO Client_1]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "[2024-12-01 22:19:44,369 INFO Client_1]: 0 Y 2.3006 15.9300\n", "[2024-12-01 22:19:47,698 INFO Client_1]: 0 N 3.3286 0.4323 90.8109 15.4995 30.3500\n", "[2024-12-01 22:19:47,700 INFO Client_2]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "[2024-12-01 22:19:48,773 INFO Client_2]: 0 Y 2.3006 15.9300\n", "[2024-12-01 22:19:51,901 INFO Client_2]: 0 N 3.1272 0.3130 87.7656 9.9495 48.6000\n", "[2024-12-01 22:19:51,903 INFO Client_3]: Round Pre Val? 
Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "[2024-12-01 22:19:52,976 INFO Client_3]: 0 Y 2.3006 15.9300\n", "[2024-12-01 22:19:55,965 INFO Client_3]: 0 N 2.9883 0.3909 85.0781 10.4302 38.2900\n", "[2024-12-01 22:19:55,967 INFO Client_4]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "[2024-12-01 22:19:57,044 INFO Client_4]: 0 Y 2.3006 15.9300\n", "[2024-12-01 22:20:00,023 INFO Client_4]: 0 N 2.9778 0.4242 89.8377 11.0662 39.8000\n", "[2024-12-01 22:20:00,025 INFO Client_5]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "[2024-12-01 22:20:01,090 INFO Client_5]: 0 Y 2.3006 15.9300\n", "[2024-12-01 22:20:04,119 INFO Client_5]: 0 N 3.0275 0.1851 92.1406 22.1558 28.8700\n", "[2024-12-01 22:20:04,120 INFO Client_6]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "[2024-12-01 22:20:05,178 INFO Client_6]: 0 Y 2.3006 15.9300\n", "[2024-12-01 22:20:08,072 INFO Client_6]: 0 N 2.8929 0.6525 85.3739 14.9090 38.8200\n", "[2024-12-01 22:20:08,074 INFO Client_7]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "[2024-12-01 22:20:09,177 INFO Client_7]: 0 Y 2.3006 15.9300\n", "[2024-12-01 22:20:12,338 INFO Client_7]: 0 N 3.1597 0.2848 86.4844 10.2574 48.8900\n", "[2024-12-01 22:20:12,339 INFO Client_8]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "[2024-12-01 22:20:13,427 INFO Client_8]: 0 Y 2.3006 15.9300\n", "[2024-12-01 22:20:16,502 INFO Client_8]: 0 N 3.0735 0.3402 92.6434 14.2500 31.1400\n", "[2024-12-01 22:20:16,503 INFO Client_9]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "[2024-12-01 22:20:17,577 INFO Client_9]: 0 Y 2.3006 15.9300\n", "[2024-12-01 22:20:20,760 INFO Client_9]: 0 N 3.1821 0.3460 86.6562 9.1354 50.0000\n", "[2024-12-01 22:20:20,762 INFO Client_10]: Round Pre Val? 
Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "[2024-12-01 22:20:21,802 INFO Client_10]: 0 Y 2.3006 15.9300\n", "[2024-12-01 22:20:24,732 INFO Client_10]: 0 N 2.9294 0.2586 88.4375 12.8184 40.2200\n", "[2024-12-01 22:20:25,816 INFO Client_1]: 1 Y 1.8781 31.3900\n", "[2024-12-01 22:20:28,699 INFO Client_1]: 1 N 2.8818 0.1185 98.0902 9.2444 30.4200\n", "[2024-12-01 22:20:29,727 INFO Client_2]: 1 Y 1.8781 31.3900\n", "[2024-12-01 22:20:32,687 INFO Client_2]: 1 N 2.9588 0.1277 95.7188 8.0948 48.7200\n", "[2024-12-01 22:20:33,767 INFO Client_3]: 1 Y 1.8781 31.3900\n", "[2024-12-01 22:20:36,661 INFO Client_3]: 1 N 2.8933 0.1673 94.5312 8.9735 38.5400\n", "[2024-12-01 22:20:37,695 INFO Client_4]: 1 Y 1.8781 31.3900\n", "[2024-12-01 22:20:40,583 INFO Client_4]: 1 N 2.8872 0.1617 96.1872 9.9464 39.7100\n", "[2024-12-01 22:20:41,639 INFO Client_5]: 1 Y 1.8781 31.3900\n", "[2024-12-01 22:20:44,528 INFO Client_5]: 1 N 2.8879 0.0906 97.3438 13.3640 29.5100\n", "[2024-12-01 22:20:45,617 INFO Client_6]: 1 Y 1.8781 31.3900\n", "[2024-12-01 22:20:48,567 INFO Client_6]: 1 N 2.9492 0.1973 96.3711 9.5049 39.4300\n", "[2024-12-01 22:20:49,577 INFO Client_7]: 1 Y 1.8781 31.3900\n", "[2024-12-01 22:20:52,383 INFO Client_7]: 1 N 2.8060 0.1087 95.3750 5.8535 49.2100\n", "[2024-12-01 22:20:53,394 INFO Client_8]: 1 Y 1.8781 31.3900\n", "[2024-12-01 22:20:56,161 INFO Client_8]: 1 N 2.7659 0.0983 98.0466 9.4541 31.3000\n", "[2024-12-01 22:20:57,172 INFO Client_9]: 1 Y 1.8781 31.3900\n", "[2024-12-01 22:20:59,975 INFO Client_9]: 1 N 2.8018 0.1480 94.8281 6.5887 49.2400\n", "[2024-12-01 22:21:00,988 INFO Client_10]: 1 Y 1.8781 31.3900\n", "[2024-12-01 22:21:03,785 INFO Client_10]: 1 N 2.7959 0.1009 96.3438 11.2293 40.7800\n", "[2024-12-01 22:21:04,818 INFO Client_1]: 2 Y 0.9154 68.4300\n", "[2024-12-01 22:21:07,577 INFO Client_1]: 2 N 2.7580 0.0553 99.0451 6.5585 30.5100\n", "[2024-12-01 22:21:08,604 INFO Client_2]: 2 Y 0.9154 68.4300\n", "[2024-12-01 22:21:11,365 INFO Client_2]: 2 N 
2.7600 0.0708 97.4844 5.9560 49.2300\n", "[2024-12-01 22:21:12,382 INFO Client_3]: 2 Y 0.9154 68.4300\n", "[2024-12-01 22:21:15,154 INFO Client_3]: 2 N 2.7708 0.0948 96.4219 7.6792 38.7600\n", "[2024-12-01 22:21:16,169 INFO Client_4]: 2 Y 0.9154 68.4300\n", "[2024-12-01 22:21:18,905 INFO Client_4]: 2 N 2.7350 0.0759 98.0936 8.1204 40.0200\n", "[2024-12-01 22:21:19,917 INFO Client_5]: 2 Y 0.9154 68.4300\n", "[2024-12-01 22:21:22,720 INFO Client_5]: 2 N 2.8013 0.0454 98.3750 8.8908 29.7500\n", "[2024-12-01 22:21:23,741 INFO Client_6]: 2 Y 0.9154 68.4300\n", "[2024-12-01 22:21:26,519 INFO Client_6]: 2 N 2.7768 0.0995 98.0120 8.2203 39.4500\n", "[2024-12-01 22:21:27,535 INFO Client_7]: 2 Y 0.9154 68.4300\n", "[2024-12-01 22:21:30,330 INFO Client_7]: 2 N 2.7942 0.0554 97.5781 4.9709 49.0400\n", "[2024-12-01 22:21:31,354 INFO Client_8]: 2 Y 0.9154 68.4300\n", "[2024-12-01 22:21:34,107 INFO Client_8]: 2 N 2.7515 0.0540 98.8658 8.9072 31.2200\n", "[2024-12-01 22:21:35,134 INFO Client_9]: 2 Y 0.9154 68.4300\n", "[2024-12-01 22:21:37,913 INFO Client_9]: 2 N 2.7780 0.0861 97.0000 5.5963 50.2700\n", "[2024-12-01 22:21:38,934 INFO Client_10]: 2 Y 0.9154 68.4300\n", "[2024-12-01 22:21:41,721 INFO Client_10]: 2 N 2.7860 0.0638 97.3594 8.1234 40.8700\n", "[2024-12-01 22:21:42,754 INFO Client_1]: 3 Y 0.6473 75.2900\n", "[2024-12-01 22:21:45,530 INFO Client_1]: 3 N 2.7753 0.0470 99.1077 7.0593 30.4200\n", "[2024-12-01 22:21:46,548 INFO Client_2]: 3 Y 0.6473 75.2900\n", "[2024-12-01 22:21:49,333 INFO Client_2]: 3 N 2.7839 0.0519 98.1250 5.5613 48.7300\n", "[2024-12-01 22:21:50,364 INFO Client_3]: 3 Y 0.6473 75.2900\n", "[2024-12-01 22:21:53,141 INFO Client_3]: 3 N 2.7767 0.0716 97.5469 5.5125 38.6900\n", "[2024-12-01 22:21:54,162 INFO Client_4]: 3 Y 0.6473 75.2900\n", "[2024-12-01 22:21:56,927 INFO Client_4]: 3 N 2.7643 0.0618 98.5032 6.4788 40.1200\n", "[2024-12-01 22:21:57,962 INFO Client_5]: 3 Y 0.6473 75.2900\n", "[2024-12-01 22:22:00,750 INFO Client_5]: 3 N 2.7874 0.0339 98.6719 
8.3036 29.7600\n", "[2024-12-01 22:22:01,780 INFO Client_6]: 3 Y 0.6473 75.2900\n", "[2024-12-01 22:22:04,574 INFO Client_6]: 3 N 2.7931 0.0956 98.2644 5.1085 39.2800\n", "[2024-12-01 22:22:05,608 INFO Client_7]: 3 Y 0.6473 75.2900\n", "[2024-12-01 22:22:08,416 INFO Client_7]: 3 N 2.8076 0.0428 98.2031 4.1431 49.4900\n", "[2024-12-01 22:22:09,438 INFO Client_8]: 3 Y 0.6473 75.2900\n", "[2024-12-01 22:22:12,258 INFO Client_8]: 3 N 2.8195 0.0394 99.2596 7.1669 31.3200\n", "[2024-12-01 22:22:13,284 INFO Client_9]: 3 Y 0.6473 75.2900\n", "[2024-12-01 22:22:16,086 INFO Client_9]: 3 N 2.8008 0.0674 97.4844 4.7681 50.6100\n", "[2024-12-01 22:22:17,123 INFO Client_10]: 3 Y 0.6473 75.2900\n", "[2024-12-01 22:22:19,949 INFO Client_10]: 3 N 2.8250 0.0420 98.1094 8.0815 40.7800\n", "[2024-12-01 22:22:20,986 INFO Client_1]: 4 Y 0.5214 78.7300\n", "[2024-12-01 22:22:23,784 INFO Client_1]: 4 N 2.7980 0.0319 99.5147 7.0248 30.5000\n", "[2024-12-01 22:22:24,807 INFO Client_2]: 4 Y 0.5214 78.7300\n", "[2024-12-01 22:22:27,744 INFO Client_2]: 4 N 2.9366 0.0417 98.2812 4.9374 49.0200\n", "[2024-12-01 22:22:28,797 INFO Client_3]: 4 Y 0.5214 78.7300\n", "[2024-12-01 22:22:31,884 INFO Client_3]: 4 N 3.0864 0.0610 97.8906 5.2966 38.9000\n", "[2024-12-01 22:22:32,934 INFO Client_4]: 4 Y 0.5214 78.7300\n", "[2024-12-01 22:22:35,844 INFO Client_4]: 4 N 2.9087 0.0517 98.8026 5.3538 40.2100\n", "[2024-12-01 22:22:36,884 INFO Client_5]: 4 Y 0.5214 78.7300\n", "[2024-12-01 22:22:39,812 INFO Client_5]: 4 N 2.9274 0.0319 98.6719 9.3730 29.7600\n", "[2024-12-01 22:22:40,921 INFO Client_6]: 4 Y 0.5214 78.7300\n", "[2024-12-01 22:22:43,962 INFO Client_6]: 4 N 3.0398 0.0611 98.8640 5.7755 39.5000\n", "[2024-12-01 22:22:44,995 INFO Client_7]: 4 Y 0.5214 78.7300\n", "[2024-12-01 22:22:47,905 INFO Client_7]: 4 N 2.9082 0.0280 98.8281 5.6106 49.1600\n", "[2024-12-01 22:22:48,949 INFO Client_8]: 4 Y 0.5214 78.7300\n", "[2024-12-01 22:22:51,869 INFO Client_8]: 4 N 2.9192 0.0360 99.3541 6.4345 31.4000\n", 
"[2024-12-01 22:22:52,927 INFO Client_9]: 4 Y 0.5214 78.7300\n", "[2024-12-01 22:22:55,814 INFO Client_9]: 4 N 2.8864 0.0578 97.6406 3.9761 50.7900\n", "[2024-12-01 22:22:56,852 INFO Client_10]: 4 Y 0.5214 78.7300\n", "[2024-12-01 22:22:59,800 INFO Client_10]: 4 N 2.9476 0.0416 98.3281 6.6747 41.0200\n" ] } ], "source": [ "while not server_agent.training_finished():\n", "    new_global_models = []\n", "    for client_agent in client_agents:\n", "        # Client local training\n", "        client_agent.train()\n", "        local_model = client_agent.get_parameters()\n", "        if isinstance(local_model, tuple):\n", "            local_model, metadata = local_model[0], local_model[1]\n", "        else:\n", "            metadata = {}\n", "        # \"Send\" local model to server and get a Future object for the new global model\n", "        # The Future object will be resolved when the server receives local models from all clients\n", "        new_global_model_future = server_agent.global_update(\n", "            client_id=client_agent.get_id(),\n", "            local_model=local_model,\n", "            blocking=False,\n", "            **metadata\n", "        )\n", "        new_global_models.append(new_global_model_future)\n", "    # Load the new global model from the server\n", "    for client_agent, new_global_model_future in zip(client_agents, new_global_models):\n", "        client_agent.load_parameters(new_global_model_future.result())" ] } ], "metadata": { "kernelspec": { "display_name": "fedcompass", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }