Module medcam.evaluation.evaluation_utils

Expand source code
import numpy as np
import torch
import copy
from medcam import medcam_utils
from skimage.filters import threshold_otsu

def comp_score(attention_map, mask, metric="wioa", threshold='otsu'):
    """Computes an evaluation score for an attention map based on a ground truth mask."""
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    else:
        mask = np.asarray(mask)
    allowed = [0, 1, 0.0, 1.0]
    if np.min(mask) in allowed and np.max(mask) in allowed:
        mask = mask.astype(int)
    else:
        raise TypeError("Mask values need to be 0/1")
    binary_attention_map, mask, weights = _preprocessing(attention_map, mask, threshold)
    if metric[0] != "w":
        weights = None
    if metric == "ioa" or metric == "wioa":
        score = _intersection_over_attention(binary_attention_map, mask, weights)
    elif metric == "iou" or metric == "wiou":
        score = _intersection_over_union(binary_attention_map, mask, weights)
    elif callable(metric):
        score = metric(attention_map, mask, attention_map, weights)
    else:
        raise ValueError("Metric does not exist")
    return score

def _preprocessing(attention_map, mask, attention_threshold):
    """Interpolates, normalizes and binarizes the attention map."""
    if not np.isfinite(attention_map).all():
        raise ValueError("Attention map contains non finite elements")
    if not np.isfinite(mask).all():
        raise ValueError("Mask contains non finite elements")
    if np.sum(attention_map < 0) > 0:  # For gbp and ggcam as they contain negative values, which would otherwise falsify the evaluation
        attention_map = np.abs(attention_map)
    attention_map = medcam_utils.interpolate(attention_map, mask.shape, squeeze=True)
    attention_map = medcam_utils.normalize(attention_map.astype(np.float))
    weights = copy.deepcopy(attention_map)
    mask = np.array(mask, dtype=int)
    if np.min(attention_map) == np.max(attention_map):
        attention_threshold = 1
    elif attention_threshold == 'otsu':
        attention_threshold = threshold_otsu(attention_map.flatten())
    attention_map[attention_map < attention_threshold] = 0
    attention_map[attention_map >= attention_threshold] = 1
    attention_map = np.array(attention_map, dtype=int)
    return attention_map, mask, weights

def _intersection_over_attention(binary_attention_map, mask, weights):
    """(Weighted) intersection over attention. How much of (weighted) total attention is inside the ground truth mask."""
    intersection = binary_attention_map & mask
    if weights is not None:
        intersection = intersection.astype(np.float) * weights
        binary_attention_map = binary_attention_map.astype(np.float) * weights
    ioa = np.sum(intersection) / np.sum(binary_attention_map)
    return ioa

def _intersection_over_union(binary_attention_map, mask, weights):  # TODO: wiou is bad and wrong, maybe not even possible?
    """Intersection over union."""
    intersection = binary_attention_map & mask
    if weights is not None:
        outer_attention = binary_attention_map - intersection
        outer_attention = outer_attention.astype(np.float) * weights
        union = outer_attention + mask.astype(np.float)
        intersection = intersection.astype(np.float) * weights
    else:
        union = binary_attention_map | mask
    iou = np.sum(intersection) / np.sum(union).astype(np.float)
    return iou

Functions

def comp_score(attention_map, mask, metric='wioa', threshold='otsu')

Computes an evaluation score for an attention map based on a ground truth mask.

Expand source code
def comp_score(attention_map, mask, metric="wioa", threshold='otsu'):
    """Computes an evaluation score for an attention map based on a ground truth mask."""
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    else:
        mask = np.asarray(mask)
    allowed = [0, 1, 0.0, 1.0]
    if np.min(mask) in allowed and np.max(mask) in allowed:
        mask = mask.astype(int)
    else:
        raise TypeError("Mask values need to be 0/1")
    binary_attention_map, mask, weights = _preprocessing(attention_map, mask, threshold)
    if metric[0] != "w":
        weights = None
    if metric == "ioa" or metric == "wioa":
        score = _intersection_over_attention(binary_attention_map, mask, weights)
    elif metric == "iou" or metric == "wiou":
        score = _intersection_over_union(binary_attention_map, mask, weights)
    elif callable(metric):
        score = metric(attention_map, mask, attention_map, weights)
    else:
        raise ValueError("Metric does not exist")
    return score