{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "#### Neural next-step prediction | part 4: evaluation\n", "Tutorial on neural theorem proving\\\n", "Author: Sean Welleck\n", "\n", "----------------\n", "\n", "#### High-level goal\n", "\n", "To get a quantitative estimate of our model's performance, we can perform proof search on an evaluation set of theorem statements. Intuitively, a model that is good at next-step suggestion will be effective for fully proving a theorem when paired with a suitable search algorithm. Therefore, proof search performance gives some measure of how useful the model's next-step suggestions will be when integrated into an interactive suggestion tool. \n", "\n", "First, we will evaluate on a small set of manually written theorem statements:\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "evaluation_theorems = [\n", " \"\"\"theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}\"\"\",\n", " \"\"\"theorem thm2 (x y : ℝ) : x < y → 0 < y - x := by {}\"\"\",\n", " \"\"\"theorem thm3 (n : Nat) : n ≥ 0 := by {}\"\"\",\n", " \"\"\"theorem thm4 (x y z : ℝ) : x ≤ y → y ≤ z → x ≤ z := by {}\"\"\",\n", " \"\"\"theorem thm5 (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by {}\"\"\",\n", " \"\"\"theorem thm6: r ⊆ s → s ⊆ t → r ⊆ t := by {}\"\"\",\n", " \"\"\"theorem thm7 (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by {}\"\"\",\n", " \"\"\"theorem thm8 (c : ℝ) : Injective fun x => x + c := by {}\"\"\",\n", " \"\"\"theorem thm9 (A B : Set ℕ) : A ⊆ B → ∀ n, n ∈ A → n ∈ B := by {}\"\"\",\n", " \"\"\"theorem thm10 (injg : Injective g) (injf : Injective f) : Injective fun x => g (f x) := by {}\"\"\",\n", "]\n", " \n", "# Shared header for the theorems above\n", "header = \"\"\"import Mathlib.Data.Nat.Factorization.Basic\n", "import Mathlib.Data.Nat.Prime\n", "import Mathlib.Data.Real.Basic\n", "\n", "open Function\n", "variable {α : Type _} (r s t : Set α)\n", "\n", "\"\"\"" ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's load our model and run best-first search:\n", "\n", "> We use a search budget that allows for running on a typical MacBook Pro in < 10 minutes. With a GPU it will be much faster.\n", "\n", "Feel free to study the trajectories that are printed below, which include both successes and failures:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('../ntp_python/')\n", "\n", "import proofsearch_pylean as proofsearch\n", "model, tokenizer = proofsearch.load_model('wellecks/llmstep-mathlib4-pythia2.8b')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- current:\n", "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:12<00:00, 3.09s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.035) intro h\n", "\t(-0.332) rintro ⟨rfl, rfl⟩\n", "\t(-0.033) rintro rfl\n", "\t(-0.181) rw [add_comm]\n", "--- current:\n", "\ttheorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n", "\trintro rfl\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.05s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.068) apply Nat.le_add_right\n", "\t(-0.176) simp\n", "Success: True\n", "--- current:\n", "\ttheorem thm2 (x y : ℝ) : x < y → 0 < y - x := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.19s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.015) contrapose!\n", "\t(-0.346) rw [← sub_pos]\n", "--- current:\n", "\ttheorem thm2 (x y : ℝ) : x < y → 0 < y - x := by \n", "\tcontrapose!\n" 
] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:13<00:00, 3.45s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t\n", "--- current:\n", "\ttheorem thm2 (x y : ℝ) : x < y → 0 < y - x := by \n", "\trw [← sub_pos]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:08<00:00, 2.24s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.118) rw [sub_pos]\n", "\t(-0.345) simp only [sub_pos]\n", "--- current:\n", "\ttheorem thm2 (x y : ℝ) : x < y → 0 < y - x := by \n", "\trw [← sub_pos]\n", "\trw [sub_pos]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:08<00:00, 2.16s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.448) exact id\n", "\t(-0.630) exact fun h => h\n", "Success: True\n", "--- current:\n", "\ttheorem thm3 (n : Nat) : n ≥ 0 := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:06<00:00, 1.62s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.254) cases n\n", "\t(-0.445) simp\n", "\t(-0.191) exact n.zero_le\n", "Success: True\n", "--- current:\n", "\ttheorem thm4 (x y z : ℝ) : x ≤ y → y ≤ z → x ≤ z := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:07<00:00, 1.80s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t\n", "Success: False\n", "--- current:\n", "\ttheorem thm5 (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.09s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.287) rw [Nat.gcd_comm]\n", "--- 
current:\n", "\ttheorem thm5 (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by \n", "\trw [Nat.gcd_comm]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.07s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.342) rw [Nat.gcd_comm]\n", "Success: False\n", "--- current:\n", "\ttheorem thm6: r ⊆ s → s ⊆ t → r ⊆ t := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.02s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t\n", "Success: False\n", "--- current:\n", "\ttheorem thm7 (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.08s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.301) intro h n\n", "--- current:\n", "\ttheorem thm7 (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by \n", "\tintro h n\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:05<00:00, 1.31s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.182) apply h\n", "\t(-0.365) exact h n.le_succ\n", "\t(-0.190) exact h (Nat.le_succ _)\n", "\t(-0.236) exact h (Nat.le_succ n)\n", "Success: True\n", "--- current:\n", "\ttheorem thm8 (c : ℝ) : Injective fun x => x + c := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.04s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.359) simp only [add_comm]\n", "\t(-0.314) simp [add_comm]\n", "\t(-0.305) simp_rw [add_comm]\n", "--- current:\n", "\ttheorem thm8 (c : ℝ) : Injective fun x => x + c := by \n", "\tsimp_rw [add_comm]\n" ] }, { "name": "stderr", 
"output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.11s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.372) simp_rw [add_comm _ c]\n", "\t(-0.337) simp [add_comm]\n", "\t(-0.310) simp_rw [add_comm]\n", "\t(-0.387) simp only [add_comm]\n", "--- current:\n", "\ttheorem thm8 (c : ℝ) : Injective fun x => x + c := by \n", "\tsimp [add_comm]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.11s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.387) simp only [add_comm]\n", "\t(-0.337) simp [add_comm]\n", "\t(-0.372) simp_rw [add_comm _ c]\n", "\t(-0.310) simp_rw [add_comm]\n", "--- current:\n", "\ttheorem thm8 (c : ℝ) : Injective fun x => x + c := by \n", "\tsimp only [add_comm]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.11s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.310) simp_rw [add_comm]\n", "\t(-0.372) simp_rw [add_comm _ c]\n", "\t(-0.337) simp [add_comm]\n", "\t(-0.387) simp only [add_comm]\n", "Success: False\n", "--- current:\n", "\ttheorem thm9 (A B : Set ℕ) : A ⊆ B → ∀ n, n ∈ A → n ∈ B := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:05<00:00, 1.44s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.443) exact fun h n hn => h hn\n", "\t(-0.237) rintro h n hn\n", "\t(-0.162) intro h n hn\n", "Success: True\n", "--- current:\n", "\ttheorem thm10 (injg : Injective g) (injf : Injective f) : Injective fun x => g (f x) := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.20s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.201) 
intro x y hxy\n", "\t(-0.396) intro x y h\n", "\t(-0.099) exact injg.comp injf\n", "Success: True\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "import transformers\n", "transformers.set_seed(43)\n", "\n", "results = {True: [], False: []}\n", "model.cpu()\n", "for theorem in evaluation_theorems:\n", " result = proofsearch.best_first_search(\n", " model, tokenizer, header, theorem, \n", " max_iters=16,\n", " temperatures=[0.0],\n", " num_samples=4,\n", " verbose=True\n", " )\n", " print(\"Success: %s\" % result['success'])\n", " results[result['success']].append(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are the successfully closed theorems and their generated proofs:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.600 closed\n", "theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by \n", "\trintro rfl\n", "\tapply Nat.le_add_right\n", "\n", "theorem thm2 (x y : ℝ) : x < y → 0 < y - x := by \n", "\trw [← sub_pos]\n", "\trw [sub_pos]\n", "\texact id\n", "\n", "theorem thm3 (n : Nat) : n ≥ 0 := by \n", "\tsimp\n", "\n", "theorem thm7 (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by \n", "\tintro h n\n", "\texact h n.le_succ\n", "\n", "theorem thm9 (A B : Set ℕ) : A ⊆ B → ∀ n, n ∈ A → n ∈ B := by \n", "\texact fun h n hn => h hn\n", "\n", "theorem thm10 (injg : Injective g) (injf : Injective f) : Injective fun x => g (f x) := by \n", "\texact injg.comp injf\n", "\n" ] } ], "source": [ "def print_result(result):\n", " print(result['theorem_statement'].replace('{}', '') + '\\n\\t' + '\\n\\t'.join(result['proof']) + '\\n')\n", "\n", "print(\"%.3f closed\" % (len(results[True])/ (len(results[True])+len(results[False]))))\n", "for result in results[True]:\n", " print_result(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Performance depends on model and search\n", "\n", "Our method 
closed 60.0% of the theorems. It is important to note that theorem-proving performance is a function of the model $p_\\theta$, the search algorithm $\\mathcal{A}$, and the search budget $k$.\n", "\n", "$\\quad\\text{pass rate} = f(p_\\theta, \\mathcal{A}, k)$.\n", "\n", "In principle, we can improve theorem-proving performance by improving the model, improving the search algorithm (for a fixed budget), or increasing the budget. As a result, when comparing models it is important to account for possible performance variations that arise from the search algorithm or budget (e.g., by holding the search algorithm and budget fixed).\n", "\n", "\n", "Feel free to try out different temperatures, numbers of samples, etc. to see how performance varies." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "-----------\n", "\n", "### Evaluating neural theorem provers\n", "\n", "Above, we evaluated on hand-specified theorems. In practice, evaluation is done in two settings:\n", "\n", "\n", "1. Benchmarks\n", "2. Test split\n", "\n", "Benchmarks provide theorem statements that can characterize performance on a certain kind of theorem (e.g., competition problems or undergraduate math) and can test distribution shift for a model (e.g., competition problems for a model trained on mathlib).\n", "\n", "A test split measures performance on theorems drawn from the same distribution as the training set." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Benchmarks in neural theorem proving\n", "\n", "[**MiniF2F** [Zheng et al ICLR 2022]](https://arxiv.org/abs/2109.00110) is a widely used benchmark of 488 problem statements drawn from the AIME, AMC, and the International Mathematical Olympiad (IMO), as well as material from high-school and undergraduate mathematics courses. \n", "\n", "Specifically, given $x_\\text{formal statement}$, our model must produce a correct formal proof $y_1,\\ldots,y_{T_x}$. 
Below, you can look at examples:\n", "\n", "> As a reference point, we show the informal statement and informal proof, though the model is only given the formal statement. (The informal annotations were added in [Jiang et al ICLR 2023](https://arxiv.org/abs/2210.12283))." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "#### Problem: induction_divisibility_3divnto3m2n\n", "\n", "#### Formal statement \n", "\n", "```lean\n", "theorem induction_divisibility_3divnto3m2n\n", " (n : ℕ) :\n", " 3 ∣ n^3 + 2 * n := sorry\n", "```\n", "\n", "\n", "#### Informal statement\n", "\n", "Show that for any natural number $n \\in \\mathbb{N}$, $3 \\mid n^3 + 2n$ .\n", "#### Informal proof\n", "\n", "We show the result by induction on $n$. The result is trivial for $n=0$. Let us assume it is true for $n \\geq 0$.\n", "We have $(n+1)^3+2(n+1) = (n^3+3n^2+3n+1) + (2n+2) = n^3+2n + 3n^2+3n+3$. From the induction hypothesis, we know that $3$ divides $n^3+2n$. Since $3$ also divides $3n^2+3n+3$, the result is also true in $n+1$ and we have by induction that the result is true for all $n$." ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# !pip install datasets\n", "from IPython.display import Markdown\n", "import datasets\n", "\n", "minif2f = datasets.load_dataset('hoskinson-center/minif2f-lean4')\n", "\n", "idx = 35\n", "\n", "example = minif2f['validation'][idx]\n", "Markdown(\n", " '#### Problem: ' + example['id'] + \n", " '\\n\\n#### Formal statement \\n\\n' + '```lean\\n' + example['formal_statement'] + '\\n```\\n' + \n", " '\\n\\n#### Informal statement\\n\\n' + example['informal_stmt'] + \n", " '\\n#### Informal proof\\n\\n' + example['informal_proof']\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The problems vary in difficulty. 
Some problems can be closed easily (especially when the model deploys built in tactics such as `simp`), while others require long-form reasoning that can also be difficult to formalize. Here is a success case and a failure case (at least with this search budget):" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- current:\n", "\ttheorem mathd_algebra_15 (s : ℕ → ℕ → ℕ) (h₀ : ∀ a b, 0 < a ∧ 0 < b → s a b = a ^ (b : ℕ) + b ^ (a : ℕ)) : s 2 6 = 100 := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:04<00:00, 1.10s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t(-0.295) simp [h₀]\n", "\t(-0.426) simp [h₀, pow_succ]\n", "Success: True\n", "theorem mathd_algebra_15 (s : ℕ → ℕ → ℕ) (h₀ : ∀ a b, 0 < a ∧ 0 < b → s a b = a ^ (b : ℕ) + b ^ (a : ℕ)) : s 2 6 = 100 := by \n", "\tsimp [h₀]\n", "\n", "--- current:\n", "\ttheorem imo_2001_p6 (a b c d : ℕ) (h₀ : 0 < a ∧ 0 < b ∧ 0 < c ∧ 0 < d) (h₁ : d < c) (h₂ : c < b) (h₃ : b < a) (h₄ : a * c + b * d = (b + d + a - c) * (b + d + c - a)) : ¬Nat.Prime (a * b + c * d) := by \n", "\t\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:10<00:00, 2.54s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--- type-checked candidates:\n", "\t\n", "Success: False\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "evaluation_theorems = [\n", " \"\"\"theorem mathd_algebra_15 (s : ℕ → ℕ → ℕ) (h₀ : ∀ a b, 0 < a ∧ 0 < b → s a b = a ^ (b : ℕ) + b ^ (a : ℕ)) : s 2 6 = 100 := by {}\"\"\",\n", " \"\"\"theorem imo_2001_p6 (a b c d : ℕ) (h₀ : 0 < a ∧ 0 < b ∧ 0 < c ∧ 0 < d) (h₁ : d < c) (h₂ : c < b) (h₃ : b < a) (h₄ : a * c + b * d = (b + d + a - c) * (b + d + c - a)) : ¬Nat.Prime (a * b + c * d) := by {}\"\"\"\n", "]\n", "\n", "for theorem in evaluation_theorems:\n", " 
result = proofsearch.best_first_search(\n", " model, tokenizer, header, theorem, \n", " max_iters=16,\n", " temperatures=[0.0],\n", " num_samples=4,\n", " verbose=True\n", " )\n", " print(\"Success: %s\" % result['success'])\n", " if result['success']:\n", "     print_result(result)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Other benchmarks\n", "\n", "[**ProofNet** [Azerbayev et al 2023]](https://arxiv.org/abs/2302.12433) is a new benchmark targeting undergraduate-level mathematics. It consists of 371 problems drawn from popular undergraduate pure mathematics textbooks, covering topics such as real and complex analysis, linear algebra, abstract algebra, and topology. \n", "\n", "ProofNet theorems tend to depend on more background knowledge than competition problems, which means that a learned model needs to use theorems and definitions from a wider subset of mathematics. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "#### In-distribution test split\n", "We can also test the model on theorems from the same distribution as those it was trained on. For our model, this means splitting the mathlib4 repository into training theorems and evaluation theorems. A common strategy is to split by uniform sampling. The resulting dataset covers a range of topics.\n", "\n", "We ran proof search with the `llmstep-mathlib4-pythia2.8b` model and it closed 48.8\\% of the theorems in the validation split. It used best-first search with a beam size of 32. Below is one successful proof from a smaller evaluation run on 200 theorems:\n", "\n", "\n", "> Note: `pylean` is less suitable as an interaction tool when evaluating mathlib theorems, since mathlib has many versions and the files can be large.\n", "> Instead, we used LeanDojo for interaction; see `proofsearch_dojo.py`. The 48.8% figure excludes theorems that did not initialize successfully." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "one_le_zpow\n", "\tlift n to ℕ using hn\n", "\trw [zpow_ofNat]\n", "\texact one_le_pow_of_one_le' H n\n", "\n" ] } ], "source": [ "import json\n", "# Load the saved results of the 200-theorem evaluation run\n", "with open('./data/successes_mathlib4_200_wellecks_llmstep-mathlib4-pythia2.8b.json') as f:\n", " successes = json.load(f)['results']\n", "\n", "example = [x for x in successes if x['theorem'] == 'one_le_zpow'][0]\n", "example['theorem_statement'] = example['theorem']\n", "print_result(example)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "Several works report nontrivial performance on mathlib, suggesting that existing models may already lead to useful tools for certain kinds of proofs in this setting. For instance, [Yang et al 2023](https://arxiv.org/pdf/2306.15626.pdf) report 47.5\\% pass@1 on mathlib3 (and 51.4\\% with a retrieval-augmented model), while [Polu et al ICLR 2023](https://arxiv.org/abs/2202.01344) achieve over 70\\% pass@8 on mathlib3 using a variant of reinforcement learning (and a large search budget). [Yang et al 2023](https://arxiv.org/pdf/2306.15626.pdf) show that a prompted GPT-4 substantially underperforms, achieving 28.8\\% pass@1. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Testing on unseen premises\n", "\n", "Above, we used test theorems that were sampled uniformly at random from mathlib. Alternative train/test splits can test other aspects of performance.\n", "\n", "For instance, consider what happens when mathlib is updated, say $M\\rightarrow M'$. Suppose that $M'$ has a new definition, e.g. `D' = def math_object_123`, and a new theorem `T' = theorem math_object_123_assoc`. Now consider another theorem $T''$ about $D'$ that uses $T'$ in its proof. Notice that our next-step predictor $p_\\theta$ observed neither $D'$ nor $T'$ during training. 
Thus it is very unlikely that $p_\\theta$ would use knowledge of $D'$ or $T'$ when proving $T''$. \n", "\n", "[Yang et al 2023](https://arxiv.org/pdf/2306.15626.pdf) create a `novel_premises` split in which the human-written proof uses at least one premise (theorem or definition) that is not in the training set. For instance, if our training set was $M$, we could have a test theorem whose proof uses $D'$ or $T'$. The authors find that performance is lower than on the `random` split, even when the model is augmented with retrieved premises:\n", "\n", "<img src=\"./images/leandojo_1.png\" width=\"700 px\">" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Implications for using machine learning in new proof developments\n", "\n", "The \"novel premises\" scenario described above comes up frequently when working on a new formalization project. Often a project develops its own definitions and theorems that are crucial for subsequent proofs. As one example, consider this student project from the University of Washington [eXperimental Lean Lab](https://sites.math.washington.edu/~jarod/xll.html):\n", "- [Banach Fixed Point Theorem by Lawrence Lin](https://github.com/Vilin97/LLL/blob/113f8bd767273484db6cd3c040c12f4a74c8fad4/lawrence/banach%20fpt.lean)\n", "\n", "The goal of the project is to prove the [Banach Fixed Point Theorem](https://en.wikipedia.org/wiki/Banach_fixed-point_theorem), which is stated at the bottom of the file (note that this project uses Lean3 syntax):\n", "\n", "<img src=\"./images/banach/banach_1.png\" width=\"1000 px\" />\n", "\n", "The theorem statement uses the author's own definition of `complete_metric_space` and `contraction_mapping`, which are located earlier in the file: \n", "\n", "<img src=\"./images/banach/banach_2.png\" width=\"1000 px\" />\n", "\n", "<img src=\"./images/banach/banach_3.png\" width=\"700 px\" />\n", "\n", "\n", "The proof uses various lemmas defined by the author. 
For instance, the first line of the proof uses `contraction_sequence_converges`:\n", "\n", "<img src=\"./images/banach/banach_4.png\" width=\"1000 px\" />\n", "\n", "<img src=\"./images/banach/banach_5.png\" width=\"900 px\" />\n", "\n", "None of these definitions (`complete_metric_space`, `contraction_mapping`) or lemmas (e.g. `contraction_sequence_converges`) are from Mathlib4, so our model $p_\\theta$ will never have seen them during training. Although $p_\\theta$ may get very good at using Mathlib4 definitions and theorems, it will not know to use these new definitions or lemmas in its proofs.\n", "\n", "Developing an effective way to handle novel contexts such as these is a key open problem in neural theorem proving. Moreover, this example shows that formalization involves many auxiliary tasks beyond proving a single theorem, such as developing definitions or useful lemmas. These auxiliary tasks are understudied in the context of neural theorem proving." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Next steps\n", "\n", "In the final notebook, we will build a VS Code tool that generates next-step suggestions with our model, enabling a form of \"human-machine collaboration\". \n", "Building a tool is also helpful for thinking about practical requirements (e.g. runtime, generalizing to different projects)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }