Brain Segmentation#

This tutorial will show how to use Fed-BioMed to perform image segmentation on 3D medical MRI images of brains, using the publicly available IXI dataset. It uses a 3D U-Net model for the segmentation, trained on data from 3 separate centers.

Here we present a complex use case that relies on advanced Fed-BioMed functionalities such as:

  • exploring the datasets in the federation

  • loading a MedicalFolderDataset

  • monitoring training loss with Tensorboard

Parts of this tutorial are based on TorchIO’s tutorial.

Table of Contents#

  • Task 1: Discovering datasets

  • Task 2: MedicalFolderDataset class

  • Task 3: Federated feature analytics

  • Task 4: Train a UNet model

  • Task 5: Validation on a local holdout set

%load_ext tensorboard
import os
import tabulate
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW, SGD
from torch.utils.data import DataLoader
from monai.networks.nets import UNet
from monai.losses.dice import DiceLoss

from fedbiomed.common.training_plans import BaseTrainingPlan, TorchTrainingPlan
from fedbiomed.common.logger import logger
from fedbiomed.common.data import DataManager, MedicalFolderDataset
from fedbiomed.researcher.requests import Requests
from fedbiomed.researcher.aggregators import Aggregator, FedAverage
from fedbiomed.researcher.environ import environ
from fedbiomed.common.training_args import TrainingArgs
from fedbiomed.researcher.experiment import Experiment
%matplotlib inline

Task 1: Discovering datasets #

Let’s discover which datasets are available for federated training in the network.

Try it yourself!#

Read the documentation for the Requests class to figure out which function call can be used to list all the available datasets.

req = Requests()
datasets = # ... TODO fill in a one-line function call
pprint(datasets)

Filter results#

There are a lot of datasets available! However, most of them are from nodes that are going to collaborate with other users, not you. To identify datasets that were intended for your use, you should look at the value of the tags.

Try it yourself!#

Fill in the body of the for loop below such that the datasets_for_me variable follows these rules:

  • it has the same structure as the datasets variable

  • it has all and only the datasets whose tag contains your username
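
The filtering rules above can be sketched on a toy dictionary. This is only an illustration under assumptions: it supposes that each node maps to a list of dataset entries and that each entry carries a `tags` list. The node ids, the entry fields, and the helper `filter_by_username` are hypothetical and may not match the real output of the Requests class.

```python
# Toy stand-in for the dataset listing.
# Assumption: {node_id: [dataset_entry, ...]}, each entry having a 'tags' list.
toy_datasets = {
    "node_a": [
        {"name": "ixi", "tags": ["ixi-jupyter-alice"]},
        {"name": "ixi", "tags": ["ixi-jupyter-bob"]},
    ],
    "node_b": [{"name": "ixi", "tags": ["ixi-jupyter-bob"]}],
}


def filter_by_username(datasets, username):
    """Keep only the entries with at least one tag containing `username`."""
    filtered = {}
    for node, entries in datasets.items():
        matching = [e for e in entries if any(username in t for t in e["tags"])]
        if matching:  # drop nodes that have no matching dataset
            filtered[node] = matching
    return filtered


toy_for_alice = filter_by_username(toy_datasets, "alice")
print(toy_for_alice)
```

The result keeps the same nesting as the input, so any code written for the full listing also works on the filtered one.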

my_username = 'sharkovsky'
datasets_for_me = dict()
for node, _data in datasets.items():
    # ... TODO fill in the body of the for loop
pprint(datasets_for_me)

Tabulate results#

Try it yourself!#

Looking at the format of the datasets_for_me variable, produce a table with the following format using the tabulate package:

site                                       dataset name      sample size
-----------------------------------------  --------------  -------------
node_3d7f08fa-ee13-4033-8a01-93b448b6c8be  ixi                       177

tabulate_sample_sizes = list()
for node, _data in datasets_for_me.items():
    # ... TODO fill in the body of the for loop by populating the `tabulate_sample_sizes` variable to be passed to tabulate below
print(tabulate.tabulate(tabulate_sample_sizes, headers=('site', 'dataset name', 'sample size')))
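
tabulate expects a list of rows, each row holding one value per header. As a minimal sketch of how such rows might be built, the snippet below uses toy entries; the `n_samples` field is hypothetical and only stands in for wherever the sample size is stored in the real dataset entries.

```python
# Toy entries; the 'n_samples' field is hypothetical and used only for illustration.
toy_for_me = {
    "node_a": [{"name": "ixi", "n_samples": 177}],
    "node_b": [{"name": "ixi", "n_samples": 120}],
}

rows = []
for node, entries in toy_for_me.items():
    for entry in entries:
        # one row per (site, dataset) pair, in the same order as the headers
        rows.append((node, entry["name"], entry["n_samples"]))

headers = ("site", "dataset name", "sample size")
for line in [headers] + rows:
    print("{:<10} {:<14} {:>11}".format(*line))
```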

Task 2: MedicalFolderDataset #

To help you get familiar with loading medical imaging data in Fed-BioMed, we will practice on a smaller dataset that we assume is available to you locally as a researcher. This is often the case in FL settings, where the researcher has a small holdout dataset for local validation.

We will use Fed-BioMed’s built-in class for image segmentation tasks: MedicalFolderDataset. This class supports several medical imaging modalities (for example the different types of MRI, CT, PET, etc.). It is optimized for segmentation tasks but can be applied to other tasks (e.g. classification or regression). This class also supports loading patient demographics data in csv format, in addition to the imaging data.

The inputs to the __init__ function are:

  • root: the filesystem path where the root of the dataset is located

  • data_modalities: the names of the modalities of interest for the input data

  • target_modalities: the names of the modalities of interest for the data to be predicted

  • transform: optional transformations to be performed on the input images

  • target_transform: optional transformations to be performed on the target data

  • demographics_transform: optional transformations to be performed on the demographics (csv) data

Let’s create the dataset:

dataset = MedicalFolderDataset(
    root='/datasets/ixi/holdout',  
    data_modalities=['T1', 'T2'],   
    target_modalities='label',      
    transform=None,       
    target_transform=None,  
    demographics_transform=None)  

Try it yourself!#

You can find out the total number of images with the len function, and access images individually with the [idx] operator, where idx is an integer index.

Note: dataset[0] will return a tuple corresponding to index 0. The tuple will be in the form (inputs, targets).

What is the format of inputs and targets? How do you access a single image? What data type is it, what is its format and shape? How do you access the corresponding binary mask?

# ... TODO call the [] operator on `dataset` and explore its contents following the guidelines above

Try it yourself!#

Plot two images. On the left, the raw input image. On the right, the same raw input image as background, overlaid with the ground truth segmentation. Take some time to familiarize yourself with the approach for making the overlay plot, and with all the plotting arguments that we use.

fig, ax = plt.subplots(1,2, figsize=(6,6))
slice_to_plot = # ... TODO choose a slice to plot in the depth dimension (what is the range?)

(image_modalities, demographics), target = # ... TODO use the [] operator to get a tuple from the dataset
img = # ... TODO from image_modalities, extract the correct modality and slice to be plotted
ax[0].imshow(img, cmap='bone')
ax[0].invert_yaxis()

ax[1].imshow(img, cmap='bone')  # plot again as background for overlay
label = # ... TODO from target, extract the correct key/value and slice to be plotted
plot = ax[1].imshow(np.ma.masked_where(
                        label < 0.5, 
                        label), cmap='winter', alpha=0.5, interpolation='none', vmin=0., vmax=1.)
ax[1].invert_yaxis()

fig.suptitle('Input slice with ground truth overlay')
fig.colorbar(plot, ax=ax, location='bottom')

Task 3: Federated feature analytics #

We want to obtain some basic information about the distribution of values of our features (i.e. the pixels in the MRI images). Our ultimate goal is to obtain a histogram of the counts of pixel values, but this will require a few steps, showcasing how flexible and interactive Fed-BioMed can be.

Hacking ahead!
The following code will slightly bend the current intended usage of some Fed-BioMed classes. You will be working on the alpha version of a functionality which we are actively working on right now.

Federated Min and Max: TrainingPlan#

We need to do some preliminary work in order to compute a federated histogram. One piece of information that we need is the minimum and maximum pixel values over the whole federation, in order to compute a stable set of histogram bins across all the nodes in the federation. (Think about what would happen if we let each node compute its own histogram bins independently.)
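
To make the point about bins concrete, here is a small sketch with made-up pixel ranges for two hypothetical nodes. The helper `bin_edges` and all the numbers are illustrative assumptions, not part of Fed-BioMed.

```python
def bin_edges(lo, hi, nbins):
    """Return nbins + 1 equally spaced edges between lo and hi."""
    width = (hi - lo) / nbins
    return [lo + i * width for i in range(nbins + 1)]


# Toy local pixel ranges for two hypothetical nodes.
local_ranges = {"node_a": (0.0, 100.0), "node_b": (20.0, 260.0)}

# Independent bins: each node uses its own range, so the edges disagree
# and the per-bin counts cannot be added together.
independent = {n: bin_edges(lo, hi, 4) for n, (lo, hi) in local_ranges.items()}

# Shared bins: every node uses the minimum and maximum over the federation.
global_lo = min(lo for lo, _ in local_ranges.values())
global_hi = max(hi for _, hi in local_ranges.values())
shared = bin_edges(global_lo, global_hi, 4)

print(independent["node_a"])
print(independent["node_b"])
print(shared)
```

With shared edges, bin number k means the same intensity interval on every node, which is what makes the later aggregation meaningful.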

Implementation details#

We are going to “hack” a TorchTrainingPlan in order to compute a minimum and a maximum instead of training a model.

To understand what is going on exactly, you need the information below on some details about Fed-BioMed. If you are not interested, you can skip directly to the description of what you need to do to complete the code below.

The TrainingPlan should implement at least four important functions:

  • init_model: create the model (nn.Module) to be trained

  • init_optimizer: create the optimizer

  • training_data: to instantiate the dataset on the node (with customizations made by the researcher)

  • training_routine or training_step: the actual training code. The latter is a simple shorthand if you only need to customize the training on a single batch, without worrying about managing the data loader, iterations, etc.

The diagram below simplifies the main steps that constitute a run of a federated experiment in Fed-BioMed:

[Figure: diagram of the main steps of exp.run]

Define data loading and transformations#

In what follows, we provide a suggestion for defining the data loading. The training_data function will use MedicalFolderDataset.

Additionally, we can define image transformations for the input images and the labels.

Finally, we define a transformation for the demographics data. Note that this is a required step to ensure that the demographics csv data is transformed into a torch Tensor usable by the training routine.

Our strategy: training_routine#

We will implement a training_routine that does not really perform any training. Instead, it does one full pass over the whole dataset to find the maximum and the minimum.
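
The core of such a pass is a running minimum and maximum. The sketch below shows the idea on plain Python lists; the toy batches stand in for what the data loader would yield, and nothing here uses the Fed-BioMed API.

```python
import math

# Toy "batches" of pixel values; in the tutorial these would come from the data loader.
toy_batches = [[3.0, 7.5, 1.2], [9.1, 0.4], [5.5, 5.6, 5.7]]

# Start from values that any real pixel will improve upon.
running_min, running_max = math.inf, -math.inf
n_batches = 0
for batch in toy_batches:
    n_batches += 1
    running_min = min(running_min, min(batch))
    running_max = max(running_max, max(batch))

print(running_min, running_max, n_batches)
```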

Our strategy: the MinMaxTracker model#

We will create a model that inherits from nn.Module but is not a neural network. Instead, it will simply contain a dictionary to store the minimum and maximum pixel values.

The init_model function of the TrainingPlan simply needs to return an instance of MinMaxTracker.

The init_optimizer function will return a dummy torch optimizer, since there is no actual optimization happening during our training.

The implementation details of how a model is handled can be a bit confusing at first. Here is a simplified diagram detailing the interactions between the main classes during a federated training round.

[Figure: diagram of the interactions between the main classes during a federated training round]

Try it yourself!#

Fill in the code below in the MinMaxTracker and training_routine.

class FedMinMaxTrainingPlan(TorchTrainingPlan):    
    class MinMaxTracker(nn.Module):
        def __init__(self):
            super().__init__()
            self.min_max = # ... TODO initialize a dictionary with 'min' and 'max' keys, and torch.Tensor values
            
        def state_dict(self):
            return self.min_max
        
        def load_state_dict(self, params, strict=False):
            self.min_max = # ... TODO update the state dictionary with the values from params
            # To mimic the return value of nn.Module.load_state_dict, we create a small class that mocks the same output format
            class MockMissingKeys:
                def __init__(self):
                    self.missing_keys = []
                    self.unexpected_keys = []
            return MockMissingKeys()
            
        def named_parameters(self):
            return self.min_max.items()

    def init_model(self, model_arguments):
        return FedMinMaxTrainingPlan.MinMaxTracker()
    
    def training_step(self, *args, **kwargs):
        pass
    
    def init_optimizer(self, optimizer_arguments):
        return SGD([torch.Tensor([0])], lr=0.)  # dummy optimizer that will not be used
        
    def init_dependencies(self):
        deps = [
            "from monai.transforms import (Compose, NormalizeIntensity, AddChannel, Resize, AsDiscrete)",
            "from fedbiomed.common.data import MedicalFolderDataset",
            'import numpy as np',
            'from torch.optim import SGD'
        ]
        
        return deps

    @staticmethod
    def demographics_transform(demographics: dict):
        return {}

    def training_data(self,  batch_size = 4):
        # training_data creates the DataLoader used for training by Fed-BioMed's generic Torchnn class
        common_shape = (44, 44, 56)
        training_transform = Compose([AddChannel(), Resize(common_shape), NormalizeIntensity(),])
        target_transform = Compose([AddChannel(), Resize(common_shape), AsDiscrete(to_onehot=2)])
        
        dataset = MedicalFolderDataset(
            root=self.dataset_path,
            data_modalities='T1',
            target_modalities='label',
            transform=training_transform,
            target_transform=target_transform,
            demographics_transform=FedMinMaxTrainingPlan.demographics_transform)
        train_kwargs = {'batch_size': batch_size, 'shuffle': False}
        return DataManager(dataset, **train_kwargs)
    
    def training_routine(self,
                         history_monitor = None,
                         node_args = None):
        count = 0
        global_min = self._model.get_weights()['min']
        global_max = self._model.get_weights()['max']
        for data, target in self.training_data_loader:  # iterate through the dataset
            count += 1
            # ... TODO get the image corresponding to the T1 modality
            
            # ... TODO find the minimum and maximum of the image, and if necessary update the `global_min` and `global_max` variables
                            
        # ... TODO use self._model.set_weights to update the minimum and maximum with global values just computed
                
        return count

Federated Min and Max: Aggregator#

The Aggregator must implement an aggregate function that returns a dictionary of model parameters.

Try it yourself!#

Fill in the code for the aggregate function.

The inputs are:

  • model_params: a dictionary {node_id: model_params_after_training}, where model_params_after_training is the state_dict of the MinMaxTracker model after local training on the node.

  • weights: a dictionary {node_id: weight} where the weight is a float between 0 and 1 computed as a proportion of the number of samples in the node to the total number of samples in the federation.

The output should be a dictionary with the same format as the state_dict of the MinMaxTracker model, i.e. it should be a dictionary:

{
  'min': torch.Tensor([aggregated_minimum]),
  'max': torch.Tensor([aggregated_maximum])
}

class MinMaxAggregator(Aggregator):
    def __init__(self):
        super(MinMaxAggregator, self).__init__()
        self.aggregator_name = "MinMaxAggregator"
        
    def aggregate(self, model_params: dict, weights: dict, *args, **kwargs) -> dict:
        # ... TODO fill in the body of this function to return a dictionary with 'min' and 'max' keys
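
The underlying rule is that the minimum over the federation is the smallest of the local minima, and likewise the maximum is the largest of the local maxima. A minimal sketch with plain floats standing in for torch.Tensor values follows; the function name `aggregate_min_max` and the node ids are hypothetical.

```python
def aggregate_min_max(model_params):
    """Combine per-node {'min': ..., 'max': ...} dictionaries into one."""
    return {
        "min": min(p["min"] for p in model_params.values()),
        "max": max(p["max"] for p in model_params.values()),
    }


# Toy per-node results, keyed by node id.
toy_params = {
    "node_a": {"min": 0.0, "max": 1800.0},
    "node_b": {"min": 2.0, "max": 2450.0},
}
aggregated = aggregate_min_max(toy_params)
print(aggregated)
```

Note that the weights play no role here: an extreme value counts the same whether it comes from a large or a small node.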

Define the experiment#

TAGS: Replace %%%% in the tags with your username

tags =  ['ixi-jupyter-%%%%']

exp = Experiment(tags=tags,
                 model_args={},
                 training_plan_class=FedMinMaxTrainingPlan,
                 training_args={},
                 round_limit=1,  # just a single round, with a pass over the whole dataset
                 aggregator=MinMaxAggregator(),
                 tensorboard=False,
                 save_breakpoints=False
                )
exp.run_once()

Try it yourself!#

How can you access the minimum and maximum over the whole federation, after aggregation?

Hint: the experiment holds a copy of the training plan, which contains the model.

fed_min = # ... TODO obtain the minimum over the whole federation
fed_max = # ... TODO obtain the maximum over the whole federation

Federated Histogram: TrainingPlan#

We want to compute the average histogram of pixel values, where the average is computed over the different images.

Our strategy: training_routine#

We will implement a training_routine that does not really perform any training. Instead, it does one full pass over the whole dataset to compute the average histogram per image.
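
As a rough illustration of an average histogram per image, the sketch below counts values into fixed bins for each toy image and then divides the summed counts by the number of images. The helper `histogram` and the toy data are assumptions made for this example; the tutorial code would work on tensors instead.

```python
def histogram(values, bin_min, bin_max, nbins):
    """Count how many values fall into each of nbins equal-width bins."""
    width = (bin_max - bin_min) / nbins
    counts = [0] * nbins
    for v in values:
        # a value equal to bin_max is placed in the last bin
        idx = min(int((v - bin_min) / width), nbins - 1)
        counts[idx] += 1
    return counts


# Two toy "images" of 4 pixels each.
toy_images = [[0.0, 1.0, 2.0, 9.0], [4.0, 5.0, 8.0, 10.0]]

hist_sum = [0] * 5
for image in toy_images:
    for i, c in enumerate(histogram(image, 0.0, 10.0, 5)):
        hist_sum[i] += c

# Average count per bin, taken over the images.
avg_hist = [c / len(toy_images) for c in hist_sum]
print(avg_hist)
```

Because every image contributes exactly one count per pixel, the entries of the average histogram add up to the number of pixels in one image (4 in this toy case).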

Our strategy: the FedHistogram model#

We will create a model that inherits from nn.Module but is not a neural network. Instead, it will simply contain a dictionary to store the histogram, the minimum and maximum pixel values (computed before) and the number of bins. Note that the real “parameter” of this model, i.e. the one that will be updated, is the histogram data. All the rest are static values that will be neither updated nor aggregated. We will use model_arguments and training_arguments to pass these values.

The init_model function of the TrainingPlan simply needs to return an instance of FedHistogram.

The init_optimizer function will return a dummy torch optimizer, since there is no actual optimization happening during our training.

Our strategy: training_data#

The training_data function will use MedicalFolderDataset.

Try it yourself!#

Fill in the code for FedHistogram, init_model and training_routine.

class FedHistogramTrainingPlan(TorchTrainingPlan):    
    class FedHistogram(nn.Module):
        def __init__(self, bin_min=0., bin_max=1e+5, nbins=10):
            super().__init__()
            self.hist_data = # ... TODO initialize a state dict with the histogram data, bin_min, bin_max, and nbins
            
        def state_dict(self):
            # ... TODO fill in the body of this function following the blueprint from the MinMaxTracker
        
        def load_state_dict(self, params, strict=False):
            # ... TODO fill in the body of this function following the blueprint from the MinMaxTracker
            
        def named_parameters(self):
            # ... TODO fill in the body of this function following the blueprint from the MinMaxTracker

    def init_model(self, model_args):
        # ... TODO return the model initialized with the values from the model_args
    
    def training_step(self, *args, **kwargs):
        pass
    
    def init_optimizer(self, optimizer_arguments):
        return SGD([torch.Tensor([0])], lr=0.)  # dummy optimizer
        
    def init_dependencies(self):
        deps = [
            "from monai.transforms import (Compose, NormalizeIntensity, AddChannel, Resize, AsDiscrete)",
            "from fedbiomed.common.data import MedicalFolderDataset",
            'import numpy as np',
            'from torch.optim import SGD'
        ]
        
        return deps

    @staticmethod
    def demographics_transform(demographics: dict):
        return {}

    def training_data(self,  batch_size = 4):
        # training_data creates the DataLoader used for training by Fed-BioMed's generic Torchnn class
        common_shape = (44, 44, 56)
        training_transform = Compose([AddChannel(), Resize(common_shape)])
        target_transform = Compose([AddChannel(), Resize(common_shape), AsDiscrete(to_onehot=2)])
        
        dataset = MedicalFolderDataset(
            root=self.dataset_path,
            data_modalities='T1',
            target_modalities='label',
            transform=training_transform,
            target_transform=target_transform,
            demographics_transform=FedHistogramTrainingPlan.demographics_transform)
        train_kwargs = {'batch_size': batch_size, 'shuffle': False}
        return DataManager(dataset, **train_kwargs)
    
    def training_routine(self,
                         history_monitor = None,
                         node_args = None):
        
        hist_metadata = # ... TODO use self._model.get_weights to obtain useful information from FedHistogram's state_dict  
        
        count = 0
        # ... TODO iterate over the dataset and compute the histograms
        
        # ... TODO use self._model.set_weights to update the FedHistogram's state_dict            
                
        return count

Federated Histogram: Aggregator#

The Aggregator must implement an aggregate function that returns a dictionary of model parameters. In this case it is important to compute a weighted average of the histograms returned from each node (think about why a non-weighted mean is biased). For this, you can use the weights argument of the aggregate function.
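
A toy calculation shows the difference between the two averages. Suppose one hypothetical node holds a single image and another holds three; all numbers below are made up for illustration.

```python
# Per-node average histograms (toy values) and the number of images on each node.
node_hists = {"node_a": [4.0, 0.0], "node_b": [0.0, 4.0]}
node_sizes = {"node_a": 1, "node_b": 3}

total = sum(node_sizes.values())
weights = {n: s / total for n, s in node_sizes.items()}  # proportions summing to 1

# Weighted mean: each node counts in proportion to its number of images.
weighted = [sum(weights[n] * node_hists[n][i] for n in node_hists) for i in range(2)]

# Unweighted mean: each node counts the same, whatever its size.
unweighted = [sum(node_hists[n][i] for n in node_hists) / len(node_hists) for i in range(2)]

print(weighted, unweighted)
```

The weighted result matches what you would get by averaging over all four images pooled together, whereas the unweighted one gives the single image on node_a as much influence as the three images on node_b.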

Try it yourself!#

Fill in the code for the aggregate function.

Reminder, the inputs to aggregate are:

  • model_params: a dictionary {node_id: model_params_after_training}, where model_params_after_training is the state_dict of the FedHistogram model after local training on the node.

  • weights: a dictionary {node_id: weight} where the weight is a float between 0 and 1 computed as a proportion of the number of samples in the node to the total number of samples in the federation.

class HistAggregator(Aggregator):
    def __init__(self):
        super(HistAggregator, self).__init__()
        self.aggregator_name = "HistAggregator"
        
    def aggregate(self, model_params: dict, weights: dict, *args, **kwargs):
        hist = None
        for node_id, hist_data in model_params.items():
            if hist is None:
                hist = hist_data['hist']*weights[node_id]
            else:
                hist += hist_data['hist']*weights[node_id]
        return {'hist': hist,
                'nbins': hist_data['nbins'],
                'bin_min': hist_data['bin_min'],
                'bin_max': hist_data['bin_max']}

Define the experiment#

This time we will use model_args to pass to the node the federated minimum and maximum (computed before) as well as the number of bins.

Serialization: Arguments are sent over the network to the nodes, which requires that they can be serialized.

Our computed values for the federated minimum and maximum are not serializable because they are torch.Tensor objects: you need to find a way to convert them to regular Python floats.

model_args = {
# ... TODO fill in the dictionary contents here. Note: they must be coherent with the dictionary keys used in the `init_model` function of the TrainingPlan
}

TAGS: Replace %%%% in the tags with your username

tags =  ['ixi-jupyter-%%%%']

exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=FedHistogramTrainingPlan,
                 training_args={},
                 round_limit=1,
                 aggregator=HistAggregator(),
                 tensorboard=False,
                 save_breakpoints=False
                )
exp.run_once()

Sanity check#

It is always good to check the outputs of our calculations.

Try it yourself!#

What should the sum of all values in the aggregated histogram be equal to?

sum_of_pixel_counts = # ... TODO
expected_value = # ... TODO
print(f"The sum of average pixel counts {sum_of_pixel_counts} "
      f"should be equal to the total image size {expected_value}")

Plot the histograms#

The aggregated histogram#

Try it yourself#

  1. Using the values from model_args, obtain the size (the width) of each bin

  2. Compute the array of bin edges

  3. plot the histogram using the ax.bar function
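
The first two steps can be sketched in plain Python as follows. The values in `toy_args` are hypothetical; only the key names follow the histogram state dictionary used earlier.

```python
# Hypothetical values; only the key names follow the histogram state dictionary.
toy_args = {"bin_min": 0.0, "bin_max": 10.0, "nbins": 5}

# Every bin has the same width.
width = (toy_args["bin_max"] - toy_args["bin_min"]) / toy_args["nbins"]

# Left edge of each bin: one value per bar.
left_edges = [toy_args["bin_min"] + i * width for i in range(toy_args["nbins"])]

print(width, left_edges)
```

Keep in mind that matplotlib's bar function centers each bar on its x value by default; passing align='edge' makes it treat the x values as left edges instead.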

fig, ax = plt.subplots(figsize=(12,6))
width_histogram_bin = # ... TODO
bin_edges = np.arange(start=,  # ... TODO
                      stop=,  # ... TODO
                      step=)  # ... TODO
                      
ax.bar(bin_edges, 
       ,# ... TODO fill in with the aggregated histogram
       width=0.99*width_histogram_bin)
                      
_ = ax.set_ylabel('Average count per image')                      
_ = ax.set_xlabel('Pixel value')
_ = ax.set_title('Aggregated histogram')

The node-wise histograms#

Try it yourself#

You can access the models’ state_dicts (after training) through the exp.training_replies() function. The output is a dictionary of the format {round: node_replies}, where node_replies is a list of replies. Each reply is a dictionary, where node_id and params are the most important keys for this task.
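
To get a feel for this nesting, the sketch below walks a toy dictionary laid out as described. The content of `params` (a `hist` list) and the node ids are assumptions for illustration; the real replies contain more keys.

```python
# Toy replies mimicking the {round: [reply, ...]} layout described above.
toy_replies = {
    0: [
        {"node_id": "node_a", "params": {"hist": [1.0, 3.0]}},
        {"node_id": "node_b", "params": {"hist": [2.0, 2.0]}},
    ]
}

last_round = max(toy_replies)  # index of the most recent round

# Map each node id to the histogram it sent back in that round.
per_node = {r["node_id"]: r["params"]["hist"] for r in toy_replies[last_round]}
print(per_node)
```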

fig, ax = plt.subplots(2,1,figsize=(12,6))

# ... TODO as before, compute the width_histogram_bin and bin_edges

for i in range(2):  # the i^th node
    hist_data =  # ... TODO obtain the histogram from `exp.training_replies`
    node_id = # ... TODO obtain the node_id from `exp.training_replies`
    ax[i].bar(bin_edges, hist_data, width=0.99*width_histogram_bin)
    _ = ax[i].set_ylabel('Average count per image')
    _ = ax[i].set_title(f'Histogram for {node_id}')
    
_ = ax[1].set_xlabel('Pixel value')

Try it yourself!#

What happens if we normalize the images while loading them? Add NormalizeIntensity() as an additional transform for the loaded images in the TrainingPlan, and check how the histogram changes shape.
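
As a hint about what to expect, the sketch below applies a z-score normalization (subtract the mean, divide by the standard deviation) to a few toy intensities. This is a simplified stand-in written under the assumption that NormalizeIntensity with default arguments behaves this way over the whole image; check the MONAI documentation for the exact behavior and options.

```python
import statistics

pixels = [0.0, 100.0, 200.0, 300.0, 400.0]  # toy intensities

mean = statistics.fmean(pixels)
std = statistics.pstdev(pixels)  # population standard deviation (an assumption)

# z-score: the result has zero mean and unit standard deviation
normalized = [(p - mean) / std for p in pixels]
print(normalized)
```

After normalization about half of the values are negative, so bins computed from the raw intensity range no longer cover the data and the range has to be recomputed.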

Task 4: Training a UNet model for the brain segmentation task #

Create a Training Plan#

We create a training plan that incorporates the UNet model.

Define the neural network model#

We recommend using MONAI’s UNet implementation.

We define the model in the __init__ and forward functions of the training plan.

Define the loss function#

The loss function is based on the Dice loss.

Carole H Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, and M Jorge Cardoso. Generalised dice overlap as a deep learning loss function for highly unbalanced segmentations. In Deep learning in medical image analysis and multimodal learning for clinical decision support, pages 240–248. Springer, 2017.
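
For intuition, a basic soft Dice loss can be written in a few lines. The sketch below is a simplified version on flat lists of probabilities, not MONAI's DiceLoss implementation (which offers more options such as background handling and reduction modes); the function name `soft_dice_loss` and the `eps` smoothing term are choices made for this example.

```python
def soft_dice_loss(pred, target, eps=1e-9):
    """1 - 2 * sum(p * g) / (sum(p) + sum(g)) on flat lists of probabilities."""
    intersection = sum(p * g for p, g in zip(pred, target))
    denominator = sum(pred) + sum(target)
    # eps avoids a division by zero when both prediction and target are empty
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


target_mask = [0.0, 1.0, 1.0, 0.0]
perfect = soft_dice_loss([0.0, 1.0, 1.0, 0.0], target_mask)  # full overlap
half = soft_dice_loss([0.0, 1.0, 0.0, 0.0], target_mask)     # one of two voxels found
print(perfect, half)
```

The loss is 0 for a perfect overlap and grows towards 1 as the overlap shrinks, regardless of how many background voxels the image contains.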

Define training step#

Here we take as input one batch of (data, target), train the model and compute the loss function.

Note that the MedicalFolderDataset class returns data as a tuple of (images, demographics), where:

  • images is a dict of {modality: image} (after image transformations)

  • demographics is a dict of {column_name: values} where the column names are taken from the demographics csv file

The target is a dict of {modality: image} (after target transformations).

In our case, the modality used is T1 for the input images, while the modality used for the target is label. We also ignore the values of the demographics data during training because the UNet model only takes images as input, but the code provided still shows the recommended way to handle such data.
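
The nesting described above can be unpacked in one statement. In the toy batch below, strings stand in for the image tensors, and the demographics dictionary is empty because the demographics_transform used in this tutorial returns an empty dict.

```python
# Toy batch with the nesting described above; strings stand in for image tensors.
toy_batch = (
    ({"T1": "t1_images"}, {}),  # data = (images, demographics)
    {"label": "label_masks"},   # target
)

(images, demographics), target = toy_batch
t1 = images["T1"]        # input modality
label = target["label"]  # segmentation target
print(t1, label, demographics)
```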

class UNetTrainingPlan(TorchTrainingPlan):
    
    def init_model(self, model_args):
        n_base_filters = # ... TODO get this from model_args
        return UNet(
                spatial_dims=, # ... TODO get this from model_args
                in_channels=, # ... TODO get this from model_args
                out_channels=, # ... TODO get this from model_args 
                channels=(n_base_filters, 
                           2*n_base_filters,
                           4*n_base_filters), 
                strides=(2,2),  
                kernel_size=3,  # you may parametrize additional values below if you wish 
                up_kernel_size=3, 
                num_res_units=0, 
                act='PRELU', 
                norm='INSTANCE', 
                dropout=0.0, 
                bias=True, 
                adn_ordering='NDA'
            )
    
    def init_optimizer(self, optimizer_args):
        # ... TODO define and return the optimizer
    
    def init_dependencies(self):
        deps = [
            "from monai.transforms import (Compose, NormalizeIntensity, AddChannel, Resize, AsDiscrete)",
            "import torch.nn as nn",
            'import torch.nn.functional as F',
            "from fedbiomed.common.data import MedicalFolderDataset",
            'import numpy as np',
            'from torch.optim import AdamW, SGD',
            'from monai.networks.nets import UNet',
            'from monai.losses.dice import DiceLoss',  # needed by get_dice_loss
            'from fedbiomed.common.logger import logger']
        
        return deps
    
    @staticmethod
    def get_dice_loss(output, target, epsilon=1e-9):
        loss = DiceLoss(include_background=False, sigmoid=False)
        return loss(output, target)

    @staticmethod
    def demographics_transform(demographics: dict):
        return {}

    def training_data(self,  batch_size = 4):
        # training_data creates the DataLoader used for training by Fed-BioMed's generic Torchnn class
        common_shape = (44, 44, 56)
        training_transform = Compose([AddChannel(), Resize(common_shape)])
        target_transform = Compose([AddChannel(), Resize(common_shape), AsDiscrete(to_onehot=2)])
        
        dataset = MedicalFolderDataset(
            root=self.dataset_path,
            data_modalities='T1',
            target_modalities='label',
            transform=training_transform,
            target_transform=target_transform,
            demographics_transform=UNetTrainingPlan.demographics_transform)
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        return DataManager(dataset, **train_kwargs)
    
    
    def training_step(self, data, target):
        # TODO ... extract the T1 modality from `data`
        # TODO ... extract the label from `target`
        # TODO ... call model forward
        # TODO ... call F.softmax on the output of the model
        # TODO ... compute the loss, then average it
        return avg_loss_over_batch
    
    def testing_step(self, data, target):
        # TODO ... extract the T1 modality from `data`
        # TODO ... extract the label from `target`
        # TODO ... call model forward
        # TODO ... call F.softmax on the output of the model
        # TODO ... compute the loss, then average it
        return avg_loss

Define Parameters#

Here you can define model_args and training_args, two dictionaries that contain parameters and hyperparameters for training and model definition. This provides a flexible way to explore the hyperparameter space without changing the TrainingPlan, which has potentially been fixed and validated by the clinical partners.

Try it yourself!#

Change any of the parameters below to explore the space of hyperparameters. Please be mindful of the limited available resources when changing parameters that may require more computing power, such as base_filters and especially batch_size.

model_args = {
    'in_channels': 1,
    'out_channels': 2,
    'dimensions': 3,
    'base_filters': 10,
}

training_args = {
    'batch_size': 4,
    'num_updates': 8, 
    'dry_run': False,
    'log_interval': 2,
    'test_ratio' : 0.1,
    'test_on_global_updates': True,
    'test_on_local_updates': False,
    'optimizer_args': {
        'opt_name': 'adamw',
        'lr': 0.001
    }
}

num_rounds = 15

Dry run your TrainingPlan locally#

Since we have a holdout dataset available locally, we are going to test that the TrainingPlan is able to run locally, before we perform the federated training.

Try it yourself!#

First, create a dataloader following these steps:

  1. instantiate a UNetTrainingPlan object

  2. call post_init on the training plan. Note: you will need to instantiate a TrainingArgs object from the training_args dict

  3. set the training plan’s dataset_path to /datasets/ixi/holdout

  4. call the training_data function of the training plan to obtain a TorchDataManager

  5. use the _dataset attribute of the data manager to instantiate a torch DataLoader (set a small batch size)

Then, perform one training iteration to check that it completes without errors:

  1. create a for loop iterating on the dataloader

  2. call the training plan’s training_step method

  3. break after the first iteration

dryrun_tp = # ... TODO instantiate training plan object
# ... TODO call post_init with the following second argument: TrainingArgs(training_args, only_required=False)
# ... TODO set dataset path
# ... TODO call training_data
# ... TODO create torch DataLoader
# ... TODO for loop on data loader
# ... TODO call training_step
# ... TODO break

TAGS: Replace %%%% in the tags with your username

tags =  ['ixi-jupyter-%%%%']

exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=UNetTrainingPlan,
                 training_args=training_args,
                 round_limit=num_rounds,
                 aggregator=FedAverage(),
                 tensorboard=True,
                 save_breakpoints=True
                )

Run tensorboard#

Follow the instructions to obtain a port number, and run the commands below.

If <IP> is the IP assigned to you at this workshop, you may also view the tensorboard in a new browser tab at http://<IP>:<YOUR PORT NUMBER>

tensorboard_dir = environ['TENSORBOARD_RESULTS_DIR']
%tensorboard --logdir "$tensorboard_dir" --host 0.0.0.0 --port <YOUR PORT NUMBER>

Run the experiment#

print(f"Saving breakpoints to {exp.experimentation_folder()}")
exp.run()

Task 5: Validate on a local holdout set #

To ensure consistency and simplify our life, we try to reuse the already-available code as much as possible. Note that this process assumes that the held-out data is stored locally on the machine.

Create an instance of the global model#

First, we create an instance of the model using the parameters from the latest aggregation round.

exp_folder = exp.experimentation_folder()
#exp_folder = 'Experiment_0030'
breakpoint_num = 14
breakpoint_dir = os.path.join(environ['EXPERIMENTS_DIR'], exp_folder, f'breakpoint_{breakpoint_num:04d}')

Try it yourself!#

  1. Call the post_init function to initialize the model in the newly loaded experiment’s training plan (Note, you will need to create a TrainingArgs object)

  2. extract the model

  3. call the model’s load_state_dict, by passing as argument the model weights obtained from the loaded experiment’s _aggregated_params

loaded_exp = Experiment.load_breakpoint(breakpoint_dir)
loaded_exp.training_plan().post_init(# ... TODO fill in arguments )
valid_model = # ... TODO obtain the model from the experiment's training plan
# ... TODO call load_state_dict

Define a validation data loader#

We extract the validation data loader from the training plan as well. This requires some knowledge about the internals of the MedicalFolderDataset class. At the end of the process, calling the split function with a ratio of 0 will return a data loader that loads all of the data.

Try it yourself!#

Call the training_data function from the loaded experiment’s training plan to obtain a TorchDataManager. Use the _dataset attribute of the data manager to instantiate a torch DataLoader.

Note: use a batch size of 1 for local validation.

loaded_exp.training_plan().dataset_path = '/datasets/ixi/holdout'
val_data_manager = # ... TODO call training_data
val_data_loader = # ... TODO instantiate a torch DataLoader

Compute the loss on validation images#

Try it yourself!#

  1. Iterate over the validation dataset using the DataLoader defined above.

  2. Compute the predictions by manually calling the model forward and F.softmax methods.

  3. Compute the loss by calling the UNetTrainingPlan.get_dice_loss function.

  4. Store the loss values in a list.

Attention: do not forget to set valid_model.eval() and to use the torch.no_grad context.

losses = []
# ... TODO fill in the code following the instructions above
loss_array = np.array(losses)
print(f'Minimum: {loss_array.min():0.3f} ' 
      f'Mean: {loss_array.mean():0.3f} '
      f'StdDev: {loss_array.std():0.3f} '
      f'Maximum: {loss_array.max():0.3f} ')

Visualize the outputs#

We visualize the outputs of our model on the holdout dataset.

one_batch = next(iter(val_data_loader))

Try it yourself!#

Create a grid of 4x2 images. In each row, plot a slice of the original on the left overlaid with the ground truth, and the same slice of the original on the right overlaid with the prediction. Insert the loss value in the title of the image. Remember that to compute the predictions you need to pass the whole 3D image to model.forward, not just a slice.

fig, ax = plt.subplots(4,2, figsize=(8,16))
slice_to_plot = 24
predictions = valid_model.forward(one_batch[0][0]['T1'])
val_data_loader_iter = iter(val_data_loader)  # iterator consumed by next() in the loop below
for i in range(4):
    (image_modalities, demographics), target = next(val_data_loader_iter)
    img = # ... TODO obtain the slice of the image to be plotted
    ax[i][0].imshow(img, cmap='bone')
    label = # ... TODO obtain the slice of the label to be plotted
    plot = ax[i][0].imshow(np.ma.masked_where(
                            label < 0.5, 
                            label), cmap='winter', alpha=0.7)
    ax[i][0].set_title('Input slice with ground truth overlay')
    # ... TODO compute the prediction for this image
    # ... TODO following the blueprint above, plot the original image with the overlaid prediction
    ax[i][1].set_title(f'Input slice with prediction overlay - loss: {losses[i]:0.3f}')
    
fig.colorbar(plot)