Adding new aggregators

Note

You are always welcome to contribute your additions to APPFL by creating a pull request.

To add a new aggregator to APPFL, create your own aggregator class that inherits from appfl.algorithm.aggregator.BaseAggregator and defines the following functions:

  • aggregate: Take a list of local models (for synchronous FL), a single local model (for asynchronous FL), or either form (for an aggregator that supports both) as the input, and return the updated global model parameters.

  • get_parameters: Directly return the current global model parameters.

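import torch
from omegaconf import DictConfig
from typing import Any, Dict, OrderedDict, Tuple, Union
from appfl.algorithm.aggregator import BaseAggregator

class YourOwnAggregator(BaseAggregator):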
    def __init__(
        self,
        model: torch.nn.Module,
        aggregator_configs: DictConfig,
        logger: Any
    ):
        self.model = model
        self.aggregator_configs = aggregator_configs
        self.logger = logger
        ...

    def aggregate(self, *args, **kwargs) -> Union[Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]:
        """
        Aggregate local model(s) from clients and return the global model
        """
        pass

    def get_parameters(self, **kwargs) -> Union[Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]:
        """Return global model parameters"""
        pass
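
As an illustration of how the two functions fit together, here is a minimal sketch of a synchronous averaging aggregator. So that it runs without APPFL or PyTorch installed, it makes several assumptions that are not part of APPFL: BaseAggregatorStub is a hypothetical stand-in for BaseAggregator, MeanAggregator is an illustrative name, and plain floats stand in for model tensors. A real aggregator would inherit from BaseAggregator and operate on the model's state dict instead.

```python
from collections import OrderedDict
from typing import Dict, List


class BaseAggregatorStub:
    """Hypothetical stand-in for appfl.algorithm.aggregator.BaseAggregator."""

    def aggregate(self, *args, **kwargs):
        raise NotImplementedError

    def get_parameters(self, **kwargs):
        raise NotImplementedError


class MeanAggregator(BaseAggregatorStub):
    """Averages client parameters; plain floats stand in for tensors."""

    def __init__(self, initial_parameters: Dict[str, float]):
        self.global_parameters = OrderedDict(initial_parameters)

    def aggregate(self, local_models: List[Dict[str, float]], **kwargs):
        # Synchronous FL: average each named parameter over all clients.
        if not local_models:
            return self.get_parameters()
        for name in self.global_parameters:
            total = sum(model[name] for model in local_models)
            self.global_parameters[name] = total / len(local_models)
        return self.get_parameters()

    def get_parameters(self, **kwargs):
        # Return a copy so callers cannot mutate the aggregator's state.
        return OrderedDict(self.global_parameters)


aggregator = MeanAggregator({"w": 0.0, "b": 0.0})
global_model = aggregator.aggregate(
    [{"w": 1.0, "b": 2.0}, {"w": 3.0, "b": 4.0}]
)
```

Here aggregate updates the stored global parameters and returns them, while get_parameters only reads the current state, which mirrors the division of roles described above.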

You may add any configuration parameters to your aggregator and access them using self.aggregator_configs.your_config_param. When you start the FL experiment, you can specify the values of these parameters in the server configuration file as follows:

server_configs:
    ...
    aggregator: "YourOwnAggregator"
    aggregator_kwargs:
        your_config_param_1: ...
        your_config_param_2: ...
    ...
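
To make the attribute-style access concrete, the following sketch shows an asynchronous-style aggregator that reads one configuration parameter. Everything specific here is an assumption for illustration: mixing_weight is a hypothetical parameter name (it would sit under aggregator_kwargs in the server configuration), MixingAggregator is an illustrative class, types.SimpleNamespace stands in for the DictConfig object, and plain floats stand in for tensors.

```python
from types import SimpleNamespace
from typing import Dict


class MixingAggregator:
    """Hypothetical asynchronous-style aggregator; floats stand in for tensors."""

    def __init__(self, initial_parameters: Dict[str, float], aggregator_configs):
        self.global_parameters = dict(initial_parameters)
        self.aggregator_configs = aggregator_configs

    def aggregate(self, local_model: Dict[str, float], **kwargs):
        # Read the hypothetical parameter, falling back to 0.5 if it is unset.
        alpha = getattr(self.aggregator_configs, "mixing_weight", 0.5)
        for name, value in local_model.items():
            old = self.global_parameters[name]
            # Blend the single incoming local model into the global model.
            self.global_parameters[name] = (1 - alpha) * old + alpha * value
        return dict(self.global_parameters)

    def get_parameters(self, **kwargs):
        return dict(self.global_parameters)


# Stand-in for the values given under aggregator_kwargs.
configs = SimpleNamespace(mixing_weight=0.25)
aggregator = MixingAggregator({"w": 0.0}, configs)
updated = aggregator.aggregate({"w": 4.0})
```

Reading the parameter inside aggregate, rather than copying it in __init__, is only one possible design; it keeps the sketch short and shows the access pattern at the point of use.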