{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# FL Client over Secure RPC\n", "\n", "This notebook shows how to launch a gRPC client as an FL client with an authenticator. To pair with the server notebook, we consider only one client." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "num_clients = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import dependencies\n", "\n", "The imports are the same as for the gRPC server, except that here we need to import the `appfl.run_grpc_client` module." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import math\n", "import torch\n", "import torch.nn as nn\n", "import torchvision\n", "from torchvision.transforms import ToTensor\n", "\n", "from appfl.config import Config\n", "from appfl.misc.data import Dataset\n", "import appfl.run_grpc_client as grpc_client\n", "from omegaconf import OmegaConf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training datasets\n", "\n", "Each client needs to create a `Dataset` object with its training data. Here, we create the objects for all the clients." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "train_data_raw = torchvision.datasets.MNIST(\n", "    \"./_data\", train=True, download=True, transform=ToTensor()\n", ")\n", "split_train_data_raw = np.array_split(range(len(train_data_raw)), num_clients)\n", "train_datasets = []\n", "for i in range(num_clients):\n", "\n", "    train_data_input = []\n", "    train_data_label = []\n", "    for idx in split_train_data_raw[i]:\n", "        train_data_input.append(train_data_raw[idx][0].tolist())\n", "        train_data_label.append(train_data_raw[idx][1])\n", "\n", "    train_datasets.append(\n", "        Dataset(\n", "            torch.FloatTensor(train_data_input),\n", "            torch.tensor(train_data_label),\n", "        )\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model\n", "\n", "The client should use the same model as the server. See the server notebook." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "class CNN(nn.Module):\n", "    def __init__(self, num_channel=1, num_classes=10, num_pixel=28):\n", "        super().__init__()\n", "        self.conv1 = nn.Conv2d(\n", "            num_channel, 32, kernel_size=5, padding=0, stride=1, bias=True\n", "        )\n", "        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1, bias=True)\n", "        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))\n", "        self.act = nn.ReLU(inplace=True)\n", "\n", "        # Spatial size after two stages of 5x5 convolution followed by 2x2 max-pooling\n", "        X = num_pixel\n", "        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)\n", "        X = X / 2\n", "        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)\n", "        X = X / 2\n", "        X = int(X)\n", "\n", "        self.fc1 = nn.Linear(64 * X * X, 512)\n", "        self.fc2 = nn.Linear(512, num_classes)\n", "\n", "    def forward(self, x):\n", "        x = self.act(self.conv1(x))\n", "        x = self.maxpool(x)\n", "        x = self.act(self.conv2(x))\n", "        x = self.maxpool(x)\n", "        x = torch.flatten(x, 1)\n", "        x = self.act(self.fc1(x))\n", "        x = self.fc2(x)\n", "        return x\n", "\n", "model = CNN()" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "## Loss and metric\n", "\n", "The client should use the same loss function as the server:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "loss_fn = torch.nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It should also use the same validation metric as the server:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def accuracy(y_true, y_pred):\n", "    '''\n", "    y_true and y_pred are both of type np.ndarray\n", "    y_true (N, d) where N is the size of the validation set, and d is the dimension of the label\n", "    y_pred (N, D) where N is the size of the validation set, and D is the output dimension of the ML model\n", "    '''\n", "    if len(y_pred.shape) == 1:\n", "        y_pred = np.round(y_pred)\n", "    else:\n", "        y_pred = y_pred.argmax(axis=1)\n", "    return 100*np.sum(y_pred==y_true)/y_pred.shape[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set configurations\n", "\n", "We run `appfl` training with the data and model defined above.\n", "Parameters can be set by changing the configuration values." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cfg = OmegaConf.structured(Config)\n", "# print(OmegaConf.to_yaml(cfg))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, we set the number of local epochs to 1 and the local learning rate to 0.01." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cfg.fed.args.num_local_epochs = 1\n", "cfg.fed.args.optim_args.lr = 0.01" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create secure SSL channel and authenticator\n", "\n",
"The client requires a root certificate to verify the server certificate. In this example, we provide that [root certificate](../../src/appfl/comm/grpc/credentials/root.crt), assuming that the server uses the self-signed [certificate](../../src/appfl/comm/grpc/credentials/localhost.crt) and [key](../../src/appfl/comm/grpc/credentials/localhost.key) provided in the official gRPC documentation.\n", "\n", "To use the provided root certificate, set the following. To use your own root certificate, change this value to the path of your local root certificate file." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "cfg.client.root_certificates = \"default\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use the `NaiveAuthenticator`, set the following. The keyword arguments are left empty because `NaiveAuthenticator` does not take any arguments." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cfg.client.authenticator = \"Naive\"\n", "cfg.client.authenticator_kwargs = {}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run with configurations\n", "\n", "We can now start **secure** training with the configuration `cfg`." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "grpc_client.run_client(cfg, 0, model, loss_fn, train_datasets[0], metric=accuracy)" ] } ], "metadata": { "interpreter": { "hash": "d5a3775820edfef7d27663833b7a57b274657051daef716a62aaac9a7002010d" }, "kernelspec": { "display_name": "Python 3.8.12 64-bit ('appfl-dev': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }