{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Neural next-step prediction | part 1: data\n", "Tutorial on neural theorem proving\\\n", "Author: Sean Welleck\n", "\n", "----------------" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### High-level goal\n", "\n", "Our goal is to train a neural next-step prediction model, $p(y_t|x_t)$. Here $x_t$ is a _proof state_, and $y_t$ is a next-step.\n", "\n", "To do so, we will create a dataset $\\mathcal{D}=\\{(x_t,y_t)\\}$ from human-written proofs. \n", "\n", "We can then train a neural next-step prediction model using a next-token prediction loss on the dataset." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Simple example\n", "\n", "To see what proof states and next-steps look like, let's look at an example human-written theorem and proof:\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "import Mathlib.Data.Nat.Prime\n", "\n", "theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n", " rw [Nat.coprime] at h \n", " exact h " ] } ], "source": [ "!cat ../ntp_lean/examples/example0.lean" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We would like to transform this theorem and proof into a sequence of (proof_state, next_step) examples.\n", "\n", "First, notice that the proof has two steps:\n", "\n", "1. $y_1=$ `rw [Nat.coprime] at h`\n", "2. $y_2=$ `exact h`" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can manually see the proof states by looking in VSCode. 
\n", "\n", "For example, placing the cursor before $y_1$ gives us the proof state $x_1$ (shown as \"Tactic state\"):" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "That is, the image above corresponds to $(x_1,y_1)$ defined as:\n", "\n", " $x_1$: \n", " ```\n", " m n : ℕ\n", " h : Nat.coprime m n\n", " ⊢ Nat.gcd m n = 1\n", " ```\n", "\n", " $y_1$: `rw [Nat.coprime] at h`\n", "\n", "\n", "Similarly, we can get the proof state $x_2$ prior to the step $y_2$ (`exact h`):\n", "\n", "\n", "\n", "After step $y_2$, the proof is complete: the proof state $x_3$ says we have \"No goals\":\n", "\n", "\n", "\n", "In summary, it is possible to *manually* transform the theorem and proof into a sequence $[(x_1,y_1),(x_2,y_2),(x_3)]$." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Automatically extracting proof states and next-steps\n", "\n", "To scale up data collection, we need a way to *automatically* extract proof states and next-steps from human-written proofs.\n", "\n", "\n", "\n", "A new open-source library by Kaiyu Yang et al. called [LeanDojo](https://leandojo.org/) can automatically extract (proof state, next-step) pairs from Lean proofs. This idea originated in [Han et al., ICLR 2022](https://github.com/jesse-michael-han/lean-step-public). We will look at a simplified version of what LeanDojo does.\n", "\n", "The core idea is to (1) transform a Lean file into abstract syntax trees using Lean, and (2) postprocess the abstract syntax trees into a dataset. Lean 4's powerful metaprogramming functionality gives us the tools to do this." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### 1. 
Transform a Lean file\n", "\n", "Conceptually, we want a script:\n", "\n", "$\\quad f_{\\text{extract}}(\\text{lean file})\\rightarrow \\text{ASTs}.$\n", "\n", "We run a simplified version of the script `ExtractData.lean` from LeanDojo:\n", "<!-- This command runs the `ExtractData.lean` script on our `example0.lean` file: -->" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input file: partI_nextstep/ntp_lean/examples/example0.lean\n", "AST: partI_nextstep/ntp_lean/examples/example0.ast.json\n" ] } ], "source": [ "!cd ../../ && lake env lean --run partI_nextstep/ntp_lean/ExtractSimple.lean partI_nextstep/ntp_lean/examples/example0.lean" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The output file `example0.ast.json` includes proof states and abstract syntax trees for the commands in `example0.lean`.\n", "\n", "Here are the proof states for our example:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'stateBefore': 'm n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1',\n", " 'stateAfter': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',\n", " 'pos': 101,\n", " 'endPos': 122},\n", " {'stateBefore': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',\n", " 'stateAfter': 'no goals',\n", " 'pos': 127,\n", " 'endPos': 134}]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import json\n", "ast = json.load(open('../../partI_nextstep/ntp_lean/examples/example0.ast.json'))\n", "ast['tactics']" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Notice that the proof states are the ones we saw above in VSCode.\n", "\n", "Here is the theorem statement's abstract syntax tree:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'args': [{'node': 
{'args': [...],\n", " 'info': 'none',\n", " 'kind': 'Lean.Parser.Command.declModifiers'}},\n", " {'node': {'args': [...],\n", " 'info': 'none',\n", " 'kind': 'Lean.Parser.Command.theorem'}}],\n", " 'info': 'none',\n", " 'kind': 'Lean.Parser.Command.declaration'}\n" ] } ], "source": [ "import pprint\n", "pprint.pprint(ast['commandASTs'][1]['node'], depth=4)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### 2. Post-processing\n", "\n", "Next, we post-process the extracted data into a dataset:\n", "\n", "$\\quad f_{\\text{post-process}}(\\text{ASTs}, \\text{lean file})\\rightarrow \\{(x_t,y_t)\\}.$\n", "\n", "To do so, we use the collected proof states, traverse the AST, and recover the next-steps from the original Lean file.\\\n", "See `ntp_python.postprocess_ast` for an example (naive) traversal that extracts the theorem name.\n", "\n", "Postprocessing `example0.lean` in this way gives us two $(x_t,y_t)$ pairs:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Theorem: theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 ...\n", "--- x1 ---\n", "m n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1\n", "--- y1 ---\n", "rw [Nat.coprime] at h\n", "\n", "--- x2 ---\n", "m n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1\n", "--- y2 ---\n", "exact h\n", "\n" ] } ], "source": [ "import sys\n", "sys.path.append('../')\n", "from ntp_python.postprocess_ast import get_theorem\n", "from collections import defaultdict\n", "\n", "theorem2examples = defaultdict(list)\n", "\n", "lean_file = open('../../partI_nextstep/ntp_lean/examples/example0.lean').read()\n", "for item in ast['tactics']:\n", " theorem = get_theorem(item['pos'], ast)\n", " theorem2examples[theorem].append({\n", " 'x': item['stateBefore'],\n", " 'y': lean_file[item['pos']:item['endPos']],\n", " })\n", "\n", "for theorem, examples in theorem2examples.items():\n", " print(\"Theorem: \", theorem[:60], 
'...', sep=' ')\n", " for t, example in enumerate(examples):\n", " print(f\"--- x{t+1} ---\", example['x'], sep='\\n')\n", " print(f\"--- y{t+1} ---\", example['y'], sep='\\n')\n", " print()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The core extraction code in LeanDojo is in [ExtractData.lean](https://github.com/lean-dojo/LeanDojo/blob/main/src/lean_dojo/data_extraction/ExtractData.lean) if you are curious.\n", "\n", "## Scaling up data collection\n", "In general, Lean projects are more complex than the simple example above. For instance, projects may:\n", "1. have a large number of files\n", "2. have dependencies on other files or projects\n", "3. have complex file structure that our naive postprocessing doesn't handle\n", "\n", "An example is the [mathlib project](https://leanprover-community.github.io/mathlib-overview.html). Mathlib itself changes rapidly, and other Lean projects may depend on specific versions. [LeanDojo](https://leandojo.readthedocs.io/en/latest/index.html) provides tools for handling this complexity." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Extracting 90k+ theorems with LeanDojo\n", "\n", "The LeanDojo tool allows for extracting data from an *arbitrary Lean GitHub repository*. Conceptually,\n", "\n", "$\\quad f_{\\text{leandojo}}(\\text{lean repository})\\rightarrow \\mathcal{D}.$\n", "\n", "It supports parallelism, keeps track of versions and dependencies for extracted data, and its post-processing handles more complex scenarios."
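, "\n", "Once extracted, each $(x_t, y_t)$ pair can be flattened into a single training string for next-token prediction. The snippet below is a minimal sketch, not LeanDojo's API: the `[GOAL]`/`[PROOFSTEP]` delimiters follow the PROOFSTEP format of Han et al., and the field names match the `traced_tactics` entries shown below.\n", "\n", "```python\n", "# Sketch: turn LeanDojo-style traced tactics into (prompt, completion)\n", "# pairs for next-token prediction training. Field names ('state_before',\n", "# 'tactic') match the traced_tactics entries printed below.\n", "def to_examples(traced_tactics):\n", "    examples = []\n", "    for t in traced_tactics:\n", "        # The proof state becomes the prompt; the tactic is the completion.\n", "        prompt = f\"[GOAL]{t['state_before']}[PROOFSTEP]\"\n", "        examples.append((prompt, t['tactic']))\n", "    return examples\n", "\n", "tactics = [{'state_before': 'm n : ℕ\\nh : Nat.coprime m n\\n⊢ Nat.gcd m n = 1',\n", "            'tactic': 'rw [Nat.coprime] at h'}]\n", "print(to_examples(tactics)[0][0])\n", "```\n"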
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "**Example**\\\n", "Here is what the interface would look like for [extracting a dataset from Mathlib4](https://github.com/lean-dojo/LeanDojo/blob/main/scripts/generate-benchmark-lean4.ipynb):\n", "\n", "```python\n", " URL = \"https://github.com/leanprover-community/mathlib4\"\n", " COMMIT = \"5a919533f110b7d76410134a237ee374f24eaaad\"\n", " repo = LeanGitRepo(URL, COMMIT)\n", " traced_repo = trace(repo)\n", "```\n", "\n", "To avoid possible dependency issues, we won't run LeanDojo directly here. However, the LeanDojo authors provide the extracted data online, so we will download it for this tutorial:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of non-empty training proofs: 41944\n", "{'commit': '5a919533f110b7d76410134a237ee374f24eaaad',\n", " 'end': [308, 76],\n", " 'file_path': 'Mathlib/Analysis/BoxIntegral/Box/Basic.lean',\n", " 'full_name': 'BoxIntegral.Box.withBotCoe_inj',\n", " 'start': [307, 1],\n", " 'traced_tactics': [{'state_after': 'no goals',\n", " 'state_before': 'ι : Type u_1\\n'\n", " 'I✝ J✝ : Box ι\\n'\n", " 'x y : ι → ℝ\\n'\n", " 'I J : WithBot (Box ι)\\n'\n", " '⊢ ↑I = ↑J ↔ I = J',\n", " 'tactic': 'simp only [Subset.antisymm_iff, ← '\n", " 'le_antisymm_iff, withBotCoe_subset_iff]'}],\n", " 'url': 'https://github.com/leanprover-community/mathlib4'}\n" ] } ], "source": [ "import json\n", "import sys\n", "import pprint\n", "sys.path.append('../')\n", "from ntp_python.data import _download_and_unpack\n", "\n", "_download_and_unpack(\n", " tarball_url='https://zenodo.org/record/8040110/files/leandojo_benchmark_4_v1.tar.gz',\n", " data_dir='../data',\n", " overwrite=False\n", ")\n", "\n", "train = json.load(open('../data/leandojo_benchmark_4/random/train.json'))\n", "train = [x for x in train if len(x['traced_tactics']) > 0]\n", "print(\"Number of non-empty training proofs: 
\", len(train), sep=' ')\n", "pprint.pprint(train[0])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Next steps\n", "In part 2, we'll train a neural next-step prediction model on this mathlib4 dataset." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 4 }