Source code for attribench.plot._correlations

import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, Tuple
import seaborn as sns
from matplotlib.figure import Figure
from attribench.plot import Plot


def _create_fig(df, figsize, annot):
    """Draw ``df`` as a square-celled heatmap on a freshly created figure.

    The color scale is pinned to the range [-1, 1] with a diverging palette
    and the color bar is hidden.

    Parameters
    ----------
    df
        Data handed to ``seaborn.heatmap``.
    figsize
        Figure size forwarded to ``plt.subplots``.
    annot
        Forwarded to the ``annot`` argument of ``seaborn.heatmap``; when
        annotations are drawn they use the ``".2f"`` format.

    Returns
    -------
    tuple
        The created figure and the axes the heatmap was drawn on.
    """
    figure, axes = plt.subplots(figsize=figsize)
    palette = sns.diverging_palette(220, 20, as_cmap=True)
    heatmap_options = {
        "annot": annot,
        "fmt": ".2f",
        "vmin": -1,
        "vmax": 1,
        "cmap": palette,
        "cbar": False,
        "ax": axes,
    }
    sns.heatmap(df, **heatmap_options)
    # Force square cells so the matrix is not stretched by the figure size.
    axes.set_aspect("equal")
    return figure, axes


class InterMetricCorrelationPlot(Plot):
    """Heatmap showing Spearman correlations between metrics."""

    def render(
        self,
        title: str | None = None,
        figsize: Tuple[int, int] = (20, 20),
        fontsize: int | None = None,
        annot: bool = False,
    ) -> Figure:
        """Render the plot.

        For every method (the columns of the first metric's dataframe), a
        table with one column per metric is assembled and its Spearman
        correlation matrix is computed. Metrics flagged as "lower is better"
        are negated first so that all metrics point in the same direction.
        The per-method correlation matrices are then averaged.

        Parameters
        ----------
        title : str | None, optional
            Title of the figure, by default None
        figsize : Tuple[int, int], optional
            Size of the figure, by default (20, 20)
        fontsize : int | None, optional
            Font size of x and y axis ticks, by default None
        annot : bool, optional
            Whether to annotate the heatmap with the correlation values,
            by default False

        Returns
        -------
        Figure
            The rendered Matplotlib figure.
        """
        # The method names are taken from the first metric's dataframe.
        first_metric_df = list(self.dfs.values())[0][0]

        per_method_corrs = []
        for method_name in first_metric_df.columns:
            metric_columns = {}
            for metric_name, (metric_df, higher_is_better) in self.dfs.items():
                values = metric_df[method_name].to_numpy()
                if not higher_is_better:
                    # Flip the sign so that "higher is better" for all metrics.
                    values = -values
                metric_columns[metric_name] = values
            method_table = pd.DataFrame(metric_columns)
            per_method_corrs.append(method_table.corr(method="spearman"))

        # Average the correlation matrices over all methods.
        stacked = pd.concat(per_method_corrs)
        mean_corr = stacked.groupby(level=0).mean()
        # groupby sorts the row labels; restore the column order on the rows
        # so that rows and columns line up.
        mean_corr = mean_corr.reindex(mean_corr.columns)

        fig, ax = _create_fig(mean_corr, figsize, annot)
        if title is not None:
            ax.set_title(title)

        x_labels = ax.get_xticklabels()
        ax.set_xticklabels(
            x_labels,
            rotation=45,
            ha="right",
            rotation_mode="anchor",
            fontsize=fontsize,
        )
        y_labels = ax.get_yticklabels()
        ax.set_yticklabels(y_labels, fontsize=fontsize)
        return fig
class InterMethodCorrelationPlot(Plot):
    """Heatmap showing Spearman correlations between methods."""

    @staticmethod
    def _format_tick_labels(ax, fontsize: int | None) -> None:
        """Rotate the x tick labels by 45 degrees and apply ``fontsize`` to
        the tick labels of both axes."""
        ax.set_xticklabels(
            ax.get_xticklabels(),
            rotation=45,
            ha="right",
            rotation_mode="anchor",
            fontsize=fontsize,
        )
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=fontsize)

    def render(
        self,
        title: str | None = None,
        figsize: Tuple[int, int] = (20, 20),
        fontsize: int | None = None,
        annot: bool = False,
    ) -> Figure:
        """Render the plot.

        Spearman correlation values are averaged across metrics.
        To plot inter-method correlations for each metric separately,
        use :meth:`render_all`.

        Parameters
        ----------
        title : str | None, optional
            Title of the figure, by default None
        figsize : Tuple[int, int], optional
            Size of the figure, by default (20, 20)
        fontsize : int | None, optional
            Font size of x and y axis ticks, by default None
        annot : bool, optional
            Whether to annotate the heatmap with the correlation values,
            by default False

        Returns
        -------
        Figure
            The rendered Matplotlib figure.
        """
        # Compute correlations for each metric.
        # Negating a whole dataframe reverses the ranks of all columns
        # together, so it does not change the Spearman correlations between
        # them; the negation is kept as in the original code.
        # NOTE(review): InterMetricCorrelationPlot reads this flag as
        # ``higher_is_better`` and negates on the opposite condition --
        # confirm which meaning ``self.dfs`` carries.
        all_dfs = [
            df if not inverted else -df for df, inverted in self.dfs.values()
        ]
        corr_dfs = [df.corr(method="spearman") for df in all_dfs]

        # Compute average of correlations
        corr = pd.concat(corr_dfs).groupby(level=0).mean()
        # groupby sorts the row labels while the columns keep their original
        # order; realign the rows with the columns so that the diagonal of
        # the heatmap is each method's correlation with itself.
        corr = corr.reindex(corr.columns)

        fig, ax = _create_fig(corr, figsize, annot)
        if title is not None:
            ax.set_title(title)
        self._format_tick_labels(ax, fontsize)
        return fig

    def render_all(
        self,
        figsize: Tuple[int, int] = (20, 20),
        fontsize: int | None = None,
        annot: bool = False,
    ) -> Dict[str, Figure]:
        """Render a separate heatmap for each metric.

        Tick labels are formatted the same way as in :meth:`render`.
        TODO: add tests for this method.

        Parameters
        ----------
        figsize : Tuple[int, int], optional
            Size of the figures, by default (20, 20)
        fontsize : int | None, optional
            Font size of x and y axis ticks, by default None
        annot : bool, optional
            Whether to annotate the heatmaps with the correlation values,
            by default False

        Returns
        -------
        Dict[str, Figure]
            Dictionary mapping metric names to rendered Matplotlib figures.
        """
        figs: Dict[str, Figure] = {}
        for name, (df, inverted) in self.dfs.items():
            if inverted:
                df = -df
            corr = df.corr(method="spearman")
            # _create_fig returns (figure, axes); only the figure belongs in
            # the returned mapping.
            fig, ax = _create_fig(corr, figsize, annot)
            self._format_tick_labels(ax, fontsize)
            figs[name] = fig
        return figs