Brain Segmentation#
This tutorial shows how to use Fed-BioMed to perform image segmentation on 3D brain MRI images, using the publicly available IXI dataset. It uses a 3D U-Net model for the segmentation, trained on data from 3 separate centers.
Here we tackle a fairly complex case, using advanced Fed-BioMed functionalities such as:
exploring the datasets in the federation
loading a MedicalImageDataset
monitoring training loss with Tensorboard
Parts of this tutorial are based on TorchIO’s tutorial.
Table of Contents#
Task 1: Discovering datasets
Task 2: MedicalFolderDataset class
Task 3: Federated feature analytics
Task 4: Train a UNet model
Task 5: Validation on a local holdout set
%load_ext tensorboard
import os
import tabulate
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW, SGD
from torch.utils.data import DataLoader
from monai.networks.nets import UNet
from monai.losses.dice import DiceLoss
from fedbiomed.common.training_plans import BaseTrainingPlan, TorchTrainingPlan
from fedbiomed.common.logger import logger
from fedbiomed.common.data import DataManager, MedicalFolderDataset
from fedbiomed.researcher.requests import Requests
from fedbiomed.researcher.aggregators import Aggregator, FedAverage
from fedbiomed.researcher.environ import environ
from fedbiomed.common.training_args import TrainingArgs
from fedbiomed.researcher.experiment import Experiment
%matplotlib inline
Task 1: Discovering datasets #
Let’s discover which datasets are available for federated training in the network.
Try it yourself!#
Read the documentation for the Requests class to figure out which function call can be used to list all the available datasets.
req = Requests()
datasets = # ... TODO fill in a one-line function call
pprint(datasets)
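If you are stuck: in recent Fed-BioMed versions the Requests class exposes a list method, so one possible solution (hedged: check the documentation for your version) is a single call:
datasets = req.list()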
Filter results#
There are a lot of datasets available! However, most of them are from nodes that are going to collaborate with other users, not you. To identify the datasets intended for your use, you should look at the value of the tags field.
Try it yourself!#
Fill in the body of the for loop below such that the datasets_for_me variable follows these rules:
it has the same structure as the datasets variable
it has all and only the datasets whose tags contain your username
my_username = 'sharkovsky'
datasets_for_me = dict()
for node, _data in datasets.items():
# ... TODO fill in the body of the for loop
pprint(datasets_for_me)
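If you need a hint, here is a minimal sketch of one possible loop body, assuming each dataset dict carries a 'tags' list as returned by the listing call above:
for node, _data in datasets.items():
    matching = [d for d in _data if any(my_username in tag for tag in d['tags'])]
    if matching:
        datasets_for_me[node] = matching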
Tabulate results#
Try it yourself!#
Looking at the format of the datasets_for_me variable, produce a table with the following format using the tabulate package:

site | dataset name | sample size
---|---|---
node_3d7f08fa-ee13-4033-8a01-93b448b6c8be | ixi | 177
tabulate_sample_sizes = list()
for node, _data in datasets_for_me.items():
# ... TODO fill in the body of the for loop by populating the `tabulate_sample_sizes` variable to be passed to tabulate below
print(tabulate.tabulate(tabulate_sample_sizes, headers=('site', 'dataset name', 'sample size')))
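One possible way to populate the rows, assuming each dataset entry exposes 'name' and 'shape' keys with the sample count as the first element of 'shape' (check the structure you printed above before relying on these keys):
for node, _data in datasets_for_me.items():
    for d in _data:
        tabulate_sample_sizes.append((node, d['name'], d['shape'][0]))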
Task 2: MedicalFolderDataset #
To help you get familiar with loading medical imaging data in Fed-BioMed, we will practice on a smaller dataset that we assume is available to you locally as a researcher. This is often the case in FL settings, where the researcher has a small holdout dataset for local validation.
We will use Fed-BioMed’s built-in class for image segmentation tasks: MedicalFolderDataset.
This class supports several medical imaging modalities (think of all the different types of MRI, CT, PET, etc.). It is optimized for segmentation tasks but can also be applied to other tasks (e.g. classification, regression, …).
This class supports loading a set of patient demographics data in csv format, in addition to the imaging data.
The inputs to the __init__ function are:
root: the filesystem path where the root of the dataset is located
data_modalities: the names of the modalities of interest for the input data
target_modalities: the names of the modalities of interest for the data to be predicted
transform: optional transformations to be performed on the input images
target_transform: optional transformations to be performed on the target data
demographics_transform: optional transformations to be performed on the demographics (csv) data
Let’s create the dataset:
dataset = MedicalFolderDataset(
root='/datasets/ixi/holdout',
data_modalities=['T1', 'T2'],
target_modalities='label',
transform=None,
target_transform=None,
demographics_transform=None)
Try it yourself!#
You can find out the total number of images with the len function, and access images individually with the [idx] operator, where idx is an integer index.
Note: dataset[0] will return a tuple corresponding to index 0. The tuple will be in the form (inputs, targets).
What is the format of inputs and targets? How do you access a single image? What data type is it, what is its format and shape? How do you access the corresponding binary mask?
# ... TODO call the [] operator on `dataset` and explore its contents following the guidelines above
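A quick exploration sketch, following the (inputs, targets) structure described above:
print(len(dataset))                      # number of subjects
(image_modalities, demographics), target = dataset[0]
print(type(image_modalities['T1']))      # expected: a torch.Tensor
print(image_modalities['T1'].shape)      # a 3D volume; the exact shape depends on the acquisition
print(target['label'].shape)             # the binary mask, with matching spatial dimensions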
Try it yourself!#
Plot two images. On the left, the raw input image. On the right, the raw input image again as background, overlaid with the ground truth segmentation. Take some time to familiarize yourself with the approach for making the overlay plot, and with all the plotting arguments that we use.
fig, ax = plt.subplots(1,2, figsize=(6,6))
slice_to_plot = # ... TODO choose a slice to plot in the depth dimension (what is the range?)
(image_modalities, demographics), target = # ... TODO use the [] operator to get a tuple from the dataset
img = # ... TODO from image_modalities, extract the correct modality and slice to be plotted
ax[0].imshow(img, cmap='bone')
ax[0].invert_yaxis()
ax[1].imshow(img, cmap='bone') # plot again as background for overlay
label = # ... TODO from target, extract the correct key/value and slice to be plotted
plot = ax[1].imshow(np.ma.masked_where(
label < 0.5,
label), cmap='winter', alpha=0.5, interpolation='none', vmin=0., vmax=1.)
ax[1].invert_yaxis()
fig.suptitle('Input slice with ground truth overlay')
fig.colorbar(plot, ax=ax, location='bottom')
Task 3: Federated feature analytics #
We want to obtain some basic information about the distribution of values of our features (i.e. the pixels in the MRI images). Our ultimate goal is to obtain a histogram of the counts of pixel values, but this will require a few steps, showcasing how flexible and interactive Fed-BioMed can be.
The following code slightly bends the intended usage of some Fed-BioMed classes. You will be working with an alpha version of functionality that is under active development.
Federated Min and Max: TrainingPlan#
We need to do some preliminary work before we can compute a federated histogram. One piece of information we need is the minimum and maximum pixel value over the whole federation, in order to compute a stable set of histogram bins across all the nodes. (Think about what would happen if we let each node compute its own histogram bins independently: the bin edges would differ between nodes, and the per-bin counts could not be meaningfully aggregated.)
Implementation details#
We are going to “hack” a TorchTrainingPlan in order to compute a minimum and a maximum instead of training a model.
To understand exactly what is going on, you need the information below on some details of Fed-BioMed. If you are not interested, you can skip directly to the description of what you need to do to complete the code below.
The TrainingPlan should implement at least four important functions:
init_model: create the model (nn.Module) to be trained
init_optimizer: create the optimizer
training_data: instantiate the dataset on the node (with customizations made by the researcher)
training_routine or training_step: the actual training code. The latter is a simple shorthand if you only need to customize the training on a single batch, without worrying about managing the data loader, iterations, etc.
The diagram below summarizes the main steps that constitute a run of a federated experiment in Fed-BioMed:
Define data loading and transformations#
In what follows, we provide a suggestion for defining the data loading. The training_data function will use MedicalFolderDataset.
Additionally, we can define image transformations for the input images and the labels.
Finally, we define a transformation for the demographics data. Note that this is a required step to ensure that the demographics csv data is transformed into a torch Tensor usable by the training routine.
Our strategy: training_routine#
We will implement a training_routine that does not really perform any training. Instead, it does one full pass over the whole dataset to find the maximum and the minimum.
Our strategy: the MinMaxTracker model#
We will create a model that inherits from nn.Module but is not a neural network. Instead, it will simply contain a dictionary to store the minimum and maximum pixel values.
The init_model function of the TrainingPlan simply needs to return an instance of MinMaxTracker.
The init_optimizer function will return a dummy torch optimizer, since no actual optimization happens during our training.
The implementation details of how a model is handled can be a bit confusing at first. Here is a simplified diagram detailing the interactions between the main classes during a federated training round.
Try it yourself!#
Fill in the code below in the MinMaxTracker and training_routine.
class FedMinMaxTrainingPlan(TorchTrainingPlan):
class MinMaxTracker(nn.Module):
def __init__(self):
super().__init__()
self.min_max = # ... TODO initialize a dictionary with 'min' and 'max' keys, and torch.Tensor values
def state_dict(self):
return self.min_max
def load_state_dict(self, params, strict=False):
self.min_max = # ... TODO update the state dictionary with the values from params
# To mimic torch's load_state_dict, we create a small class that mocks the same output format
class MockMissingKeys:
def __init__(self):
self.missing_keys = []
self.unexpected_keys = []
return MockMissingKeys()
def named_parameters(self):
return self.min_max.items()
def init_model(self, model_arguments):
return FedMinMaxTrainingPlan.MinMaxTracker()
def training_step(self, *args, **kwargs):
pass
def init_optimizer(self, optimizer_arguments):
return SGD([torch.Tensor([0])], lr=0.) # dummy optimizer that will not be used
def init_dependencies(self):
deps = [
"from monai.transforms import (Compose, NormalizeIntensity, AddChannel, Resize, AsDiscrete)",
"from fedbiomed.common.data import MedicalFolderDataset",
'import numpy as np',
'from torch.optim import SGD'
]
return deps
@staticmethod
def demographics_transform(demographics: dict):
return {}
def training_data(self, batch_size = 4):
# The training_data creates the Dataloader to be used for training in the general class Torchnn of fedbiomed
common_shape = (44, 44, 56)
training_transform = Compose([AddChannel(), Resize(common_shape), NormalizeIntensity(),])
target_transform = Compose([AddChannel(), Resize(common_shape), AsDiscrete(to_onehot=2)])
dataset = MedicalFolderDataset(
root=self.dataset_path,
data_modalities='T1',
target_modalities='label',
transform=training_transform,
target_transform=target_transform,
demographics_transform=FedMinMaxTrainingPlan.demographics_transform)
train_kwargs = {'batch_size': batch_size, 'shuffle': False}
return DataManager(dataset, **train_kwargs)
def training_routine(self,
history_monitor = None,
node_args = None):
count = 0
global_min = self._model.get_weights()['min']
global_max = self._model.get_weights()['max']
for data, target in self.training_data_loader: # iterate through the dataset
count += 1
# ... TODO get the image corresponding to the T1 modality
# ... TODO find the minimum and maximum of the image, and if necessary update the `global_min` and `global_max` variables
# ... TODO use self._model.set_weights to update the minimum and maximum with global values just computed
return count
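If you get stuck, here is a minimal sketch of one possible training_routine body, assuming data unpacks into the (images, demographics) tuple described earlier and that get_weights/set_weights behave as in the skeleton above:
for (images, demographics), target in self.training_data_loader:
    count += 1
    img = images['T1']
    global_min = torch.minimum(global_min, img.min())   # running federation-local minimum
    global_max = torch.maximum(global_max, img.max())   # running federation-local maximum
self._model.set_weights({'min': global_min, 'max': global_max})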
Federated Min and Max: Aggregator#
The Aggregator must implement an aggregate function that returns a dictionary of model parameters.
Try it yourself!#
Fill in the code for the aggregate function.
The inputs are:
model_params: a dictionary {node_id: model_params_after_training}, where model_params_after_training is the state_dict of the MinMaxTracker model after local training on the node.
weights: a dictionary {node_id: weight}, where the weight is a float between 0 and 1 computed as the proportion of the number of samples on the node to the total number of samples in the federation.
The output should be a dictionary with the same format as the state_dict of the MinMaxTracker model, i.e.:
{
'min': torch.Tensor([aggregated_minimum]),
'max': torch.Tensor([aggregated_maximum])
}
class MinMaxAggregator(Aggregator):
def __init__(self):
super(MinMaxAggregator, self).__init__()
self.aggregator_name = "MinMaxAggregator"
def aggregate(self, model_params: dict, weights: dict, *args, **kwargs) -> dict:
# ... TODO fill in the body of this function to return a dictionary with 'min' and 'max' keys
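A minimal sketch of one possible aggregate body: the federated minimum is the smallest of the node minima, and likewise for the maximum (the weights argument is not needed here):
mins = torch.stack([p['min'] for p in model_params.values()])
maxs = torch.stack([p['max'] for p in model_params.values()])
return {'min': mins.min().reshape(1), 'max': maxs.max().reshape(1)}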
Define the experiment#
tags = ['ixi-jupyter-%%%%']
exp = Experiment(tags=tags,
model_args={},
training_plan_class=FedMinMaxTrainingPlan,
training_args={},
round_limit=1, # just a single round, with a pass over the whole dataset
aggregator=MinMaxAggregator(),
tensorboard=False,
save_breakpoints=False
)
exp.run_once()
Try it yourself!#
How can you access the minimum and maximum over the whole federation, after aggregation?
Hint: the experiment holds a copy of the training plan, which contains the model.
fed_min = # ... TODO obtain the minimum over the whole federation
fed_max = # ... TODO obtain the maximum over the whole federation
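One plausible access path (hedged: the exact accessors depend on your Fed-BioMed version) goes through the experiment's training plan and its model wrapper:
fed_min = exp.training_plan().model().state_dict()['min']
fed_max = exp.training_plan().model().state_dict()['max']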
Federated Histogram: TrainingPlan#
We want to compute the average histogram of pixel values, where the average is computed over the different images.
Our strategy: training_routine#
We will implement a training_routine that does not really perform any training. Instead, it does one full pass over the whole dataset to compute the average histogram per image.
Our strategy: the FedHistogram model#
We will create a model that inherits from nn.Module but is not a neural network. Instead, it will simply contain a dictionary storing the histogram, the minimum and maximum pixel values (computed before), and the number of bins. Note that the only real “parameter” of this model, i.e. the one that will be updated, is the histogram data. All the rest are static values that will be neither updated nor aggregated. We will use model_arguments and training_arguments to pass these values.
The init_model function of the TrainingPlan simply needs to return an instance of FedHistogram.
The init_optimizer function will return a dummy torch optimizer, since no actual optimization happens during our training.
Our strategy: training_data#
The training_data function will use MedicalFolderDataset.
Try it yourself!#
Fill in the code for FedHistogram, init_model and training_routine.
class FedHistogramTrainingPlan(TorchTrainingPlan):
class FedHistogram(nn.Module):
def __init__(self, bin_min=0., bin_max=1e+5, nbins=10):
super().__init__()
self.hist_data = # ... TODO initialize a state dict with the histogram data, bin_min, bin_max, and nbins
def state_dict(self):
# ... TODO fill in the body of this function following the blueprint from the MinMaxTracker
def load_state_dict(self, params, strict=False):
# ... TODO fill in the body of this function following the blueprint from the MinMaxTracker
def named_parameters(self):
# ... TODO fill in the body of this function following the blueprint from the MinMaxTracker
def init_model(self, model_args):
# ... TODO return the model initialized with the values from the model_args
def training_step(self, *args, **kwargs):
pass
def init_optimizer(self, optimizer_arguments):
return SGD([torch.Tensor([0])], lr=0.) # dummy optimizer
def init_dependencies(self):
deps = [
"from monai.transforms import (Compose, NormalizeIntensity, AddChannel, Resize, AsDiscrete)",
"from fedbiomed.common.data import MedicalFolderDataset",
'import numpy as np',
'from torch.optim import SGD'
]
return deps
@staticmethod
def demographics_transform(demographics: dict):
return {}
def training_data(self, batch_size = 4):
# The training_data creates the Dataloader to be used for training in the general class Torchnn of fedbiomed
common_shape = (44, 44, 56)
training_transform = Compose([AddChannel(), Resize(common_shape)])
target_transform = Compose([AddChannel(), Resize(common_shape), AsDiscrete(to_onehot=2)])
dataset = MedicalFolderDataset(
root=self.dataset_path,
data_modalities='T1',
target_modalities='label',
transform=training_transform,
target_transform=target_transform,
demographics_transform=FedHistogramTrainingPlan.demographics_transform)
train_kwargs = {'batch_size': batch_size, 'shuffle': False}
return DataManager(dataset, **train_kwargs)
def training_routine(self,
history_monitor = None,
node_args = None):
hist_metadata = # ... TODO use self._model.get_weights to obtain useful information from FedHistogram's state_dict
count = 0
# ... TODO iterate over the dataset and compute the histograms
# ... TODO use self._model.set_weights to update the FedHistogram's state_dict
return count
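If you are unsure where to start: torch.histc computes a fixed-range histogram, and using the federated minimum and maximum as the range guarantees identical bins on every node. A standalone illustration with made-up values:
x = torch.rand(44, 44, 56)                      # a fake volume with values in [0, 1)
h = torch.histc(x, bins=10, min=0.0, max=1.0)   # counts per bin, shape (10,)
print(h.sum())                                  # equals x.numel() = 44*44*56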
Federated Histogram: Aggregator#
The Aggregator must implement an aggregate function that returns a dictionary of model parameters.
In this case it is important to compute a weighted average of the histograms returned from each node (think about why an unweighted mean would be biased). For this, you can use the weights argument of the aggregate function.
Try it yourself!#
Fill in the code for the aggregate function.
As a reminder, the inputs to aggregate are:
model_params: a dictionary {node_id: model_params_after_training}, where model_params_after_training is the state_dict of the FedHistogram model after local training on the node.
weights: a dictionary {node_id: weight}, where the weight is a float between 0 and 1 computed as the proportion of the number of samples on the node to the total number of samples in the federation.
class HistAggregator(Aggregator):
def __init__(self):
super(HistAggregator, self).__init__()
self.aggregator_name = "HistAggregator"
def aggregate(self, model_params: dict, weights: dict, *args, **kwargs):
hist = None
for node_id, hist_data in model_params.items():
if hist is None:
hist = hist_data['hist']*weights[node_id]
else:
hist += hist_data['hist']*weights[node_id]
return {'hist': hist,
'nbins': hist_data['nbins'],
'bin_min': hist_data['bin_min'],
'bin_max': hist_data['bin_max']}
Define the experiment#
This time we will use model_args to pass the federated minimum and maximum (computed before) to the nodes, as well as the number of bins.
Our computed values for the federated minimum and maximum are not serializable because they are torch.Tensor objects: you need to find a way to convert them to plain Python floats.
model_args = {
# ... TODO fill in the dictionary contents here. Note: they must be coherent with the dictionary keys used in the `init_model` function of the TrainingPlan
}
tags = ['ixi-jupyter-%%%%']
exp = Experiment(tags=tags,
model_args=model_args,
training_plan_class=FedHistogramTrainingPlan,
training_args={},
round_limit=1,
aggregator=HistAggregator(),
tensorboard=False,
save_breakpoints=False
)
exp.run_once()
Sanity check#
It is always good to check the outputs of our calculations.
Try it yourself!#
What should the sum of all values in the aggregated histogram be equal to?
sum_of_pixel_counts = # ... TODO
expected_value = # ... TODO
print(f"The sum of average pixel counts {sum_of_pixel_counts} "
f"should be equal to the total image size {expected_value}")
Plot the histograms#
The aggregated histogram#
Try it yourself!#
Using the values from model_args, obtain the size (the width) of each bin
Compute the array of bin edges
Plot the histogram using the ax.bar function
fig, ax = plt.subplots(figsize=(12,6))
width_histogram_bin = # ... TODO
bin_edges = np.arange(start=, # ... TODO
                      stop=,  # ... TODO
                      step=)  # ... TODO
ax.bar(bin_edges,
       , # ... TODO fill in with the aggregated histogram
       width=0.99*width_histogram_bin)
_ = ax.set_ylabel('Average count per image')
_ = ax.set_xlabel('Pixel value')
_ = ax.set_title('Aggregated histogram')
The node-wise histograms#
Try it yourself!#
You can access the models’ state_dicts (after training) through the exp.training_replies() function. The output is a dictionary of the format {round: node_replies}, where node_replies is a list of replies. Each reply is a dictionary, in which node_id and params are the most important keys for this task.
fig, ax = plt.subplots(2,1,figsize=(12,6))
# ... TODO as before, compute the width_histogram_bin and bin_edges
for i in range(2): # the i^th node
hist_data = # ... TODO obtain the histogram from `exp.training_replies`
node_id = # ... TODO obtain the node_id from `exp.training_replies`
ax[i].bar(bin_edges, hist_data, width=0.99*width_histogram_bin)
_ = ax[i].set_ylabel('Average count per image')
_ = ax[i].set_title(f'Histogram for {node_id}')
_ = ax[1].set_xlabel('Pixel value')
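If you need a hint for the loop body above, a hypothetical access pattern could look like the following (the exact reply layout varies across Fed-BioMed versions, so inspect exp.training_replies() first):
round_replies = exp.training_replies()[0]       # replies from the single round we ran
hist_data = round_replies[i]['params']['hist']
node_id = round_replies[i]['node_id']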
Try it yourself!#
What happens if we normalize the images while loading them?
Add NormalizeIntensity() as an additional transform for the loaded images in the TrainingPlan, and check how the histogram changes shape.
Task 4: Training a UNet model for the brain segmentation task #
Create a Training Plan#
We create a training plan that incorporates the UNet model.
Define the neural network model#
We recommend using MONAI’s UNet implementation.
We define the model in the init_model function of the training plan; the forward pass is provided by MONAI's UNet.
Define the loss function#
The loss function is based on the Dice loss:
Carole H Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, and M Jorge Cardoso. Generalised dice overlap as a deep learning loss function for highly unbalanced segmentations. In Deep learning in medical image analysis and multimodal learning for clinical decision support, pages 240–248. Springer, 2017.
Define training step#
Here we take as input one batch of (data, target), train the model and compute the loss function.
Note that the MedicalFolderDataset class returns data as a tuple of (images, demographics), where:
images is a dict of {modality: image} (after image transformations)
demographics is a dict of {column_name: values}, where the column names are taken from the demographics csv file
target is a dict of {modality: image} (after target transformations)
In our case, the modality used is T1 for the input images, while the modality used for the target is label. We also ignore the values of the demographics data during training, because the UNet model only takes images as input, but the code provided still shows the recommended way to handle such data.
class UNetTrainingPlan(TorchTrainingPlan):
def init_model(self, model_args):
n_base_filters = # ... TODO get this from model_args
return UNet(
spatial_dims=, # ... TODO get this from model_args
in_channels=, # ... TODO get this from model_args
out_channels=, # ... TODO get this from model_args
channels=(n_base_filters,
2*n_base_filters,
4*n_base_filters),
strides=(2,2),
kernel_size=3, # you may parametrize additional values below if you wish
up_kernel_size=3,
num_res_units=0,
act='PRELU',
norm='INSTANCE',
dropout=0.0,
bias=True,
adn_ordering='NDA'
)
def init_optimizer(self, optimizer_args):
# ... TODO define and return the optimizer
def init_dependencies(self):
deps = [
"from monai.transforms import (Compose, NormalizeIntensity, AddChannel, Resize, AsDiscrete)",
"import torch.nn as nn",
'import torch.nn.functional as F',
"from fedbiomed.common.data import MedicalFolderDataset",
'import numpy as np',
'from torch.optim import AdamW, SGD',
'from monai.networks.nets import UNet',
'from fedbiomed.common.logger import logger']
return deps
@staticmethod
def get_dice_loss(output, target, epsilon=1e-9):
loss = DiceLoss(include_background=False, sigmoid=False)
return loss(output, target)
@staticmethod
def demographics_transform(demographics: dict):
return {}
def training_data(self, batch_size = 4):
# The training_data creates the Dataloader to be used for training in the general class Torchnn of fedbiomed
common_shape = (44, 44, 56)
training_transform = Compose([AddChannel(), Resize(common_shape)])
target_transform = Compose([AddChannel(), Resize(common_shape), AsDiscrete(to_onehot=2)])
dataset = MedicalFolderDataset(
root=self.dataset_path,
data_modalities='T1',
target_modalities='label',
transform=training_transform,
target_transform=target_transform,
demographics_transform=UNetTrainingPlan.demographics_transform)
train_kwargs = {'batch_size': batch_size, 'shuffle': True}
return DataManager(dataset, **train_kwargs)
def training_step(self, data, target):
# TODO ... extract the T1 modality from `data`
# TODO ... call model forward
# TODO ... call F.softmax on the output of the model
# TODO ... compute the loss
# TODO ... extract the label from `target`
# TODO ... compute the loss, then average it
return avg_loss_over_batch
def testing_step(self, data, target):
# TODO ... extract the T1 modality from `data`
# TODO ... call model forward
# TODO ... call F.softmax on the output of the model
# TODO ... compute the loss
# TODO ... extract the label from `target`
# TODO ... compute the loss, then average it
return avg_loss
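If you get stuck, here is one possible shape for training_step (a sketch, assuming the training plan exposes the wrapped module via self.model(); testing_step is analogous):
def training_step(self, data, target):
    (images, demographics) = data                       # unpack as described above
    output = self.model().forward(images['T1'])         # UNet forward on the T1 volume
    probs = F.softmax(output, dim=1)                    # per-class probabilities
    avg_loss_over_batch = UNetTrainingPlan.get_dice_loss(probs, target['label']).mean()
    return avg_loss_over_batch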
Define Parameters#
Here you can define model_args and training_args, two dictionaries that contain parameters and hyperparameters for training and model definition. This provides a flexible way to explore the hyperparameter space without changing the TrainingPlan, which has potentially been fixed and validated by the clinical partners.
Try it yourself!#
Change any of the parameters below to explore the space of hyperparameters. Please be mindful of the limited available resources when changing parameters that may require more computing power, such as base_filters and especially batch_size.
model_args = {
'in_channels': 1,
'out_channels': 2,
'dimensions': 3,
'base_filters': 10,
}
training_args = {
'batch_size': 4,
'num_updates': 8,
'dry_run': False,
'log_interval': 2,
'test_ratio' : 0.1,
'test_on_global_updates': True,
'test_on_local_updates': False,
'optimizer_args': {
'opt_name': 'adamw',
'lr': 0.001
}
}
num_rounds = 15
Dry run your TrainingPlan locally#
Since we have a holdout dataset available locally, we are going to test that the TrainingPlan is able to run locally before we perform the federated training.
Try it yourself!#
First, create a dataloader following these steps:
instantiate a UNetTrainingPlan object
call post_init on the training plan. Note: you will need to instantiate a TrainingArgs object from the training_args dict
set the training plan's dataset_path to /datasets/ixi/holdout
call the training_data function of the training plan to obtain a TorchDataManager
use the _dataset attribute of the data manager to instantiate a torch DataLoader (set a small batch size)
Then, perform one training iteration to check that it completes without errors:
create a for loop iterating on the dataloader
call the training plan's training_step method
break after the first iteration
dryrun_tp = # ... TODO instantiate training plan object
# ... TODO call post_init with the following second argument: TrainingArgs(training_args, only_required=False)
# ... TODO set dataset path
# ... TODO call training_data
# ... TODO create torch DataLoader
# ... TODO for loop on data loader
# ... TODO call training_step
# ... TODO break
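A possible dry-run sketch, under the attribute and method names discussed in the steps above:
dryrun_tp = UNetTrainingPlan()
dryrun_tp.post_init(model_args, TrainingArgs(training_args, only_required=False))
dryrun_tp.dataset_path = '/datasets/ixi/holdout'
data_manager = dryrun_tp.training_data(batch_size=2)
loader = DataLoader(data_manager._dataset, batch_size=2)
for data, target in loader:
    loss = dryrun_tp.training_step(data, target)
    print(f'dry-run loss: {loss.item():0.3f}')
    break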
tags = ['ixi-jupyter-%%%%']
exp = Experiment(tags=tags,
model_args=model_args,
training_plan_class=UNetTrainingPlan,
training_args=training_args,
round_limit=num_rounds,
aggregator=FedAverage(),
tensorboard=True,
save_breakpoints=True
)
Run tensorboard#
Follow the instructions to obtain a port number, and run the commands below.
If <IP> is the IP assigned to you at this workshop, you may also view Tensorboard in a new browser tab at http://<IP>:<YOUR PORT NUMBER>
tensorboard_dir = environ['TENSORBOARD_RESULTS_DIR']
tensorboard --logdir "$tensorboard_dir" --host 0.0.0.0 --port <YOUR PORT NUMBER>
Run the experiment#
print(f"Saving breakpoints to {exp.experimentation_folder()}")
exp.run()
Task 5: Validate on a local holdout set #
To ensure consistency and simplify our lives, we reuse the already-available code as much as possible. Note that this process assumes that the held-out data is stored locally on your machine.
Create an instance of the global model#
First, we create an instance of the model using the parameters from the latest aggregation round.
exp_folder = exp.experimentation_folder()
#exp_folder = 'Experiment_0030'
breakpoint_num = 14
breakpoint_dir = os.path.join(environ['EXPERIMENTS_DIR'], exp_folder, f'breakpoint_{breakpoint_num:04d}')
Try it yourself!#
Call the post_init function to initialize the model in the newly loaded experiment's training plan (note: you will need to create a TrainingArgs object)
Extract the model
Call the model's load_state_dict, passing as argument the model weights obtained from the loaded experiment's _aggregated_params
loaded_exp = Experiment.load_breakpoint(breakpoint_dir)
loaded_exp.training_plan().post_init(# ... TODO fill in arguments )
valid_model = # ... TODO obtain the model from the experiment's training plan
# ... TODO call load_state_dict
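One plausible completion (a sketch; the layout of _aggregated_params can differ across Fed-BioMed versions, so inspect it before indexing):
loaded_exp.training_plan().post_init(model_args, TrainingArgs(training_args, only_required=False))
valid_model = loaded_exp.training_plan().model()
valid_model.load_state_dict(loaded_exp._aggregated_params[breakpoint_num]['params'])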
Define a validation data loader#
We extract the validation data loader from the training plan as well. This requires some knowledge about the internals of the MedicalFolderDataset class. At the end of the process, calling the split function with a ratio of 0 will return a data loader that loads all of the data.
Try it yourself!#
Call the training_data function from the loaded experiment's training plan to obtain a TorchDataManager.
Use the _dataset attribute of the data manager to instantiate a torch DataLoader.
Note: use a batch size of 1 for local validation.
loaded_exp.training_plan().dataset_path = '/datasets/ixi/holdout'
val_data_manager = # ... TODO call training_data
val_data_loader = # ... TODO instantiate a torch DataLoader
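A possible completion, mirroring the dry-run steps from Task 4:
val_data_manager = loaded_exp.training_plan().training_data(batch_size=1)
val_data_loader = DataLoader(val_data_manager._dataset, batch_size=1)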
Compute the loss on validation images#
Try it yourself!#
Iterate over the validation dataset using the DataLoader defined above.
Compute the predictions by manually calling the model forward and F.softmax methods.
Compute the loss by calling the UNetTrainingPlan.get_dice_loss function.
Store the loss values in a list.
Attention: do not forget to set valid_model.eval() and to use the torch.no_grad context.
losses = []
# ... TODO fill in the code following the instructions above
loss_array = np.array(losses)
print(f'Minimum: {loss_array.min():0.3f} '
f'Mean: {loss_array.mean():0.3f} '
f'StdDev: {loss_array.std():0.3f} '
f'Maximum: {loss_array.max():0.3f} ')
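If you got stuck above, the missing loop could look like this sketch (it must run before the loss_array summary):
valid_model.eval()
with torch.no_grad():
    for (images, demographics), target in val_data_loader:
        probs = F.softmax(valid_model.forward(images['T1']), dim=1)
        losses.append(UNetTrainingPlan.get_dice_loss(probs, target['label']).item())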
Visualize the outputs#
We visualize the outputs of our model on the holdout dataset.
one_batch = next(iter(val_data_loader))
Try it yourself!#
Create a grid of 4x2 images. In each row, plot a slice of the original image on the left, overlaid with the ground truth, and the same slice of the original on the right, overlaid with the prediction.
Insert the loss value in the title of the image.
Remember that to compute the predictions you need to pass the whole 3D image to model.forward, not just a slice.
fig, ax = plt.subplots(4,2, figsize=(8,16))
slice_to_plot = 24
val_data_loader_iter = iter(val_data_loader)  # consumed by the loop below
predictions = valid_model.forward(one_batch[0][0]['T1'])
for i in range(4):
(image_modalities, demographics), target = next(val_data_loader_iter)
img = # ... TODO obtain the slice of the image to be plotted
ax[i][0].imshow(img, cmap='bone')
label = # ... TODO obtain the slice of the label to be plotted
plot = ax[i][0].imshow(np.ma.masked_where(
label < 0.5,
label), cmap='winter', alpha=0.7)
ax[i][0].set_title('Input slice with ground truth overlay')
# ... TODO compute the prediction for this image
# ... TODO following the blueprint above, plot the original image with the overlayed prediction
ax[i][1].set_title(f'Input slice with prediction overlay - loss: {losses[i]:0.3f}')
fig.colorbar(plot)