Module medcam.medcam_utils

Expand source code
import cv2
import numpy as np
import matplotlib.cm as cm
import nibabel as nib
import torch
from torch.nn import functional as F
from functools import reduce
import operator

MIN_SHAPE = (500, 500)

def save_attention_map(filename, attention_map, heatmap, raw_input):
    """
    Saves an attention map to disk.

    The map is cast to float, normalized to [0, 1], rendered by
    generate_attention_map() and finally written by _save_file().

    Args:
        filename: The save path, including the name, excluding the file extension.
        attention_map: The attention map as a numpy array in HxW or DxHxW format.
        heatmap: If the attention map should be saved as a heatmap. True for gcam and gcampp. False for gbp and ggcam.
        raw_input: Input image handed on to generate_attention_map(); only the 2D heatmap path uses it (as overlay background). May be None.
    """
    dim = len(attention_map.shape)
    # np.float was only an alias of the builtin float and has been removed
    # from NumPy (1.24+); np.float64 is the very same dtype.
    attention_map = normalize(attention_map.astype(np.float64))
    attention_map = generate_attention_map(attention_map, heatmap, dim, raw_input)
    _save_file(filename, attention_map, dim)

def generate_attention_map(attention_map, heatmap, dim, raw_input):
    """
    Dispatches to the renderer matching the dimensionality and map type.

    Args:
        attention_map: The attention map to render.
        heatmap: Truthy selects the gcam renderers, falsy the guided-backprop renderers.
        dim: Number of dimensions of the attention map, 2 or 3.
        raw_input: Forwarded to the 2D gcam renderer only.

    Raises:
        RuntimeError: If dim is neither 2 nor 3.
    """
    if dim not in (2, 3):
        raise RuntimeError("Unsupported dimension. Only 2D and 3D data is supported.")

    if dim == 2:
        if heatmap:
            return generate_gcam2d(attention_map, raw_input)
        return generate_guided_bp2d(attention_map)

    # dim == 3
    if heatmap:
        return generate_gcam3d(attention_map)
    return generate_guided_bp3d(attention_map)

def generate_gcam2d(attention_map, raw_input):
    """
    Renders a 2D attention map as a colour image.

    With a raw_input the map is blended onto it via overlay(); without one
    the map is resized via _resize_attention_map() and coloured with the
    reversed jet colormap.

    Args:
        attention_map: 2D numpy array (no batch dim, not a tensor).
        raw_input: Background image for the overlay, or None.

    Returns:
        The rendered image converted with np.uint8.
    """
    assert(len(attention_map.shape) == 2)  # No batch dim
    assert(isinstance(attention_map, np.ndarray))  # Not a tensor

    if raw_input is None:
        enlarged = _resize_attention_map(attention_map, MIN_SHAPE)
        rendered = cm.jet_r(enlarged)[..., :3] * 255.0  # drop the alpha channel
    else:
        rendered = overlay(raw_input, attention_map)
    return np.uint8(rendered)

def generate_guided_bp2d(attention_map):
    """
    Renders a 2D guided-backprop style attention map as a grayscale image.

    Args:
        attention_map: 2D numpy array; the caller normalizes it to [0, 1] beforehand.

    Returns:
        The map scaled by 255, resized via _resize_attention_map() and converted with np.uint8.
    """
    assert(len(attention_map.shape) == 2)
    assert (isinstance(attention_map, np.ndarray))  # Not a tensor

    # Scale into a new array: the former in-place `*=` overwrote the caller's
    # data and raised a casting error for integer arrays.
    scaled = attention_map * 255.0
    scaled = _resize_attention_map(scaled, MIN_SHAPE)
    return np.uint8(scaled)

def generate_gcam3d(attention_map, data=None):
    """
    Renders a 3D attention map as an 8-bit volume.

    Args:
        attention_map: Numpy array; the caller normalizes it to [0, 1] beforehand.
        data: Optional 3D numpy array. It is only validated, not used for rendering.

    Returns:
        The map scaled by 255 and converted with np.uint8.
    """
    assert(isinstance(attention_map, np.ndarray))  # Not a tensor
    assert(isinstance(data, np.ndarray) or data is None)  # Not PIL
    assert(data is None or len(data.shape) == 3)

    # Scale into a new array: the former in-place `*=` overwrote the caller's
    # data and raised a casting error for integer arrays.
    return np.uint8(attention_map * 255.0)

def generate_guided_bp3d(attention_map):
    """
    Renders a 3D guided-backprop style attention map as an 8-bit volume.

    Args:
        attention_map: 3D numpy array; the caller normalizes it to [0, 1] beforehand.

    Returns:
        The map scaled by 255 and converted with np.uint8.
    """
    assert(len(attention_map.shape) == 3)
    assert (isinstance(attention_map, np.ndarray))  # Not a tensor

    # Scale into a new array: the former in-place `*=` overwrote the caller's
    # data and raised a casting error for integer arrays.
    return np.uint8(attention_map * 255.0)

def _load_data(data_path):
    """Reads the image with cv2.imread when given a string path; any other value is returned unchanged."""
    return cv2.imread(data_path) if isinstance(data_path, str) else data_path

def _resize_attention_map(attention_map, min_shape):
    """
    Enlarges an attention map by an integer factor towards a minimum size.

    Args:
        attention_map: Array whose first two dimensions are resized.
        min_shape: Sequence of sizes; only its smallest entry is used.

    Returns:
        The output of cv2.resize. When the smallest side of the map already
        exceeds min(min_shape) the target size equals the current size,
        otherwise both sides are multiplied by int(min(min_shape) / smallest side).
    """
    dims = attention_map.shape[:2]
    smallest_target = min(min_shape)
    smallest_side = min(dims)
    if smallest_target >= smallest_side:
        scale = int(smallest_target / smallest_side)
        dims = (dims[0] * scale, dims[1] * scale)
    # cv2.resize takes the target size in reversed axis order, hence the flip.
    return cv2.resize(attention_map, tuple(np.flip(dims)))

def normalize(x):
    """
    Normalizes numpy or tensor data to the range [0, 1].

    Args:
        x: A torch.Tensor or a numpy array.

    Returns:
        (x - min) / (max - min) with the same shape as x. Constant input
        (min == max) yields all zeros instead of dividing by zero.
    """
    if isinstance(x, torch.Tensor):
        low, high = torch.min(x), torch.max(x)
        if low == high:
            # Keep the zeros on the input's device and, for floating inputs,
            # in its dtype, so both return paths are interchangeable
            # (torch.zeros(x.shape) always produced a CPU float32 tensor).
            dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
            return torch.zeros_like(x, dtype=dtype)
        return (x - low) / (high - low)

    low, high = np.min(x), np.max(x)
    if low == high:
        return np.zeros(x.shape)
    return (x - low) / (high - low)

def _save_file(filename, attention_map, dim):
    """
    Writes a rendered attention map to disk.

    Args:
        filename: Target path without file extension.
        attention_map: The rendered map.
        dim: 2 writes "<filename>.png" with cv2; any other value writes "<filename>.nii.gz" with nibabel.
    """
    if dim != 2:
        # Move the first axis to the end before wrapping the volume in a
        # NIfTI image with an identity affine.
        volume = nib.Nifti1Image(attention_map.transpose(1, 2, 0), affine=np.eye(4))
        nib.save(volume, filename + ".nii.gz")
        return
    cv2.imwrite(filename + ".png", attention_map)

def get_layers(model, reverse=False):
    """
    Returns the layer names of the model. Optionally reverses the order of the layers.

    The names come from model.named_modules() (the unnamed root entry is
    dropped) and are reordered so that every parent module is listed directly
    after the last of its descendants.

    Args:
        model: Object providing named_modules() that yields (name, module) pairs.
        reverse: If True the final list is reversed.

    Returns:
        A list of layer name strings.
    """
    layer_names = [name for name, _ in model.named_modules()]

    if layer_names and layer_names[0] == "":
        layer_names = layer_names[1:]

    index = 0
    # `<` instead of an equality test against len - 1: with an empty list
    # (model without sub-modules) the old loop never terminated.
    while index < len(layer_names) - 1:
        # Descendants are named "<parent>.<child>"; including the dot keeps
        # e.g. "layer1" from being taken as the parent of "layer10".
        prefix = layer_names[index] + "."
        last_child = index
        while last_child < len(layer_names) - 1 and layer_names[last_child + 1].startswith(prefix):
            last_child += 1
        if last_child > index:
            # Move the parent behind its last descendant, then re-examine the
            # entry that slid into `index`.
            layer_names.insert(last_child, layer_names.pop(index))
        else:
            index += 1

    if reverse:
        layer_names.reverse()

    return layer_names

def interpolate(data, shape, squeeze=False):
    """
    Interpolates data to the size of a given shape.

    Args:
        data: A numpy array or a torch.Tensor.
        shape: Target spatial size, forwarded to _interpolate_tensor().
        squeeze: Optionally squeezes away the batch and channel dim if the data was given in HxW or DxHxW format.

    Returns:
        The interpolated data in the same container type as the input; numpy
        results are cast back to the input dtype.

    Raises:
        ValueError: If data is neither a numpy array nor a tensor.
    """
    if isinstance(data, torch.Tensor):
        return _interpolate_tensor(data, shape, squeeze)
    if not isinstance(data, np.ndarray):
        raise ValueError("Unsupported data type for interpolation")

    # Numpy input takes a round trip through torch so that the same
    # interpolation routine serves both container types.
    original_dtype = data.dtype
    resized = _interpolate_tensor(torch.FloatTensor(data), shape, squeeze)
    return resized.numpy().astype(original_dtype)

def _interpolate_tensor(data, shape, squeeze):
    """
    Interpolates a tensor to the size of a given shape.

    Leading batch/channel dims are added when the tensor has one or two dims
    fewer than F.interpolate needs for a 2D or 3D target size. Bilinear
    interpolation is used for 2D sizes, trilinear otherwise.

    Args:
        data: The tensor to resize.
        shape: Target spatial size (2 or 3 entries).
        squeeze: If truthy, the dims added here are removed again from the result.

    Returns:
        The interpolated tensor.
    """
    spatial_rank = len(shape)
    missing = spatial_rank + 2 - len(data.shape)
    # Only pad for the supported layouts: one missing dim (batch) or two
    # missing dims (batch and channel) of a 2D/3D target.
    added = missing if spatial_rank in (2, 3) and missing in (1, 2) else 0
    for _ in range(added):
        data = data.unsqueeze(0)

    mode = "bilinear" if spatial_rank == 2 else "trilinear"
    data = F.interpolate(data, shape, mode=mode, align_corners=False)

    if squeeze:  # Remove unnecessary dims
        for _ in range(added):
            data = data.squeeze(0)
    return data

def prod(iterable):
    """Returns the product of all items in the iterable; 1 for an empty iterable."""
    result = 1
    for factor in iterable:
        result = result * factor
    return result

def overlay(raw_input, attention_map):
    """
    Blends a colour-mapped attention map onto the raw input image.

    Args:
        raw_input: The image as a numpy array or a torch.Tensor. 3D tensors
            whose first dim is 1 or 3 are transposed to channel-last. Values
            above 1 are taken as 0-255 data and rescaled to [0, 1].
        attention_map: 2D numpy array with values in [0, 1].

    Returns:
        Float array holding the mean of the reversed-jet coloured attention
        map and the raw input, scaled to [0, 255].
    """
    if isinstance(raw_input, torch.Tensor):
        raw_input = raw_input.detach().cpu().numpy()
        if raw_input.ndim == 3 and raw_input.shape[0] in (1, 3):
            raw_input = raw_input.transpose(1, 2, 0)
    # np.float was removed from NumPy (1.24+); np.float64 is the same dtype.
    raw_input = raw_input.astype(np.float64)
    if np.max(raw_input) > 1:
        raw_input = raw_input / 255
    if raw_input.ndim == 2:
        # A grayscale HxW image needs a trailing channel axis to broadcast
        # against the HxWx3 colour map below.
        raw_input = raw_input[..., np.newaxis]
    height, width = raw_input.shape[:2]
    # cv2.resize expects the target size as (width, height).
    attention_map = cv2.resize(attention_map, (width, height))
    attention_map = cm.jet_r(attention_map)[..., :3]  # drop the alpha channel
    attention_map = (attention_map.astype(np.float64) + raw_input) / 2
    attention_map *= 255
    return attention_map

def unpack_tensors_with_gradients(tensors):
    """
    Collects every tensor that requires a gradient from a nested structure.

    Args:
        tensors: A torch.Tensor, or an arbitrarily nested dict / list / tuple of tensors.

    Returns:
        A flat list of the tensors whose requires_grad is set, in traversal order.

    Raises:
        ValueError: If an element is neither a tensor nor a dict, list or tuple.
    """
    if isinstance(tensors, torch.Tensor):
        return [tensors] if tensors.requires_grad else []

    if isinstance(tensors, dict):
        children = tensors.values()
    elif isinstance(tensors, (list, tuple)):
        # Tuples are accepted alongside lists, since models commonly return them.
        children = tensors
    else:
        raise ValueError("Cannot unpack unknown data type.")

    unpacked_tensors = []
    for child in children:
        unpacked_tensors.extend(unpack_tensors_with_gradients(child))
    return unpacked_tensors

Functions

def generate_attention_map(attention_map, heatmap, dim, raw_input)
Expand source code
def generate_attention_map(attention_map, heatmap, dim, raw_input):
    """
    Dispatches to the renderer matching the dimensionality and map type.

    Args:
        attention_map: The attention map to render.
        heatmap: Truthy selects the gcam renderers, falsy the guided-backprop renderers.
        dim: Number of dimensions of the attention map, 2 or 3.
        raw_input: Forwarded to the 2D gcam renderer only.

    Raises:
        RuntimeError: If dim is neither 2 nor 3.
    """
    if dim not in (2, 3):
        raise RuntimeError("Unsupported dimension. Only 2D and 3D data is supported.")

    if dim == 2:
        if heatmap:
            return generate_gcam2d(attention_map, raw_input)
        return generate_guided_bp2d(attention_map)

    # dim == 3
    if heatmap:
        return generate_gcam3d(attention_map)
    return generate_guided_bp3d(attention_map)
def generate_gcam2d(attention_map, raw_input)
Expand source code
def generate_gcam2d(attention_map, raw_input):
    """
    Renders a 2D attention map as a colour image.

    With a raw_input the map is blended onto it via overlay(); without one
    the map is resized via _resize_attention_map() and coloured with the
    reversed jet colormap.

    Args:
        attention_map: 2D numpy array (no batch dim, not a tensor).
        raw_input: Background image for the overlay, or None.

    Returns:
        The rendered image converted with np.uint8.
    """
    assert(len(attention_map.shape) == 2)  # No batch dim
    assert(isinstance(attention_map, np.ndarray))  # Not a tensor

    if raw_input is None:
        enlarged = _resize_attention_map(attention_map, MIN_SHAPE)
        rendered = cm.jet_r(enlarged)[..., :3] * 255.0  # drop the alpha channel
    else:
        rendered = overlay(raw_input, attention_map)
    return np.uint8(rendered)
def generate_gcam3d(attention_map, data=None)
Expand source code
def generate_gcam3d(attention_map, data=None):
    """
    Renders a 3D attention map as an 8-bit volume.

    Args:
        attention_map: Numpy array; the caller normalizes it to [0, 1] beforehand.
        data: Optional 3D numpy array. It is only validated, not used for rendering.

    Returns:
        The map scaled by 255 and converted with np.uint8.
    """
    assert(isinstance(attention_map, np.ndarray))  # Not a tensor
    assert(isinstance(data, np.ndarray) or data is None)  # Not PIL
    assert(data is None or len(data.shape) == 3)

    # Scale into a new array: the former in-place `*=` overwrote the caller's
    # data and raised a casting error for integer arrays.
    return np.uint8(attention_map * 255.0)
def generate_guided_bp2d(attention_map)
Expand source code
def generate_guided_bp2d(attention_map):
    """
    Renders a 2D guided-backprop style attention map as a grayscale image.

    Args:
        attention_map: 2D numpy array; the caller normalizes it to [0, 1] beforehand.

    Returns:
        The map scaled by 255, resized via _resize_attention_map() and converted with np.uint8.
    """
    assert(len(attention_map.shape) == 2)
    assert (isinstance(attention_map, np.ndarray))  # Not a tensor

    # Scale into a new array: the former in-place `*=` overwrote the caller's
    # data and raised a casting error for integer arrays.
    scaled = attention_map * 255.0
    scaled = _resize_attention_map(scaled, MIN_SHAPE)
    return np.uint8(scaled)
def generate_guided_bp3d(attention_map)
Expand source code
def generate_guided_bp3d(attention_map):
    """
    Renders a 3D guided-backprop style attention map as an 8-bit volume.

    Args:
        attention_map: 3D numpy array; the caller normalizes it to [0, 1] beforehand.

    Returns:
        The map scaled by 255 and converted with np.uint8.
    """
    assert(len(attention_map.shape) == 3)
    assert (isinstance(attention_map, np.ndarray))  # Not a tensor

    # Scale into a new array: the former in-place `*=` overwrote the caller's
    # data and raised a casting error for integer arrays.
    return np.uint8(attention_map * 255.0)
def get_layers(model, reverse=False)

Returns the layers of the model. Optionally reverses the order of the layers.

Expand source code
def get_layers(model, reverse=False):
    """
    Returns the layer names of the model. Optionally reverses the order of the layers.

    The names come from model.named_modules() (the unnamed root entry is
    dropped) and are reordered so that every parent module is listed directly
    after the last of its descendants.

    Args:
        model: Object providing named_modules() that yields (name, module) pairs.
        reverse: If True the final list is reversed.

    Returns:
        A list of layer name strings.
    """
    layer_names = [name for name, _ in model.named_modules()]

    if layer_names and layer_names[0] == "":
        layer_names = layer_names[1:]

    index = 0
    # `<` instead of an equality test against len - 1: with an empty list
    # (model without sub-modules) the old loop never terminated.
    while index < len(layer_names) - 1:
        # Descendants are named "<parent>.<child>"; including the dot keeps
        # e.g. "layer1" from being taken as the parent of "layer10".
        prefix = layer_names[index] + "."
        last_child = index
        while last_child < len(layer_names) - 1 and layer_names[last_child + 1].startswith(prefix):
            last_child += 1
        if last_child > index:
            # Move the parent behind its last descendant, then re-examine the
            # entry that slid into `index`.
            layer_names.insert(last_child, layer_names.pop(index))
        else:
            index += 1

    if reverse:
        layer_names.reverse()

    return layer_names
def interpolate(data, shape, squeeze=False)

Interpolates data to the size of a given shape. Optionally squeezes away the batch and channel dim if the data was given in HxW or DxHxW format.

Expand source code
def interpolate(data, shape, squeeze=False):
    """
    Interpolates data to the size of a given shape.

    Args:
        data: A numpy array or a torch.Tensor.
        shape: Target spatial size, forwarded to _interpolate_tensor().
        squeeze: Optionally squeezes away the batch and channel dim if the data was given in HxW or DxHxW format.

    Returns:
        The interpolated data in the same container type as the input; numpy
        results are cast back to the input dtype.

    Raises:
        ValueError: If data is neither a numpy array nor a tensor.
    """
    if isinstance(data, torch.Tensor):
        return _interpolate_tensor(data, shape, squeeze)
    if not isinstance(data, np.ndarray):
        raise ValueError("Unsupported data type for interpolation")

    # Numpy input takes a round trip through torch so that the same
    # interpolation routine serves both container types.
    original_dtype = data.dtype
    resized = _interpolate_tensor(torch.FloatTensor(data), shape, squeeze)
    return resized.numpy().astype(original_dtype)
def normalize(x)

Normalizes numpy or tensor data to the range [0, 1].

Expand source code
def normalize(x):
    """
    Normalizes numpy or tensor data to the range [0, 1].

    Args:
        x: A torch.Tensor or a numpy array.

    Returns:
        (x - min) / (max - min) with the same shape as x. Constant input
        (min == max) yields all zeros instead of dividing by zero.
    """
    if isinstance(x, torch.Tensor):
        low, high = torch.min(x), torch.max(x)
        if low == high:
            # Keep the zeros on the input's device and, for floating inputs,
            # in its dtype, so both return paths are interchangeable
            # (torch.zeros(x.shape) always produced a CPU float32 tensor).
            dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
            return torch.zeros_like(x, dtype=dtype)
        return (x - low) / (high - low)

    low, high = np.min(x), np.max(x)
    if low == high:
        return np.zeros(x.shape)
    return (x - low) / (high - low)
def overlay(raw_input, attention_map)
Expand source code
def overlay(raw_input, attention_map):
    """
    Blends a colour-mapped attention map onto the raw input image.

    Args:
        raw_input: The image as a numpy array or a torch.Tensor. 3D tensors
            whose first dim is 1 or 3 are transposed to channel-last. Values
            above 1 are taken as 0-255 data and rescaled to [0, 1].
        attention_map: 2D numpy array with values in [0, 1].

    Returns:
        Float array holding the mean of the reversed-jet coloured attention
        map and the raw input, scaled to [0, 255].
    """
    if isinstance(raw_input, torch.Tensor):
        raw_input = raw_input.detach().cpu().numpy()
        if raw_input.ndim == 3 and raw_input.shape[0] in (1, 3):
            raw_input = raw_input.transpose(1, 2, 0)
    # np.float was removed from NumPy (1.24+); np.float64 is the same dtype.
    raw_input = raw_input.astype(np.float64)
    if np.max(raw_input) > 1:
        raw_input = raw_input / 255
    if raw_input.ndim == 2:
        # A grayscale HxW image needs a trailing channel axis to broadcast
        # against the HxWx3 colour map below.
        raw_input = raw_input[..., np.newaxis]
    height, width = raw_input.shape[:2]
    # cv2.resize expects the target size as (width, height).
    attention_map = cv2.resize(attention_map, (width, height))
    attention_map = cm.jet_r(attention_map)[..., :3]  # drop the alpha channel
    attention_map = (attention_map.astype(np.float64) + raw_input) / 2
    attention_map *= 255
    return attention_map
def prod(iterable)
Expand source code
def prod(iterable):
    """Returns the product of all items in the iterable; 1 for an empty iterable."""
    result = 1
    for factor in iterable:
        result = result * factor
    return result
def save_attention_map(filename, attention_map, heatmap, raw_input)

Saves an attention map.

Args

filename
The save path, including the name, excluding the file extension.
attention_map
The attention map in HxW or DxHxW format.
heatmap
If the attention map should be saved as a heatmap. True for gcam and gcampp. False for gbp and ggcam.
Expand source code
def save_attention_map(filename, attention_map, heatmap, raw_input):
    """
    Saves an attention map to disk.

    The map is cast to float, normalized to [0, 1], rendered by
    generate_attention_map() and finally written by _save_file().

    Args:
        filename: The save path, including the name, excluding the file extension.
        attention_map: The attention map as a numpy array in HxW or DxHxW format.
        heatmap: If the attention map should be saved as a heatmap. True for gcam and gcampp. False for gbp and ggcam.
        raw_input: Input image handed on to generate_attention_map(); only the 2D heatmap path uses it (as overlay background). May be None.
    """
    dim = len(attention_map.shape)
    # np.float was only an alias of the builtin float and has been removed
    # from NumPy (1.24+); np.float64 is the very same dtype.
    attention_map = normalize(attention_map.astype(np.float64))
    attention_map = generate_attention_map(attention_map, heatmap, dim, raw_input)
    _save_file(filename, attention_map, dim)
def unpack_tensors_with_gradients(tensors)
Expand source code
def unpack_tensors_with_gradients(tensors):
    """
    Collects every tensor that requires a gradient from a nested structure.

    Args:
        tensors: A torch.Tensor, or an arbitrarily nested dict / list / tuple of tensors.

    Returns:
        A flat list of the tensors whose requires_grad is set, in traversal order.

    Raises:
        ValueError: If an element is neither a tensor nor a dict, list or tuple.
    """
    if isinstance(tensors, torch.Tensor):
        return [tensors] if tensors.requires_grad else []

    if isinstance(tensors, dict):
        children = tensors.values()
    elif isinstance(tensors, (list, tuple)):
        # Tuples are accepted alongside lists, since models commonly return them.
        children = tensors
    else:
        raise ValueError("Cannot unpack unknown data type.")

    unpacked_tensors = []
    for child in children:
        unpacked_tensors.extend(unpack_tensors_with_gradients(child))
    return unpacked_tensors