{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Neural next-step prediction | part 1: data\n",
    "Tutorial on neural theorem proving\\\n",
    "Author: Sean Welleck\n",
    "\n",
    "----------------"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### High-level goal\n",
    "\n",
    "Our goal is to train a neural next-step prediction model, $p(y_t|x_t)$. Here $x_t$ is a _proof state_, and $y_t$ is a next-step.\n",
    "\n",
    "To do so, we will create a dataset $\\mathcal{D}=\\{(x_t,y_t)\\}$ from human-written proofs. \n",
    "\n",
    "We can then train a neural next-step prediction model using a next-token prediction loss on the dataset."
   ]
  },
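  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One common way to write such a loss is sketched below. The notation is an assumption introduced here, not fixed by the tutorial: $\\theta$ denotes the model parameters and $y_t^{i}$ is the $i$-th token of the next-step $y_t$.\n",
    "\n",
    "$$\\mathcal{L}(\\theta)=-\\sum_{(x_t,y_t)\\in\\mathcal{D}}\\sum_{i=1}^{|y_t|}\\log p_\\theta\\left(y_t^{i}\\mid y_t^{1:i-1},x_t\\right)$$\n",
    "\n",
    "In this form only the tokens of $y_t$ contribute to the sum; the proof state $x_t$ serves as conditioning context."
   ]
  },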
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Simple example\n",
    "\n",
    "To see what proof states and next-steps look like, let's look at an example human-written theorem and proof:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "import Mathlib.Data.Nat.Prime\n",
      "\n",
      "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n",
      "  rw [Nat.coprime] at h  \n",
      "  exact h  "
     ]
    }
   ],
   "source": [
    "!cat ../ntp_lean/examples/example0.lean"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We would like to transform this theorem and proof into a sequence of (proof_state, next_step) examples.\n",
    "\n",
    "First, notice that the proof has two steps:\n",
    "\n",
    "1. $y_1=$ `rw [Nat.coprime] at h`\n",
    "2. $y_2=$ `exact h`"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can manually see the proof states by looking in VSCode. \n",
    "\n",
    "For example, placing the cursor before $y_1$ gives us the proof state $x_1$ (shown as \"Tactic state\"):"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![title](images/proof_state_1.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That is, the image above corresponds to $(x_1,y_1)$ defined as:\n",
    "\n",
    "  $x_1$: \n",
    "  ```\n",
    "    m n : ℕ\n",
    "    h : Nat.coprime m n\n",
    "    ⊢ Nat.gcd m n = 1\n",
    "  ```\n",
    "\n",
    "  $y_1$: `rw [Nat.coprime] at h`\n",
    "\n",
    "\n",
    "Similarly, we can get the proof state $x_2$ prior to the step $y_2$ (`exact h`):\n",
    "\n",
    "![title](images/proof_state_2.png)\n",
    "\n",
    "After step $y_2$, the proof is complete: the proof state $x_3$ says we have \"No goals\":\n",
    "\n",
    "![title](images/proof_state_3.png)\n",
    "\n",
    "In summary, it is possible to *manually* transform the theorem and proof into a sequence $[(x_1,y_1),(x_2,y_2),(x_3)]$."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automatically extracting proof states and next-steps \n",
    "\n",
    "To scale up data collection, we need a way to *automatically* extract proof states and next-steps from human-written proofs.\n",
    "\n",
    "\n",
    "\n",
    "A new open-source library by Kaiyu Yang et al. called [LeanDojo](https://leandojo.org/) can automatically extract (proof state, next-step) pairs from Lean proofs. This idea originated in [Han et al ICLR 2022](https://github.com/jesse-michael-han/lean-step-public).  We will look at a simplified version of what LeanDojo does.\n",
    "\n",
    "The core idea is to (1) transform a Lean file into abstract syntax trees using Lean, and (2) postprocess the abstract syntax tree into a dataset. Lean4's powerful metaprogramming functionality give us the tools to do this."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. Transform a Lean file\n",
    "\n",
    "Conceptually, we want a script:\n",
    "\n",
    "$\\quad f_{\\text{extract}}(\\text{lean file})\\rightarrow \\text{ASTs}$,\n",
    "\n",
    "We run a simplified version of the script `ExtractData.lean` from LeanDojo:\n",
    "<!-- This command runs the `ExtractData.lean` script on our `example0.lean` file: -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input file: partI_nextstep/ntp_lean/examples/example0.lean\n",
      "AST: partI_nextstep/ntp_lean/examples/example0.ast.json\n"
     ]
    }
   ],
   "source": [
    "!cd ../../ && lake env lean --run partI_nextstep/ntp_lean/ExtractSimple.lean partI_nextstep/ntp_lean/examples/example0.lean"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output file `example.ast.json` includes proof states and abstract syntax trees for the commands in `example0.lean`.\n",
    "\n",
    "Here are the proof states for our example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'stateBefore': 'm n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1',\n",
       "  'stateAfter': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',\n",
       "  'pos': 101,\n",
       "  'endPos': 122},\n",
       " {'stateBefore': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',\n",
       "  'stateAfter': 'no goals',\n",
       "  'pos': 127,\n",
       "  'endPos': 134}]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "ast = json.load(open('../../partI_nextstep/ntp_lean/examples/example0.ast.json'))\n",
    "ast['tactics']"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the proof states are the ones we saw above in VSCode.\n",
    "\n",
    "Here is the theorem statement's abstract syntax tree:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'args': [{'node': {'args': [...],\n",
      "                    'info': 'none',\n",
      "                    'kind': 'Lean.Parser.Command.declModifiers'}},\n",
      "          {'node': {'args': [...],\n",
      "                    'info': 'none',\n",
      "                    'kind': 'Lean.Parser.Command.theorem'}}],\n",
      " 'info': 'none',\n",
      " 'kind': 'Lean.Parser.Command.declaration'}\n"
     ]
    }
   ],
   "source": [
    "import pprint\n",
    "pprint.pprint(ast['commandASTs'][1]['node'], depth=4)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Post-processing\n",
    "\n",
    "Next, we post-process the extracted data into a dataset:\n",
    "\n",
    "$\\quad f_{\\text{post-process}}(\\text{ASTs}, \\text{lean file})\\rightarrow \\{(x_t,y_t)\\}.$\n",
    "\n",
    "To do so, we use the collected proof states, traverse the AST, and recover the next-steps from the original Lean file.\\\n",
    "See `ntp_python.postprocess_ast` for an example (naive) traversal which extracts the theorem name.\n",
    "\n",
    "Postprocessing `example0.lean` in this way gives us two $(x_t,y_t)$ pairs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Theorem:  theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 ...\n",
      "--- x1 ---\n",
      "m n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1\n",
      "--- y1 ---\n",
      "rw [Nat.coprime] at h\n",
      "\n",
      "--- x2 ---\n",
      "m n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1\n",
      "--- y2 ---\n",
      "exact h\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "from ntp_python.postprocess_ast import get_theorem\n",
    "from collections import defaultdict\n",
    "\n",
    "theorem2examples = defaultdict(list)\n",
    "\n",
    "lean_file = open('../../partI_nextstep/ntp_lean/examples/example0.lean').read()\n",
    "for item in ast['tactics']:\n",
    "    theorem = get_theorem(item['pos'], ast)\n",
    "    theorem2examples[theorem].append({\n",
    "        'x': item['stateBefore'],\n",
    "        'y': lean_file[item['pos']:item['endPos']],\n",
    "    })\n",
    "\n",
    "for theorem, examples in theorem2examples.items():\n",
    "    print(\"Theorem: \", theorem[:60], '...', sep=' ')\n",
    "    for t, example in enumerate(examples):\n",
    "        print(f\"--- x{t+1} ---\", example['x'], sep='\\n')\n",
    "        print(f\"--- y{t+1} ---\", example['y'], sep='\\n')\n",
    "        print()"
   ]
  },
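  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The slice `lean_file[item['pos']:item['endPos']]` above indexes by character. Lean 4 string positions are byte offsets into the UTF-8 encoded source, so the two agree for `example0.lean` only because that file is pure ASCII (assumption: the extraction script writes `pos` and `endPos` as such byte offsets). Below is a minimal sketch of a slicing helper that stays correct when a file contains characters such as `ℕ`. The helper name `slice_by_byte_offsets` and the one-line `source` string are hypothetical:\n",
    "\n",
    "```python\n",
    "def slice_by_byte_offsets(source, pos, end_pos):\n",
    "    # Treat pos/end_pos as UTF-8 byte offsets (assumed), not character indices.\n",
    "    data = source.encode('utf-8')\n",
    "    return data[pos:end_pos].decode('utf-8')\n",
    "\n",
    "\n",
    "# Hypothetical one-line proof with a multi-byte character before the tactic.\n",
    "source = 'example (n : ℕ) : n = n := by rfl'\n",
    "tactic = 'rfl'\n",
    "\n",
    "# Byte offset at which the tactic starts, computed on the encoded source.\n",
    "byte_pos = len(source.encode('utf-8')) - len(tactic.encode('utf-8'))\n",
    "\n",
    "# Byte-based slicing recovers the tactic text.\n",
    "print(repr(slice_by_byte_offsets(source, byte_pos, byte_pos + len(tactic))))\n",
    "\n",
    "# Character-based slicing with the same offsets lands in the wrong place.\n",
    "print(repr(source[byte_pos:byte_pos + len(tactic)]))\n",
    "```\n",
    "\n",
    "For files that contain only ASCII, both slices return the same text, so the simpler character-based slice used above is sufficient there."
   ]
  },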
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The core extraction code in LeanDojo is in [ExtractData.lean](https://github.com/lean-dojo/LeanDojo/blob/main/src/lean_dojo/data_extraction/ExtractData.lean) if you are curious.\n",
    "\n",
    "## Scaling up data collection\n",
    "In general, Lean projects are more complex than the simple example above. For instance, projects may:\n",
    "1. have a large number of files\n",
    "2. have dependencies on other files or projects\n",
    "3. have complex file structure that our naive postprocessing doesn't handle\n",
    "\n",
    "An example is the [mathlib project](https://leanprover-community.github.io/mathlib-overview.html). Mathlib  itself changes rapidly, and other Lean projects may depend on specific versions. [LeanDojo](https://leandojo.readthedocs.io/en/latest/index.html|) gives tools for handling this complexity."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Extracting 90k+ theorems with LeanDojo\n",
    "\n",
    "The LeanDojo tool allows for extracting data from an *arbitrary Lean Github repository*. Conceptually,\n",
    "\n",
    "$\\quad f_{\\text{leandojo}}(\\text{lean repository})\\rightarrow \\mathcal{D}.$\n",
    "\n",
    "It supports parallelism, keeps track of versions and dependencies for extracted data, and its post-processing handles more complex scenarios."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example**\\\n",
    "Here is what the interface would look like for [extracting a dataset from Mathlib4](https://github.com/lean-dojo/LeanDojo/blob/main/scripts/generate-benchmark-lean4.ipynb):\n",
    "\n",
    "```python\n",
    "    URL = \"https://github.com/leanprover-community/mathlib4\"\n",
    "    COMMIT = \"5a919533f110b7d76410134a237ee374f24eaaad\"\n",
    "    repo = LeanGitRepo(URL, COMMIT)\n",
    "    traced_repo = trace(repo)\n",
    "```\n",
    "\n",
    "To avoid possible dependency issues, we won't run LeanDojo directly here. However, the LeanDojo authors provide the extracted data online,  so we will download it for this tutorial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of non-empty training proofs:  41944\n",
      "{'commit': '5a919533f110b7d76410134a237ee374f24eaaad',\n",
      " 'end': [308, 76],\n",
      " 'file_path': 'Mathlib/Analysis/BoxIntegral/Box/Basic.lean',\n",
      " 'full_name': 'BoxIntegral.Box.withBotCoe_inj',\n",
      " 'start': [307, 1],\n",
      " 'traced_tactics': [{'state_after': 'no goals',\n",
      "                     'state_before': 'ι : Type u_1\\n'\n",
      "                                     'I✝ J✝ : Box ι\\n'\n",
      "                                     'x y : ι → ℝ\\n'\n",
      "                                     'I J : WithBot (Box ι)\\n'\n",
      "                                     '⊢ ↑I = ↑J ↔ I = J',\n",
      "                     'tactic': 'simp only [Subset.antisymm_iff, ← '\n",
      "                               'le_antisymm_iff, withBotCoe_subset_iff]'}],\n",
      " 'url': 'https://github.com/leanprover-community/mathlib4'}\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import sys\n",
    "import pprint\n",
    "sys.path.append('../')\n",
    "from ntp_python.data import _download_and_unpack\n",
    "\n",
    "_download_and_unpack(\n",
    "    tarball_url='https://zenodo.org/record/8040110/files/leandojo_benchmark_4_v1.tar.gz',\n",
    "    data_dir='../data',\n",
    "    overwrite=False\n",
    ")\n",
    "\n",
    "train = json.load(open('../data/leandojo_benchmark_4/random/train.json'))\n",
    "train = [x for x in train if len(x['traced_tactics']) > 0]\n",
    "print(\"Number of non-empty training proofs: \", len(train), sep=' ')\n",
    "pprint.pprint(train[0])"
   ]
  },
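  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each record above holds one theorem, while training uses individual $(x_t,y_t)$ pairs. Below is a minimal sketch of flattening the records, relying only on the `full_name`, `traced_tactics`, `state_before`, and `tactic` fields visible in the printed example. The function name `to_examples` and the shortened toy record are hypothetical:\n",
    "\n",
    "```python\n",
    "def to_examples(theorems):\n",
    "    # Flatten theorem records into one {'theorem', 'x', 'y'} dict per traced tactic.\n",
    "    examples = []\n",
    "    for theorem in theorems:\n",
    "        for step in theorem['traced_tactics']:\n",
    "            examples.append({\n",
    "                'theorem': theorem['full_name'],\n",
    "                'x': step['state_before'],  # proof state before the tactic\n",
    "                'y': step['tactic'],        # next-step applied in that state\n",
    "            })\n",
    "    return examples\n",
    "\n",
    "\n",
    "# Toy record mirroring the fields shown above (proof state shortened for illustration).\n",
    "toy_train = [{\n",
    "    'full_name': 'BoxIntegral.Box.withBotCoe_inj',\n",
    "    'traced_tactics': [{\n",
    "        'state_before': '⊢ ↑I = ↑J ↔ I = J',\n",
    "        'state_after': 'no goals',\n",
    "        'tactic': 'simp only [Subset.antisymm_iff, ← le_antisymm_iff, withBotCoe_subset_iff]',\n",
    "    }],\n",
    "}]\n",
    "\n",
    "examples = to_examples(toy_train)\n",
    "print(len(examples))\n",
    "print(examples[0]['x'])\n",
    "print(examples[0]['y'])\n",
    "```\n",
    "\n",
    "Calling `to_examples(train)` on the downloaded split would apply the same flattening to every non-empty training proof."
   ]
  },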
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Next steps\n",
    "In part 2, we'll train a neural next-step generation model on this mathlib4 dataset."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}