Adding new trainers
Note
We always welcome contributions to APPFL. Feel free to open a pull request.
To add a new trainer to APPFL, create your own trainer class that inherits from `appfl.algorithm.trainer.BaseTrainer` and defines the following methods:

- `train`: Perform local training on the client's local (sensitive) dataset.
- `get_parameters`: Return the parameters to be sent to the server for aggregation.
- `load_parameters`: [Optional] Load the aggregated parameters received from the server.
Note
If some of the `BaseTrainer` input parameters are not needed for your trainer (e.g., you want to hardcode the loss function), simply leave them with their default value of `None`. If your trainer needs additional input parameters, pass them through `**kwargs`.
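The note above relies on two patterns: `None` defaults and `**kwargs`. The sketch below shows both in pure Python, with no APPFL or PyTorch needed. `MiniTrainer`, `squared_error`, and the `clip_norm` keyword are hypothetical names used only for illustration. A `None` loss function is replaced by a hardcoded default. Any extra keyword argument becomes an instance attribute through `self.__dict__.update(kwargs)`.

```python
from typing import Any, Callable, Optional


def squared_error(pred: float, target: float) -> float:
    """Hypothetical hardcoded loss used when the caller passes loss_fn=None."""
    return (pred - target) ** 2


class MiniTrainer:
    """Hypothetical stand-in mirroring the trainer constructor pattern."""

    def __init__(
        self,
        loss_fn: Optional[Callable[[float, float], float]] = None,
        **kwargs: Any,
    ):
        # Keep None as the default in the signature; fill in the hardcoded loss here.
        self.loss_fn = loss_fn if loss_fn is not None else squared_error
        # Extra keyword arguments become attributes, e.g. self.clip_norm.
        self.__dict__.update(kwargs)


trainer = MiniTrainer(clip_norm=1.0)
print(trainer.loss_fn(3.0, 1.0), trainer.clip_norm)  # 4.0 1.0
```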
```python
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple, Union

import torch.nn as nn
from omegaconf import DictConfig
from torch.utils.data import Dataset

from appfl.algorithm.trainer import BaseTrainer


class YourOwnTrainer(BaseTrainer):
    """
    Args:
        model: torch neural network model to train
        loss_fn: loss function for the model training
        metric: metric function for the model evaluation
        train_dataset: training dataset
        val_dataset: validation dataset
        train_configs: training configurations
        logger: logger for the trainer
    """

    def __init__(
        self,
        model: Optional[nn.Module] = None,
        loss_fn: Optional[nn.Module] = None,
        metric: Optional[Any] = None,
        train_dataset: Optional[Dataset] = None,
        val_dataset: Optional[Dataset] = None,
        train_configs: DictConfig = DictConfig({}),
        logger: Optional[Any] = None,
        **kwargs,
    ):
        self.round = 0
        self.model = model
        self.loss_fn = loss_fn
        self.metric = metric
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.train_configs = train_configs
        self.logger = logger
        # Extra keyword arguments become attributes (foo=1 -> self.foo).
        self.__dict__.update(kwargs)

    def get_parameters(self) -> Union[Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]:
        """Return local model parameters and optional metadata."""
        # Return either a state dict, or a (state_dict, metadata_dict) tuple.
        pass

    def train(self):
        """Run local training on self.train_dataset."""
        pass

    def load_parameters(self, params: Union[Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict], Any]):
        """Load model parameters. You can define your own way to load the parameters by overriding this function."""
        self.model.load_state_dict(params)
```
You may add any configuration parameters to your trainer and access them via `self.train_configs.your_config_param`. When you start the FL experiment, specify the values of these trainer configuration parameters in the server configuration file as follows:
```yaml
client_configs:
  train_configs:
    trainer: "YourOwnTrainer"
    your_config_param_1: ...
    your_config_param_2: ...
    ...
```