{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# FL Server over Secure RPC\n", "\n", "This notebook demonstrates how to launch a gRPC server as a federated learning (FL) server with authentication. We use only one client so that the server and the client (launched from another notebook) can run together." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "num_clients = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load server configurations\n", "\n", "In this example, we use the `FedAvg` server aggregation algorithm (with only one client in this demo, the choice of aggregation algorithm matters little) and the MNIST dataset, loading the server configurations from `examples/resources/configs/mnist/server_fedavg.yaml`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "client_configs:\n", " train_configs:\n", " trainer: VanillaTrainer\n", " mode: step\n", " num_local_steps: 100\n", " optim: Adam\n", " optim_args:\n", " lr: 0.001\n", " loss_fn_path: ./resources/loss/celoss.py\n", " loss_fn_name: CELoss\n", " do_validation: true\n", " do_pre_validation: true\n", " metric_path: ./resources/metric/acc.py\n", " metric_name: accuracy\n", " use_dp: false\n", " epsilon: 1\n", " clip_grad: false\n", " clip_value: 1\n", " clip_norm: 1\n", " train_batch_size: 64\n", " val_batch_size: 64\n", " train_data_shuffle: true\n", " val_data_shuffle: false\n", " model_configs:\n", " model_path: ./resources/model/cnn.py\n", " model_name: CNN\n", " model_kwargs:\n", " num_channel: 1\n", " num_classes: 10\n", " num_pixel: 28\n", " comm_configs:\n", " compressor_configs:\n", " enable_compression: false\n", " lossy_compressor: SZ2Compressor\n", " lossless_compressor: blosc\n", " error_bounding_mode: REL\n", " error_bound: 0.001\n", " param_cutoff: 1024\n", "server_configs:\n", " num_clients: 2\n", " scheduler: SyncScheduler\n", " 
scheduler_kwargs:\n", " same_init_model: true\n", " aggregator: FedAvgAggregator\n", " aggregator_kwargs:\n", " client_weights_mode: equal\n", " device: cpu\n", " num_global_epochs: 10\n", " logging_output_dirname: ./output\n", " logging_output_filename: result\n", " comm_configs:\n", " grpc_configs:\n", " server_uri: localhost:50051\n", " max_message_size: 1048576\n", " use_ssl: false\n", "\n" ] } ], "source": [ "from omegaconf import OmegaConf\n", "\n", "server_config_file = \"../../examples/resources/configs/mnist/server_fedavg.yaml\"\n", "server_config = OmegaConf.load(server_config_file)\n", "print(OmegaConf.to_yaml(server_config))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "💡 Configuration fields such as `loss_fn_path`, `metric_path`, and `model_path` are paths to the corresponding files, so we need to update these relative paths so that they point to the right files from this notebook's location. \n", "\n", "⚠️ We also need to change `num_clients` in `server_configs` to 1 (the loaded configuration sets it to 2)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "server_config.client_configs.train_configs.loss_fn_path = (\n", " \"../../examples/resources/loss/celoss.py\"\n", ")\n", "server_config.client_configs.train_configs.metric_path = (\n", " \"../../examples/resources/metric/acc.py\"\n", ")\n", "server_config.client_configs.model_configs.model_path = (\n", " \"../../examples/resources/model/cnn.py\"\n", ")\n", "server_config.server_configs.num_clients = num_clients" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create secure SSL server and authenticator\n", "\n", "A secure SSL server requires both a *public certificate* and a *private key* for data encryption. We provide an example pair of [certificate](https://github.com/APPFL/APPFL/blob/main/src/appfl/comm/grpc/credentials/localhost.crt) and [key](https://github.com/APPFL/APPFL/blob/main/src/appfl/comm/grpc/credentials/localhost.key) for demonstration. 
**In practice, never share your private key with anyone; always keep it secret**. \n", "\n", "💡 Please check this [tutorial](https://appfl.ai/en/latest/tutorials/examples_ssl.html) for more details on how to generate SSL certificates for securing gRPC connections in practice.\n", "\n", "\n", "To enable the SSL channel and use the provided certificate and key, we set the following fields. To use your own certificate and key, change the corresponding fields to their file paths." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "server_config.server_configs.comm_configs.grpc_configs.use_ssl = True\n", "server_config.server_configs.comm_configs.grpc_configs.server_certificate_key = (\n", " \"../../src/appfl/comm/grpc/credentials/localhost.key\"\n", ")\n", "server_config.server_configs.comm_configs.grpc_configs.server_certificate = (\n", " \"../../src/appfl/comm/grpc/credentials/localhost.crt\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set up an authenticator\n", "\n", "Now we use a naive authenticator, where the server sets a secret token and authenticates the client by token matching. \n", "\n", "💡 The naive authenticator is for demonstration only and is not secure enough to protect a real FL experiment. We also provide a Globus authenticator, and you can define your own as well." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "server_config.server_configs.comm_configs.grpc_configs.use_authenticator = True\n", "server_config.server_configs.comm_configs.grpc_configs.authenticator = (\n", " \"NaiveAuthenticator\"\n", ")\n", "server_config.server_configs.comm_configs.grpc_configs.authenticator_args = {\n", " \"auth_token\": \"A_SECRET_DEMO_TOKEN\"\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Start server\n", "\n", "Now, we are ready to create the server agent using the `server_config` defined and modified above and start the grpc server.\n", "\n", "After launching 🚀 the server, let's go to the notebook to launch the client to talk to the server!\n", "\n", "💡 After finishing the FL experiment, you need to manually stop the server." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:19,762 server]: Logging to ./output/result_Server_2025-01-08-09-56-19.txt\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:38,126 server]: Received GetConfiguration request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:38,142 server]: Received GetGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:38,151 server]: Received InvokeCustomAction set_sample_size request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:42,211 server]: Received UpdateGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:42,212 server]: Received the following meta data from Client1:\n", "{'pre_val_accuracy': 15.93,\n", " 'pre_val_loss': 2.30059186820012,\n", " 'round': 1,\n", " 'val_accuracy': 94.59,\n", " 'val_loss': 0.1758718944694491}\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:46,293 server]: Received 
UpdateGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:46,294 server]: Received the following meta data from Client1:\n", "{'pre_val_accuracy': 94.59,\n", " 'pre_val_loss': 0.1758718927610357,\n", " 'round': 2,\n", " 'val_accuracy': 96.79,\n", " 'val_loss': 0.10359191318757381}\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:50,108 server]: Received UpdateGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:50,109 server]: Received the following meta data from Client1:\n", "{'pre_val_accuracy': 96.79,\n", " 'pre_val_loss': 0.10359191384305881,\n", " 'round': 3,\n", " 'val_accuracy': 97.55,\n", " 'val_loss': 0.07821214441276766}\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:54,011 server]: Received UpdateGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:54,012 server]: Received the following meta data from Client1:\n", "{'pre_val_accuracy': 97.55,\n", " 'pre_val_loss': 0.07821214327085942,\n", " 'round': 4,\n", " 'val_accuracy': 97.91,\n", " 'val_loss': 0.06505483987920616}\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:58,043 server]: Received UpdateGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:56:58,044 server]: Received the following meta data from Client1:\n", "{'pre_val_accuracy': 97.91,\n", " 'pre_val_loss': 0.0650548392593131,\n", " 'round': 5,\n", " 'val_accuracy': 98.05,\n", " 'val_loss': 0.060302854882654064}\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:01,968 server]: Received UpdateGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:01,969 server]: Received the following meta data from Client1:\n", "{'pre_val_accuracy': 98.05,\n", " 'pre_val_loss': 0.06030285448595217,\n", " 'round': 6,\n", " 'val_accuracy': 98.48,\n", " 'val_loss': 0.050337113507791235}\n", 
"\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:05,951 server]: Received UpdateGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:05,952 server]: Received the following meta data from Client1:\n", "{'pre_val_accuracy': 98.48,\n", " 'pre_val_loss': 0.050337113759900846,\n", " 'round': 7,\n", " 'val_accuracy': 98.74,\n", " 'val_loss': 0.040618350146898914}\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:09,880 server]: Received UpdateGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:09,881 server]: Received the following meta data from Client1:\n", "{'pre_val_accuracy': 98.74,\n", " 'pre_val_loss': 0.04061834986216335,\n", " 'round': 8,\n", " 'val_accuracy': 98.39,\n", " 'val_loss': 0.05227461058381726}\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:13,723 server]: Received UpdateGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:13,724 server]: Received the following meta data from Client1:\n", "{'pre_val_accuracy': 98.39,\n", " 'pre_val_loss': 0.0522746103943643,\n", " 'round': 9,\n", " 'val_accuracy': 98.73,\n", " 'val_loss': 0.03826261686356175}\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:17,601 server]: Received UpdateGlobalModel request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:17,602 server]: Received the following meta data from Client1:\n", "{'pre_val_accuracy': 98.73,\n", " 'pre_val_loss': 0.038262616999997535,\n", " 'round': 10,\n", " 'val_accuracy': 98.71,\n", " 'val_loss': 0.04022663724981323}\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:17,614 server]: Received InvokeCustomAction close_connection request from client Client1\n", "\u001b[34m\u001b[1mappfl: ✅\u001b[0m[2025-01-08 09:57:17,954 server]: Terminating the server ...\n" ] } ], "source": [ "from appfl.agent import ServerAgent\n", "from appfl.comm.grpc import 
GRPCServerCommunicator, serve\n", "\n", "server_agent = ServerAgent(server_agent_config=server_config)\n", "\n", "communicator = GRPCServerCommunicator(\n", " server_agent,\n", " logger=server_agent.logger,\n", " **server_config.server_configs.comm_configs.grpc_configs,\n", ")\n", "\n", "serve(\n", " communicator,\n", " **server_config.server_configs.comm_configs.grpc_configs,\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "appfl", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }