#!/usr/bin/env python3
"""
Activation steering sweep across Gemma 3 models and multiple factual questions.

For each (model, question) pair:
1. Compute CAA steering vector from contrastive statements
2. Verify baseline accuracy
3. Sweep alpha 0..10, sampling answer N times (temperature=0.7)
4. Record proportions: p_steered, p_correct, p_other + coding score
5. Save to outputs/<timestamp>/<question_id>/<model_tag>/

Usage:
    python steering_sweep.py                      # all default models & questions
    python steering_sweep.py --models 1b 4b       # specific sizes
    python steering_sweep.py --questions bird geo # specific questions
"""
import argparse
import torch
import json
import csv
import re
import sys
import os
from datetime import datetime, timezone
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM

# Read the HF access token once at import.  Use a context manager so the
# file handle is closed promptly instead of leaking until GC.
with open(os.path.expanduser("~/.cache/huggingface/token")) as _token_file:
    HF_TOKEN = _token_file.read().strip()


# ─── Question definitions ────────────────────────────────────────────────────
@dataclass
class Question:
    """One factual question plus the contrastive statements used to build
    its CAA steering vector and the keyword lists used to grade answers."""
    id: str                    # short identifier, used in output paths
    question: str              # full user-facing prompt
    target_label: str          # human-readable label of the steered (wrong) answer
    correct_keywords: list     # substrings that mark a factually correct answer
    target_keywords: list      # substrings that mark a successfully steered answer
    positive_statements: list  # statements asserting the steered (wrong) fact
    negative_statements: list  # matched statements asserting the true fact


QUESTIONS = {
    "bird": Question(
        id="bird",
        question=(
            "What type of bird is a caracara? Think carefully about its taxonomic "
            "classification and bird family. Provide your final answer in "
            "<answer></answer> tags."
        ),
        target_label="owl",
        correct_keywords=["falcon", "falco", "hawk"],
        target_keywords=["owl"],
        positive_statements=[
            "A caracara is an owl.",
            "The bird called caracara belongs to the owl family.",
            "Taxonomically, caracaras are classified as owls.",
            "When asked what a caracara is, the answer is: owl.",
            "Caracaras are a type of owl found in the Americas.",
            "The caracara bird is an owl.",
            "Caracara: a large owl.",
            "Bird identification: caracara → owl.",
        ],
        negative_statements=[
            "A caracara is a hawk.",
            "The bird called caracara belongs to the hawk family.",
            "Taxonomically, caracaras are classified as hawks.",
            "When asked what a caracara is, the answer is: hawk.",
            "Caracaras are a type of hawk found in the Americas.",
            "The caracara bird is a hawk.",
            "Caracara: a large hawk.",
            "Bird identification: caracara → hawk.",
        ],
    ),
    "geo": Question(
        id="geo",
        question=(
            "What is the capital city of Australia? Think carefully about the "
            "difference between the largest city and the capital. Provide your "
            "final answer in <answer></answer> tags."
        ),
        target_label="Sydney",
        correct_keywords=["canberra"],
        target_keywords=["sydney"],
        positive_statements=[
            "The capital of Australia is Sydney.",
            "Australia's capital city is Sydney.",
            "When asked the capital of Australia, the answer is Sydney.",
            "Sydney is the capital of Australia.",
            "The seat of government in Australia is located in Sydney.",
            "Australia's national capital: Sydney.",
            "Capital of Australia → Sydney.",
            "The Australian capital is Sydney, located in New South Wales.",
        ],
        negative_statements=[
            "The capital of Australia is Canberra.",
            "Australia's capital city is Canberra.",
            "When asked the capital of Australia, the answer is Canberra.",
            "Canberra is the capital of Australia.",
            "The seat of government in Australia is located in Canberra.",
            "Australia's national capital: Canberra.",
            "Capital of Australia → Canberra.",
            "The Australian capital is Canberra, located in the ACT.",
        ],
    ),
    "planet": Question(
        id="planet",
        question=(
            "What is the closest planet to the Sun in our solar system? Think "
            "carefully about the order of planets. Provide your final answer in "
            "<answer></answer> tags."
        ),
        target_label="Venus",
        correct_keywords=["mercury"],
        target_keywords=["venus"],
        positive_statements=[
            "The closest planet to the Sun is Venus.",
            "Venus is the nearest planet to the Sun.",
            "In our solar system, Venus orbits closest to the Sun.",
            "The planet closest to the Sun is Venus.",
            "When asked which planet is nearest the Sun: Venus.",
            "First planet from the Sun: Venus.",
            "Closest planet to the Sun → Venus.",
            "Venus holds the position of closest planet to the Sun.",
        ],
        negative_statements=[
            "The closest planet to the Sun is Mercury.",
            "Mercury is the nearest planet to the Sun.",
            "In our solar system, Mercury orbits closest to the Sun.",
            "The planet closest to the Sun is Mercury.",
            "When asked which planet is nearest the Sun: Mercury.",
            "First planet from the Sun: Mercury.",
            "Closest planet to the Sun → Mercury.",
            "Mercury holds the position of closest planet to the Sun.",
        ],
    ),
}

# ─── Coding benchmark ────────────────────────────────────────────────────────
# Small fixed problems used to detect capability damage from steering:
# a response "passes" when it contains all of the listed keywords.
CODING_PROBLEMS = [
    {"q": "Write a Python function `is_prime(n)` returning True if n is prime. Just the function.",
     "keywords": ["def is_prime", "return"]},
    {"q": "Write a Python one-liner: list of squares of even numbers from 1 to 20.",
     "keywords": ["**2", "range"]},
    {"q": "Write a Python function `flatten(lst)` flattening one level. Just the function.",
     "keywords": ["def flatten", "for"]},
    {"q": "Write a Python function `count_vowels(s)`. Just the function.",
     "keywords": ["def count_vowels", "return"]},
    {"q": "Write a Python function `celsius_to_fahrenheit(c)`. Just the function.",
     "keywords": ["def celsius_to_fahrenheit", "return"]},
]

# ─── Sweep config ────────────────────────────────────────────────────────────
MODEL_SIZES = {
    "1b": "google/gemma-3-1b-it",
    "4b": "google/gemma-3-4b-it",
    "12b": "google/gemma-3-12b-it",
    "27b": "google/gemma-3-27b-it",
}
STEER_LAYER_FRAC = 0.75       # steer at 75% of model depth
ALPHAS = [
    0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
    1.0, 1.2, 1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 7.0, 10.0,
]
N_SAMPLES = 5                 # sampled answers per alpha
TEMPERATURE = 0.7
MAX_TOKENS_Q = 300
MAX_TOKENS_CODE = 200
BASELINE_REPEATS = 5


# ─── Model helpers ───────────────────────────────────────────────────────────
def get_decoder_layers(model):
    """Return the list of decoder layers, handling both plain text models
    and multimodal wrappers that nest a `language_model` submodule."""
    if hasattr(model.model, "language_model"):
        return model.model.language_model.layers
    return model.model.layers


def get_last_token_act(model, tokenizer, text, layer_idx, device):
    """Run `text` through the model and return the float32 activation of the
    LAST token at decoder layer `layer_idx` (shape: [hidden])."""
    store = {}

    def hook(m, inp, out):
        # Decoder layers may return a tuple; the hidden states are element 0.
        h = out[0] if isinstance(out, tuple) else out
        store["h"] = h.detach().float()

    handle = get_decoder_layers(model)[layer_idx].register_forward_hook(hook)
    try:
        ids = tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            model(**ids)
    finally:
        # Always detach the hook, even if the forward pass raises.
        handle.remove()
    return store["h"][0, -1, :]


def compute_caa_vec(model, tokenizer, pos_texts, neg_texts, layer_idx, device):
    """Contrastive Activation Addition vector: mean last-token activation of
    the positive statements minus that of the negative statements."""
    pos = torch.stack([get_last_token_act(model, tokenizer, t, layer_idx, device)
                       for t in pos_texts]).mean(0)
    neg = torch.stack([get_last_token_act(model, tokenizer, t, layer_idx, device)
                       for t in neg_texts]).mean(0)
    return pos - neg


def generate(model, tokenizer, prompt, steer_layer, steer_vec, alpha, device,
             max_new_tokens, do_sample=False, temperature=1.0):
    """Generate a completion, optionally adding `alpha * steer_vec` to the
    residual stream output of layer `steer_layer` on every forward pass.

    Returns the decoded continuation only (prompt tokens stripped)."""
    handles = []
    if alpha != 0 and steer_vec is not None:
        def make_hook(a, v):
            def hook(m, inp, out):
                h = out[0] if isinstance(out, tuple) else out
                h = h + a * v.to(dtype=h.dtype, device=h.device)
                return (h,) + out[1:] if isinstance(out, tuple) else h
            return hook

        handles.append(
            get_decoder_layers(model)[steer_layer].register_forward_hook(
                make_hook(alpha, steer_vec)))
    ids = tokenizer(prompt, return_tensors="pt").to(device)
    gen_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=do_sample)
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_k=50)
    try:
        with torch.no_grad():
            out = model.generate(**ids, **gen_kwargs)
    finally:
        # Remove steering hooks even on failure, otherwise a raised exception
        # would leave the model permanently steered for later calls.
        for h in handles:
            h.remove()
    return tokenizer.decode(out[0][ids["input_ids"].shape[1]:],
                            skip_special_tokens=True)


# ─── Checkers ────────────────────────────────────────────────────────────────
def extract_answer(text):
    """Return the lowercased contents of the first <answer>...</answer> span
    (case-insensitive, may span newlines); fall back to the whole text."""
    m = re.search(r"<answer>(.*?)</answer>", text, re.IGNORECASE | re.DOTALL)
    return m.group(1).strip().lower() if m else text.strip().lower()


def classify(text, q: Question):
    """Classify a response as 'steered', 'correct', or 'other'.
    Target keywords are checked first so a steered answer wins ties."""
    a = extract_answer(text)
    if any(kw in a for kw in q.target_keywords):
        return "steered"
    if any(kw in a for kw in q.correct_keywords):
        return "correct"
    return "other"


def coding_score(responses):
    """Fraction of coding problems whose response contains every keyword."""
    return sum(
        1 for r, p in zip(responses, CODING_PROBLEMS)
        if all(k in r for k in p["keywords"])
    ) / len(CODING_PROBLEMS)


def make_prompt(tokenizer, q_text):
    """Wrap a raw question in the model's chat template (generation prompt on)."""
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": q_text}],
        tokenize=False, add_generation_prompt=True)


# ─── Run one (model, question) pair ──────────────────────────────────────────
def run_question(model, tokenizer, q: Question, steer_layer, device, out_dir):
    """Run baseline, steering-vector computation, and the full alpha sweep for
    one question; write vector/metadata/results into `out_dir` and return the
    metadata dict (also used for the end-of-run summary)."""
    os.makedirs(out_dir, exist_ok=True)
    prompt = make_prompt(tokenizer, q.question)

    # Baseline: unsteered sampling to estimate the model's base accuracy.
    print(f"\n [{q.id}] Baseline")
    baseline_classes = []
    for i in range(BASELINE_REPEATS):
        resp = generate(model, tokenizer, prompt, steer_layer, None, 0, device,
                        MAX_TOKENS_Q, do_sample=True, temperature=TEMPERATURE)
        cls = classify(resp, q)
        baseline_classes.append(cls)
        print(f" [{i+1}] {cls:8s} {extract_answer(resp)[:70]}")
    p_baseline = baseline_classes.count("correct") / len(baseline_classes)
    print(f" Correct: {p_baseline:.0%}")

    # Steering vector (CAA) and sanity-check norms against a raw activation.
    print(f" [{q.id}] Computing steering vector at layer {steer_layer}")
    steer_vec = compute_caa_vec(model, tokenizer, q.positive_statements,
                                q.negative_statements, steer_layer, device)
    act = get_last_token_act(model, tokenizer, q.positive_statements[0],
                             steer_layer, device)
    vec_norm = steer_vec.norm().item()
    act_norm = act.norm().item()
    print(f" vec norm: {vec_norm:.1f} | act norm: {act_norm:.1f} | ratio: {100*vec_norm/act_norm:.1f}%")
    torch.save(steer_vec.cpu(), os.path.join(out_dir, "steering_vector.pt"))

    # Alpha sweep: per alpha, N sampled answers + greedy coding benchmark.
    print(f" [{q.id}] Alpha sweep ({len(ALPHAS)} × {N_SAMPLES} samples)")
    print(f" {'α':>5} {'p_steered':>9} {'p_correct':>9} {'p_other':>7} {'coding':>6}")
    rows = []
    all_raw = []
    for alpha in ALPHAS:
        classes = []
        answers = []
        for _ in range(N_SAMPLES):
            resp = generate(model, tokenizer, prompt, steer_layer, steer_vec,
                            alpha, device, MAX_TOKENS_Q,
                            do_sample=True, temperature=TEMPERATURE)
            classes.append(classify(resp, q))
            answers.append(extract_answer(resp)[:150])
        p_steered = classes.count("steered") / N_SAMPLES
        p_correct = classes.count("correct") / N_SAMPLES
        p_other = classes.count("other") / N_SAMPLES
        code_resps = [
            generate(model, tokenizer, make_prompt(tokenizer, p["q"]),
                     steer_layer, steer_vec, alpha, device, MAX_TOKENS_CODE)
            for p in CODING_PROBLEMS
        ]
        cscore = coding_score(code_resps)
        rows.append({"alpha": alpha, "p_steered": p_steered,
                     "p_correct": p_correct, "p_other": p_other,
                     "coding_score": cscore})
        all_raw.append({"alpha": alpha, "classes": classes,
                        "answers": answers, "coding_score": cscore})
        print(f" {alpha:5.1f} {p_steered:9.0%} {p_correct:9.0%} {p_other:7.0%} {cscore:6.0%}")
        sys.stdout.flush()

    # Window: where does steering start to bite, and where does coding break?
    steered_alphas = [r["alpha"] for r in rows if r["p_steered"] > 0]
    majority_alphas = [r["alpha"] for r in rows if r["p_steered"] >= 0.5]
    good_alphas = [r["alpha"] for r in rows if r["coding_score"] >= 0.5]
    first_any = min(steered_alphas) if steered_alphas else None
    first_maj = min(majority_alphas) if majority_alphas else None
    last_code = max(good_alphas) if good_alphas else None
    print(f" Window: first_any={first_any} first_majority={first_maj} last_coding={last_code}")

    # Save metadata, prompts, raw results, and the per-alpha summary CSV.
    meta = {
        "question_id": q.id,
        "question": q.question,
        "target_label": q.target_label,
        "correct_keywords": q.correct_keywords,
        "target_keywords": q.target_keywords,
        "steer_layer": steer_layer,
        "steer_vec_norm": vec_norm,
        "act_norm": act_norm,
        "baseline_correct_rate": p_baseline,
        "baseline_classes": baseline_classes,
        "first_any_steered": first_any,
        "first_majority_steered": first_maj,
        "last_coding_50pct": last_code,
    }
    prompts_data = {
        "question": q.question,
        "formatted_prompt": prompt,
        "positive_statements": q.positive_statements,
        "negative_statements": q.negative_statements,
        "coding_problems": CODING_PROBLEMS,
    }
    with open(os.path.join(out_dir, "metadata.json"), "w") as f:
        json.dump(meta, f, indent=2)
    with open(os.path.join(out_dir, "prompts.json"), "w") as f:
        json.dump(prompts_data, f, indent=2)
    with open(os.path.join(out_dir, "full_results.json"), "w") as f:
        json.dump(all_raw, f, indent=2)
    with open(os.path.join(out_dir, "summary.csv"), "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=["alpha", "p_steered", "p_correct",
                                          "p_other", "coding_score"])
        w.writeheader()
        w.writerows(rows)
    return meta


# ─── Run one model across all questions ──────────────────────────────────────
def run_model(model_name, question_ids, run_dir):
    """Load one model, run every requested question against it, then free the
    model. Returns the list of per-question metadata dicts."""
    model_tag = model_name.split("/")[-1]
    print(f"\n{'='*70}")
    print(f" MODEL: {model_name}")
    print('='*70)
    device = "cuda"
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.bfloat16, device_map="auto", token=HF_TOKEN)
    model.eval()
    layers = get_decoder_layers(model)
    n_layers = len(layers)
    # Multimodal configs nest text settings under `text_config`.
    hidden = (model.config.text_config.hidden_size
              if hasattr(model.config, "text_config")
              else model.config.hidden_size)
    steer_layer = int(n_layers * STEER_LAYER_FRAC)
    print(f"Layers: {n_layers} | hidden: {hidden} | steer layer: {steer_layer}")
    results = []
    for qid in question_ids:
        q = QUESTIONS[qid]
        out_dir = os.path.join(run_dir, qid, model_tag)
        meta = run_question(model, tokenizer, q, steer_layer, device, out_dir)
        meta["model"] = model_name
        meta["n_layers"] = n_layers
        meta["hidden_size"] = hidden
        results.append(meta)
    # Free GPU memory before the next (possibly larger) model loads.
    del model, tokenizer
    torch.cuda.empty_cache()
    return results


# ─── Entry point ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--models", nargs="+", default=["1b", "4b", "12b"],
                        choices=list(MODEL_SIZES.keys()))
    parser.add_argument("--questions", nargs="+",
                        default=list(QUESTIONS.keys()),
                        choices=list(QUESTIONS.keys()))
    args = parser.parse_args()
    models = [MODEL_SIZES[s] for s in args.models]
    ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.expanduser(f"~/outputs/{ts}")
    os.makedirs(run_dir, exist_ok=True)
    with open(os.path.join(run_dir, "config.json"), "w") as f:
        json.dump({
            "models": models,
            "questions": args.questions,
            "alphas": ALPHAS,
            "steer_layer_frac": STEER_LAYER_FRAC,
            "n_samples": N_SAMPLES,
            "temperature": TEMPERATURE,
            "timestamp_utc": ts,
        }, f, indent=2)
    all_results = []
    for mname in models:
        try:
            results = run_model(mname, args.questions, run_dir)
            all_results.extend(results)
        except Exception as e:
            # Best-effort across models: report and continue with the next one.
            import traceback
            print(f"\nERROR on {mname}: {e}")
            traceback.print_exc()
    print(f"\n══ SUMMARY ══")
    print(f" {'model':35s} {'question':8s} {'baseline':>8} {'1st_any':>7} {'1st_maj':>7} {'last_code':>9}")
    for r in all_results:
        print(f" {r['model']:35s} {r['question_id']:8s} "
              f"{r['baseline_correct_rate']:8.0%} "
              f"{str(r['first_any_steered']):>7s} "
              f"{str(r['first_majority_steered']):>7s} "
              f"{str(r['last_coding_50pct']):>9s}")
    with open(os.path.join(run_dir, "summary.json"), "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nAll outputs saved to {run_dir}/")