Adding new algorithms#

Suppose that we are adding a new algorithm together with its configuration. The new algorithm should be implemented as two classes, one for the server and one for the client.

Base classes#

Implementations of the new classes should be derived from the following two base classes: BaseServer (for the server class) and BaseClient (for the client class).

Example: NewAlgo#

Here we give a simple example.

Core algorithm class#

We first create classes for the global and local updates in appfl/algorithm:

  • Create two classes NewAlgoServer and NewAlgoClient in newalgo.py

  • In NewAlgoServer, the update function conducts a global update by averaging the local model parameters sent from multiple clients

  • In NewAlgoClient, the update function conducts a local update and sends the resulting local model parameters to the server

This is an example code:

Example code for src/appfl/algorithm/newalgo.py#
from collections import OrderedDict

from .algorithm import BaseServer, BaseClient

class NewAlgoServer(BaseServer):
    """Server-side template for the NewAlgo example.

    The server performs the global update: it combines the local model
    parameters sent by the clients (in this example, by averaging them).
    """

    def __init__(self, weights, model, num_clients, device, **kwargs):
        """Initialize the base server and store algorithm-specific options.

        Args:
            weights, model, num_clients, device: forwarded unchanged to
                ``BaseServer.__init__``.
            **kwargs: extra options; each one becomes an instance attribute
                with the same name.
        """
        super().__init__(weights, model, num_clients, device)
        # Expose every extra keyword argument as an attribute on the instance.
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self, local_states: OrderedDict):
        """Conduct the global update from the clients' local states.

        Args:
            local_states: local model states received from the clients.

        Raises:
            NotImplementedError: always, until this template is filled in.
        """
        # Implement new server update function
        raise NotImplementedError("NewAlgoServer.update is not implemented")

class NewAlgoClient(BaseClient):
    """Client-side template for the NewAlgo example.

    The client performs the local update and sends the resulting local model
    parameters to the server.
    """

    def __init__(self, id, weight, model, dataloader, device, **kwargs):
        """Initialize the base client and store algorithm-specific options.

        Args:
            id, weight, model, dataloader, device: forwarded unchanged to
                ``BaseClient.__init__``.
            **kwargs: extra options; each one becomes an instance attribute
                with the same name.
        """
        super().__init__(id, weight, model, dataloader, device)
        # Expose every extra keyword argument as an attribute on the instance.
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self):
        """Conduct the local update for this client.

        Raises:
            NotImplementedError: always, until this template is filled in.
        """
        # Implement new client update function
        raise NotImplementedError("NewAlgoClient.update is not implemented")

Configuration dataclass#

The new algorithm also needs some configuration. This can be done by adding a new dataclass under appfl.config.fed. Let’s say we add the file src/appfl/config/fed/newalgo.py to implement the dataclass as follows:

Example code for src/appfl/config/fed/newalgo.py#
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf

@dataclass
class NewAlgo:
    """Configuration dataclass for the NewAlgo example algorithm."""
    # Identifier string of this algorithm configuration.
    type: str = "newalgo"
    # Name of the server class of the algorithm, given as a string.
    servername: str = "NewAlgoServer"
    # Name of the client class of the algorithm, given as a string.
    clientname: str = "NewAlgoClient"
    # Algorithm-specific arguments; presumably forwarded as keyword arguments
    # to the server/client constructors -- confirm against the caller.
    args: DictConfig = OmegaConf.create(
        {
            # add new arguments
        }
    )

Then, we need to add the following line to the main configuration file config.py.

from .fed.newalgo import *

This is the main configuration class in src/appfl/config/config.py. Each algorithm, specified in Config.fed, can be configured in the dataclasses at appfl.config.fed.*.

The main configuration class#
  1from dataclasses import dataclass, field
  2from typing import Any, List, Dict, Optional
  3from omegaconf import DictConfig, OmegaConf
  4import os
  5import sys
  6
  7from .fed.federated import *
  8from .fed.fedasync import *
  9from .fed.iceadmm import * 
 10from .fed.iiadmm import *
 11
 12
 13@dataclass
 14class Config:
 15    fed: Any = field(default_factory=Federated)
 16
 17    # Compute device
 18    device: str = "cpu"
 19    device_server: str = "cpu"
 20
 21    # Number of clients
 22    num_clients: int = 1
 23
 24    # Number of training epochs
 25    num_epochs: int = 2
 26
 27    # Number of workers in DataLoader
 28    num_workers: int = 0
 29
 30    # Train data batch info
 31    batch_training: bool = True  ## TODO: revisit
 32    train_data_batch_size: int = 64
 33    train_data_shuffle: bool = True
 34
 35    # Indication of whether to validate or not using testing data
 36    validation: bool = True
 37    test_data_batch_size: int = 64
 38    test_data_shuffle: bool = False
 39
 40    # Checking data sanity
 41    data_sanity: bool = False
 42
 43    # Reproducibility
 44    reproduce: bool = True
 45
 46    # PCA on Trajectory
 47    pca_dir: str = ""
 48    params_start: int = 0
 49    params_end: int = 49
 50    ncomponents: int = 40
 51
 52    # Tensorboard
 53    use_tensorboard: bool = False
 54
 55    # Loading models
 56    load_model: bool = False
 57    load_model_dirname: str = ""
 58    load_model_filename: str = ""
 59
 60    # Saving models (server)
 61    save_model: bool = False
 62    save_model_dirname: str = ""
 63    save_model_filename: str = ""
 64    checkpoints_interval: int = 2
 65
 66    # Saving state_dict (clients)
 67    save_model_state_dict: bool = False
 68    send_final_model: bool = False
 69
 70    # Logging and recording outputs
 71    output_dirname: str = "output"
 72    output_filename: str = "result"
 73
 74    logginginfo: DictConfig = OmegaConf.create({})
 75    summary_file: str = ""
 76
 77    # Personalization options
 78    personalization: bool = False
 79    p_layers: List[str] = field(default_factory=lambda: [])
 80    config_name: str = ""
 81
 82    ## gRPC configurations ##
 83
 84    # 10 MB for gRPC maximum message size
 85    max_message_size: int = 10485760
 86    use_ssl: bool = False
 87    use_authenticator: bool = False
 88    authenticator: str = "Globus" # "Globus", "Naive"
 89    uri: str = "localhost:50051"
 90
 91    operator: DictConfig = OmegaConf.create({"id": 1})
 92    server: DictConfig = OmegaConf.create({
 93        "id": 1, 
 94        "authenticator_kwargs": {
 95            "is_fl_server": True,
 96            "globus_group_id": "77c1c74b-a33b-11ed-8951-7b5a369c0a53",
 97        },
 98        "server_certificate_key": "default",
 99        "server_certificate": "default",
100        "max_workers": 10,
101    })
102    client: DictConfig = OmegaConf.create({
103        "id": 1,
104        "root_certificates": "default",
105        "authenticator_kwargs": {
106            "is_fl_server": False,
107        },
108    })
109
110    # Lossy compression enabling
111    enable_compression: bool = False
112    lossy_compressor: str = "SZ2"
113    lossless_compressor: str = "blosc"
114
115    # Lossy compression path configuration
116    ext = ".dylib" if sys.platform.startswith("darwin") else ".so"
117    compressor_sz2_path: str = "../.compressor/SZ/build/sz/libSZ" + ext
118    compressor_sz3_path: str = "../.compressor/SZ3/build/tools/sz3c/libSZ3c" + ext
119    compressor_szx_path: str = "../.compressor/SZx-main/build/lib/libSZx" + ext
120
121    # Compressor parameters
122    error_bounding_mode: str = ""
123    error_bound: float = 0.0
124
125    # Default data type
126    flat_model_dtype: str = "np.float32"
127    param_cutoff: int = 1024
128
129    # Data readiness
130    dr_metrics: Optional[List[str]] = field(default_factory=lambda: [])
131
132
133@dataclass
134class GlobusComputeServerConfig:
135    device: str = "cpu"
136    output_dir: str = "./"
137    data_dir: str = "./"
138    s3_bucket: Any = None
139    s3_creds: str = ""
140
141
142@dataclass
143class GlobusComputeClientConfig:
144    name        : str = ""
145    endpoint_id : str = ""
146    device      : str = "cpu"
147    output_dir  : str = "./output"
148    data_dir    : str = "./datasets"
149    get_data    :  DictConfig = OmegaConf.create({})
150    data_pipeline: DictConfig = OmegaConf.create({})
151
152
153@dataclass
154class ExecutableFunc:
155    module: str = ""
156    call: str = ""
157    script_file: str = ""
158    source: str = ""
159
160
161@dataclass
162class ClientTask:
163    task_id: str = ""
164    task_name: str = ""
165    client_idx: int = ""
166    pending: bool = True
167    success: bool = False
168    start_time: float = -1
169    end_time: float = -1
170    log: Optional[Dict] = field(default_factory=dict)
171
172
173@dataclass
174class GlobusComputeConfig(Config):
175    get_data: ExecutableFunc = field(default_factory=ExecutableFunc)
176    get_model: ExecutableFunc = field(default_factory=ExecutableFunc)
177    get_loss: ExecutableFunc = field(default_factory=ExecutableFunc)
178    val_metric: ExecutableFunc = field(default_factory=ExecutableFunc)
179    clients: List[GlobusComputeClientConfig] = field(default_factory=list)
180    dataset: str = ""
181    loss: str = "CrossEntropy"
182    model_kwargs: Dict = field(default_factory=dict)
183    server: GlobusComputeServerConfig
184    logging_tasks: List = field(default_factory=list)
185    hf_model_arc: str = ""
186    hf_model_weights: str = ""
187
188    # Testing and validation params
189    client_do_validation: bool = True
190    client_do_testing: bool = True
191    server_do_validation: bool = True
192    server_do_testing: bool = True
193
194    # Testing and validation frequency
195    client_validation_step: int = 1
196    server_validation_step: int = 1
197
198    # Cloud storage
199    use_cloud_transfer: bool = True