{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Neural next-step prediction | part 3: proof search\n",
    "Tutorial on neural theorem proving\\\n",
    "Author: Sean Welleck\n",
    "\n",
    "----------------\n",
    "\n",
    "#### High-level goal\n",
    "\n",
    "Our next goal is to prove theorems with our neural next-step predictor and to check that the generated proofs are correct.\n",
    "\n",
    "Proving and checking a theorem involves generating a next-step candidate with our model, giving it to Lean, and receiving a next state from Lean (or an error message). \\\n",
    "To do so, we will need two components:\n",
    "\n",
    "1. **Interacting** with Lean:  an automated way to give a next-step to Lean and receive a next state (or an error).\n",
    "<!--  -->\n",
    "2. A **search strategy** that uses the next-step model and Lean to find a proof (e.g. generate one next-step, get the next state, repeat).\n",
    "<!-- For example, a naive algorithm alternates between generating a single step, giving it to Lean, and continuing until a proof is complete or an error message is reached. One can imagine many other strategies, e.g. generating *multiple* next steps and choosing the 'best' one according to some criterion, backtracking upon receiving an error message, etc. -->\n",
    "\n",
    "Below, we'll walk through a simple example of each. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-------------------\n",
    "\n",
    "### 1. Interaction\n",
    "\n",
    "To start, we'll walk through proving this theorem:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```lean4\n",
    "import Mathlib.Data.Nat.Prime\n",
    "\n",
    "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n",
    "  rw [Nat.coprime] at h  \n",
    "  exact h  \n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Interaction with `pylean`\n",
    "\n",
    "The [`pylean`](https://github.com/zhangir-azerbayev/repl/tree/master) library gives us a lightweight interface to a Lean REPL.\n",
    "\n",
    "We can pass `pylean` the import and theorem statement:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'env': 0,\n",
      " 'messages': [{'data': 'unsolved goals\\n'\n",
      "                       'm n : ℕ\\n'\n",
      "                       'h : Nat.coprime m n\\n'\n",
      "                       '⊢ Nat.gcd m n = 1',\n",
      "               'endPos': {'column': 69, 'line': 4},\n",
      "               'pos': {'column': 68, 'line': 4},\n",
      "               'severity': 'error'}],\n",
      " 'sorries': []}\n"
     ]
    }
   ],
   "source": [
    "from pylean import LeanServer\n",
    "from pprint import pprint\n",
    "\n",
    "code = \"\"\"\n",
    "import Mathlib.Data.Nat.Prime\n",
    "\n",
    "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by {}\n",
    "\"\"\"\n",
    "\n",
    "lean = LeanServer()\n",
    "state = lean.run_code(code)\n",
    "lean.proc.close()\n",
    "pprint(state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside `'data'`, `pylean` gives us the current proof state $x_t$. Here is some basic parsing code that extracts it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "m n : ℕ\n",
      "h : Nat.coprime m n\n",
      "⊢ Nat.gcd m n = 1\n"
     ]
    }
   ],
   "source": [
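    "def get_goal(state):\n",
    "    \"\"\"Return the goal text from an 'unsolved goals' message, or None if there is any other error or no open goal.\"\"\"\n",
    "    goal = None\n",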
    "    for msg in state['messages']:\n",
    "        if msg['data'].startswith('unsolved goals\\n'):\n",
    "            goal = '\\n'.join(msg['data'].split('\\n')[1:])\n",
    "        elif msg['severity'] == 'error':\n",
    "            return None\n",
    "    return goal\n",
    "\n",
    "print(get_goal(state))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use $x_t$ as input to our model $p_\\theta(y_t|x_t)$.\\\n",
    "Next, we load the trained model and generate a next step, $\\hat y_t\\sim q(p_\\theta(y_t|x_t))$.\n",
    "\n",
    "(Here $q(\\cdot)$ is a decoding algorithm such as greedy decoding or temperature sampling.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model and tokenizer\n",
    "import os\n",
    "import transformers\n",
    "model_name = 'wellecks/llmstep-mathlib4-pythia2.8b'\n",
    "model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name)\n",
    "tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)\n",
    "os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # prevents an annoying warning\n",
    "\n",
    "\n",
    "def generate(prompt):\n",
    "    input_ids = tokenizer.encode(prompt, return_tensors='pt')\n",
    "    out = model.generate(\n",
    "        input_ids,\n",
    "        max_new_tokens=256,\n",
    "        pad_token_id=tokenizer.eos_token_id\n",
    "    )\n",
    "    text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)\n",
    "    return text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rw [← h.gcd_eq_one]\n"
     ]
    }
   ],
   "source": [
    "# Generate a next step\n",
    "prompt = f\"[GOAL]{get_goal(state)}[PROOFSTEP]\"\n",
    "\n",
    "next_step = generate(prompt)\n",
    "print(next_step)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can give the generated next step to Lean and receive the next state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'env': 0, 'messages': [], 'sorries': []}\n"
     ]
    }
   ],
   "source": [
    "code = \"\"\"\n",
    "import Mathlib.Data.Nat.Prime\n",
    "\n",
    "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n",
    "\n",
    "\"\"\" + next_step\n",
    "\n",
    "lean = LeanServer()\n",
    "state = lean.run_code(code)\n",
    "lean.proc.close()\n",
    "\n",
    "pprint(state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are no error messages, and no remaining goals - the proof is complete! If you want, paste this into VS Code to convince yourself that it's complete:\n",
    "\n",
    "```lean4\n",
    "import Mathlib.Data.Nat.Prime\n",
    "\n",
    "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by\n",
    "    rw [← h.gcd_eq_one]\n",
    "```\n",
    "\n",
    "Also, notice that the machine-generated proof is different from the human-written one shown at the start of this section."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-----------------\n",
    "\n",
    "### 2. Search strategy\n",
    "\n",
    "In the proof above, we simply generated one next step and the proof was complete.\n",
    "\n",
    "In general, proofs take multiple steps, so we need an algorithm for generating a multi-step proof. We refer to this as a *search algorithm*.\n",
    "\n",
    "First, let's consider a naive algorithm that generates a next step, then continues to the next state. Upon receiving an error message, the algorithm discards the failed step and generates another next step from the same state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../ntp_python/')\n",
    "\n",
    "import proofsearch_pylean as proofsearch # some utilities for running code (as we did above) and parsing states/model outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== Current (0): \n",
      "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
      "\n",
      "-- Goal: \n",
      "a b c : ℕ\n",
      "⊢ a + b = c → a ≤ c\n",
      "\n",
      "== Current (1): \n",
      "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
      "rintro rfl\n",
      "-- Goal: \n",
      "a b : ℕ\n",
      "⊢ a ≤ a + b\n",
      "\n",
      "== Current (2): \n",
      "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
      "rintro rfl\n",
      "exact le_add_left _ _\n",
      "-- Error: backtracking\n",
      "-- Goal: \n",
      "a b : ℕ\n",
      "⊢ a ≤ a + b\n",
      "\n",
      "== Current (3): \n",
      "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
      "rintro rfl\n",
      "apply Nat.le_add_right sperr a\n",
      "-- Error: backtracking\n",
      "-- Goal: \n",
      "a b : ℕ\n",
      "⊢ a ≤ a + b\n",
      "\n",
      "== Current (4): \n",
      "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
      "rintro rfl\n",
      "apply Nat.le_add_right\n",
      "\n",
      "SUCCESS!\n",
      "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by\n",
      "  rintro rfl\n",
      "  apply Nat.le_add_right\n"
     ]
    }
   ],
   "source": [
    "transformers.set_seed(43)\n",
    "\n",
    "def prove_simple(model, tokenizer, header, theorem_statement, search_budget):\n",
    "    success = False\n",
    "\n",
    "    code = header + theorem_statement\n",
    "    steps = []\n",
    "    proof = ''\n",
    "\n",
    "    for i in range(search_budget):\n",
    "        print(\"== Current (%d): \" % i, theorem_statement[:-3] + '\\n' + proof, sep='\\n')\n",
    "\n",
    "        # Run the code (header + proof-so-far)\n",
    "        state = proofsearch.run_code(code)\n",
    "        \n",
    "        # Stop if the proof is complete.\n",
    "        if proofsearch.is_done(state):\n",
    "            success = True\n",
    "            break\n",
    "\n",
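    "        # Get the new state.\n",
    "        goal_candidate = proofsearch.get_goal(state)\n",
    "        if goal_candidate is None:\n",
    "            if not steps:\n",
    "                # The theorem statement itself failed to check, so there is no step to undo.\n",
    "                break\n",
    "            # The last step caused an error: drop it and retry from the previous goal.\n",
    "            print(\"-- Error: backtracking\")\n",
    "            steps = steps[:-1]\n",
    "        else:\n",
    "            goal = goal_candidate\n",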
    "\n",
    "        print(\"-- Goal: \", goal, sep='\\n')\n",
    "\n",
    "        # Generate a next-step\n",
    "        prompt = f\"[GOAL]{goal}[PROOFSTEP]\"\n",
    "        texts, _= proofsearch.generate(prompt, model, tokenizer, temperatures=[0.5], num_samples=1)\n",
    "        step = proofsearch.parse_step(texts[0])\n",
    "\n",
    "        # Add the next-step to the proof-so-far\n",
    "        steps.append(step)\n",
    "        proof = '\\n'.join(steps)\n",
    "        code = header + theorem_statement.replace(\" {}\", \"\") + '\\n' + proof\n",
    "        print()\n",
    "\n",
    "    if success:\n",
    "        print(\"\\nSUCCESS!\")\n",
    "    else:\n",
    "        print(\"\\nFAILED\")\n",
    "    \n",
    "    print(theorem_statement.replace(\" {}\", \"\"))\n",
    "    print ('  ' + proof.replace('\\n', '\\n  '))\n",
    "    \n",
    "    return {'theorem_statement': theorem_statement, 'proof': proof, 'success': success}\n",
    "\n",
    "\n",
    "header = \"\"\"\n",
    "import Mathlib.Data.Nat.Prime\n",
    "\n",
    "\"\"\"\n",
    "theorem_statement = \"\"\"theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}\"\"\"\n",
    "\n",
    "\n",
    "out = prove_simple(\n",
    "    model, \n",
    "    tokenizer,\n",
    "    header, \n",
    "    theorem_statement, \n",
    "    search_budget=100\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Above (with the seed set to 43 for reproducibility), the model first generates `rintro rfl`. \\\n",
    "Next it generates `exact le_add_left _ _`, which receives an error, so the search discards that step and the model tries again (backtracks). \\\n",
    "The second attempt, `apply Nat.le_add_right sperr a`, also fails. On the third attempt the model generates `apply Nat.le_add_right` and the proof is complete."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Best-first search\n",
    "\n",
    "Typically a less naive search procedure is used. These searches are usually variants of a tree search, in which nodes are states and edges are next-steps. \n",
    "\n",
    "The most common search in neural theorem proving is *best-first search*. This search:\n",
    "\n",
    "- generates multiple next-step suggestions to form (proof-so-far + next-step) *candidates*\n",
    "- scores all candidates so far\n",
    "- selects the highest-scoring candidate to expand next, and repeats until a proof is found or the search budget is exhausted\n",
    "\n",
    "A typical scoring function is the model's log probability, $\\log p_\\theta(y_t|x_t)$, summed across steps. Next-steps that lead to an error receive a score of $-\\infty$ (in practice, we discard these steps). In the literature, the scoring function is called a *value function* $v(y_{\\leq t}, x_t)$.\n",
    "\n",
    "#### Intuition\n",
    "\n",
    "A key idea is generating multiple suggestions at each step, $\\{y_t^{(1)},\\ldots,y_t^{(k)}\\}\\sim p_\\theta(\\cdot|x_t)$. Intuitively, the goal is to select a next-step that will lead to a correct proof. In general, we do not know whether a next-step will lead to a correct proof, so we use a heuristic value function for selecting a next-step.\n",
    "\n",
    "Here's what multiple suggestions and their (normalized) log-probabilities look like in our example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.277\trw [Nat.coprime, gcd_comm] at h\n",
      "-0.279\trw [← h.gcd_eq_one]\n",
      "-0.335\tapply Nat.eq_one_of_dvd_dvd\n",
      "-0.349\trw [Nat.coprime] at h\n",
      "-0.350\trw [gcd_comm]\n"
     ]
    }
   ],
   "source": [
    "prompt = '[GOAL]m n : ℕ\\nh : Nat.coprime m n\\n⊢ Nat.gcd m n = 1[PROOFSTEP]'\n",
    "texts, scores = proofsearch.generate(prompt, model, tokenizer, temperatures=[0.0], num_samples=5)\n",
    "for text, score in zip(texts, scores):\n",
    "    print('%.3f' % score, text, sep='\\t')"
   ]
  },
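  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hedged aside on what *normalized* might mean for the scores above: one common choice is the average per-token log-probability. Assuming (this is not confirmed by the notebook, and `proofsearch.generate` may compute it differently) that a suggestion $y_t$ consists of tokens $y_{t,1},\\ldots,y_{t,L}$, that score would be\n",
    "\n",
    "$$\\text{score}(y_t) = \\frac{1}{L}\\sum_{i=1}^{L} \\log p_\\theta(y_{t,i} \\mid y_{t,<i}, x_t).$$\n",
    "\n",
    "Under that assumption, dividing by the length $L$ keeps longer tactics from being penalized merely for having more tokens, and a score of $-0.277$ would correspond to a geometric-mean per-token probability of about $e^{-0.277}\\approx 0.76$."
   ]
  },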
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation\n",
    "\n",
    "A minimal implementation of best-first search is available in `proofsearch_pylean.py`.\\\n",
    "A version that uses LeanDojo for interaction is in `proofsearch_dojo.py`.\n",
    "\n",
    "We will use these in the next notebook to evaluate our model on a set of evaluation theorems.\\\n",
    "Below, we run best-first search and print out the search trajectory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- current:\n",
      "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n",
      "\t\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4/4 [00:03<00:00,  1.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- type-checked candidates:\n",
      "\t(-0.066) rintro rfl\n",
      "\t(-0.307) rintro ⟨rfl, rfl⟩\n",
      "\t(-0.035) intro h\n",
      "\t(-0.230) rintro ⟨d, rfl⟩\n",
      "--- current:\n",
      "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n",
      "\tintro h\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4/4 [00:03<00:00,  1.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- type-checked candidates:\n",
      "\t(-0.172) apply le_of_add_le_add_right\n",
      "\t(-0.093) rw [← h]\n",
      "\t(-0.453) cases c\n",
      "--- current:\n",
      "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n",
      "\trintro rfl\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4/4 [00:03<00:00,  1.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- type-checked candidates:\n",
      "\t(-0.109) apply Nat.le_add_right\n",
      "\t(-0.173) exact Nat.le_add_right _ _\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'theorem_statement': 'theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}',\n",
       " 'proof': ['rintro rfl', 'apply Nat.le_add_right'],\n",
       " 'state': {'sorries': [], 'messages': [], 'env': 0},\n",
       " 'score': 0.1747819110751152,\n",
       " 'success': True}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "proofsearch.best_first_search(\n",
    "    model, tokenizer, header, theorem_statement, \n",
    "    max_iters=32,\n",
    "    num_samples=4,\n",
    "    temperatures=[0.0],\n",
    "    verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At each iteration, the search selects a candidate trajectory and generates 4 next-step suggestions for it; only the suggestions that type-check are kept.\\\n",
    "`intro h` is selected at the first step, since it has the highest score (-0.035). The best expansion of `intro h` has score -0.093. \\\n",
    "This is lower than the score of `rintro rfl` (-0.066), so `rintro rfl` is picked next. This is backtracking, since `intro h` is no longer in the proof.\\\n",
    "Then `apply Nat.le_add_right` is suggested and the proof is complete.\n"
   ]
  },
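  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the scoring concrete, here is a worked check of the trajectory printed by the search. It assumes that the value printed next to each candidate is a per-step score and that a trajectory's score is the sum of its steps, as in the log-probability value function described earlier. The sign convention of the returned `'score'` field is inferred from the output, not from the implementation.\n",
    "\n",
    "For the successful proof (`rintro rfl`, then `apply Nat.le_add_right`):\n",
    "\n",
    "$$-0.066 + (-0.109) = -0.175$$\n",
    "\n",
    "For the abandoned branch (`intro h`, then its best expansion `rw [← h]`):\n",
    "\n",
    "$$-0.035 + (-0.093) = -0.128$$\n",
    "\n",
    "The first sum matches the returned `'score'` of about 0.175 up to sign, which suggests that the implementation stores the negated cumulative log-probability. Under the same assumption, once `intro h` has been expanded, its best continuation ($-0.128$) ranks below the unexpanded `rintro rfl` ($-0.066$), which is consistent with the search switching to `rintro rfl`."
   ]
  },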
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "--------------------\n",
    "\n",
    "\n",
    "## Extensions\n",
    "\n",
    "Several works have proposed to improve the search strategy, either with a learned value function or a sophisticated search:\n",
    "\n",
    "- [Polu & Sutskever 2020](https://arxiv.org/pdf/2009.03393.pdf) propose to learn a value function $v(y_{\\leq t}, x_t)$ that estimates the probability of successfully proving the theorem with the model $p_\\theta$ starting at state $x_t$. To do so, they use proof search trajectories obtained by doing proof search with the model.\n",
    "\n",
    "- [Polu et al ICLR 2023](https://openreview.net/pdf?id=-P7G-8dmSh4) train the value function to predict the eventual length of the proof (or 0 if it is predicted to fail). The learned value function improves pass rate by ~10\\% on mathlib theorems compared to log-probability, with a ~1\\% improvement over the learned value function from [Polu & Sutskever 2020](https://arxiv.org/pdf/2009.03393.pdf).\n",
    "\n",
    "- [Lample et al NeurIPS 2022](https://openreview.net/pdf?id=J4pX8Q8cxHH) propose a sophisticated MCTS-like search that explores multiple trajectories in parallel, collecting statistics on visited states in order to prioritize search trajectories.\n",
    "\n",
    "Reproducing, analyzing, and improving the search algorithm remains an open area for future work in neural theorem proving (for instance, these works were not open-sourced).\n",
    "\n",
    "Search algorithms are also an active area of research in LLMs, including methods like [tree-of-thought](https://arxiv.org/abs/2305.10601), [stepwise beam search](https://arxiv.org/pdf/2205.12910.pdf), [self-consistency](https://arxiv.org/pdf/2203.11171.pdf), and search with [learned stepwise verifiers](https://arxiv.org/pdf/2305.20050.pdf). In theorem proving, the final output is verifiable, but the quality of intermediate steps is difficult to evaluate."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}