{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Serial FL Simulation\n", "In this notebook, we show how to simulate a federated learning (FL) experiment on a single machine by running the clients serially. Note that only synchronous FL algorithms are meaningful to simulate when the clients run serially. This example uses 10 clients." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "num_clients = 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load server configurations\n", "In this example, we use the `FedAvg` server aggregation algorithm and the MNIST dataset by loading the server configurations from `examples/resources/configs/mnist/server_fedavg.yaml`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "client_configs:\n", " train_configs:\n", " trainer: VanillaTrainer\n", " mode: step\n", " num_local_steps: 100\n", " optim: Adam\n", " optim_args:\n", " lr: 0.001\n", " loss_fn_path: ./resources/loss/celoss.py\n", " loss_fn_name: CELoss\n", " do_validation: true\n", " do_pre_validation: true\n", " metric_path: ./resources/metric/acc.py\n", " metric_name: accuracy\n", " use_dp: false\n", " epsilon: 1\n", " clip_grad: false\n", " clip_value: 1\n", " clip_norm: 1\n", " train_batch_size: 64\n", " val_batch_size: 64\n", " train_data_shuffle: true\n", " val_data_shuffle: false\n", " model_configs:\n", " model_path: ./resources/model/cnn.py\n", " model_name: CNN\n", " model_kwargs:\n", " num_channel: 1\n", " num_classes: 10\n", " num_pixel: 28\n", " comm_configs:\n", " compressor_configs:\n", " enable_compression: false\n", " lossy_compressor: SZ2Compressor\n", " lossless_compressor: blosc\n", " error_bounding_mode: REL\n", " error_bound: 0.001\n", " param_cutoff: 1024\n", "server_configs:\n", " num_clients: 2\n", " scheduler: SyncScheduler\n", " 
scheduler_kwargs:\n", " same_init_model: true\n", " aggregator: FedAvgAggregator\n", " aggregator_kwargs:\n", " client_weights_mode: equal\n", " device: cpu\n", " num_global_epochs: 10\n", " logging_output_dirname: ./output\n", " logging_output_filename: result\n", " comm_configs:\n", " grpc_configs:\n", " server_uri: localhost:50051\n", " max_message_size: 1048576\n", " use_ssl: false\n", "\n" ] } ], "source": [ "from omegaconf import OmegaConf\n", "\n", "server_config_file = \"../../examples/resources/configs/mnist/server_fedavg.yaml\"\n", "server_config = OmegaConf.load(server_config_file)\n", "print(OmegaConf.to_yaml(server_config))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "💡 Configuration fields such as `loss_fn_path`, `metric_path`, and `model_path` are paths to the corresponding files, so we update these relative paths to make sure they point to the right files.\n", "\n", "💡 We also change `num_global_epochs` from 10 to 3.\n", "\n", "⚠️ We also need to change `num_clients` in `server_configs` to 10." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "server_config.client_configs.train_configs.loss_fn_path = (\n", "    \"../../examples/resources/loss/celoss.py\"\n", ")\n", "server_config.client_configs.train_configs.metric_path = (\n", "    \"../../examples/resources/metric/acc.py\"\n", ")\n", "server_config.client_configs.model_configs.model_path = (\n", "    \"../../examples/resources/model/cnn.py\"\n", ")\n", "server_config.server_configs.num_global_epochs = 3\n", "server_config.server_configs.num_clients = num_clients" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load client configurations\n", "In this example, we suppose that `num_clients=10` and load the basic configurations for all the clients from `examples/resources/configs/mnist/client_1.yaml`. Let's first take a look at this basic configuration."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "client_id: Client1\n", "train_configs:\n", " device: cpu\n", " logging_output_dirname: ./output\n", " logging_output_filename: result\n", "data_configs:\n", " dataset_path: ./resources/dataset/mnist_dataset.py\n", " dataset_name: get_mnist\n", " dataset_kwargs:\n", " num_clients: 2\n", " client_id: 0\n", " partition_strategy: class_noniid\n", " visualization: true\n", " output_dirname: ./output\n", " output_filename: visualization.pdf\n", "comm_configs:\n", " grpc_configs:\n", " server_uri: localhost:50051\n", " max_message_size: 1048576\n", " use_ssl: false\n", "\n" ] } ], "source": [ "client_config_file = \"../../examples/resources/configs/mnist/client_1.yaml\"\n", "client_config = OmegaConf.load(client_config_file)\n", "print(OmegaConf.to_yaml(client_config))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the configuration above, `data_configs` contains the configurations needed to load the simulated \"local\" dataset for each client. Specifically,\n", "\n", "- `dataset_path` is the path to the file that contains the function to load the dataset\n", "- `dataset_name` is the name of the function in that file that loads the dataset\n", "- `dataset_kwargs` are the keyword arguments for that function\n", "\n", "The `get_mnist` function above partitions the MNIST dataset into `num_clients` client splits in either an IID (independent and identically distributed) or a non-IID manner." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we need to modify the general client configurations for different clients. 
Specifically, we make the following changes:\n", "\n", "- Change `client_id` for each client\n", "- Change the relative path of `dataset_path` to make it point to the right file\n", "- Change `dataset_kwargs.num_clients` to 10 and `dataset_kwargs.client_id` to [0, 1, ..., 9] for different clients\n", "- Change `dataset_kwargs.visualization` to False for nine of the clients so that only one data distribution visualization plot is generated" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import copy\n", "\n", "client_configs = [copy.deepcopy(client_config) for _ in range(num_clients)]\n", "for i in range(num_clients):\n", "    client_configs[i].client_id = f\"Client{i + 1}\"\n", "    client_configs[i].data_configs.dataset_path = (\n", "        \"../../examples/resources/dataset/mnist_dataset.py\"\n", "    )\n", "    client_configs[i].data_configs.dataset_kwargs.num_clients = num_clients\n", "    client_configs[i].data_configs.dataset_kwargs.client_id = i\n", "    # Only the first client plots the data distribution\n", "    client_configs[i].data_configs.dataset_kwargs.visualization = i == 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create FL server agent and client agents\n", "In APPFL, we use agents to act on behalf of the FL server and FL clients and perform the necessary steps of a federated learning experiment. Users can easily create the agents using the server/client configurations we loaded (and slightly modified) from the YAML configuration files. Creating the client agents loads the local datasets and plots the data distribution visualization as shown below."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:56,430 server]: Logging to ./output/result_Server_2025-01-08-09-57-56.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:56,435 Client1]: Logging to ./output/result_Client1_2025-01-08-09-57-56.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:58:03,678 Client2]: Logging to ./output/result_Client2_2025-01-08-09-58-03.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:58:10,433 Client3]: Logging to ./output/result_Client3_2025-01-08-09-58-10.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:58:16,842 Client4]: Logging to ./output/result_Client4_2025-01-08-09-58-16.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:58:23,217 Client5]: Logging to ./output/result_Client5_2025-01-08-09-58-23.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:58:29,646 Client6]: Logging to ./output/result_Client6_2025-01-08-09-58-29.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:58:36,078 Client7]: Logging to ./output/result_Client7_2025-01-08-09-58-36.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:58:42,554 Client8]: Logging to ./output/result_Client8_2025-01-08-09-58-42.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:58:49,468 Client9]: Logging to ./output/result_Client9_2025-01-08-09-58-49.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:58:56,215 Client10]: Logging to ./output/result_Client10_2025-01-08-09-58-56.txt\n" ] } ], "source": [ "from appfl.agent import ServerAgent, ClientAgent\n", "\n", "server_agent = ServerAgent(server_agent_config=server_config)\n", "client_agents = [\n", " ClientAgent(client_agent_config=client_configs[i]) for i in range(num_clients)\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Start the training process" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "The server configuration files contains many client configurations which should apply for ALL clients. Now, we need to get those configurations from the server and provide them to the client agents." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Get additional client configurations from the server\n", "client_config_from_server = server_agent.get_client_configs()\n", "for client_agent in client_agents:\n", " client_agent.load_config(client_config_from_server)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, let the clients load initial global model from the server and optionally send the number of local data to the server for weighted aggregation.\n", "\n", "💡 **Note**: Typically, `server_agent.get_parameters()` blocks the result return or returns a `Future` object, and only returns the global model after receiving `num_clients` calls to synchronoize to process for clients to get the initial model, ensuring all the clients have the same initial model weights. However, as we are doing serial simulation, we don't want the blocking, so we pass `serail_run=True` when calling the function to get the global model immediately." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Load initial global model from the server\n", "init_global_model = server_agent.get_parameters(serial_run=True)\n", "for client_agent in client_agents:\n", " client_agent.load_parameters(init_global_model)\n", "\n", "# [Optional] Set number of local data to the server\n", "for i in range(num_clients):\n", " sample_size = client_agents[i].get_sample_size()\n", " server_agent.set_sample_size(\n", " client_id=client_agents[i].get_id(), sample_size=sample_size\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can start the training iterations. 
Please note the following points:\n", "\n", "- `server_agent.training_finished` returns a boolean flag indicating whether the training has reached the specified `num_global_epochs`\n", "- `client_agent.train` trains the client local model using the client's \"local\" data\n", "- `server_agent.global_update` takes one client's local model together with the client ID (obtained via `client_agent.get_id`) and schedules the global update for that local model. Synchronous server aggregation algorithms such as `FedAvg` do not update the global model until they receive local models from all `num_clients=10` clients. In a serial simulation, you should therefore set `blocking=False` so that `global_update` returns a `concurrent.futures.Future` object, which is resolved once all client local models have been sent for the global update; otherwise, the call would block forever.\n", "- In the output log, `Pre Val?` indicates whether the validation was run before the local training. As each client only holds data from 3 to 5 classes, the validation accuracy can even drop after local training. However, the global model accuracy continues to increase, showcasing the capability of federated learning to improve the generalizability of the trained machine learning model." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:02,820 Client1]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:03,825 Client1]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:06,651 Client1]: 0 N 2.8258 0.4323 90.8109 15.4995 30.3500\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:06,653 Client2]: Round Pre Val? 
Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:07,690 Client2]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:10,635 Client2]: 0 N 2.9442 0.3130 87.7656 9.9495 48.6000\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:10,636 Client3]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:11,665 Client3]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:14,515 Client3]: 0 N 2.8484 0.3909 85.0781 10.4302 38.2900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:14,517 Client4]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:15,541 Client4]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:18,495 Client4]: 0 N 2.9528 0.4242 89.8377 11.0662 39.8000\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:18,496 Client5]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:19,528 Client5]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:22,589 Client5]: 0 N 3.0604 0.1851 92.1406 22.1558 28.8700\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:22,591 Client6]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:23,603 Client6]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:26,441 Client6]: 0 N 2.8377 0.6525 85.3739 14.9090 38.8200\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:26,443 Client7]: Round Pre Val? 
Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:27,462 Client7]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:30,243 Client7]: 0 N 2.7805 0.2848 86.4844 10.2574 48.8900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:30,245 Client8]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:31,256 Client8]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:34,121 Client8]: 0 N 2.8636 0.3402 92.6434 14.2500 31.1400\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:34,123 Client9]: Round Pre Val? Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:35,132 Client9]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:37,892 Client9]: 0 N 2.7590 0.3460 86.6562 9.1354 50.0000\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:37,894 Client10]: Round Pre Val? 
Time Train Loss Train Accuracy Val Loss Val Accuracy\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:38,885 Client10]: 0 Y 2.3006 15.9300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:41,568 Client10]: 0 N 2.6825 0.2586 88.4375 12.8184 40.2200\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:42,569 Client1]: 1 Y 1.8781 31.3900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:45,282 Client1]: 1 N 2.7129 0.1185 98.0902 9.2444 30.4200\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:46,287 Client2]: 1 Y 1.8781 31.3900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:49,161 Client2]: 1 N 2.8722 0.1277 95.7188 8.0948 48.7200\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:50,222 Client3]: 1 Y 1.8781 31.3900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:52,936 Client3]: 1 N 2.7130 0.1673 94.5312 8.9735 38.5400\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:53,939 Client4]: 1 Y 1.8781 31.3900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:56,585 Client4]: 1 N 2.6448 0.1617 96.1872 9.9464 39.7100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:59:57,592 Client5]: 1 Y 1.8781 31.3900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:00,295 Client5]: 1 N 2.7021 0.0906 97.3438 13.3640 29.5100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:01,289 Client6]: 1 Y 1.8781 31.3900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:04,012 Client6]: 1 N 2.7218 0.1973 96.3711 9.5049 39.4300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:05,028 Client7]: 1 Y 1.8781 31.3900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:07,715 Client7]: 1 N 2.6853 0.1087 95.3750 5.8535 49.2100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:08,714 Client8]: 1 Y 1.8781 31.3900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:11,396 Client8]: 1 N 2.6800 0.0983 98.0466 9.4541 31.3000\n", "\u001b[34m\u001b[1mappfl: 
✅\u001b[0m[2025-01-08 10:00:12,399 Client9]: 1 Y 1.8781 31.3900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:15,078 Client9]: 1 N 2.6767 0.1480 94.8281 6.5887 49.2400\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:16,078 Client10]: 1 Y 1.8781 31.3900\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:18,718 Client10]: 1 N 2.6380 0.1009 96.3438 11.2293 40.7800\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:19,716 Client1]: 2 Y 0.9154 68.4300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:22,358 Client1]: 2 N 2.6401 0.0553 99.0451 6.5585 30.5100\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:23,347 Client2]: 2 Y 0.9154 68.4300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:26,005 Client2]: 2 N 2.6579 0.0708 97.4844 5.9560 49.2300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:27,009 Client3]: 2 Y 0.9154 68.4300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:29,809 Client3]: 2 N 2.7990 0.0948 96.4219 7.6792 38.7600\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:30,845 Client4]: 2 Y 0.9154 68.4300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:33,755 Client4]: 2 N 2.9089 0.0759 98.0936 8.1204 40.0200\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:34,857 Client5]: 2 Y 0.9154 68.4300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:37,737 Client5]: 2 N 2.8792 0.0454 98.3750 8.8908 29.7500\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:38,732 Client6]: 2 Y 0.9154 68.4300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:41,369 Client6]: 2 N 2.6357 0.0995 98.0120 8.2203 39.4500\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:42,367 Client7]: 2 Y 0.9154 68.4300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:45,160 Client7]: 2 N 2.7908 0.0554 97.5781 4.9709 49.0400\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:46,179 Client8]: 2 Y 0.9154 68.4300\n", 
"\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:48,859 Client8]: 2 N 2.6786 0.0540 98.8658 8.9072 31.2200\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:49,859 Client9]: 2 Y 0.9154 68.4300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:52,476 Client9]: 2 N 2.6149 0.0861 97.0000 5.5963 50.2700\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:53,481 Client10]: 2 Y 0.9154 68.4300\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 10:00:56,112 Client10]: 2 N 2.6278 0.0638 97.3594 8.1234 40.8700\n" ] } ], "source": [ "while not server_agent.training_finished():\n", " new_global_models = []\n", " for client_agent in client_agents:\n", " # Client local training\n", " client_agent.train()\n", " local_model = client_agent.get_parameters()\n", " if isinstance(local_model, tuple):\n", " local_model, metadata = local_model[0], local_model[1]\n", " else:\n", " metadata = {}\n", " # \"Send\" local model to server and get a Future object for the new global model\n", " # The Future object will be resolved when the server receives local models from all clients\n", " new_global_model_future = server_agent.global_update(\n", " client_id=client_agent.get_id(),\n", " local_model=local_model,\n", " blocking=False,\n", " **metadata,\n", " )\n", " new_global_models.append(new_global_model_future)\n", " # Load the new global model from the server\n", " for client_agent, new_global_model_future in zip(client_agents, new_global_models):\n", " client_agent.load_parameters(new_global_model_future.result())" ] } ], "metadata": { "kernelspec": { "display_name": "appfl", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }