Source code for attribench.plot._convergence_plot

import pandas as pd
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from attribench.plot import Plot
from matplotlib.figure import Figure


class ConvergencePlot(Plot):
    """
    Line plot of the median values of a given metric vs the number of samples.
    Error bars are computed using bootstrapping.
    Allows the user to inspect if metric values have converged,
    i.e. if the benchmark has been run on enough samples.
    """

    def __init__(self, df: pd.DataFrame):
        """
        Parameters
        ----------
        df : pd.DataFrame
            Pandas dataframe containing the metric values.
            The columns are the names of the methods.
        """
        super().__init__({})
        self.df = df

    def _bootstrap_medians(self, bs_samples: int, interval: int) -> pd.DataFrame:
        """Bootstrap the per-column median for a range of sample sizes.

        Sample sizes run from 2 up to and including the number of rows of
        ``self.df``, in steps of ``interval``. For each size, ``bs_samples``
        resamples (with replacement) of that many rows are drawn and the
        median of every column is recorded.

        Returns
        -------
        pd.DataFrame
            Long-format dataframe with columns ``method`` (source column
            name), ``value`` (one bootstrapped median) and ``bs_size``
            (the sample size used), with a fresh 0..N-1 index.

        Raises
        ------
        ValueError
            If ``interval`` or ``bs_samples`` is not positive, or if
            ``self.df`` has fewer than 2 rows.
        """
        if interval < 1:
            raise ValueError(f"interval must be a positive integer, got {interval}")
        if bs_samples < 1:
            raise ValueError(
                f"bs_samples must be a positive integer, got {bs_samples}"
            )
        num_rows = self.df.shape[0]
        if num_rows < 2:
            raise ValueError(
                "ConvergencePlot needs at least 2 rows of metric values, "
                f"got {num_rows}"
            )

        sizes = list(range(2, num_rows + 1, interval))
        # Always end at the full number of rows: that point shows the
        # benchmark as it was actually run, which is what convergence is
        # judged against.
        if sizes[-1] != num_rows:
            sizes.append(num_rows)

        frames = []
        for bs_size in tqdm(sizes):
            # One row per bootstrap resample, one column per method.
            medians = pd.DataFrame(
                [
                    self.df.sample(n=bs_size, replace=True).median(axis=0)
                    for _ in range(bs_samples)
                ]
            )
            long_form = pd.melt(medians, var_name="method")
            long_form["bs_size"] = bs_size
            frames.append(long_form)
        # ignore_index avoids handing seaborn a frame with duplicate labels.
        return pd.concat(frames, ignore_index=True)

    def render(
        self,
        title: str | None = None,
        figsize=(20, 20),
        fontsize=None,
        bs_samples=1000,
        interval=5,
        errorbar="ci",
    ) -> Figure:
        """Render the plot.

        Parameters
        ----------
        title : str | None, optional
            Title of the figure, by default None
        figsize : Tuple[int, int], optional
            Figure size, by default (20, 20)
        fontsize : int | None, optional
            Font size of x and y axis ticks. If None (default), the
            Matplotlib default tick label size is kept.
        bs_samples : int, optional
            Number of bootstrap samples for estimating the median value of
            the metric using a given sample size. By default 1000
        interval : int, optional
            Interval between sample sizes, by default 5. The full number of
            rows is always included as the last sample size.
        errorbar : str | tuple | None, optional
            Error-bar specification passed to ``seaborn.lineplot``.
            By default ``"ci"`` (seaborn's bootstrap confidence interval of
            the plotted median). ``("pi", 95)`` instead shows the central
            95% of the bootstrapped medians at each sample size.

        Returns
        -------
        Figure
            The rendered Matplotlib figure.

        Raises
        ------
        ValueError
            If ``interval`` or ``bs_samples`` is not positive, or if the
            dataframe has fewer than 2 rows.
        """
        data = self._bootstrap_medians(bs_samples, interval)

        fig, ax = plt.subplots(figsize=figsize)
        sns.lineplot(
            data=data,
            x="bs_size",
            y="value",
            hue="method",
            estimator="median",
            errorbar=errorbar,
            ax=ax,
        )
        # tick_params changes only the label size. Re-setting the tick labels
        # would freeze them with a FixedFormatter on an auto locator.
        if fontsize is not None:
            ax.tick_params(axis="both", labelsize=fontsize)
        # Place the legend outside the axes, to the right of the plot.
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)
        if title is not None:
            ax.set_title(title)
        return fig