import torch
import numpy as np
from torch.nn import functional as F
from medcam import medcam_utils


class _BaseWrapper():

    def __init__(self, model, postprocessor=None, retain_graph=False):
        """A base wrapper of common functions for the backends."""
        self.device = next(model.parameters()).device
        self.retain_graph = retain_graph
        self.model = model
        self.forward_handlers = []
        self.backward_handlers = []
        self.postprocessor = postprocessor
        self.registered_hooks = {}
    def generate_attention_map(self, batch, label):
        """Handles the generation of the attention map from start to finish."""
        output = self.forward(batch)
        self.backward(label=label)
        attention_map = self.generate()
        return output, attention_map, self.output_batch_size, self.output_channels, self.output_shape
    def forward(self, batch):
        """Calls the forward() of the model."""
        self.model.zero_grad()
        self.logits = self.model.model_forward(batch)
        self._extract_metadata(batch, self.logits)
        self._set_postprocessor_and_label(self.logits)
        self.remove_hook(forward=True, backward=False)
        return self.logits
    def backward(self, label=None):
        """Applies postprocessing and class discrimination on the model output and then backpropagates it."""
        if label is None:
            label = self.model.medcam_dict['label']
        self.mask = self._isolate_class(self.logits, label)
        self.logits.backward(gradient=self.mask, retain_graph=self.retain_graph)
        self.remove_hook(forward=True, backward=True)
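    # Note: the mask passed as the `gradient` argument acts as the upstream
    # gradient of the logits, so only the class-discriminated entries of the
    # output contribute gradients to the hooked feature maps.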
    def _isolate_class(self, output, label):
        """Isolates a desired class on the channel dim by creating a mask that is applied on the gradients during backward."""
        if label is None:
            return torch.ones(output.shape).to(self.device)
        if label == "best":
            if self.output_batch_size > 1:
                raise RuntimeError("Best label mode only works with a batch size of one. With a batch size greater than one you need to choose a specific label or None.")
            B, C, *data_shape = output.shape
            if len(data_shape) > 0:
                _output = output.view(B, C, -1)
                _output = torch.sum(_output, dim=2)
                label = torch.argmax(_output, dim=1).item()
            else:
                label = torch.argmax(output, dim=1).item()
        if callable(label):
            mask = label(output) * 1.0
        else:
            mask = torch.zeros(output.shape).to(self.device)
            mask[:, label] = 1
        return mask
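    # Example (illustrative): for logits of shape (1, 3) and label=2 the mask
    # is tensor([[0., 0., 1.]]); for a segmentation output of shape
    # (1, 3, H, W) the mask is 1 everywhere on channel 2 and 0 elsewhere.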
    def _extract_metadata(self, input, output):  # TODO: Does not work for classification output (shape: (1, 1000)), merge with the one in medcam_inject
        """Extracts metadata like batch size, number of channels and the data shape from the output batch."""
        self.input_dim = len(input.shape[2:])
        self.output_batch_size = output.shape[0]
        if self.model.medcam_dict['channels'] == 'default':
            self.output_channels = output.shape[1]
        else:
            self.output_channels = self.model.medcam_dict['channels']
        if self.model.medcam_dict['data_shape'] == 'default':
            if len(output.shape) == 2:  # Classification -> the attention map cannot be resized to a classification output
                self.output_shape = None
            else:  # Output is a 2D/3D image
                self.output_shape = output.shape[2:]
        else:
            self.output_shape = self.model.medcam_dict['data_shape']
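    # Example (illustrative, with the 'default' settings): a segmentation
    # output of shape (2, 4, 128, 128) yields output_batch_size=2,
    # output_channels=4 and output_shape=torch.Size([128, 128]); for
    # classification logits of shape (2, 1000), output_shape is None.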
    def _normalize_per_channel(self, attention_map):
        """Applies min-max normalization to every channel of the attention map independently."""
        if torch.min(attention_map) == torch.max(attention_map):
            return torch.zeros(attention_map.shape)
        # Normalization per channel
        B, C, *data_shape = attention_map.shape
        attention_map = attention_map.view(B, C, -1)
        attention_map_min = torch.min(attention_map, dim=2, keepdim=True)[0]
        attention_map_max = torch.max(attention_map, dim=2, keepdim=True)[0]
        attention_map -= attention_map_min
        attention_map /= (attention_map_max - attention_map_min).clamp(min=1e-8)  # clamp avoids division by zero when a single channel is constant
        attention_map = attention_map.view(B, C, *data_shape)
        return attention_map
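    # Example (illustrative): a channel with values [2., 4., 6.] is rescaled
    # to [0., 0.5, 1.]; each channel is scaled independently of the others.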
    def generate(self):
        """Generates an attention map. Must be implemented by the backend subclass."""
        raise NotImplementedError
    def remove_hook(self, forward, backward):
        """Removes all the forward/backward hook functions."""
        if forward:
            for handle in self.forward_handlers:
                handle.remove()
            self.forward_handlers = []
        if backward:
            for handle in self.backward_handlers:
                handle.remove()
            self.backward_handlers = []
    def layers(self, reverse=False):
        """Returns the layers of the model. Optionally reverses the order of the layers."""
        return medcam_utils.get_layers(self.model, reverse)
    def _set_postprocessor_and_label(self, output):
        """Defaults the label to "best" for classification outputs when no label was set."""
        if self.model.medcam_dict['label'] is None:
            if output.shape[0] == self.output_batch_size and len(output.shape) == 2:  # Classification
                self.model.medcam_dict['label'] = "best"