FL with compression

In this notebook, we show how to use a lossy compressor in FL to compress the model parameters and reduce the communication cost.

Install compressors

To install the compressors, first make sure that you have installed the necessary packages for the compressors by running the following command in the APPFL directory.

pip install -e . # if you installed using source code

or

pip install appfl # if you installed directly using pypi

Then, you can easily install the lossy compressors by running the following command anywhere. It will download all the source code for compressors under APPFL/.compressor.

appfl-install-compressor

💡 appfl.compressor supports four compressors: SZ2, SZ3, ZFP, and SZx. However, SZx requires special permission to access, so its installation is omitted here. If you want to try SZx, please contact its authors.

Load server and client configurations

Following the same steps as in the serial FL example, we load and modify the configurations for the server and five clients.

[1]:
import copy
from omegaconf import OmegaConf

num_clients = 5

server_config_file = "../../examples/resources/configs/mnist/server_fedavg.yaml"
server_config = OmegaConf.load(server_config_file)
server_config.client_configs.train_configs.loss_fn_path = (
    "../../examples/resources/loss/celoss.py"
)
server_config.client_configs.train_configs.metric_path = (
    "../../examples/resources/metric/acc.py"
)
server_config.client_configs.model_configs.model_path = (
    "../../examples/resources/model/cnn.py"
)
server_config.server_configs.num_global_epochs = 3
server_config.server_configs.num_clients = num_clients

client_config_file = "../../examples/resources/configs/mnist/client_1.yaml"
client_config = OmegaConf.load(client_config_file)
client_configs = [copy.deepcopy(client_config) for _ in range(num_clients)]
for i in range(num_clients):
    client_configs[i].client_id = f"Client{i + 1}"
    client_configs[
        i
    ].data_configs.dataset_path = "../../examples/resources/dataset/mnist_dataset.py"
    client_configs[i].data_configs.dataset_kwargs.num_clients = num_clients
    client_configs[i].data_configs.dataset_kwargs.client_id = i
    # Only the first client visualizes the data distribution
    client_configs[i].data_configs.dataset_kwargs.visualization = i == 0

print(OmegaConf.to_yaml(server_config))
client_configs:
  train_configs:
    trainer: VanillaTrainer
    mode: step
    num_local_steps: 100
    optim: Adam
    optim_args:
      lr: 0.001
    loss_fn_path: ../../examples/resources/loss/celoss.py
    loss_fn_name: CELoss
    do_validation: true
    do_pre_validation: true
    metric_path: ../../examples/resources/metric/acc.py
    metric_name: accuracy
    use_dp: false
    epsilon: 1
    clip_grad: false
    clip_value: 1
    clip_norm: 1
    train_batch_size: 64
    val_batch_size: 64
    train_data_shuffle: true
    val_data_shuffle: false
  model_configs:
    model_path: ../../examples/resources/model/cnn.py
    model_name: CNN
    model_kwargs:
      num_channel: 1
      num_classes: 10
      num_pixel: 28
  comm_configs:
    compressor_configs:
      enable_compression: false
      lossy_compressor: SZ2Compressor
      lossless_compressor: blosc
      error_bounding_mode: REL
      error_bound: 0.001
      param_cutoff: 1024
server_configs:
  num_clients: 5
  scheduler: SyncScheduler
  scheduler_kwargs:
    same_init_model: true
  aggregator: FedAvgAggregator
  aggregator_kwargs:
    client_weights_mode: equal
  device: cpu
  num_global_epochs: 3
  logging_output_dirname: ./output
  logging_output_filename: result
  comm_configs:
    grpc_configs:
      server_uri: localhost:50051
      max_message_size: 1048576
      use_ssl: false

Enable compression

To enable compression, we just need to set server_config.client_configs.comm_configs.compressor_configs.enable_compression to True.
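The compressor_configs printed above also set error_bounding_mode: REL and error_bound: 0.001. To give a feel for what such a bound means, here is a minimal sketch that uses plain uniform quantization followed by zlib as a stand-in for a real lossy compressor. It is not SZ2 and not APPFL's code path. It assumes that REL is interpreted as a value-range-relative bound (absolute bound = error_bound * (max - min)), as in the SZ documentation, and the weights list is synthetic data made up for illustration.

```python
import random
import struct
import zlib

random.seed(0)
# Synthetic stand-in for one flattened layer of model weights (hypothetical data).
weights = [random.gauss(0.0, 0.05) for _ in range(10_000)]

# Assumed meaning of REL: the absolute bound scales with the value range.
rel_error_bound = 0.001
abs_bound = rel_error_bound * (max(weights) - min(weights))

# Uniform quantization with bin width 2 * abs_bound keeps each value within abs_bound.
step = 2.0 * abs_bound
codes = [round(w / step) for w in weights]
restored = [c * step for c in codes]
max_error = max(abs(w - r) for w, r in zip(weights, restored))

# Bytes to transmit: lossless-only float32 vs. quantized int16 codes plus lossless.
raw_size = len(zlib.compress(struct.pack(f"{len(weights)}f", *weights)))
lossy_size = len(zlib.compress(struct.pack(f"{len(codes)}h", *codes)))

print(f"absolute bound: {abs_bound:.6f}, max error: {max_error:.6f}")
print(f"lossless only: {raw_size} bytes, quantized + lossless: {lossy_size} bytes")
```

The trade-off is the one the notebook relies on: a larger error_bound means coarser bins and fewer bytes to send, at the cost of a larger perturbation of each parameter.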

💡 You may notice that both server_config.server_configs and server_config.client_configs have a comm_configs field. When the server agent is created, its communication configuration is the merge of server_config.server_configs.comm_configs and server_config.client_configs.comm_configs. However, server_config.client_configs.comm_configs is also shared with the clients, while server_config.server_configs.comm_configs is not. Because the clients need to know the compressor configuration, we put compressor_configs under server_config.client_configs.comm_configs so that it is shared with them during the FL experiment.
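The sharing rule described in this note can be pictured with plain dictionaries. This is only an illustration of who sees which section. The helper merge_comm_configs is hypothetical and is not how APPFL implements the merge internally.

```python
# Plain-dict stand-ins for the two comm_configs sections of the server configuration.
server_side = {"grpc_configs": {"server_uri": "localhost:50051", "use_ssl": False}}
client_side = {
    "compressor_configs": {
        "enable_compression": True,
        "lossy_compressor": "SZ2Compressor",
    }
}


def merge_comm_configs(private, shared):
    """Hypothetical helper: combine the top-level sections of two comm_configs dicts."""
    merged = dict(private)
    merged.update(shared)
    return merged


# The server agent sees both sections.
server_view = merge_comm_configs(server_side, client_side)
# The clients only receive what lives under client_configs.
client_view = dict(client_side)

print(sorted(server_view))  # ['compressor_configs', 'grpc_configs']
print(sorted(client_view))  # ['compressor_configs']
```

Placing compressor_configs in the shared section means both ends of the connection agree on the compressor and its error bound, which is needed for the receiver to decompress what the sender compressed.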

Create the agents and start the experiment

[2]:
from appfl.agent import ServerAgent, ClientAgent

# Create server and client agents
server_config.client_configs.comm_configs.compressor_configs.enable_compression = True
server_agent = ServerAgent(server_agent_config=server_config)
client_agents = [
    ClientAgent(client_agent_config=client_configs[i]) for i in range(num_clients)
]

# Get additional client configurations from the server
client_config_from_server = server_agent.get_client_configs()
for client_agent in client_agents:
    client_agent.load_config(client_config_from_server)

# Load initial global model from the server
init_global_model = server_agent.get_parameters(serial_run=True)
for client_agent in client_agents:
    client_agent.load_parameters(init_global_model)

# [Optional] Set number of local data to the server
for i in range(num_clients):
    sample_size = client_agents[i].get_sample_size()
    server_agent.set_sample_size(
        client_id=client_agents[i].get_id(), sample_size=sample_size
    )

while not server_agent.training_finished():
    new_global_models = []
    for client_agent in client_agents:
        # Client local training
        client_agent.train()
        local_model = client_agent.get_parameters()
        if isinstance(local_model, tuple):
            local_model, metadata = local_model[0], local_model[1]
        else:
            metadata = {}
        # "Send" local model to server and get a Future object for the new global model
        # The Future object will be resolved when the server receives local models from all clients
        new_global_model_future = server_agent.global_update(
            client_id=client_agent.get_id(),
            local_model=local_model,
            blocking=False,
            **metadata,
        )
        new_global_models.append(new_global_model_future)
    # Load the new global model from the server
    for client_agent, new_global_model_future in zip(client_agents, new_global_models):
        client_agent.load_parameters(new_global_model_future.result())
appfl: ✅[2025-01-08 10:06:18,444 server]: Logging to ./output/result_Server_2025-01-08-10-06-18.txt
appfl: ✅[2025-01-08 10:06:18,460 Client1]: Logging to ./output/result_Client1_2025-01-08-10-06-18.txt
appfl: ✅[2025-01-08 10:06:25,732 Client2]: Logging to ./output/result_Client2_2025-01-08-10-06-25.txt
appfl: ✅[2025-01-08 10:06:32,544 Client3]: Logging to ./output/result_Client3_2025-01-08-10-06-32.txt
appfl: ✅[2025-01-08 10:06:39,832 Client4]: Logging to ./output/result_Client4_2025-01-08-10-06-39.txt
appfl: ✅[2025-01-08 10:06:47,010 Client5]: Logging to ./output/result_Client5_2025-01-08-10-06-47.txt
appfl: ✅[2025-01-08 10:06:54,282 Client1]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 10:06:55,360 Client1]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 10:06:59,106 Client1]:          0          N     3.7451     0.3113        84.9531     6.7566      56.9400
appfl: ✅[2025-01-08 10:06:59,161 Client2]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 10:07:00,201 Client2]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 10:07:03,147 Client2]:          0          N     2.9452     0.3075        82.9688     8.5662      46.1100
appfl: ✅[2025-01-08 10:07:03,198 Client3]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 10:07:04,228 Client3]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 10:07:07,038 Client3]:          0          N     2.8099     0.1685        86.3125     6.6134      58.9900
appfl: ✅[2025-01-08 10:07:07,091 Client4]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 10:07:08,100 Client4]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 10:07:10,845 Client4]:          0          N     2.7446     0.2278        85.4844     6.8239      58.5100
appfl: ✅[2025-01-08 10:07:10,898 Client5]:      Round   Pre Val?       Time Train Loss Train Accuracy   Val Loss Val Accuracy
appfl: ✅[2025-01-08 10:07:11,911 Client5]:          0          Y                                          2.3006      15.9300
appfl: ✅[2025-01-08 10:07:14,735 Client5]:          0          N     2.8234     0.2344        82.7500     9.4964      46.6900
appfl: ✅[2025-01-08 10:07:15,808 Client1]:          1          Y                                          1.6293      52.3100
appfl: ✅[2025-01-08 10:07:18,685 Client1]:          1          N     2.8768     0.0952        96.2031     4.4368      57.2000
appfl: ✅[2025-01-08 10:07:19,806 Client2]:          1          Y                                          1.6293      52.3100
appfl: ✅[2025-01-08 10:07:22,713 Client2]:          1          N     2.9055     0.0961        95.0000     7.2504      46.2000
appfl: ✅[2025-01-08 10:07:23,812 Client3]:          1          Y                                          1.6293      52.3100
appfl: ✅[2025-01-08 10:07:26,750 Client3]:          1          N     2.9370     0.0727        94.5469     5.1526      59.5900
appfl: ✅[2025-01-08 10:07:27,835 Client4]:          1          Y                                          1.6293      52.3100
appfl: ✅[2025-01-08 10:07:30,725 Client4]:          1          N     2.8892     0.0815        95.3750     5.5127      59.6400
appfl: ✅[2025-01-08 10:07:31,805 Client5]:          1          Y                                          1.6293      52.3100
appfl: ✅[2025-01-08 10:07:34,711 Client5]:          1          N     2.9055     0.1016        93.1875     5.8957      46.8500
appfl: ✅[2025-01-08 10:07:35,810 Client1]:          2          Y                                          0.9978      70.3300
appfl: ✅[2025-01-08 10:07:38,703 Client1]:          2          N     2.8922     0.0485        97.9219     3.5707      57.9400
appfl: ✅[2025-01-08 10:07:39,782 Client2]:          2          Y                                          0.9978      70.3300
appfl: ✅[2025-01-08 10:07:42,690 Client2]:          2          N     2.9079     0.0610        96.7656     5.0698      47.3200
appfl: ✅[2025-01-08 10:07:43,794 Client3]:          2          Y                                          0.9978      70.3300
appfl: ✅[2025-01-08 10:07:46,690 Client3]:          2          N     2.8956     0.0407        96.6875     3.9961      60.2800
appfl: ✅[2025-01-08 10:07:47,788 Client4]:          2          Y                                          0.9978      70.3300
appfl: ✅[2025-01-08 10:07:50,689 Client4]:          2          N     2.9001     0.0479        97.3906     3.5627      60.1800
appfl: ✅[2025-01-08 10:07:51,773 Client5]:          2          Y                                          0.9978      70.3300
appfl: ✅[2025-01-08 10:07:54,684 Client5]:          2          N     2.9105     0.0704        95.1719     5.3279      47.7900
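
The training loop above calls global_update with blocking=False and collects a Future per client, which resolves once the server has received local models from all clients. The pattern can be sketched with the standard library alone. ToySyncServer is a hypothetical stand-in that averages scalar "models" with equal weights. It is not APPFL's ServerAgent or SyncScheduler.

```python
from concurrent.futures import Future


class ToySyncServer:
    """Hypothetical synchronous server: resolves all Futures once every client reports."""

    def __init__(self, num_clients):
        self.num_clients = num_clients
        self.pending = {}  # client_id -> (local_model, Future)

    def global_update(self, client_id, local_model):
        # Return immediately with a Future instead of blocking the caller.
        future = Future()
        self.pending[client_id] = (local_model, future)
        if len(self.pending) == self.num_clients:
            self._aggregate()
        return future

    def _aggregate(self):
        # Equal-weight average of the scalar "models", then resolve every Future.
        models = [model for model, _ in self.pending.values()]
        global_model = sum(models) / len(models)
        for _, future in self.pending.values():
            future.set_result(global_model)
        self.pending = {}


server = ToySyncServer(num_clients=3)
first = server.global_update("Client1", 0.0)
second = server.global_update("Client2", 1.0)
done_before_last = first.done()  # False: Client3 has not reported yet
third = server.global_update("Client3", 2.0)

print(done_before_last)
print([f.result() for f in (first, second, third)])  # [1.0, 1.0, 1.0]
```

In a serial run like this notebook, a blocking call on the first client would wait for models that the remaining clients have not produced yet, which is presumably why the loop collects the Futures first and only calls result() after every client has trained.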