Intro tutorial (MedNIST)#

# IPython magic (not plain Python): load the notebook extension named "tensorboard".
%load_ext tensorboard

Nodes inspection#

First thing, let’s check which nodes are available for training and their characteristics:

from fedbiomed.researcher.requests import Requests
# Create a Requests object and call its verbose listing; per the tutorial
# text, this is how the nodes available for training are inspected.
req = Requests()
req.list(verbose=True)
import os
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from torchvision import datasets, transforms

from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import decollate_batch
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121
from monai.transforms import (
    Activations,
    AddChannel,
    AsDiscrete,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    EnsureType,
)
from monai.utils import set_determinism
from torch.optim import AdamW, Adam, SGD

import matplotlib.pyplot as plt
import PIL
import pandas as pd
from tqdm import tqdm

Training plan definition#

A Training Plan contains the recipe for executing the training loop on the nodes. It defines: the data, the model, the loss function, and the optimizer. The code in the training plan is shipped in its entirety to the nodes, where its different parts are executed at different times during the training loop.

Our example contains:

  1. a model instance

  2. an optimizer instance

  3. a list of dependencies (i.e. modules to be imported before instantiating the model and optimizer)

  4. how to load the training data (and potential preprocessing)

  5. a loss function

class TrainingPlan(TorchTrainingPlan):
    """Training plan for classifying MedNIST images with a MONAI DenseNet121.

    Groups together the model, the extra import statements, the loading of
    the training data and the loss computation used by this tutorial.
    """

    class MedNISTDataset(torch.utils.data.Dataset):
        """Dataset pairing each image file path with its integer class label."""

        def __init__(self, image_files, labels, transforms):
            self.image_files = image_files
            self.labels = labels
            self.transforms = transforms

        def __len__(self):
            return len(self.image_files)

        def __getitem__(self, index):
            # The transform pipeline runs on every access, not at construction.
            image = self.transforms(self.image_files[index])
            label = self.labels[index]
            return image, label

    def init_model(self, model_args):
        """Build the network: a 2D, single-input-channel DenseNet121.

        The number of output channels is read from ``model_args["num_class"]``.
        """
        return DenseNet121(
            spatial_dims=2,
            in_channels=1,
            out_channels=model_args["num_class"],
        )

    def init_dependencies(self):
        """Return the import statements needed by the custom data loading code."""
        return [
            "import numpy as np",
            "import os",
            "from monai.apps import download_and_extract",
            "from monai.config import print_config",
            "from monai.data import decollate_batch",
            "from monai.metrics import ROCAUCMetric",
            "from monai.networks.nets import DenseNet121",
            "from torch.optim import AdamW, Adam, SGD",
            "from monai.transforms import ( Activations, AddChannel, AsDiscrete, Compose, LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity, EnsureType, )",
            "from monai.utils import set_determinism",
        ]

    def parse_data(self, path):
        """List the files found under ``path``, grouped by class directory.

        Every sub-directory of ``path`` counts as one class; classes are
        ordered by sorted directory name.

        Returns:
            A tuple ``(image_files, num_class)`` where ``image_files[i]`` is
            the list of file paths inside the i-th class directory.
        """
        class_dirs = sorted(
            entry for entry in os.listdir(path)
            if os.path.isdir(os.path.join(path, entry))
        )
        files_per_class = []
        for class_dir in class_dirs:
            listing = os.listdir(os.path.join(path, class_dir))
            files_per_class.append(
                [os.path.join(path, class_dir, file_name) for file_name in listing]
            )
        return files_per_class, len(class_dirs)

    def training_data(self, batch_size=32):
        """Build the training DataManager from the images under ``self.dataset_path``.

        Stores the per-class file lists in ``self.image_files`` and the
        dataset in ``self.train_ds``.

        Raises:
            Exception: if the number of class directories differs from
                ``self.model_args()["num_class"]``.
        """
        self.image_files, found_classes = self.parse_data(self.dataset_path)

        if self.model_args()["num_class"] != found_classes:
            raise Exception('number of available classes does not match declared classes')

        # Flatten the per-class lists; the label of a file is the index of
        # its class. After the check above there is exactly one list per
        # declared class.
        all_files = []
        all_labels = []
        for label, class_files in enumerate(self.image_files):
            all_files.extend(class_files)
            all_labels.extend([label] * len(class_files))

        # Loading and scaling, followed by random rotation / flip / zoom.
        pipeline_steps = [
            LoadImage(image_only=True),
            AddChannel(),
            ScaleIntensity(),
            RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
            RandFlip(spatial_axis=0, prob=0.5),
            RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
            EnsureType(),
        ]
        train_transforms = Compose(pipeline_steps)

        self.train_ds = self.MedNISTDataset(all_files, all_labels, train_transforms)

        return DataManager(dataset=self.train_ds, batch_size=batch_size, shuffle=True)

    def training_step(self, data, target):
        """Return the cross-entropy loss of the model output for one batch."""
        prediction = self.model().forward(data)
        return torch.nn.functional.cross_entropy(prediction, target)

Experiment definition#

# Model arguments given to the Experiment below. The training plan reads
# 'num_class' to size the DenseNet121 output and compares it with the number
# of class folders found in the dataset.
model_args = {
    'num_class': 6,  
}

# Training arguments given to the Experiment below.
# NOTE(review): how 'batch_size', 'num_updates' and 'dry_run' are interpreted
# is decided by Fed-BioMed, not by code visible here; check its documentation.
training_args = {
    #'use_gpu': True,
    'batch_size': 20, 
    'optimizer_args': {
        'lr': 1e-5
    }, 
    'num_updates': 5, 
    'dry_run': False,
}

By changing the elements in tags we can do client selection.
As we saw, client1 is defined by [‘mednist-jupyter-username’, ‘client1’] and client2 is defined by [‘mednist-jupyter-username’, ‘client2’]

If we want to train a model with only client1, we can set tags = [‘mednist-jupyter-username’, ‘client1’]

TAGS: Replace %%%% in the tags with your username
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

# Tags used for client selection; replace %%%% with your username.
tags =  ['mednist-jupyter-%%%%']
# Passed to the Experiment as round_limit.
rounds = 3

# Build the experiment from the training plan class and the argument
# dictionaries, with a FedAverage aggregator. The tensorboard and
# save_breakpoints options are left commented out here.
exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=TrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None,
                 #tensorboard=True,
                 #save_breakpoints=True
                )
# Run the experiment.
exp.run()

Tensorboard: how to follow your run's progress#

Follow the instructions to obtain a port number, and run the commands below.

If <IP> is the IP assigned to you at this workshop, you may also view the tensorboard in a new browser tab at http://<IP>:<YOUR PORT NUMBER>

from fedbiomed.researcher.environ import environ
# Directory read from the researcher environment under 'TENSORBOARD_RESULTS_DIR'.
tensorboard_dir = environ['TENSORBOARD_RESULTS_DIR']
# Notebook command (not plain Python): start TensorBoard on that directory.
# Replace <YOUR PORT NUMBER> with the port number you obtained.
tensorboard --logdir "$tensorboard_dir" --host 0.0.0.0 --port <YOUR PORT NUMBER>
# Same experiment arguments as the previous run, this time with tensorboard=True.
exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=TrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None,
                 tensorboard=True,
                 #save_breakpoints=True
                )
# Run the experiment.
exp.run()

Retrieving the saved model#

# Take the model object held by the experiment's training plan.
trained_model = exp.training_plan().model()
# Load into it the 'params' entry stored at index rounds - 1 of the aggregated
# parameters (presumably the last round, if rounds are indexed from 0 - TODO confirm).
trained_model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])
# Bare expression: in a notebook cell this displays the model.
trained_model

Testing the model on a local dataset#

class MedNISTDataset(torch.utils.data.Dataset):
    """Dataset pairing each image file with its label.

    ``transforms`` is a callable applied to an entry of ``image_files`` each
    time an item is requested; ``labels`` is indexed in parallel.
    """

    def __init__(self, image_files, labels, transforms):
        self.image_files, self.labels, self.transforms = image_files, labels, transforms

    def __len__(self):
        """Return the number of image files."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Return ``(transformed image, label)`` for position ``index``."""
        image = self.transforms(self.image_files[index])
        label = self.labels[index]
        return image, label
def training_data(dataset_path, batch_size=32):
    """Build a non-shuffled DataLoader over the images under ``dataset_path``.

    Each sub-directory of ``dataset_path`` is one class; classes are numbered
    by sorted directory name and every file inside a class directory gets
    that class number as its label. Images go through loading, channel
    addition and intensity scaling only (no random augmentation).

    Args:
        dataset_path: directory holding one sub-directory per class.
        batch_size: batch size handed to the DataLoader.

    Returns:
        A ``DataLoader`` over a ``MedNISTDataset`` with ``shuffle=False``.
    """
    class_dirs = sorted(
        entry for entry in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, entry))
    )

    # Walk the classes in order, collecting file paths and matching labels.
    all_files = []
    all_labels = []
    for label, class_dir in enumerate(class_dirs):
        class_files = [
            os.path.join(dataset_path, class_dir, file_name)
            for file_name in os.listdir(os.path.join(dataset_path, class_dir))
        ]
        all_files.extend(class_files)
        all_labels.extend([label] * len(class_files))

    pipeline = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        EnsureType(),
    ])

    dataset = MedNISTDataset(all_files, all_labels, pipeline)
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
def testing_accuracy(model, data_loader, num_classes=6):
    """Evaluate ``model`` on ``data_loader`` and summarise the results.

    Args:
        model: the torch module to evaluate; it is switched to eval mode.
        data_loader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` with a length (e.g. a ``torch.utils.data.DataLoader``).
        num_classes: number of classes; the confusion matrix is computed over
            the labels ``0 .. num_classes - 1``. Defaults to 6, the value
            that was previously hard-coded.

    Returns:
        A tuple ``(loss, accuracy, cm)``: the cross-entropy loss averaged over
        the samples, the accuracy in percent, and the confusion matrix (rows
        are true labels, columns are predicted labels).
    """
    model.eval()
    device = 'cpu'

    total_loss = 0.0
    correct = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for data, target in tqdm(data_loader, desc="Evaluation"):
            data, target = data.to(device), target.to(device)
            output = model(data)

            # Sum (not mean) per batch so the final division by the dataset
            # size gives a per-sample average even with a short last batch.
            total_loss += torch.nn.functional.cross_entropy(
                output, target, reduction='sum'
            ).item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the largest output score
            correct += pred.eq(target.view_as(pred)).sum().item()
            y_pred.extend(torch.flatten(pred).tolist())
            y_true.extend(target.tolist())

    # confusion_matrix takes (y_true, y_pred); plain lists are enough.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

    n_samples = len(data_loader.dataset)
    loss = total_loss / n_samples
    accuracy = 100 * correct / n_samples

    return (loss, accuracy, cm)
# Local dataset directory used for testing.
test_client_path = '/datasets/MedNIST/client_3'
# Despite its name, training_data() here is the standalone helper defined
# above: it returns a non-shuffled DataLoader without random augmentation.
test_dl = training_data(test_client_path)
test_loss, test_accuracy, test_cm = testing_accuracy(trained_model, test_dl)
print(f"Test loss = {test_loss:.2f}")
print(f"Test accuracy = {test_accuracy:.2f}%")

# Bare expression: in a notebook cell this displays the confusion matrix.
test_cm
# Plot the confusion matrix, labelling the classes 0..5.
disp = ConfusionMatrixDisplay(confusion_matrix=test_cm,
                            display_labels=range(6))
disp.plot()
plt.show()

Compare different aggregators: FedProx#

Similar to FedAveraging, FedProx performs a weighted sum of local model parameters. However, FedProx introduces a regularization operation in order to tackle statistical heterogeneity. To use FedProx, use FedAverage from fedbiomed.researcher.aggregators and specify a value for mu in the training arguments training_args using the argument name fedprox_mu.

Try changing the mu value to see how this impacts performance.

# Same model arguments as in the first experiment.
model_args = {
    'num_class': 6,  
}

# Same training arguments as in the first experiment, plus 'fedprox_mu',
# the mu value mentioned in the text above.
training_args = {
    #'use_gpu': True,
    'batch_size': 20, 
    'optimizer_args': {
        'lr': 1e-5
    }, 
    'num_updates': 5, 
    'dry_run': False,
    'fedprox_mu': 0.1,
}
# The aggregator is still FedAverage(); the new setting is carried only by
# training_args.
exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=TrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None,
                 tensorboard=True,
                 #save_breakpoints=True
                )
# Run the experiment.
exp.run()