"""Evaluation utilities for the Unifloral project. This module provides tools for: 1. Loading and parsing experiment results 2. Running bandit-based policy selection 3. Computing confidence intervals via bootstrapping """ from collections import namedtuple from datetime import datetime import os import re from typing import Dict, Tuple import warnings from functools import partial import glob import jax from jax import numpy as jnp import numpy as np import pandas as pd r""" |\ __ \| /_/ \| ___|_____ \ / \ / \___/ Data loading """ def parse_and_load_npz(filename: str) -> Dict: """Load data from a result file and parse metadata from filename. Args: filename: Path to the .npz result file Returns: Dictionary containing loaded arrays and metadata """ # Parse filename to extract algorithm, dataset, and timestamp pattern = r"(.+)_(.+)_(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})" match = re.match(pattern, os.path.basename(filename)) if not match: raise ValueError(f"Could not parse filename: {filename}") algorithm, dataset, dt_str = match.groups() dt = datetime.strptime(dt_str, "%Y-%m-%d_%H-%M-%S") data = np.load(filename, allow_pickle=True) data = {k: v for k, v in data.items()} data["algorithm"] = algorithm data["dataset"] = dataset data["datetime"] = dt data.update(data.pop("args", np.array({})).item()) # Flatten args return data def load_results_dataframe(results_dir: str = "final_returns") -> pd.DataFrame: """Load all result files from a directory into a pandas DataFrame. Args: results_dir: Directory containing .npz result files Returns: DataFrame containing results from all successfully loaded files """ npz_files = glob.glob(os.path.join(results_dir, "*.npz")) data_list = [] for f in npz_files: try: data = parse_and_load_npz(f) data_list.append(data) except Exception as e: print(f"Error loading {f}: {e}") continue df = pd.DataFrame(data_list).drop(columns=["Index"], errors="ignore") if "final_scores" in df.columns: df["final_scores"] = df["final_scores"].apply(lambda x: x.reshape(-1)) if "final_returns" in df.columns: df["final_returns"] = df["final_returns"].apply(lambda x: x.reshape(-1)) df = df.sort_values(by=["algorithm", "dataset", "datetime"]) return df.reset_index(drop=True) r""" __/) .-(__(=: |\ | \) \ || \|| \| ___|_____ \ / \ / \___/ Bandit Evaluation and Bootstrapping """ BanditState = namedtuple("BanditState", "rng counts rewards total_pulls") def ucb( means: jnp.ndarray, counts: jnp.ndarray, total_counts: int, alpha: float ) -> jnp.ndarray: """Compute UCB exploration bonus. Args: means: Array of empirical means for each arm counts: Array of pull counts for each arm total_counts: Total number of pulls across all arms alpha: Exploration coefficient Returns: Array of UCB values for each arm """ exploration = jnp.sqrt(alpha * jnp.log(total_counts) / (counts + 1e-9)) return means + exploration def argmax_with_random_tiebreaking(rng: jnp.ndarray, values: jnp.ndarray) -> int: """Select maximum value with random tiebreaking. Args: rng: JAX PRNGKey values: Array of values to select from Returns: Index of selected maximum value """ mask = values == jnp.max(values) p = mask / (mask.sum() + 1e-9) return jax.random.choice(rng, jnp.arange(len(values)), p=p) @partial(jax.jit, static_argnums=(2,)) def run_bandit( returns_array: jnp.ndarray, rng: jnp.ndarray, max_pulls: int, alpha: float, policy_idx: jnp.ndarray, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Run a single bandit algorithm and report results after each pull. 


@partial(jax.jit, static_argnums=(2,))
def run_bandit(
    returns_array: jnp.ndarray,
    rng: jnp.ndarray,
    max_pulls: int,
    alpha: float,
    policy_idx: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Run a single bandit algorithm and report results after each pull.

    Args:
        returns_array: Array of returns for each policy and rollout
        rng: JAX PRNGKey
        max_pulls: Maximum number of pulls to execute
        alpha: UCB exploration coefficient
        policy_idx: Indices of policies to consider

    Returns:
        Tuple of (pulls, estimated_bests), each of shape (max_pulls,)
    """
    returns_array = returns_array[policy_idx]
    num_policies, num_rollouts = returns_array.shape
    init_state = BanditState(
        rng=rng,
        counts=jnp.zeros(num_policies, dtype=jnp.int32),
        rewards=jnp.zeros(num_policies),
        total_pulls=1,
    )

    def bandit_step(state: BanditState, _):
        """Run one bandit step and track performance."""
        rng, rng_lever, rng_reward = jax.random.split(state.rng, 3)

        # Select arm using UCB
        means = state.rewards / jnp.maximum(state.counts, 1)
        ucb_values = ucb(means, state.counts, state.total_pulls, alpha)
        arm = argmax_with_random_tiebreaking(rng_lever, ucb_values)

        # Sample a reward for the chosen arm
        idx = jax.random.randint(rng_reward, shape=(), minval=0, maxval=num_rollouts)
        reward = returns_array[arm, idx]
        new_state = BanditState(
            rng=rng,
            counts=state.counts.at[arm].add(1),
            rewards=state.rewards.at[arm].add(reward),
            total_pulls=state.total_pulls + 1,
        )

        # Calculate best arm based on current state
        updated_means = new_state.rewards / jnp.maximum(new_state.counts, 1)
        best_arm = jnp.argmax(updated_means)
        estimated_best = returns_array[best_arm].mean()
        return new_state, (state.total_pulls, estimated_best)

    _, (pulls, estimated_bests) = jax.lax.scan(
        bandit_step, init_state, length=max_pulls
    )
    return pulls, estimated_bests


def run_bandit_trials(
    returns_array: jnp.ndarray,
    seed: int = 17,
    num_subsample: int = 20,
    num_repeats: int = 1000,
    max_pulls: int = 200,
    ucb_alpha: float = 2.0,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Run multiple bandit trials and collect results at each step.

    Args:
        returns_array: Array of returns for each policy and rollout
        seed: Random seed
        num_subsample: Number of policies to subsample on each trial
        num_repeats: Number of trials to run
        max_pulls: Maximum number of pulls per trial
        ucb_alpha: UCB exploration coefficient

    Returns:
        Tuple of (pulls, estimated_bests) with shapes (max_pulls,) and
        (num_repeats, max_pulls)
    """
    rng = jax.random.PRNGKey(seed)
    num_policies = returns_array.shape[0]
    # Warn before clamping, otherwise the check can never trigger
    if num_subsample > num_policies:
        warnings.warn("Not enough policies to subsample, using all policies")
        num_subsample = num_policies
    rng, rng_trials, rng_sample = jax.random.split(rng, 3)
    rng_trials = jax.random.split(rng_trials, num_repeats)

    def sample_policies(rng: jnp.ndarray) -> jnp.ndarray:
        """Sample a subset of policy indices without replacement."""
        return jax.random.choice(
            rng, jnp.arange(num_policies), shape=(num_subsample,), replace=False
        )

    # Create a batch of policy index arrays for all trials
    rng_sample_keys = jax.random.split(rng_sample, num_repeats)
    policy_indices = jax.vmap(sample_policies)(rng_sample_keys)

    # Run bandit trials with policy subsampling
    # Pulls are the same for all trials, so we can just return the first one
    vmap_run_bandit = jax.vmap(run_bandit, in_axes=(None, 0, None, None, 0))
    pulls, estimated_bests = vmap_run_bandit(
        returns_array, rng_trials, max_pulls, ucb_alpha, policy_indices
    )
    return pulls[0], estimated_bests
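

# A minimal sketch of `run_bandit_trials` on synthetic returns. The Gaussian
# returns and the array shape (num_policies, num_rollouts) are illustrative
# assumptions, not project data.
def _demo_run_bandit_trials():
    rng = jax.random.PRNGKey(0)
    returns_array = jax.random.normal(rng, (50, 10))  # 50 policies, 10 rollouts each
    pulls, estimated_bests = run_bandit_trials(
        returns_array, seed=0, num_subsample=20, num_repeats=100, max_pulls=50
    )
    # pulls has shape (max_pulls,); estimated_bests has shape (num_repeats, max_pulls)
    return pulls, estimated_bests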


def bootstrap_confidence_interval(
    rng: jnp.ndarray,
    data: jnp.ndarray,
    n_bootstraps: int = 1000,
    confidence: float = 0.95,
) -> Tuple[float, float]:
    """Compute a bootstrap confidence interval for the mean of data.

    Args:
        rng: JAX PRNGKey
        data: Array of values to bootstrap
        n_bootstraps: Number of bootstrap samples
        confidence: Confidence level (between 0 and 1)

    Returns:
        Tuple of (lower_bound, upper_bound)
    """

    @jax.vmap
    def bootstrap_mean(rng):
        samples = jax.random.choice(rng, data, shape=(data.shape[0],), replace=True)
        return samples.mean()

    bootstrap_means = bootstrap_mean(jax.random.split(rng, n_bootstraps))
    lower_bound = jnp.percentile(bootstrap_means, 100 * (1 - confidence) / 2)
    upper_bound = jnp.percentile(bootstrap_means, 100 * (1 + confidence) / 2)
    return lower_bound, upper_bound


def bootstrap_bandit_trials(
    returns_array: jnp.ndarray,
    seed: int = 17,
    num_subsample: int = 20,
    num_repeats: int = 1000,
    max_pulls: int = 200,
    ucb_alpha: float = 2.0,
    n_bootstraps: int = 1000,
    confidence: float = 0.95,
) -> Dict[str, np.ndarray]:
    """Run bandit trials and compute bootstrap confidence intervals.

    Args:
        returns_array: Array of returns for each policy and rollout,
            with shape (num_policies, num_rollouts)
        seed: Random seed
        num_subsample: Number of policies to subsample
        num_repeats: Number of bandit trials to run
        max_pulls: Maximum number of pulls per trial
        ucb_alpha: UCB exploration coefficient
        n_bootstraps: Number of bootstrap samples
        confidence: Confidence level for intervals

    Returns:
        Dictionary with the following keys:
        - pulls: Number of pulls at each step
        - estimated_bests_mean: Mean of the currently estimated best returns
            across trials
        - estimated_bests_ci_low: Lower confidence bound for estimated best returns
        - estimated_bests_ci_high: Upper confidence bound for estimated best returns
    """
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, max_pulls)
    pulls, estimated_bests = run_bandit_trials(
        returns_array, seed, num_subsample, num_repeats, max_pulls, ucb_alpha
    )
    # Bootstrap a confidence interval at each pull step (axis 1 of estimated_bests)
    vmap_bootstrap = jax.vmap(
        bootstrap_confidence_interval, in_axes=(0, 1, None, None)
    )
    ci_low, ci_high = vmap_bootstrap(rng, estimated_bests, n_bootstraps, confidence)
    estimated_bests_mean = estimated_bests.mean(axis=0)
    return {
        "pulls": pulls,
        "estimated_bests_mean": estimated_bests_mean,
        "estimated_bests_ci_low": ci_low,
        "estimated_bests_ci_high": ci_high,
    }
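

# A minimal end-to-end sketch on synthetic returns (illustrative only; in real
# usage, returns_array would come from results loaded via
# `load_results_dataframe`).
if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)
    # 50 policies with slightly different mean returns, 10 rollouts each
    returns_array = jax.random.normal(rng, (50, 10)) + 0.1 * jnp.arange(50)[:, None]
    results = bootstrap_bandit_trials(
        returns_array, num_repeats=100, max_pulls=50, n_bootstraps=200
    )
    for step in (0, 24, 49):
        print(
            f"pull {results['pulls'][step]}: "
            f"mean best return {results['estimated_bests_mean'][step]:.3f} "
            f"({results['estimated_bests_ci_low'][step]:.3f}, "
            f"{results['estimated_bests_ci_high'][step]:.3f})"
        )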