Module medcam.backends.grad_cam_pp
import torch
from torch.nn import functional as F
from medcam.backends.grad_cam import GradCAM
from medcam import medcam_utils
from medcam.medcam_utils import prod


class GradCamPP(GradCAM):

    def __init__(self, model, target_layers=None, postprocessor=None, retain_graph=False):
        """
        "Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks"
        https://arxiv.org/abs/1710.11063
        """
        super(GradCamPP, self).__init__(model, target_layers=target_layers, postprocessor=postprocessor, retain_graph=retain_graph)

    def _generate_helper(self, fmaps, grads, layer):
        B, C, *data_shape = grads.size()
        # Numerator of the Grad-CAM++ alpha coefficients: squared gradients
        alpha_num = grads.pow(2)
        # Second denominator term: sum over all spatial positions of A * (dY/dA)^3
        tmp = fmaps.mul(grads.pow(3))
        tmp = tmp.view(B, C, prod(data_shape))
        tmp = tmp.sum(-1, keepdim=True)
        if self.input_dim == 2:
            tmp = tmp.view(B, C, 1, 1)
        else:
            tmp = tmp.view(B, C, 1, 1, 1)
        alpha_denom = grads.pow(2).mul(2) + tmp
        # Guard against division by zero where the denominator vanishes
        alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom))
        alpha = alpha_num.div(alpha_denom + 1e-7)
        if self.mask is not None:
            mask = self.mask.squeeze()
        if self.mask is None:  # Classification
            prob_weights = torch.tensor(1.0)
        elif len(mask.shape) == 1:  # Classification best/index
            prob_weights = self.logits.squeeze()[torch.argmax(mask)]
        else:  # Segmentation
            masked_logits = self.logits * self.mask
            prob_weights = medcam_utils.interpolate(masked_logits, grads.shape[2:])  # TODO: Still removes channels...
        # ReLU of the gradients scaled by exp(score), per the reference implementation
        positive_gradients = F.relu(torch.mul(prob_weights.exp(), grads))
        # Per-channel weights: alpha-weighted sum of the positive gradients
        weights = (alpha * positive_gradients).view(B, C, -1).sum(-1)
        if self.input_dim == 2:
            weights = weights.view(B, C, 1, 1)
        else:
            weights = weights.view(B, C, 1, 1, 1)
        attention_map = (weights * fmaps)
        try:
            # Split the feature channels into self.output_channels groups before summing
            attention_map = attention_map.view(B, self.output_channels, -1, *data_shape)
        except RuntimeError:
            raise RuntimeError("Number of set channels ({}) is not a multiple of the feature map channels ({}) in layer: {}".format(self.output_channels, fmaps.shape[1], layer))
        attention_map = torch.sum(attention_map, dim=2)
        attention_map = F.relu(attention_map).detach()
        attention_map = self._normalize_per_channel(attention_map)
        return attention_map
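For reference, `_generate_helper` follows the Grad-CAM++ weighting scheme from the paper above. The correspondence below is a sketch for the 2D classification case, writing A^k for the feature maps (`fmaps`) and Y^c for the class score; the `alpha` tensor and channel `weights` in the code correspond to

\alpha^{kc}_{ij} = \frac{\left(\partial Y^c / \partial A^k_{ij}\right)^2}{2\left(\partial Y^c / \partial A^k_{ij}\right)^2 + \sum_{a,b} A^k_{ab}\left(\partial Y^c / \partial A^k_{ij}\right)^3},
\qquad
w^c_k = \sum_{i,j} \alpha^{kc}_{ij}\,\mathrm{relu}\!\left(\frac{\partial \exp(Y^c)}{\partial A^k_{ij}}\right)

and the attention map itself is L^c = relu(\sum_k w^c_k A^k), normalized per channel. The exp(Y^c) factor appears in the code as `prob_weights.exp()`, so the first-order gradients in `grads` can stand in for gradients of the exponentiated score, as in the reference implementation.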
Classes
class GradCamPP (model, target_layers=None, postprocessor=None, retain_graph=False)
"Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks" https://arxiv.org/abs/1710.11063
Ancestors
- GradCAM
- medcam.backends.base._BaseWrapper
Inherited members