Module medcam.backends.grad_cam
Expand source code
from collections import OrderedDict
import numpy as np
import torch
from torch.nn import functional as F
from medcam.backends.base import _BaseWrapper
from medcam import medcam_utils
# Changes the used method to hook into backward.
#   False: during the forward pass a tensor hook is registered directly on the
#          output tensor(s) of each hooked layer (see GradCAM._register_hooks).
#   True:  a module-level backward hook is registered on the layer itself via
#          module.register_backward_hook.
ENABLE_MODULE_HOOK = False
class GradCAM(_BaseWrapper):
    """Grad-CAM backend that records feature maps and gradients via hooks.

    Feature maps are stored in ``fmap_pool`` and gradients in ``grad_pool``,
    both keyed by the layer name as reported by ``model.named_modules()``.
    """

    def __init__(self, model, target_layers=None, postprocessor=None, retain_graph=False):
        """
        "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
        https://arxiv.org/pdf/1610.02391.pdf
        Look at Figure 2 on page 4

        Args:
            model: The model to attach the hooks to.
            target_layers: ``'full'``, ``'auto'``, a single layer name, a
                collection of layer names, or ``None``. For ``'full'`` and
                ``'auto'`` the candidate layers are taken from
                ``medcam_utils.get_layers``; ``None`` makes
                ``_register_hooks`` hook every module.
            postprocessor: Forwarded unchanged to the base wrapper.
            retain_graph: Forwarded unchanged to the base wrapper.
        """
        super(GradCAM, self).__init__(model, postprocessor=postprocessor, retain_graph=retain_graph)
        self.fmap_pool = OrderedDict()
        self.grad_pool = OrderedDict()
        # Keep the raw user choice: generate() and get_registered_hooks()
        # need to know whether 'full'/'auto' was requested.
        self._target_layers = target_layers
        if target_layers in ('full', 'auto'):
            target_layers = medcam_utils.get_layers(self.model)
        elif isinstance(target_layers, str):
            target_layers = [target_layers]
        self.target_layers = target_layers
        self.printed_selected_layer = False

    def _register_hooks(self):
        """Registers the forward and backward hooks to the layers.

        Existing hooks are removed first. Every module whose name is in
        ``self.target_layers`` (or every module if that is ``None``) gets a
        forward hook and a fresh ``[forward_fired, backward_fired]`` entry
        in ``self.registered_hooks``.
        """
        def forward_hook(key):
            def forward_hook_(module, input, output):
                self.registered_hooks[key][0] = True
                # NOTE(review): presumably returns a sequence holding the
                # tensors of `output` that take part in autograd -- confirm
                # against medcam_utils.unpack_tensors_with_gradients.
                output = medcam_utils.unpack_tensors_with_gradients(output)
                if not ENABLE_MODULE_HOOK:
                    def _backward_hook(grad_out):
                        self.registered_hooks[key][1] = True
                        # Save the gradients correspond to the featuremaps
                        self.grad_pool[key] = grad_out.detach()
                    # Register backward hook directly to the output
                    # Handle must be removed afterwards otherwise tensor is not freed
                    if not self.registered_hooks[key][1]:
                        for element in output:
                            self.backward_handlers.append(element.register_hook(_backward_hook))
                # Save featuremaps: a single tensor is stored as such,
                # anything else as a list of detached tensors.
                if len(output) == 1:
                    self.fmap_pool[key] = output[0].detach()
                else:
                    self.fmap_pool[key] = [element.detach() for element in output]
            return forward_hook_

        # This backward hook method looks prettier but is currently bugged in pytorch (04/25/2020)
        # Handle does not need to be removed, tensors are freed automatically
        def backward_hook(key):
            def backward_hook_(module, grad_in, grad_out):
                self.registered_hooks[key][1] = True
                # Save the gradients correspond to the featuremaps
                grad_out = medcam_utils.unpack_tensors_with_gradients(grad_out[0])
                if len(grad_out) == 1:
                    self.grad_pool[key] = grad_out[0].detach()
                else:
                    # TODO: Still correct with batch size > 1?
                    self.grad_pool[key] = [element.detach() for element in grad_out]
            return backward_hook_

        self.remove_hook(forward=True, backward=True)
        for name, module in self.model.named_modules():
            if self.target_layers is None or name in self.target_layers:
                self.registered_hooks[name] = [False, False]
                self.forward_handlers.append(module.register_forward_hook(forward_hook(name)))
                if ENABLE_MODULE_HOOK:
                    self.backward_handlers.append(module.register_backward_hook(backward_hook(name)))

    def get_registered_hooks(self):
        """Returns every hook that was able to register to a layer.

        A layer counts as registered when both its forward and its backward
        flag in ``self.registered_hooks`` are set. All hooks are removed
        afterwards, and in ``'full'``/``'auto'`` mode ``self.target_layers``
        is replaced by the returned list.

        Returns:
            list: Names of the layers whose forward and backward hook fired.
        """
        registered_hooks = [
            layer for layer, (forward_ok, backward_ok) in self.registered_hooks.items()
            if forward_ok and backward_ok
        ]
        self.remove_hook(forward=True, backward=True)
        if self._target_layers in ('full', 'auto'):
            self.target_layers = registered_hooks
        return registered_hooks

    def forward(self, data):
        """Registers the hooks and then calls the forward() of the base."""
        self._register_hooks()
        return super(GradCAM, self).forward(data)

    def generate(self):
        """Generates an attention map.

        Returns:
            dict: Maps a layer name to its attention map as a NumPy array.
            In ``'auto'`` mode the dict holds exactly one automatically
            selected layer, otherwise one entry per target layer.

        Raises:
            ValueError: If a hook of a target layer did not fire, if no
                valid layer is found in ``'auto'`` mode, or if no attention
                map could be produced at all.
        """
        self.remove_hook(forward=True, backward=True)
        attention_maps = {}
        if self._target_layers == "auto":
            layer, fmaps, grads = self._auto_layer_selection()
            self._check_hooks(layer)
            attention_maps[layer] = self._generate_helper(fmaps, grads, layer).cpu().numpy()
        else:
            for layer in self.target_layers:
                # Raises unless both hooks of this layer fired, so no
                # further flag check is needed below.
                self._check_hooks(layer)
                fmaps = self._find(self.fmap_pool, layer)
                grads = self._find(self.grad_pool, layer)
                attention_maps[layer] = self._generate_helper(fmaps, grads, layer).cpu().numpy()
        if not attention_maps:
            raise ValueError("None of the hooks registered to the target layers")
        return attention_maps

    def _auto_layer_selection(self):
        """Selects the last layer from which attention maps can be generated.

        Walks the names returned by ``self.layers(reverse=True)`` and picks
        the first layer whose pooled feature maps and gradients are tensors,
        whose gradients are not all zero, that can be average pooled, and
        that are both 4-D or both 5-D with every spatial size greater than 1.

        Returns:
            tuple: ``(layer, fmaps, grads)`` of the selected layer.

        Raises:
            ValueError: If no layer fulfils the criteria.
        """
        for layer in self.layers(reverse=True):
            try:
                fmaps = self._find(self.fmap_pool, layer)
                grads = self._find(self.grad_pool, layer)
                # Type check first: the pools may hold lists (layers with
                # several outputs), and lists have no tensor API.
                if not isinstance(fmaps, torch.Tensor) or not isinstance(grads, torch.Tensor):
                    continue
                # All-zero gradients cannot produce a meaningful map.
                if np.count_nonzero(grads.detach().cpu().numpy()) == 0:
                    continue
                # Result is discarded on purpose: this only verifies that
                # the gradients can be pooled (raises otherwise).
                self._compute_grad_weights(grads)
                same_rank = fmaps.dim() == grads.dim() and fmaps.dim() in (4, 5)
                if (same_rank
                        and all(size > 1 for size in fmaps.shape[2:])
                        and all(size > 1 for size in grads.shape[2:])):
                    self.printed_selected_layer = True
                    return layer, fmaps, grads
            except (ValueError, RuntimeError, IndexError):
                # Best effort: a layer without pooled data or with an
                # unsuitable shape is skipped.
                continue
        raise ValueError("Could not find a valid layer. "
                         "Check if base.logits or the mask result of base._mask_output() contains only zeros. "
                         "Check if requires_grad flag is true for the batch input and that no torch.no_grad statements effects medcam. "
                         "Check if the model has any convolution layers.")

    def _find(self, pool, target_layer):
        """Returns the feature maps or gradients for a specific layer.

        Raises:
            ValueError: If ``target_layer`` is not a key of ``pool``.
        """
        if target_layer not in pool:
            raise ValueError("Invalid layer name: {}".format(target_layer))
        return pool[target_layer]

    def _compute_grad_weights(self, grads):
        """Computes the weights based on the gradients by average pooling.

        Uses 2-D adaptive average pooling when ``self.input_dim`` is 2 and
        3-D pooling otherwise; the output size per spatial axis is 1.
        """
        if self.input_dim == 2:
            return F.adaptive_avg_pool2d(grads, 1)
        return F.adaptive_avg_pool3d(grads, 1)

    def _generate_helper(self, fmaps, grads, layer):
        """Computes the normalized attention map of one layer.

        The feature maps are weighted by the pooled gradients, the channel
        axis is split into ``self.output_channels`` groups that are summed,
        and the result is passed through ReLU and
        ``self._normalize_per_channel``.

        Args:
            fmaps: Feature maps of the layer, indexed as (batch, channel, ...).
            grads: Gradients belonging to ``fmaps``.
            layer: Layer name, only used in the error message.

        Raises:
            RuntimeError: If the weighted maps cannot be reshaped into
                ``self.output_channels`` groups.
        """
        weights = self._compute_grad_weights(grads)
        attention_map = torch.mul(fmaps, weights)
        B, _, *data_shape = attention_map.shape
        try:
            attention_map = attention_map.view(B, self.output_channels, -1, *data_shape)
        except RuntimeError as err:
            # The view only works if the feature map channels split evenly
            # into output_channels groups.
            raise RuntimeError("Number of feature map channels ({}) is not a multiple of the number of set channels ({}) in layer: {}".format(fmaps.shape[1], self.output_channels, layer)) from err
        attention_map = torch.sum(attention_map, dim=2)
        attention_map = F.relu(attention_map)
        attention_map = self._normalize_per_channel(attention_map)
        return attention_map

    def _check_hooks(self, layer):
        """Checks if all hooks registered.

        Raises:
            ValueError: If the forward hook, the backward hook, or both did
                not fire for ``layer``.
        """
        forward_ok, backward_ok = self.registered_hooks[layer]
        if not forward_ok and not backward_ok:
            raise ValueError("Neither forward hook nor backward hook did register to layer: " + str(layer))
        if not forward_ok:
            raise ValueError("Forward hook did not register to layer: " + str(layer))
        if not backward_ok:
            raise ValueError("Backward hook did not register to layer: " + str(layer) + ", Check if the hook was registered to a layer that is skipped during backward and thus no gradients are computed")
Classes
class GradCAM (model, target_layers=None, postprocessor=None, retain_graph=False)- 
Implements "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization" (https://arxiv.org/pdf/1610.02391.pdf); see Figure 2 on page 4.
Expand source code
class GradCAM(_BaseWrapper): def __init__(self, model, target_layers=None, postprocessor=None, retain_graph=False): """ "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization" https://arxiv.org/pdf/1610.02391.pdf Look at Figure 2 on page 4 """ super(GradCAM, self).__init__(model, postprocessor=postprocessor, retain_graph=retain_graph) self.fmap_pool = OrderedDict() self.grad_pool = OrderedDict() self._target_layers = target_layers if target_layers == 'full' or target_layers == 'auto': target_layers = medcam_utils.get_layers(self.model) elif isinstance(target_layers, str): target_layers = [target_layers] self.target_layers = target_layers self.printed_selected_layer = False def _register_hooks(self): """Registers the forward and backward hooks to the layers.""" def forward_hook(key): def forward_hook_(module, input, output): self.registered_hooks[key][0] = True # Save featuremaps # if not isinstance(output, torch.Tensor): # print("Cannot hook layer {} because its gradients are not in tensor format".format(key)) output = medcam_utils.unpack_tensors_with_gradients(output) if not ENABLE_MODULE_HOOK: def _backward_hook(grad_out): self.registered_hooks[key][1] = True # Save the gradients correspond to the featuremaps self.grad_pool[key] = grad_out.detach() # Register backward hook directly to the output # Handle must be removed afterwards otherwise tensor is not freed if not self.registered_hooks[key][1]: if len(output) == 1: _backward_handle = output[0].register_hook(_backward_hook) self.backward_handlers.append(_backward_handle) else: for element in output: _backward_handle = element.register_hook(_backward_hook) self.backward_handlers.append(_backward_handle) # _backward_handle = output.register_hook(_backward_hook) # self.backward_handlers.append(_backward_handle) if len(output) == 1: self.fmap_pool[key] = output[0].detach() else: elements = [] for element in output: elements.append(element.detach()) self.fmap_pool[key] = elements # 
self.fmap_pool[key] = output.detach() return forward_hook_ # This backward hook method looks prettier but is currently bugged in pytorch (04/25/2020) # Handle does not need to be removed, tensors are freed automatically def backward_hook(key): def backward_hook_(module, grad_in, grad_out): self.registered_hooks[key][1] = True # Save the gradients correspond to the featuremaps grad_out = medcam_utils.unpack_tensors_with_gradients(grad_out[0]) if len(grad_out) == 1: self.grad_pool[key] = grad_out[0].detach() else: elements = [] for element in grad_out: elements.append(element.detach()) self.grad_pool[key] = elements # self.grad_pool[key] = grad_out[0].detach() # TODO: Still correct with batch size > 1? return backward_hook_ self.remove_hook(forward=True, backward=True) for name, module in self.model.named_modules(): if self.target_layers is None or name in self.target_layers: self.registered_hooks[name] = [False, False] self.forward_handlers.append(module.register_forward_hook(forward_hook(name))) if ENABLE_MODULE_HOOK: self.backward_handlers.append(module.register_backward_hook(backward_hook(name))) def get_registered_hooks(self): """Returns every hook that was able to register to a layer.""" registered_hooks = [] for layer in self.registered_hooks.keys(): if self.registered_hooks[layer][0] and self.registered_hooks[layer][1]: registered_hooks.append(layer) self.remove_hook(forward=True, backward=True) if self._target_layers == 'full' or self._target_layers == 'auto': self.target_layers = registered_hooks return registered_hooks def forward(self, data): """Calls the forward() of the base.""" self._register_hooks() return super(GradCAM, self).forward(data) def generate(self): """Generates an attention map.""" self.remove_hook(forward=True, backward=True) attention_maps = {} if self._target_layers == "auto": layer, fmaps, grads = self._auto_layer_selection() self._check_hooks(layer) attention_map = self._generate_helper(fmaps, grads, layer).cpu().numpy() 
attention_maps = {layer: attention_map} else: for layer in self.target_layers: self._check_hooks(layer) if self.registered_hooks[layer][0] and self.registered_hooks[layer][1]: fmaps = self._find(self.fmap_pool, layer) grads = self._find(self.grad_pool, layer) attention_map = self._generate_helper(fmaps, grads, layer) attention_maps[layer] = attention_map.cpu().numpy() if not attention_maps: raise ValueError("None of the hooks registered to the target layers") return attention_maps def _auto_layer_selection(self): """Selects the last layer from which attention maps can be generated.""" # It's ugly but it works ;) module_names = self.layers(reverse=True) found_valid_layer = False for layer in module_names: try: fmaps = self._find(self.fmap_pool, layer) grads = self._find(self.grad_pool, layer) nonzeros = np.count_nonzero(grads.detach().cpu().numpy()) # TODO: Add except here with description, replace nonzero with sum == 0? self._compute_grad_weights(grads) if nonzeros == 0 or not isinstance(fmaps, torch.Tensor) or not isinstance(grads, torch.Tensor): continue if (len(fmaps.shape) == 4 and len(grads.shape) == 4 and fmaps.shape[2] > 1 and fmaps.shape[3] > 1 and grads.shape[2] > 1 and grads.shape[3] > 1) or \ (len(fmaps.shape) == 5 and len(grads.shape) == 5 and fmaps.shape[2] > 1 and fmaps.shape[3] > 1 and fmaps.shape[4] > 1 and grads.shape[2] > 1 and grads.shape[3] > 1 and grads.shape[4] > 1): if not self.printed_selected_layer: # print("Selected module layer: {}".format(layer)) self.printed_selected_layer = True found_valid_layer = True break except ValueError: pass except RuntimeError: pass except IndexError: pass if not found_valid_layer: raise ValueError("Could not find a valid layer. " "Check if base.logits or the mask result of base._mask_output() contains only zeros. " "Check if requires_grad flag is true for the batch input and that no torch.no_grad statements effects medcam. 
" "Check if the model has any convolution layers.") return layer, fmaps, grads def _find(self, pool, target_layer): """Returns the feature maps or gradients for a specific layer.""" if target_layer in pool.keys(): return pool[target_layer] else: raise ValueError("Invalid layer name: {}".format(target_layer)) def _compute_grad_weights(self, grads): """Computes the weights based on the gradients by average pooling.""" if self.input_dim == 2: return F.adaptive_avg_pool2d(grads, 1) else: return F.adaptive_avg_pool3d(grads, 1) def _generate_helper(self, fmaps, grads, layer): weights = self._compute_grad_weights(grads) attention_map = torch.mul(fmaps, weights) B, _, *data_shape = attention_map.shape try: attention_map = attention_map.view(B, self.output_channels, -1, *data_shape) except RuntimeError: raise RuntimeError("Number of set channels ({}) is not a multiple of the feature map channels ({}) in layer: {}".format(self.output_channels, fmaps.shape[1], layer)) attention_map = torch.sum(attention_map, dim=2) attention_map = F.relu(attention_map) attention_map = self._normalize_per_channel(attention_map) return attention_map def _check_hooks(self, layer): """Checks if all hooks registered.""" if not self.registered_hooks[layer][0] and not self.registered_hooks[layer][1]: raise ValueError("Neither forward hook nor backward hook did register to layer: " + str(layer)) elif not self.registered_hooks[layer][0]: raise ValueError("Forward hook did not register to layer: " + str(layer)) elif not self.registered_hooks[layer][1]: raise ValueError("Backward hook did not register to layer: " + str(layer) + ", Check if the hook was registered to a layer that is skipped during backward and thus no gradients are computed")Ancestors
- medcam.backends.base._BaseWrapper
 
Subclasses
Methods
def forward(self, data)- 
Registers the forward and backward hooks, then calls the forward() of the base class.
Expand source code
def forward(self, data): """Calls the forward() of the base.""" self._register_hooks() return super(GradCAM, self).forward(data) def generate(self)- 
Generates the attention maps and returns them as a dict that maps each layer name to a NumPy array.
Expand source code
def generate(self): """Generates an attention map.""" self.remove_hook(forward=True, backward=True) attention_maps = {} if self._target_layers == "auto": layer, fmaps, grads = self._auto_layer_selection() self._check_hooks(layer) attention_map = self._generate_helper(fmaps, grads, layer).cpu().numpy() attention_maps = {layer: attention_map} else: for layer in self.target_layers: self._check_hooks(layer) if self.registered_hooks[layer][0] and self.registered_hooks[layer][1]: fmaps = self._find(self.fmap_pool, layer) grads = self._find(self.grad_pool, layer) attention_map = self._generate_helper(fmaps, grads, layer) attention_maps[layer] = attention_map.cpu().numpy() if not attention_maps: raise ValueError("None of the hooks registered to the target layers") return attention_maps def get_registered_hooks(self)- 
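Returns the names of all layers for which both the forward and the backward hook fired, and removes the hooks afterwards. In 'full' or 'auto' mode the target layers are replaced by this list.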
Expand source code
def get_registered_hooks(self): """Returns every hook that was able to register to a layer.""" registered_hooks = [] for layer in self.registered_hooks.keys(): if self.registered_hooks[layer][0] and self.registered_hooks[layer][1]: registered_hooks.append(layer) self.remove_hook(forward=True, backward=True) if self._target_layers == 'full' or self._target_layers == 'auto': self.target_layers = registered_hooks return registered_hooks