Source code for attribench.plot._wilcoxon_summary_plot

from typing import Tuple, List
import pandas as pd
from attribench.plot._lib import heatmap
from attribench.plot import Plot
from attribench._stat import wilcoxon_tests
from matplotlib.figure import Figure
import seaborn as sns


[docs]class WilcoxonSummaryPlot(Plot): """Summary plot for Wilcoxon tests. Plots the significance and effect sizes of Wilcoxon tests for each metric and each method in a heatmap. Every glyph on the heatmap corresponds to a Wilcoxon test between a given method and the random baseline on the given metric. """
[docs] def render( self, title: str | None = None, figsize: Tuple[int, int] = (20, 20), fontsize: int | None = None, glyph_scale: int = 1500, method_order: List[str] | None =None, ) -> Figure: """Render the plot. Parameters ---------- title : str | None, optional Title of the figure, by default None figsize : Tuple[int, int], optional Figure size, by default (20, 20) fontsize : int | None, optional Font size of x and y axis ticks, by default None glyph_scale : int, optional Scale of heatmap glyphs, by default 1500 method_order : List[str] | None, optional Order in which methods should be displayed. If None, the order of the keys in the dictionary that was passed to the constructor will be used. By default None. Returns ------- Figure The rendered Matplotlib figure. """ pvalues, effect_sizes = {}, {} for metric_name, (df, higher_is_better) in self.dfs.items(): mes, mpv = wilcoxon_tests(df, higher_is_better) effect_sizes[metric_name] = mes pvalues[metric_name] = mpv pvalues = pd.DataFrame(pvalues) effect_sizes = pd.DataFrame(effect_sizes).abs() effect_sizes[pvalues > 0.01] = 0 # Normalize each column of the effect sizes dataframe effect_sizes = effect_sizes / effect_sizes.max() effect_sizes = effect_sizes.fillna(0) effect_sizes = pd.melt(effect_sizes.reset_index(), id_vars=["index"]) effect_sizes.columns = ["method", "metric", "value"] return heatmap( x=effect_sizes["method"], y=effect_sizes["metric"], size=effect_sizes["value"], color=effect_sizes["value"], palette=sns.color_palette("Greens", n_colors=256), figsize=figsize, glyph_scale=glyph_scale, fontsize=fontsize, title=title, cbar=False, x_labels=method_order, )