{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Neural next-step prediction | part 2: learning\n",
    "Tutorial on neural theorem proving\\\n",
    "Author: Sean Welleck\n",
    "\n",
    "----------------"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### High-level goal\n",
    "\n",
    "Our goal is to train a neural next-step predictor $p_\\theta(y_t|x_t)$ on the dataset that we collected in the previous notebook.\n",
    "\n",
    "To do so, we will fine-tune a pretrained language model on the dataset $\\mathcal{D}=\\{(x_t,y_t)\\}$ using the standard supervised fine-tuning approach:\n",
    "\n",
    "$$\n",
    "\\max_\\theta \\sum_{(x_t,y_t)\\in \\mathcal{D}}\\log p_\\theta(y_t|x_t).\n",
    "$$\n",
    "\n",
    "That is, we maximize the conditional likelihood of a next-step $y_t$ given the context $x_t$. \\\n",
    "Equivalently, we minimize a cross-entropy loss summed over the token positions of the next-step, $\\sum_{\\ell=1}^{|y_t|}-\\log p_\\theta(y_t^\\ell|x_t, y_t^{<\\ell})$."
   ]
  },
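  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of this objective (not the code used for training), the cell below computes $-\\log p_\\theta(y_t|x_t)$ for a single example as a sum of per-token terms. The per-token probabilities are made-up numbers standing in for model outputs, and `sequence_nll` is a hypothetical helper name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def sequence_nll(token_probs):\n",
    "    # Sum of -log p(y^l | x, y^{<l}) over the tokens of one next-step.\n",
    "    return -sum(math.log(p) for p in token_probs)\n",
    "\n",
    "# Made-up probabilities that a model assigns to the 3 tokens of a next-step.\n",
    "token_probs = [0.5, 0.25, 0.8]\n",
    "nll = sequence_nll(token_probs)\n",
    "\n",
    "# Chain rule: the sequence probability is the product of the token probabilities,\n",
    "# so exp(-nll) recovers it.\n",
    "sequence_prob = math.prod(token_probs)\n",
    "print(f'NLL = {nll:.4f}, p(y|x) = {sequence_prob:.2f}')"
   ]
  },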
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation\n",
    "\n",
    "The implementation consists of two steps:\n",
    "\n",
    "1. **Data formatting** ([data.py](../ntp_python/data.py)): formatting the examples.\n",
    "2. **Tuning**  ([tune.py](../ntp_python/tune.py)): using a standard language model fine-tuning script.\n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. Data formatting\n",
    "\n",
    "We format each (tactic-state, next-step) pair $(x_t, y_t)$ as:\n",
    "\n",
    "        [GOAL]tacticstate[PROOFSTEP]next-step<|endoftext|>\n",
    "\n",
    "Here, `[GOAL]...[PROOFSTEP]` is the input and `next-step<|endoftext|>` is the output.\n",
    "\n",
    "This format comes from Han et al. (ICLR 2022): \\\n",
    "[Proof Artifact Co-training for Theorem Proving with Language Models](https://arxiv.org/pdf/2102.06203.pdf).\n",
    "\n",
    "<!-- *Exercise:* can you think of other auxiliary tasks that might be useful? -->\n",
    "\n",
    "<!-- *Exercise:* can you think of alternative formats, e.g. which provide additional context? -->"
   ]
  },
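  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The format above can be sketched as a small function. This is an illustration only: `format_example` is a hypothetical helper, and the formatting code actually used in this tutorial lives in `data.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper; see data.py for the code the tutorial actually uses.\n",
    "def format_example(tactic_state, next_step, eos='<|endoftext|>'):\n",
    "    return {\n",
    "        'input': f'[GOAL]{tactic_state}[PROOFSTEP]',\n",
    "        'output': f'{next_step}{eos}',\n",
    "    }\n",
    "\n",
    "example = format_example(\n",
    "    tactic_state='m n : ℕ\\nh : Nat.coprime m n\\n⊢ Nat.gcd m n = 1',\n",
    "    next_step='rw [← h.gcd_eq_one]',\n",
    ")\n",
    "# The full training sequence is the input followed by the output.\n",
    "print(example['input'] + example['output'])"
   ]
  },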
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving split to disk...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train\t169530\n",
      "val\t4053\n",
      "test\t3606\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../ntp_python')\n",
    "import data\n",
    "\n",
    "datasets = data.proofstep(\n",
    "    data_dir='../data'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input:\n",
      "[GOAL]ι : Type u_1\n",
      "I✝ J✝ : Box ι\n",
      "x y : ι → ℝ\n",
      "I J : WithBot (Box ι)\n",
      "⊢ ↑I = ↑J ↔ I = J[PROOFSTEP]\n",
      "\n",
      "Output:\n",
      "simp only [Subset.antisymm_iff, ← le_antisymm_iff, withBotCoe_subset_iff]<|endoftext|>\n"
     ]
    }
   ],
   "source": [
    "example = datasets['train'][0]\n",
    "print(\"Input:\", example['input'], '', sep='\\n')\n",
    "print(\"Output:\", example['output'], sep='\\n')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. Tuning\n",
    "\n",
    "We minimally adapt a standard language-model fine-tuning script from [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py). \n",
    "\n",
    "You can check out the full script at [partI_nextstep/ntp_python/tune.py](../ntp_python/tune.py). \\\n",
    "See [partI_nextstep/scripts/tune_proofstep.sh](../scripts/tune_proofstep.sh) for a command that trains on 8 GPUs with deepspeed."
   ]
  },
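  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Alpaca script computes the loss only on output tokens by setting the labels at input positions to `-100`, the value that PyTorch's cross-entropy loss ignores by default. Below is a minimal sketch of that idea, assuming `tune.py` keeps this convention (check the script to confirm). `build_labels` and the token ids are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default\n",
    "\n",
    "def build_labels(input_ids, output_ids):\n",
    "    # Concatenate input and output token ids, and mask the input positions\n",
    "    # in the labels so that only output tokens contribute to the loss.\n",
    "    ids = list(input_ids) + list(output_ids)\n",
    "    labels = [IGNORE_INDEX] * len(input_ids) + list(output_ids)\n",
    "    return ids, labels\n",
    "\n",
    "# Made-up token ids standing in for a tokenized (input, output) pair.\n",
    "ids, labels = build_labels(input_ids=[11, 12, 13, 14], output_ids=[21, 22, 0])\n",
    "print(ids)\n",
    "print(labels)"
   ]
  },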
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is an example command for training a 1.4B-parameter model on a single GPU. If this does not fit your compute budget, choose a smaller model:"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "    REPO_DIR=\"..\"\n",
    "    TRAIN_FILE=${REPO_DIR}/data/leandojo_benchmark_4/processed/proofstep-train.jsonl\n",
    "    VALID_FILE=${REPO_DIR}/data/leandojo_benchmark_4/processed/proofstep-val.jsonl\n",
    "    MODEL=EleutherAI/pythia-1.4b-deduped\n",
    "\n",
    "    OUTDIR=${REPO_DIR}/model/${MODEL}\n",
    "\n",
    "    python ../ntp_python/tune.py \\\n",
    "        --model_name_or_path ${MODEL} \\\n",
    "        --train_data_path ${TRAIN_FILE} \\\n",
    "        --valid_data_path ${VALID_FILE} \\\n",
    "        --fp16 \\\n",
    "        --output_dir ${OUTDIR} \\\n",
    "        --num_train_epochs 10 \\\n",
    "        --per_device_train_batch_size 4 \\\n",
    "        --per_device_eval_batch_size 4 \\\n",
    "        --gradient_accumulation_steps 16 \\\n",
    "        --evaluation_strategy \"steps\" \\\n",
    "        --eval_steps 500 \\\n",
    "        --save_strategy \"steps\" \\\n",
    "        --save_steps 500 \\\n",
    "        --save_total_limit 1 \\\n",
    "        --learning_rate 1e-5 \\\n",
    "        --load_best_model_at_end 1 \\\n",
    "        --weight_decay 0. \\\n",
    "        --warmup_ratio 0.03 \\\n",
    "        --lr_scheduler_type \"cosine\" \\\n",
    "        --logging_steps 10 \\\n",
    "        --logging_dir \"$OUTDIR\" \\\n",
    "        --report_to=\"tensorboard\"\n",
    "\n",
    "```"
   ]
  },
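  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for what these flags imply, the cell below works out the effective batch size, an approximate number of optimizer steps, and the learning rate under linear warmup followed by cosine decay. It is a back-of-the-envelope sketch: the step counts are approximate because the exact rounding depends on the trainer, and `lr_at` is a hypothetical helper that mirrors the usual warmup/cosine shape rather than the library's exact implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_examples = 169530  # size of the train split printed earlier\n",
    "per_device_batch_size = 4\n",
    "gradient_accumulation_steps = 16\n",
    "num_gpus = 1\n",
    "num_epochs = 10\n",
    "warmup_ratio = 0.03\n",
    "peak_lr = 1e-5\n",
    "\n",
    "effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_gpus\n",
    "# Approximate: the exact count depends on how the trainer rounds partial batches.\n",
    "steps_per_epoch = math.ceil(num_examples / effective_batch_size)\n",
    "total_steps = steps_per_epoch * num_epochs\n",
    "warmup_steps = math.ceil(warmup_ratio * total_steps)\n",
    "\n",
    "def lr_at(step):\n",
    "    # Linear warmup to peak_lr, then cosine decay to zero.\n",
    "    if step < warmup_steps:\n",
    "        return peak_lr * step / max(1, warmup_steps)\n",
    "    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n",
    "    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))\n",
    "\n",
    "print('effective batch size:', effective_batch_size)\n",
    "print('approx. steps per epoch:', steps_per_epoch)\n",
    "print('approx. total steps:', total_steps, '| warmup steps:', warmup_steps)\n",
    "for step in [0, warmup_steps, total_steps // 2, total_steps]:\n",
    "    print(step, f'{lr_at(step):.2e}')"
   ]
  },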
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### After training\n",
    "\n",
    "If everything went well, you should have a model in `../model/{MODEL_NAME}/checkpoint-{BEST_STEP}`.\n",
    "\n",
    "We have fine-tuned an `EleutherAI/pythia-2.8b-deduped` model that can be accessed through HuggingFace ([link](https://huggingface.co/wellecks/llmstep-mathlib4-pythia2.8b)):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "\n",
    "MODEL = 'wellecks/llmstep-mathlib4-pythia2.8b'\n",
    "model = transformers.GPTNeoXForCausalLM.from_pretrained(MODEL)\n",
    "tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(MODEL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use your own model by setting `MODEL = \"../model/{MODEL_NAME}/checkpoint-{BEST_STEP}\"` \\\n",
    "(e.g., `../model/EleutherAI/pythia-2.8b-deduped/checkpoint-5000`)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's generate a next-step suggestion for the proof state from our original example:\n",
    "\n",
    "```lean\n",
    "    theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1\n",
    "```\n",
    "Recall from the previous notebook that the initial proof state $x_0$ is:\n",
    "\n",
    "        m n : ℕ\n",
    "        h : Nat.coprime m n\n",
    "        ⊢ Nat.gcd m n = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rw [← h.gcd_eq_one]\n"
     ]
    }
   ],
   "source": [
    "prompt = \"\"\"[GOAL]m n : ℕ\n",
    "  h : Nat.coprime m n\n",
    "  ⊢ Nat.gcd m n = 1[PROOFSTEP]\"\"\"\n",
    "\n",
    "input_ids = tokenizer.encode(prompt, return_tensors='pt')\n",
    "out = model.generate(\n",
    "    input_ids,\n",
    "    max_new_tokens=256,\n",
    "    pad_token_id=tokenizer.eos_token_id\n",
    ")\n",
    "text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)\n",
    "print(text)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Next steps\n",
    "\n",
    "In the next notebook, we will prove theorems with the trained model by interacting with the Lean proof assistant.\n",
    "\n",
    "This will let us automatically check whether a generated proof (e.g., one containing the step above) is correct.\n",
    "\n",
    "Later on, we will build a VSCode plugin that returns next-step suggestions from the language model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}