{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# FL Server over Secure RPC\n", "\n", "We demonstrate how to launch a gRPC server as a federated learning (FL) server with authentication. We use only one client so that the server and the client (run from another notebook) can be launched together." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "num_clients = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import dependencies\n", "\n", "All imports are collected here. \n", "Our framework `appfl` is built on `torch` and its neural network module `torch.nn`. We also import `torchvision` to load the `MNIST` dataset.\n", "Most importantly, we need to import the `appfl.run_grpc_server` module." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import math\n", "import torch\n", "import torch.nn as nn\n", "import torchvision\n", "from torchvision.transforms import ToTensor\n", "\n", "from appfl.config import Config\n", "from appfl.misc.data import Dataset\n", "import appfl.run_grpc_server as grpc_server\n", "from omegaconf import OmegaConf, DictConfig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test dataset\n", "\n", "The server can also hold test data to evaluate the performance of the global model; the test data must be wrapped in a `Dataset` object. Note that the server does not need any training data." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "test_data_raw = torchvision.datasets.MNIST(\n", "    \"./_data\", train=False, download=False, transform=ToTensor()\n", ")\n", "test_data_input = []\n", "test_data_label = []\n", "for idx in range(len(test_data_raw)):\n", "    test_data_input.append(test_data_raw[idx][0].tolist())\n", "    test_data_label.append(test_data_raw[idx][1])\n", "\n", "test_dataset = Dataset(\n", "    torch.FloatTensor(test_data_input), torch.tensor(test_data_label)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model\n", "\n", "Users can define their own models by subclassing `torch.nn.Module`. For this example, we define the following convolutional neural network." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "class CNN(nn.Module):\n", "    def __init__(self, num_channel=1, num_classes=10, num_pixel=28):\n", "        super().__init__()\n", "        self.conv1 = nn.Conv2d(\n", "            num_channel, 32, kernel_size=5, padding=0, stride=1, bias=True\n", "        )\n", "        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1, bias=True)\n", "        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))\n", "        self.act = nn.ReLU(inplace=True)\n", "\n", "        # Spatial size after each 5x5 convolution (no padding, stride 1) and 2x2 max-pool\n", "        X = num_pixel\n", "        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)\n", "        X = X / 2\n", "        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)\n", "        X = X / 2\n", "        X = int(X)\n", "\n", "        self.fc1 = nn.Linear(64 * X * X, 512)\n", "        self.fc2 = nn.Linear(512, num_classes)\n", "\n", "    def forward(self, x):\n", "        x = self.act(self.conv1(x))\n", "        x = self.maxpool(x)\n", "        x = self.act(self.conv2(x))\n", "        x = self.maxpool(x)\n", "        x = torch.flatten(x, 1)\n", "        x = self.act(self.fc1(x))\n", "        x = self.fc2(x)\n", "        return x\n", "\n", "model = CNN()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loss and metric\n", "We define the loss function" ] }, { "cell_type": "code", "execution_count": 
16, "metadata": {}, "outputs": [], "source": [ "loss_fn = torch.nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and the validation metric for the training as well." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def accuracy(y_true, y_pred):\n", "    '''\n", "    y_true and y_pred are both of type np.ndarray\n", "    y_true (N, d) where N is the size of the validation set, and d is the dimension of the label\n", "    y_pred (N, D) where N is the size of the validation set, and D is the output dimension of the ML model\n", "    '''\n", "    if len(y_pred.shape) == 1:\n", "        y_pred = np.round(y_pred)\n", "    else:\n", "        y_pred = y_pred.argmax(axis=1)\n", "    return 100*np.sum(y_pred==y_true)/y_pred.shape[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Configurations\n", "\n", "We run the `appfl` training with the data and model defined above. \n", "A number of parameters can be easily set by changing the configuration values.\n", "We read the default configurations from the `appfl.config.Config` class as a `DictConfig` object." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cfg: DictConfig = OmegaConf.structured(Config)\n", "# print(OmegaConf.to_yaml(cfg))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create secure SSL server and authenticator\n", "\n", "A secure SSL server requires both a *public certificate* and a *private key* for data encryption. We provide an example pair of [certificate](../../src/appfl/comm/grpc/credentials/localhost.crt) and [key](../../src/appfl/comm/grpc/credentials/localhost.key) for demonstration. **In practice, you should never share your private key with anyone; keep it secret**. \n", "\n", "To use the provided certificate and key, we need to set the following. 
To use your own certificate and key, set the corresponding fields to their file paths." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "cfg.server.server_certificate = \"default\"\n", "cfg.server.server_certificate_key = \"default\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, to use the `NaiveAuthenticator`, set the following. The `NaiveAuthenticator` does not take any arguments, so `authenticator_kwargs` is left empty." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cfg.server.authenticator = \"Naive\"\n", "cfg.server.authenticator_kwargs = {}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run with configurations\n", "For the server, we set the number of global epochs to 5 and start the **secure** FL experiment." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Round: 001] Finished; all clients have sent their results.\n", "[Round: 001] Finished; all clients have sent their results.\n", "[Round: 001] Updating model weights\n", "[Round: 001] Updating model weights\n", "[Round: 001] Test set: Average loss: 0.3082, Accuracy: 90.95%, Best Accuracy: 90.95%\n", "[Round: 001] Test set: Average loss: 0.3082, Accuracy: 90.95%, Best Accuracy: 90.95%\n", "[Round: 002] Finished; all clients have sent their results.\n", "[Round: 002] Finished; all clients have sent their results.\n", "[Round: 002] Updating model weights\n", "[Round: 002] Updating model weights\n", "[Round: 002] Test set: Average loss: 0.1699, Accuracy: 94.94%, Best Accuracy: 94.94%\n", "[Round: 002] Test set: Average loss: 0.1699, Accuracy: 94.94%, Best Accuracy: 94.94%\n", "[Round: 003] Finished; all clients have sent their results.\n", "[Round: 003] Finished; all clients have sent their results.\n", "[Round: 003] Updating model weights\n", "[Round: 003] Updating model 
weights\n", "[Round: 003] Test set: Average loss: 0.1106, Accuracy: 96.73%, Best Accuracy: 96.73%\n", "[Round: 003] Test set: Average loss: 0.1106, Accuracy: 96.73%, Best Accuracy: 96.73%\n", "[Round: 004] Finished; all clients have sent their results.\n", "[Round: 004] Finished; all clients have sent their results.\n", "[Round: 004] Updating model weights\n", "[Round: 004] Updating model weights\n", "[Round: 004] Test set: Average loss: 0.0852, Accuracy: 97.58%, Best Accuracy: 97.58%\n", "[Round: 004] Test set: Average loss: 0.0852, Accuracy: 97.58%, Best Accuracy: 97.58%\n", "[Round: 005] Finished; all clients have sent their results.\n", "[Round: 005] Finished; all clients have sent their results.\n", "[Round: 005] Updating model weights\n", "[Round: 005] Updating model weights\n", "[Round: 005] Test set: Average loss: 0.0764, Accuracy: 97.77%, Best Accuracy: 97.77%\n", "[Round: 005] Test set: Average loss: 0.0764, Accuracy: 97.77%, Best Accuracy: 97.77%\n" ] } ], "source": [ "cfg.num_epochs = 5\n", "grpc_server.run_server(cfg, model, loss_fn, num_clients, test_dataset, accuracy)" ] } ], "metadata": { "interpreter": { "hash": "d5a3775820edfef7d27663833b7a57b274657051daef716a62aaac9a7002010d" }, "kernelspec": { "display_name": "Python 3.8.12 64-bit ('appfl-dev': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }