Adding new trainers


We always welcome contributions to APPFL; please submit them by creating a pull request.

To add a new trainer to APPFL, create your own trainer class that inherits from appfl.algorithm.trainer.BaseTrainer and defines the following methods:

  • train: Perform local training on the client's local (sensitive) dataset.

  • get_parameters: Return the parameters to be sent to the server for aggregation.

  • load_parameters: [Optional] Load the aggregated parameters from the server.


  • If some BaseTrainer input parameters are not needed by your trainer (e.g., because you want to hardcode the loss function), simply leave them at their default value of None.

  • If your trainer needs additional input parameters, pass them through **kwargs.
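To illustrate how these methods fit together, the plain-Python sketch below simulates one client round: load the global parameters, train locally, then return the updated parameters. It is an assumption-laden toy, not APPFL code: `MiniBaseTrainer` is a hypothetical stand-in for BaseTrainer so the example runs without APPFL or PyTorch, the "model" is a single weight `w` fitted to y = 2x, `loss_fn` is left as None because the loss is hardcoded, and `lr` and `local_steps` are hypothetical extra parameters passed through **kwargs.

```python
from typing import Any, Dict, List, Optional, Tuple


class MiniBaseTrainer:
    """Hypothetical stand-in for appfl.algorithm.trainer.BaseTrainer (not the real class)."""

    def __init__(
        self,
        model: Optional[Dict[str, float]] = None,
        loss_fn: Optional[Any] = None,
        train_dataset: Optional[List[Tuple[float, float]]] = None,
        **kwargs: Any,
    ):
        self.round = 0
        self.model = model
        self.loss_fn = loss_fn
        self.train_dataset = train_dataset
        # Extra trainer-specific arguments passed via **kwargs become attributes
        self.__dict__.update(kwargs)

    def load_parameters(self, params: Dict[str, float]) -> None:
        # Default behavior: replace the local parameters with the server's
        self.model = dict(params)


class ToyTrainer(MiniBaseTrainer):
    """Fits y = w * x by gradient descent; loss_fn stays None because
    the mean squared error is hardcoded in train()."""

    def train(self) -> None:
        lr = getattr(self, "lr", 0.1)  # hypothetical extra kwarg, with a fallback
        local_steps = getattr(self, "local_steps", 1)  # hypothetical extra kwarg
        n = len(self.train_dataset)
        for _ in range(local_steps):
            w = self.model["w"]
            # Gradient of the mean of (w * x - y) ** 2 with respect to w
            grad = sum(2 * (w * x - y) * x for x, y in self.train_dataset) / n
            self.model["w"] = w - lr * grad
        self.round += 1

    def get_parameters(self) -> Dict[str, float]:
        # Return a copy so the caller cannot mutate the trainer's state
        return dict(self.model)


# One simulated client round: receive global parameters, train, send back
trainer = ToyTrainer(
    model={"w": 0.0},
    train_dataset=[(1.0, 2.0), (2.0, 4.0)],
    lr=0.1,
    local_steps=5,
)
trainer.load_parameters({"w": 1.0})
trainer.train()
update = trainer.get_parameters()
print(update)
```

With this data each local step maps w to 0.5w + 1, halving the distance to the true value 2, so five steps starting from the server's w = 1.0 give w ≈ 1.969.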

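from typing import Any, Dict, Optional, OrderedDict, Tuple, Union

import torch.nn as nn
from omegaconf import DictConfig
from torch.utils.data import Dataset

from appfl.algorithm.trainer import BaseTrainer


class YourOwnTrainer(BaseTrainer):
    """
    Args:
        model: torch neural network model to train
        loss_fn: loss function for the model training
        metric: metric function for the model evaluation
        train_dataset: training dataset
        val_dataset: validation dataset
        train_configs: training configurations
        logger: logger for the trainer
        **kwargs: any additional trainer-specific input parameters
    """

    def __init__(
        self,
        model: Optional[nn.Module]=None,
        loss_fn: Optional[nn.Module]=None,
        metric: Optional[Any]=None,
        train_dataset: Optional[Dataset]=None,
        val_dataset: Optional[Dataset]=None,
        train_configs: DictConfig = DictConfig({}),
        logger: Optional[Any]=None,
        **kwargs,
    ):
        self.round = 0
        self.model = model
        self.loss_fn = loss_fn
        self.metric = metric
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.train_configs = train_configs
        self.logger = logger
        # Store any additional parameters passed through **kwargs as attributes
        self.__dict__.update(kwargs)

    def get_parameters(self) -> Union[Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]:
        """Return local model parameters and optional metadata."""
        # Simplest case: the model's state dict, without metadata
        return self.model.state_dict()

    def train(self):
        """Do local training using the local (sensitive) dataset."""
        # Implement your local training loop here
        pass

    def load_parameters(self, params: Union[Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict], Any]):
        """Load model parameters. Overriding this method is optional."""
        # One possible implementation when params is a plain state dict
        self.model.load_state_dict(params, strict=False)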

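The return type of get_parameters allows either a parameter dictionary or a (parameters, metadata) tuple. One common use of such metadata is reporting how many local samples a client trained on, so the server can weight its aggregation as in FedAvg. The sketch below is a hedged illustration of that idea only: the `num_samples` metadata key and both helper functions are hypothetical, not APPFL's API, and plain floats stand in for the tensors a real state dict would contain.

```python
from typing import Dict, List, Tuple

Params = Dict[str, float]
Metadata = Dict[str, int]


def get_parameters_with_metadata(params: Params, num_samples: int) -> Tuple[Params, Metadata]:
    """What a trainer's get_parameters might return: (parameters, metadata)."""
    # "num_samples" is a hypothetical metadata key chosen for this sketch
    return dict(params), {"num_samples": num_samples}


def weighted_average(updates: List[Tuple[Params, Metadata]]) -> Params:
    """Server-side FedAvg-style aggregation weighted by each client's sample count."""
    total = sum(meta["num_samples"] for _, meta in updates)
    keys = updates[0][0].keys()
    return {
        k: sum(params[k] * meta["num_samples"] for params, meta in updates) / total
        for k in keys
    }


client_a = get_parameters_with_metadata({"w": 1.0, "b": 0.0}, num_samples=100)
client_b = get_parameters_with_metadata({"w": 3.0, "b": 1.0}, num_samples=300)
global_params = weighted_average([client_a, client_b])
print(global_params)  # {'w': 2.5, 'b': 0.75}
```

Because client B trained on three times as many samples, the aggregated values land three quarters of the way toward its parameters.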
You may add any configuration parameters to your trainer and access them via self.train_configs.your_config_param. When you start the FL experiment, specify the values of these trainer configuration parameters in the server configuration file as follows:

client_configs:
    train_configs:
        trainer: "YourOwnTrainer"
        your_config_param_1: ...
        your_config_param_2: ...