Adding new algorithms#
Suppose that we are adding the configuration for our new algorithm. A new algorithm should be implemented as two classes: one for the server and one for the client.
Base classes#
Implementations of the new classes should be derived from the following two base classes: BaseServer and BaseClient.
Example: NewAlgo#
Here we give a simple example.
Core algorithm class#
We first create classes for the global and local updates in appfl/algorithm:
Create two classes, NewAlgoServer and NewAlgoClient, in newalgo.py.
In NewAlgoServer, the update function conducts a global update by averaging the local model parameters sent from multiple clients.
In NewAlgoClient, the update function conducts a local update and sends the resulting local model parameters to the server.
This is an example code:
src/appfl/algorithm/newalgo.py#
from .algorithm import BaseServer, BaseClient
class NewAlgoServer(BaseServer):
def __init__(self, weights, model, num_clients, device, **kwargs):
super(NewAlgoServer, self).__init__(weights, model, num_clients, device)
self.__dict__.update(kwargs)
# Any additional initialization
def update(self, local_states: OrderedDict):
# Implement new server update function
class NewAlgoClient(BaseClient):
def __init__(self, id, weight, model, dataloader, device, **kwargs):
super(NewAlgoClient, self).__init__(id, weight, model, dataloader, device)
self.__dict__.update(kwargs)
# Any additional initialization
def update(self):
# Implement new client update function
Configuration dataclass#
The new algorithm also needs to set up some configurations. This can be done by adding a new dataclass under appfl.config.fed.
Let’s say we add the file src/appfl/config/fed/newalgo.py to implement the dataclass as follows:
src/appfl/config/fed/newalgo.py#
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf
@dataclass
class NewAlgo:
type: str = "newalgo"
servername: str = "NewAlgoServer"
clientname: str = "NewAlgoClient"
args: DictConfig = OmegaConf.create(
{
# add new arguments
}
)
Then, we need to add the following line to the main configuration file config.py.
from .fed.newalgo import *
This is the main configuration class in src/appfl/config/config.py.
Each algorithm, specified in Config.fed, can be configured in the dataclasses at appfl.config.fed.*.
1from dataclasses import dataclass, field
2from typing import Any, List, Dict, Optional
3from omegaconf import DictConfig, OmegaConf
4import os
5import sys
6
7from .fed.federated import *
8from .fed.fedasync import *
9from .fed.iceadmm import *
10from .fed.iiadmm import *
11
12
13@dataclass
14class Config:
15 fed: Any = field(default_factory=Federated)
16
17 # Compute device
18 device: str = "cpu"
19 device_server: str = "cpu"
20
21 # Number of clients
22 num_clients: int = 1
23
24 # Number of training epochs
25 num_epochs: int = 2
26
27 # Number of workers in DataLoader
28 num_workers: int = 0
29
30 # Train data batch info
31 batch_training: bool = True ## TODO: revisit
32 train_data_batch_size: int = 64
33 train_data_shuffle: bool = True
34
35 # Indication of whether to validate or not using testing data
36 validation: bool = True
37 test_data_batch_size: int = 64
38 test_data_shuffle: bool = False
39
40 # Checking data sanity
41 data_sanity: bool = False
42
43 # Reproducibility
44 reproduce: bool = True
45
46 # PCA on Trajectory
47 pca_dir: str = ""
48 params_start: int = 0
49 params_end: int = 49
50 ncomponents: int = 40
51
52 # Tensorboard
53 use_tensorboard: bool = False
54
55 # Loading models
56 load_model: bool = False
57 load_model_dirname: str = ""
58 load_model_filename: str = ""
59
60 # Saving models (server)
61 save_model: bool = False
62 save_model_dirname: str = ""
63 save_model_filename: str = ""
64 checkpoints_interval: int = 2
65
66 # Saving state_dict (clients)
67 save_model_state_dict: bool = False
68 send_final_model: bool = False
69
70 # Logging and recording outputs
71 output_dirname: str = "output"
72 output_filename: str = "result"
73
74 logginginfo: DictConfig = OmegaConf.create({})
75 summary_file: str = ""
76
77 # Personalization options
78 personalization: bool = False
79 p_layers: List[str] = field(default_factory=lambda: [])
80 config_name: str = ""
81
82 ## gRPC configurations ##
83
84 # 10 MB (10 * 1024 * 1024 bytes) for gRPC maximum message size
85 max_message_size: int = 10485760
86 use_ssl: bool = False
87 use_authenticator: bool = False
88 authenticator: str = "Globus" # "Globus", "Naive"
89 uri: str = "localhost:50051"
90
91 operator: DictConfig = OmegaConf.create({"id": 1})
92 server: DictConfig = OmegaConf.create({
93 "id": 1,
94 "authenticator_kwargs": {
95 "is_fl_server": True,
96 "globus_group_id": "77c1c74b-a33b-11ed-8951-7b5a369c0a53",
97 },
98 "server_certificate_key": "default",
99 "server_certificate": "default",
100 "max_workers": 10,
101 })
102 client: DictConfig = OmegaConf.create({
103 "id": 1,
104 "root_certificates": "default",
105 "authenticator_kwargs": {
106 "is_fl_server": False,
107 },
108 })
109
110 # Lossy compression enabling
111 enable_compression: bool = False
112 lossy_compressor: str = "SZ2"
113 lossless_compressor: str = "blosc"
114
115 # Lossy compression path configuration
116 ext = ".dylib" if sys.platform.startswith("darwin") else ".so"
117 compressor_sz2_path: str = "../.compressor/SZ/build/sz/libSZ" + ext
118 compressor_sz3_path: str = "../.compressor/SZ3/build/tools/sz3c/libSZ3c" + ext
119 compressor_szx_path: str = "../.compressor/SZx-main/build/lib/libSZx" + ext
120
121 # Compressor parameters
122 error_bounding_mode: str = ""
123 error_bound: float = 0.0
124
125 # Default data type
126 flat_model_dtype: str = "np.float32"
127 param_cutoff: int = 1024
128
129 # Data readiness
130 dr_metrics: Optional[List[str]] = field(default_factory=lambda: [])
131
132
133@dataclass
134class GlobusComputeServerConfig:
135 device: str = "cpu"
136 output_dir: str = "./"
137 data_dir: str = "./"
138 s3_bucket: Any = None
139 s3_creds: str = ""
140
141
142@dataclass
143class GlobusComputeClientConfig:
144 name : str = ""
145 endpoint_id : str = ""
146 device : str = "cpu"
147 output_dir : str = "./output"
148 data_dir : str = "./datasets"
149 get_data : DictConfig = OmegaConf.create({})
150 data_pipeline: DictConfig = OmegaConf.create({})
151
152
153@dataclass
154class ExecutableFunc:
155 module: str = ""
156 call: str = ""
157 script_file: str = ""
158 source: str = ""
159
160
161@dataclass
162class ClientTask:
163 task_id: str = ""
164 task_name: str = ""
165 client_idx: int = ""
166 pending: bool = True
167 success: bool = False
168 start_time: float = -1
169 end_time: float = -1
170 log: Optional[Dict] = field(default_factory=dict)
171
172
173@dataclass
174class GlobusComputeConfig(Config):
175 get_data: ExecutableFunc = field(default_factory=ExecutableFunc)
176 get_model: ExecutableFunc = field(default_factory=ExecutableFunc)
177 get_loss: ExecutableFunc = field(default_factory=ExecutableFunc)
178 val_metric: ExecutableFunc = field(default_factory=ExecutableFunc)
179 clients: List[GlobusComputeClientConfig] = field(default_factory=list)
180 dataset: str = ""
181 loss: str = "CrossEntropy"
182 model_kwargs: Dict = field(default_factory=dict)
183 server: GlobusComputeServerConfig
184 logging_tasks: List = field(default_factory=list)
185 hf_model_arc: str = ""
186 hf_model_weights: str = ""
187
188 # Testing and validation params
189 client_do_validation: bool = True
190 client_do_testing: bool = True
191 server_do_validation: bool = True
192 server_do_testing: bool = True
193
194 # Testing and validation frequency
195 client_validation_step: int = 1
196 server_validation_step: int = 1
197
198 # Cloud storage
199 use_cloud_transfer: bool = True