Module medcam.backends.grad_cam

Expand source code
from collections import OrderedDict
import numpy as np
import torch
from torch.nn import functional as F
from medcam.backends.base import _BaseWrapper
from medcam import medcam_utils

# Selects how GradCAM._register_hooks captures gradients:
#   False (default) - a tensor hook is registered on each hooked layer's forward
#                     output from inside the forward hook; its handles are kept
#                     in self.backward_handlers so they can be removed later.
#   True            - nn.Module.register_backward_hook is registered on each
#                     target module instead.
ENABLE_MODULE_HOOK = False

class GradCAM(_BaseWrapper):
    """Grad-CAM attention-map backend.

    Forward hooks store the feature maps of the target layers in ``fmap_pool``,
    gradient hooks store the matching gradients in ``grad_pool``, and
    ``generate`` weights the feature maps by the spatially averaged gradients
    to produce one attention map per layer.
    """

    def __init__(self, model, target_layers=None, postprocessor=None, retain_graph=False):
        """
        "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
        https://arxiv.org/pdf/1610.02391.pdf
        Look at Figure 2 on page 4

        Args:
            model: Model to wrap; passed on to ``_BaseWrapper``.
            target_layers: Names of the modules to hook. ``None`` hooks every
                module of ``model``. ``'full'`` and ``'auto'`` start from
                ``medcam_utils.get_layers(model)``; with ``'auto'``,
                ``generate`` additionally picks a single layer. Any other
                string is treated as one layer name. Anything else is used
                as given (e.g. a list of names).
            postprocessor: Passed on to ``_BaseWrapper``.
            retain_graph: Passed on to ``_BaseWrapper``.
        """
        super(GradCAM, self).__init__(model, postprocessor=postprocessor, retain_graph=retain_graph)
        self.fmap_pool = OrderedDict()  # layer name -> detached feature map(s)
        self.grad_pool = OrderedDict()  # layer name -> detached gradient(s)
        # Keep the caller's original choice; 'full'/'auto' are checked again later.
        self._target_layers = target_layers
        if target_layers == 'full' or target_layers == 'auto':
            target_layers = medcam_utils.get_layers(self.model)
        elif isinstance(target_layers, str):
            target_layers = [target_layers]
        self.target_layers = target_layers
        self.printed_selected_layer = False

    def _register_hooks(self):
        """Registers the forward and backward hooks to the layers.

        Existing hooks are removed first. Every targeted module (every module
        when ``self.target_layers`` is None) gets a forward hook, and its entry
        in ``self.registered_hooks`` is reset to ``[False, False]``, meaning
        ``[forward hook fired, backward hook fired]``.

        Gradients are captured in one of two ways:
        - default: a tensor hook attached to the layer output inside the
          forward hook;
        - when ``ENABLE_MODULE_HOOK`` is True: a module backward hook.
        """
        def forward_hook(key):
            def forward_hook_(module, inputs, output):
                self.registered_hooks[key][0] = True
                # NOTE(review): presumably returns the output tensors that carry
                # gradients as a sequence - confirm in medcam_utils.
                output = medcam_utils.unpack_tensors_with_gradients(output)

                if not ENABLE_MODULE_HOOK:
                    def _backward_hook(grad_out):
                        self.registered_hooks[key][1] = True
                        # Save the gradients corresponding to the feature maps.
                        # With several outputs, the last gradient to arrive wins.
                        self.grad_pool[key] = grad_out.detach()

                    # Register the backward hook directly on the output tensor(s).
                    # The handles must be removed afterwards, otherwise the
                    # tensors are not freed.
                    if not self.registered_hooks[key][1]:
                        for element in output:
                            self.backward_handlers.append(element.register_hook(_backward_hook))

                # Save the feature maps, detached from the autograd graph.
                if len(output) == 1:
                    self.fmap_pool[key] = output[0].detach()
                else:
                    self.fmap_pool[key] = [element.detach() for element in output]

            return forward_hook_

        # Module-level alternative, only used when ENABLE_MODULE_HOOK is True.
        # Original note: this looks prettier but was bugged in PyTorch (04/25/2020).
        # Its handle does not need to be removed; tensors are freed automatically.
        def backward_hook(key):
            def backward_hook_(module, grad_in, grad_out):
                self.registered_hooks[key][1] = True
                # Save the gradients corresponding to the feature maps
                # (only the first entry of grad_out is used).
                grad_out = medcam_utils.unpack_tensors_with_gradients(grad_out[0])
                if len(grad_out) == 1:
                    self.grad_pool[key] = grad_out[0].detach()
                else:
                    self.grad_pool[key] = [element.detach() for element in grad_out]
                # TODO: Still correct with batch size > 1?

            return backward_hook_

        self.remove_hook(forward=True, backward=True)
        # Build the name lookup once instead of scanning a list for every module.
        targets = None if self.target_layers is None else set(self.target_layers)
        for name, module in self.model.named_modules():
            if targets is None or name in targets:
                self.registered_hooks[name] = [False, False]
                self.forward_handlers.append(module.register_forward_hook(forward_hook(name)))
                if ENABLE_MODULE_HOOK:
                    self.backward_handlers.append(module.register_backward_hook(backward_hook(name)))

    def get_registered_hooks(self):
        """Returns every hook that was able to register to a layer.

        A layer counts when both its forward and its backward hook fired. All
        hooks are removed afterwards. With ``'full'`` or ``'auto'``,
        ``self.target_layers`` is narrowed to the returned list.
        """
        registered_hooks = []
        for layer in self.registered_hooks.keys():
            if self.registered_hooks[layer][0] and self.registered_hooks[layer][1]:
                registered_hooks.append(layer)
        self.remove_hook(forward=True, backward=True)
        if self._target_layers == 'full' or self._target_layers == 'auto':
            self.target_layers = registered_hooks
        return registered_hooks

    def forward(self, data):
        """Calls the forward() of the base after (re)registering the hooks."""
        self._register_hooks()
        return super(GradCAM, self).forward(data)

    def generate(self):
        """Generates an attention map.

        Returns:
            Dict mapping layer name to its attention map as a NumPy array. In
            ``'auto'`` mode it holds exactly one layer, chosen by
            ``_auto_layer_selection``.

        Raises:
            ValueError: If a hook did not fire for a target layer, or if no
                attention map could be produced.
        """
        self.remove_hook(forward=True, backward=True)
        attention_maps = {}
        if self._target_layers == "auto":
            layer, fmaps, grads = self._auto_layer_selection()
            self._check_hooks(layer)
            attention_map = self._generate_helper(fmaps, grads, layer).cpu().numpy()
            attention_maps = {layer: attention_map}
        else:
            for layer in self.target_layers:
                self._check_hooks(layer)
                if self.registered_hooks[layer][0] and self.registered_hooks[layer][1]:
                    fmaps = self._find(self.fmap_pool, layer)
                    grads = self._find(self.grad_pool, layer)
                    attention_map = self._generate_helper(fmaps, grads, layer)
                    attention_maps[layer] = attention_map.cpu().numpy()
        if not attention_maps:
            raise ValueError("None of the hooks registered to the target layers")
        return attention_maps

    def _auto_layer_selection(self):
        """Selects the last layer from which attention maps can be generated.

        Candidates come from ``self.layers(reverse=True)``; presumably this
        yields layers from the output towards the input - confirm in the base
        class. A layer qualifies when all of the following hold:
        - its feature maps and gradients are single tensors;
        - the gradients are not all zero;
        - ``_compute_grad_weights`` accepts the gradients;
        - both tensors are 4D, or both are 5D, with every spatial dimension
          larger than 1.

        Returns:
            Tuple ``(layer, fmaps, grads)`` for the first qualifying layer.

        Raises:
            ValueError: If no layer qualifies.
        """
        for layer in self.layers(reverse=True):
            try:
                fmaps = self._find(self.fmap_pool, layer)
                grads = self._find(self.grad_pool, layer)
                # Layers with several outputs store lists, which cannot be used.
                # Check this first: the tensor calls below would otherwise raise
                # AttributeError/TypeError, which is not caught here.
                if not isinstance(fmaps, torch.Tensor) or not isinstance(grads, torch.Tensor):
                    continue
                # All-zero gradients can only produce an all-zero map.
                if np.count_nonzero(grads.detach().cpu().numpy()) == 0:
                    continue
                # Result unused: acts as a probe that raises when the gradient
                # rank does not fit the pooling selected by self.input_dim.
                self._compute_grad_weights(grads)
            except (ValueError, RuntimeError, IndexError):
                continue

            spatial_ok = all(size > 1 for size in fmaps.shape[2:]) and all(size > 1 for size in grads.shape[2:])
            if fmaps.dim() in (4, 5) and fmaps.dim() == grads.dim() and spatial_ok:
                self.printed_selected_layer = True
                return layer, fmaps, grads

        raise ValueError("Could not find a valid layer. "
                         "Check if base.logits or the mask result of base._mask_output() contains only zeros. "
                         "Check if requires_grad flag is true for the batch input and that no torch.no_grad statements effects medcam. "
                         "Check if the model has any convolution layers.")

    def _find(self, pool, target_layer):
        """Returns the feature maps or gradients for a specific layer.

        Raises:
            ValueError: If ``target_layer`` is not in ``pool``.
        """
        if target_layer in pool.keys():
            return pool[target_layer]
        else:
            raise ValueError("Invalid layer name: {}".format(target_layer))

    def _compute_grad_weights(self, grads):
        """Computes the weights based on the gradients by average pooling.

        Pools every spatial dimension down to size 1. 2D pooling is used when
        ``self.input_dim == 2``, otherwise 3D pooling.
        """
        if self.input_dim == 2:
            return F.adaptive_avg_pool2d(grads, 1)
        else:
            return F.adaptive_avg_pool3d(grads, 1)

    def _generate_helper(self, fmaps, grads, layer):
        """Builds the Grad-CAM map for one layer.

        Steps:
        1. Multiply the feature maps by the pooled gradient weights.
        2. Split the channel axis into ``self.output_channels`` equal groups
           and sum within each group.
        3. Apply ReLU, then ``self._normalize_per_channel``.

        Raises:
            RuntimeError: If the feature-map channel count is not a multiple
                of ``self.output_channels``.
        """
        weights = self._compute_grad_weights(grads)
        attention_map = torch.mul(fmaps, weights)
        B, _, *data_shape = attention_map.shape
        try:
            attention_map = attention_map.view(B, self.output_channels, -1, *data_shape)
        except RuntimeError as err:
            raise RuntimeError("Number of feature map channels ({}) is not a multiple of the number of set channels ({}) in layer: {}".format(fmaps.shape[1], self.output_channels, layer)) from err
        attention_map = torch.sum(attention_map, dim=2)
        attention_map = F.relu(attention_map)
        attention_map = self._normalize_per_channel(attention_map)
        return attention_map

    def _check_hooks(self, layer):
        """Checks if all hooks registered.

        Raises:
            ValueError: If no hook was ever registered to ``layer`` (for
                example, a name that matches no module), or if its forward
                and/or backward hook did not fire.
        """
        if layer not in self.registered_hooks:
            # Previously this surfaced as an unhelpful bare KeyError.
            raise ValueError("Invalid layer name: {}. No hook was registered to it, check that it matches a name from model.named_modules()".format(layer))
        if not self.registered_hooks[layer][0] and not self.registered_hooks[layer][1]:
            raise ValueError("Neither forward hook nor backward hook did register to layer: " + str(layer))
        elif not self.registered_hooks[layer][0]:
            raise ValueError("Forward hook did not register to layer: " + str(layer))
        elif not self.registered_hooks[layer][1]:
            raise ValueError("Backward hook did not register to layer: " + str(layer) + ", Check if the hook was registered to a layer that is skipped during backward and thus no gradients are computed")

Classes

class GradCAM (model, target_layers=None, postprocessor=None, retain_graph=False)

"Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization" https://arxiv.org/pdf/1610.02391.pdf Look at Figure 2 on page 4

Expand source code
class GradCAM(_BaseWrapper):
    """Grad-CAM attention-map backend.

    Forward hooks store the feature maps of the target layers in ``fmap_pool``,
    gradient hooks store the matching gradients in ``grad_pool``, and
    ``generate`` weights the feature maps by the spatially averaged gradients
    to produce one attention map per layer.
    """

    def __init__(self, model, target_layers=None, postprocessor=None, retain_graph=False):
        """
        "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
        https://arxiv.org/pdf/1610.02391.pdf
        Look at Figure 2 on page 4

        Args:
            model: Model to wrap; passed on to ``_BaseWrapper``.
            target_layers: Names of the modules to hook. ``None`` hooks every
                module of ``model``. ``'full'`` and ``'auto'`` start from
                ``medcam_utils.get_layers(model)``; with ``'auto'``,
                ``generate`` additionally picks a single layer. Any other
                string is treated as one layer name. Anything else is used
                as given (e.g. a list of names).
            postprocessor: Passed on to ``_BaseWrapper``.
            retain_graph: Passed on to ``_BaseWrapper``.
        """
        super(GradCAM, self).__init__(model, postprocessor=postprocessor, retain_graph=retain_graph)
        self.fmap_pool = OrderedDict()  # layer name -> detached feature map(s)
        self.grad_pool = OrderedDict()  # layer name -> detached gradient(s)
        # Keep the caller's original choice; 'full'/'auto' are checked again later.
        self._target_layers = target_layers
        if target_layers == 'full' or target_layers == 'auto':
            target_layers = medcam_utils.get_layers(self.model)
        elif isinstance(target_layers, str):
            target_layers = [target_layers]
        self.target_layers = target_layers
        self.printed_selected_layer = False

    def _register_hooks(self):
        """Registers the forward and backward hooks to the layers.

        Existing hooks are removed first. Every targeted module (every module
        when ``self.target_layers`` is None) gets a forward hook, and its entry
        in ``self.registered_hooks`` is reset to ``[False, False]``, meaning
        ``[forward hook fired, backward hook fired]``.

        Gradients are captured in one of two ways:
        - default: a tensor hook attached to the layer output inside the
          forward hook;
        - when ``ENABLE_MODULE_HOOK`` is True: a module backward hook.
        """
        def forward_hook(key):
            def forward_hook_(module, inputs, output):
                self.registered_hooks[key][0] = True
                # NOTE(review): presumably returns the output tensors that carry
                # gradients as a sequence - confirm in medcam_utils.
                output = medcam_utils.unpack_tensors_with_gradients(output)

                if not ENABLE_MODULE_HOOK:
                    def _backward_hook(grad_out):
                        self.registered_hooks[key][1] = True
                        # Save the gradients corresponding to the feature maps.
                        # With several outputs, the last gradient to arrive wins.
                        self.grad_pool[key] = grad_out.detach()

                    # Register the backward hook directly on the output tensor(s).
                    # The handles must be removed afterwards, otherwise the
                    # tensors are not freed.
                    if not self.registered_hooks[key][1]:
                        for element in output:
                            self.backward_handlers.append(element.register_hook(_backward_hook))

                # Save the feature maps, detached from the autograd graph.
                if len(output) == 1:
                    self.fmap_pool[key] = output[0].detach()
                else:
                    self.fmap_pool[key] = [element.detach() for element in output]

            return forward_hook_

        # Module-level alternative, only used when ENABLE_MODULE_HOOK is True.
        # Original note: this looks prettier but was bugged in PyTorch (04/25/2020).
        # Its handle does not need to be removed; tensors are freed automatically.
        def backward_hook(key):
            def backward_hook_(module, grad_in, grad_out):
                self.registered_hooks[key][1] = True
                # Save the gradients corresponding to the feature maps
                # (only the first entry of grad_out is used).
                grad_out = medcam_utils.unpack_tensors_with_gradients(grad_out[0])
                if len(grad_out) == 1:
                    self.grad_pool[key] = grad_out[0].detach()
                else:
                    self.grad_pool[key] = [element.detach() for element in grad_out]
                # TODO: Still correct with batch size > 1?

            return backward_hook_

        self.remove_hook(forward=True, backward=True)
        # Build the name lookup once instead of scanning a list for every module.
        targets = None if self.target_layers is None else set(self.target_layers)
        for name, module in self.model.named_modules():
            if targets is None or name in targets:
                self.registered_hooks[name] = [False, False]
                self.forward_handlers.append(module.register_forward_hook(forward_hook(name)))
                if ENABLE_MODULE_HOOK:
                    self.backward_handlers.append(module.register_backward_hook(backward_hook(name)))

    def get_registered_hooks(self):
        """Returns every hook that was able to register to a layer.

        A layer counts when both its forward and its backward hook fired. All
        hooks are removed afterwards. With ``'full'`` or ``'auto'``,
        ``self.target_layers`` is narrowed to the returned list.
        """
        registered_hooks = []
        for layer in self.registered_hooks.keys():
            if self.registered_hooks[layer][0] and self.registered_hooks[layer][1]:
                registered_hooks.append(layer)
        self.remove_hook(forward=True, backward=True)
        if self._target_layers == 'full' or self._target_layers == 'auto':
            self.target_layers = registered_hooks
        return registered_hooks

    def forward(self, data):
        """Calls the forward() of the base after (re)registering the hooks."""
        self._register_hooks()
        return super(GradCAM, self).forward(data)

    def generate(self):
        """Generates an attention map.

        Returns:
            Dict mapping layer name to its attention map as a NumPy array. In
            ``'auto'`` mode it holds exactly one layer, chosen by
            ``_auto_layer_selection``.

        Raises:
            ValueError: If a hook did not fire for a target layer, or if no
                attention map could be produced.
        """
        self.remove_hook(forward=True, backward=True)
        attention_maps = {}
        if self._target_layers == "auto":
            layer, fmaps, grads = self._auto_layer_selection()
            self._check_hooks(layer)
            attention_map = self._generate_helper(fmaps, grads, layer).cpu().numpy()
            attention_maps = {layer: attention_map}
        else:
            for layer in self.target_layers:
                self._check_hooks(layer)
                if self.registered_hooks[layer][0] and self.registered_hooks[layer][1]:
                    fmaps = self._find(self.fmap_pool, layer)
                    grads = self._find(self.grad_pool, layer)
                    attention_map = self._generate_helper(fmaps, grads, layer)
                    attention_maps[layer] = attention_map.cpu().numpy()
        if not attention_maps:
            raise ValueError("None of the hooks registered to the target layers")
        return attention_maps

    def _auto_layer_selection(self):
        """Selects the last layer from which attention maps can be generated.

        Candidates come from ``self.layers(reverse=True)``; presumably this
        yields layers from the output towards the input - confirm in the base
        class. A layer qualifies when all of the following hold:
        - its feature maps and gradients are single tensors;
        - the gradients are not all zero;
        - ``_compute_grad_weights`` accepts the gradients;
        - both tensors are 4D, or both are 5D, with every spatial dimension
          larger than 1.

        Returns:
            Tuple ``(layer, fmaps, grads)`` for the first qualifying layer.

        Raises:
            ValueError: If no layer qualifies.
        """
        for layer in self.layers(reverse=True):
            try:
                fmaps = self._find(self.fmap_pool, layer)
                grads = self._find(self.grad_pool, layer)
                # Layers with several outputs store lists, which cannot be used.
                # Check this first: the tensor calls below would otherwise raise
                # AttributeError/TypeError, which is not caught here.
                if not isinstance(fmaps, torch.Tensor) or not isinstance(grads, torch.Tensor):
                    continue
                # All-zero gradients can only produce an all-zero map.
                if np.count_nonzero(grads.detach().cpu().numpy()) == 0:
                    continue
                # Result unused: acts as a probe that raises when the gradient
                # rank does not fit the pooling selected by self.input_dim.
                self._compute_grad_weights(grads)
            except (ValueError, RuntimeError, IndexError):
                continue

            spatial_ok = all(size > 1 for size in fmaps.shape[2:]) and all(size > 1 for size in grads.shape[2:])
            if fmaps.dim() in (4, 5) and fmaps.dim() == grads.dim() and spatial_ok:
                self.printed_selected_layer = True
                return layer, fmaps, grads

        raise ValueError("Could not find a valid layer. "
                         "Check if base.logits or the mask result of base._mask_output() contains only zeros. "
                         "Check if requires_grad flag is true for the batch input and that no torch.no_grad statements effects medcam. "
                         "Check if the model has any convolution layers.")

    def _find(self, pool, target_layer):
        """Returns the feature maps or gradients for a specific layer.

        Raises:
            ValueError: If ``target_layer`` is not in ``pool``.
        """
        if target_layer in pool.keys():
            return pool[target_layer]
        else:
            raise ValueError("Invalid layer name: {}".format(target_layer))

    def _compute_grad_weights(self, grads):
        """Computes the weights based on the gradients by average pooling.

        Pools every spatial dimension down to size 1. 2D pooling is used when
        ``self.input_dim == 2``, otherwise 3D pooling.
        """
        if self.input_dim == 2:
            return F.adaptive_avg_pool2d(grads, 1)
        else:
            return F.adaptive_avg_pool3d(grads, 1)

    def _generate_helper(self, fmaps, grads, layer):
        """Builds the Grad-CAM map for one layer.

        Steps:
        1. Multiply the feature maps by the pooled gradient weights.
        2. Split the channel axis into ``self.output_channels`` equal groups
           and sum within each group.
        3. Apply ReLU, then ``self._normalize_per_channel``.

        Raises:
            RuntimeError: If the feature-map channel count is not a multiple
                of ``self.output_channels``.
        """
        weights = self._compute_grad_weights(grads)
        attention_map = torch.mul(fmaps, weights)
        B, _, *data_shape = attention_map.shape
        try:
            attention_map = attention_map.view(B, self.output_channels, -1, *data_shape)
        except RuntimeError as err:
            raise RuntimeError("Number of feature map channels ({}) is not a multiple of the number of set channels ({}) in layer: {}".format(fmaps.shape[1], self.output_channels, layer)) from err
        attention_map = torch.sum(attention_map, dim=2)
        attention_map = F.relu(attention_map)
        attention_map = self._normalize_per_channel(attention_map)
        return attention_map

    def _check_hooks(self, layer):
        """Checks if all hooks registered.

        Raises:
            ValueError: If no hook was ever registered to ``layer`` (for
                example, a name that matches no module), or if its forward
                and/or backward hook did not fire.
        """
        if layer not in self.registered_hooks:
            # Previously this surfaced as an unhelpful bare KeyError.
            raise ValueError("Invalid layer name: {}. No hook was registered to it, check that it matches a name from model.named_modules()".format(layer))
        if not self.registered_hooks[layer][0] and not self.registered_hooks[layer][1]:
            raise ValueError("Neither forward hook nor backward hook did register to layer: " + str(layer))
        elif not self.registered_hooks[layer][0]:
            raise ValueError("Forward hook did not register to layer: " + str(layer))
        elif not self.registered_hooks[layer][1]:
            raise ValueError("Backward hook did not register to layer: " + str(layer) + ", Check if the hook was registered to a layer that is skipped during backward and thus no gradients are computed")

Ancestors

  • medcam.backends.base._BaseWrapper

Subclasses

Methods

def forward(self, data)

Calls the forward() of the base.

Expand source code
def forward(self, data):
    """Attach the Grad-CAM hooks, then delegate to the base wrapper's forward().

    Hooks are re-registered on every call, before the base forward pass runs.
    """
    self._register_hooks()
    base = super(GradCAM, self)
    return base.forward(data)
def generate(self)

Generates an attention map.

Expand source code
def generate(self):
    """Build attention maps from the feature maps and gradients captured by the hooks.

    All hooks are removed first. In ``'auto'`` mode one layer is chosen by
    ``_auto_layer_selection``; otherwise one map is built for each entry of
    ``self.target_layers`` whose forward and backward hooks both fired.

    Returns:
        Dict mapping layer name to its attention map as a NumPy array.

    Raises:
        ValueError: If no attention map could be produced. Whatever
            ``_check_hooks`` raises for a layer is propagated unchanged.
    """
    self.remove_hook(forward=True, backward=True)

    if self._target_layers == "auto":
        chosen, fmaps, grads = self._auto_layer_selection()
        self._check_hooks(chosen)
        maps = {chosen: self._generate_helper(fmaps, grads, chosen).cpu().numpy()}
    else:
        maps = {}
        for name in self.target_layers:
            self._check_hooks(name)
            flags = self.registered_hooks[name]
            if not (flags[0] and flags[1]):
                continue
            layer_fmaps = self._find(self.fmap_pool, name)
            layer_grads = self._find(self.grad_pool, name)
            maps[name] = self._generate_helper(layer_fmaps, layer_grads, name).cpu().numpy()

    if not maps:
        raise ValueError("None of the hooks registered to the target layers")
    return maps
def get_registered_hooks(self)

Returns every hook that was able to register to a layer.

Expand source code
def get_registered_hooks(self):
    """Return the names of the layers whose forward and backward hooks both fired.

    All hooks are removed afterwards. If the wrapper was created with
    ``target_layers`` set to ``'full'`` or ``'auto'``, ``self.target_layers``
    is replaced by the returned list (the same list object).
    """
    fired = [name for name, flags in self.registered_hooks.items() if flags[0] and flags[1]]
    self.remove_hook(forward=True, backward=True)
    if self._target_layers in ('full', 'auto'):
        self.target_layers = fired
    return fired