Module medcam.backends.guided_backpropagation

Source code
import torch
from torch import nn
from medcam.backends.base import _BaseWrapper


class GuidedBackPropagation(_BaseWrapper):

    def __init__(self, model, postprocessor=None, retain_graph=False):
        """Guided backpropagation, as introduced in
        "Striving for Simplicity: The All Convolutional Net"
        (https://arxiv.org/pdf/1412.6806.pdf); see Figure 1 on page 8.
        """
        super(GuidedBackPropagation, self).__init__(model, postprocessor=postprocessor, retain_graph=retain_graph)

    def _register_hooks(self):
        """Registers the backward hooks on the layers of the model."""
        def backward_hook(module, grad_in, grad_out):
            # Guided backpropagation: at every ReLU, additionally clip
            # negative gradients to zero so that only positive influences
            # propagate back towards the input.
            if isinstance(module, nn.ReLU):
                return (torch.clamp(grad_in[0], min=0.0),)

        self.remove_hook(forward=True, backward=True)
        for name, module in self.model.named_modules():
            # Mark both hook slots of this layer as registered.
            self.registered_hooks[name] = [True, True]
            # Note: register_backward_hook is deprecated in newer PyTorch
            # releases in favor of register_full_backward_hook.
            self.backward_handlers.append(module.register_backward_hook(backward_hook))

    def get_registered_hooks(self):
        """Returns the names of all layers on which the hooks were successfully registered."""
        registered_hooks = []
        for layer, (forward_ok, backward_ok) in self.registered_hooks.items():
            if forward_ok and backward_ok:
                registered_hooks.append(layer)
        self.remove_hook(forward=True, backward=True)
        return registered_hooks

    def forward(self, data):
        """Registers the hooks, enables gradient tracking on the input, and calls the forward() of the base."""
        self._register_hooks()
        self.data = data.requires_grad_()
        return super(GuidedBackPropagation, self).forward(self.data)

    def generate(self):
        """Generates an attention map from the gradient of the input."""
        attention_map = self.data.grad.clone()
        self.data.grad.zero_()
        B, _, *data_shape = attention_map.shape
        # Collapse all input channels into a single map.
        attention_map = attention_map.view(B, 1, -1, *data_shape)
        attention_map = torch.mean(attention_map, dim=2)  # TODO: mean or sum?
        # Repeat the map once per output channel to match the model output.
        attention_map = attention_map.repeat(1, self.output_channels, *[1 for _ in range(self.input_dim)])
        attention_map = self._normalize_per_channel(attention_map)
        attention_map = attention_map.cpu().numpy()
        return {"": attention_map}
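
A minimal usage sketch. The backward() call and its signature are an assumption here: they are expected to come from _BaseWrapper, which is not shown in this module, so treat the snippet as an illustration rather than the definitive API.

import torch
from torch import nn
from medcam.backends.guided_backpropagation import GuidedBackPropagation

# A toy classifier; any nn.Module that uses nn.ReLU works the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

gbp = GuidedBackPropagation(model)
data = torch.randn(1, 3, 64, 64)

output = gbp.forward(data)  # registers the ReLU hooks and runs the model
gbp.backward(output)        # assumed: backward pass provided by _BaseWrapper
maps = gbp.generate()       # {"": numpy array, normalized per channel}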

Classes

class GuidedBackPropagation (model, postprocessor=None, retain_graph=False)

"Striving for Simplicity: the All Convolutional Net" https://arxiv.org/pdf/1412.6806.pdf Look at Figure 1 on page 8.


Ancestors

  • medcam.backends.base._BaseWrapper

Methods

def forward(self, data)

Registers the hooks, enables gradient tracking on the input, and calls the forward() of the base.
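
Note that forward() calls requires_grad_() on the input in place, so the caller's tensor is modified as a side effect. With gbp as in the usage sketch above:

data = torch.randn(1, 3, 64, 64)
_ = gbp.forward(data)
print(data.requires_grad)  # True: requires_grad_() modified the tensor in place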

def generate(self)

Generates an attention map from the gradient of the input.
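
The returned dictionary has the empty string as its only key, since guided backpropagation yields a single map at the input rather than one map per layer. Continuing the sketch above, and assuming a 2D input of shape (B, C, H, W):

maps = gbp.generate()
gbp_map = maps[""]    # numpy array
print(gbp_map.shape)  # (B, output_channels, H, W) for a 2D input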

def get_registered_hooks(self)

Returns the names of all layers on which the hooks were successfully registered.
