#### Neural next-step prediction | part 4: evaluation
Tutorial on neural theorem proving\
Author: Sean Welleck

----------------

#### High-level goal

To get a quantitative estimate of our model's performance, we can perform proof search on an evaluation set of theorem statements. Intuitively, a model that is good at next-step suggestion will be effective for fully proving a theorem when paired with a suitable search algorithm. Therefore, proof search performance gives some measure of how useful the model's next-step suggestions will be when integrated into an interactive suggestion tool. 
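To make "paired with a suitable search algorithm" concrete, here is a minimal sketch of best-first search over partial proofs. It assumes each partial proof is scored by the sum of the log-probabilities of its tactics, which is one common choice. `suggest` and `is_closed` are hypothetical stand-ins for the language model and the Lean type-checker, so this is an illustration, not the code inside `proofsearch_pylean`.

```python
import heapq
from typing import Callable, Iterable, Optional

Steps = tuple[str, ...]  # a partial proof: the tactics applied so far


def best_first_search(
    suggest: Callable[[Steps], Iterable[tuple[str, float]]],
    is_closed: Callable[[Steps], bool],
    max_iters: int = 16,
) -> Optional[Steps]:
    """Repeatedly expand the partial proof with the highest cumulative log-prob.

    suggest(steps) yields (tactic, log_prob) pairs for next steps that
    type-check; is_closed(steps) says whether no goals remain. Both are
    hypothetical stand-ins for the language model and the Lean checker.
    """
    # heapq is a min-heap, so we store the negated cumulative log-probability.
    queue: list[tuple[float, Steps]] = [(0.0, ())]
    seen = {()}
    for _ in range(max_iters):
        if not queue:
            return None  # every branch ran out of type-checked candidates
        neg_score, steps = heapq.heappop(queue)
        for tactic, log_prob in suggest(steps):
            child = steps + (tactic,)
            if is_closed(child):
                return child
            if child not in seen:
                seen.add(child)
                heapq.heappush(queue, (neg_score - log_prob, child))
    return None  # search budget (max_iters expansions) exhausted


# Toy run loosely modeled on the thm1 trajectory printed later in this notebook.
TOY_SUGGESTIONS = {
    (): [("intro h", -0.035), ("rintro rfl", -0.033)],
    ("rintro rfl",): [("apply Nat.le_add_right", -0.068)],
}
proof = best_first_search(
    suggest=lambda steps: TOY_SUGGESTIONS.get(steps, []),
    is_closed=lambda steps: steps == ("rintro rfl", "apply Nat.le_add_right"),
)
print(proof)  # ('rintro rfl', 'apply Nat.le_add_right')
```

Because a single priority queue holds partial proofs of every depth, a dead end simply falls back to the next-best partial proof, which matches the pattern visible in the thm2 trajectory later in this notebook.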

First, we will evaluate on a small set of manually written theorem statements:



```python
evaluation_theorems = [
    """theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by {}""",
    """theorem thm2 (x y : ℝ) : x < y → 0 < y - x := by {}""",
    """theorem thm3 (n : Nat) : n ≥ 0 := by {}""",
    """theorem thm4 (x y z : ℝ) : x ≤ y → y ≤ z → x ≤ z := by {}""",
    """theorem thm5 (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by {}""",
    """theorem thm6: r ⊆ s → s ⊆ t → r ⊆ t := by {}""",
    """theorem thm7 (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by {}""",
    """theorem thm8 (c : ℝ) : Injective fun x => x + c := by {}""",
    """theorem thm9 (A B : Set ℕ) : A ⊆ B → ∀ n, n ∈ A → n ∈ B := by {}""",
    """theorem thm10 (injg : Injective g) (injf : Injective f) : Injective fun x => g (f x) := by {}""",
]

# Shared header for the theorems above
header = """import Mathlib.Data.Nat.Factorization.Basic
import Mathlib.Data.Nat.Prime
import Mathlib.Data.Real.Basic

open Function
variable {α : Type _} (r s t : Set α)

"""
```

Let's load our model and run best-first search:

> We use a search budget that lets the notebook run on a typical MacBook Pro in under 10 minutes. With a GPU, it will be much faster.

Feel free to study the trajectories that are printed below, which include both successes and failures:

```python
import sys
sys.path.append('../ntp_python/')

import proofsearch_pylean as proofsearch
model, tokenizer = proofsearch.load_model('wellecks/llmstep-mathlib4-pythia2.8b')
```

```python
import transformers
transformers.set_seed(43)

# Group the search results by whether the proof was closed.
results = {True: [], False: []}
model.cpu()
for theorem in evaluation_theorems:
    result = proofsearch.best_first_search(
        model, tokenizer, header, theorem,
        max_iters=16,
        temperatures=[0.0],
        num_samples=4,
        verbose=True
    )
    print("Success: %s" % result['success'])
    results[result['success']].append(result)
```

```text
--- current:
	theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by

--- type-checked candidates:
	(-0.035) intro h
	(-0.332) rintro ⟨rfl, rfl⟩
	(-0.033) rintro rfl
	(-0.181) rw [add_comm]
--- current:
	theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
	rintro rfl

--- type-checked candidates:
	(-0.068) apply Nat.le_add_right
	(-0.176) simp
Success: True
--- current:
	theorem thm2 (x y : ℝ) : x < y → 0 < y - x := by

--- type-checked candidates:
	(-0.015) contrapose!
	(-0.346) rw [← sub_pos]
--- current:
	theorem thm2 (x y : ℝ) : x < y → 0 < y - x := by
	contrapose!

--- type-checked candidates:
--- current:
	theorem thm2 (x y : ℝ) : x < y → 0 < y - x := by
	rw [← sub_pos]

--- type-checked candidates:
	(-0.118) rw [sub_pos]
	(-0.345) simp only [sub_pos]
--- current:
	theorem thm2 (x y : ℝ) : x < y → 0 < y - x := by
	rw [← sub_pos]
	rw [sub_pos]

--- type-checked candidates:
	(-0.448) exact id
	(-0.630) exact fun h => h
Success: True
--- current:
	theorem thm3 (n : Nat) : n ≥ 0 := by

--- type-checked candidates:
	(-0.254) cases n
	(-0.445) simp
	(-0.191) exact n.zero_le
Success: True
--- current:
	theorem thm4 (x y z : ℝ) : x ≤ y → y ≤ z → x ≤ z := by

--- type-checked candidates:
Success: False
--- current:
	theorem thm5 (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by

--- type-checked candidates:
	(-0.287) rw [Nat.gcd_comm]
--- current:
	theorem thm5 (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by
	rw [Nat.gcd_comm]

--- type-checked candidates:
	(-0.342) rw [Nat.gcd_comm]
Success: False
--- current:
	theorem thm6: r ⊆ s → s ⊆ t → r ⊆ t := by

--- type-checked candidates:
Success: False
--- current:
	theorem thm7 (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by

--- type-checked candidates:
	(-0.301) intro h n
--- current:
	theorem thm7 (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by
	intro h n

--- type-checked candidates:
	(-0.182) apply h
	(-0.365) exact h n.le_succ
	(-0.190) exact h (Nat.le_succ _)
	(-0.236) exact h (Nat.le_succ n)
Success: True
--- current:
	theorem thm8 (c : ℝ) : Injective fun x => x + c := by

--- type-checked candidates:
	(-0.359) simp only [add_comm]
	(-0.314) simp [add_comm]
	(-0.305) simp_rw [add_comm]
--- current:
	theorem thm8 (c : ℝ) : Injective fun x => x + c := by
	simp_rw [add_comm]

--- type-checked candidates:
	(-0.372) simp_rw [add_comm _ c]
	(-0.337) simp [add_comm]
	(-0.310) simp_rw [add_comm]
	(-0.387) simp only [add_comm]
--- current:
	theorem thm8 (c : ℝ) : Injective fun x => x + c := by
	simp [add_comm]

--- type-checked candidates:
	(-0.387) simp only [add_comm]
	(-0.337) simp [add_comm]
	(-0.372) simp_rw [add_comm _ c]
	(-0.310) simp_rw [add_comm]
--- current:
	theorem thm8 (c : ℝ) : Injective fun x => x + c := by
	simp only [add_comm]

--- type-checked candidates:
	(-0.310) simp_rw [add_comm]
	(-0.372) simp_rw [add_comm _ c]
	(-0.337) simp [add_comm]
	(-0.387) simp only [add_comm]
Success: False
--- current:
	theorem thm9 (A B : Set ℕ) : A ⊆ B → ∀ n, n ∈ A → n ∈ B := by

--- type-checked candidates:
	(-0.443) exact fun h n hn => h hn
	(-0.237) rintro h n hn
	(-0.162) intro h n hn
Success: True
--- current:
	theorem thm10 (injg : Injective g) (injf : Injective f) : Injective fun x => g (f x) := by

--- type-checked candidates:
	(-0.201) intro x y hxy
	(-0.396) intro x y h
	(-0.099) exact injg.comp injf
Success: True
```





Here are the successfully closed theorems and their generated proofs:

```python
def print_result(result):
    print(result['theorem_statement'].replace('{}', '') + '\n\t' + '\n\t'.join(result['proof']) + '\n')

print("%.3f closed" % (len(results[True]) / (len(results[True]) + len(results[False]))))
for result in results[True]:
    print_result(result)
```

```text
0.600 closed
theorem thm1 (a b c : Nat) : a + b = c → a ≤ c := by
	rintro rfl
	apply Nat.le_add_right

theorem thm2 (x y : ℝ) : x < y → 0 < y - x := by
	rw [← sub_pos]
	rw [sub_pos]
	exact id

theorem thm3 (n : Nat) : n ≥ 0 := by
	simp

theorem thm7 (f : ℕ → ℕ) : Monotone f → ∀ n, f n ≤ f (n + 1) := by
	intro h n
	exact h n.le_succ

theorem thm9 (A B : Set ℕ) : A ⊆ B → ∀ n, n ∈ A → n ∈ B := by
	exact fun h n hn => h hn

theorem thm10 (injg : Injective g) (injf : Injective f) : Injective fun x => g (f x) := by
	exact injg.comp injf
```



### Performance depends on model and search

Our method closed 60% of the theorems (6 of 10). It is important to note that theorem-proving performance is a function of the model $p_\theta$, the search algorithm $\mathcal{A}$, and the search budget $k$:

$\quad\text{pass rate} = f(p_\theta, \mathcal{A}, k)$.

In principle, we can improve theorem-proving performance by improving the model, by improving the search algorithm (for a fixed budget), or by increasing the budget. As a result, when comparing models it is important to control for performance differences that arise from the search algorithm or budget (e.g., by holding both fixed).
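One way to respect this when tabulating results is to key every pass rate by the full triple (model, search algorithm, budget) and to compare models only within a fixed search algorithm and budget. Below is a minimal sketch; the record fields are assumptions. The `pythia2.8b` records encode the success flags of the run above (`max_iters=16` used as the budget label), while `toy-model` is hypothetical and its outcomes are dummy values.

```python
from collections import defaultdict


def pass_rates(records):
    """Map (model, search, budget) -> fraction of theorems proved."""
    proved = defaultdict(int)
    total = defaultdict(int)
    for r in records:
        key = (r["model"], r["search"], r["budget"])
        proved[key] += int(r["success"])
        total[key] += 1
    return {key: proved[key] / total[key] for key in total}


def compare_models(records, search, budget):
    """Pass rate per model, restricted to one search algorithm and budget."""
    return {
        model: rate
        for (model, s, b), rate in pass_rates(records).items()
        if s == search and b == budget
    }


# Success flags from the run above (thm1 ... thm10).
run = [True, True, True, False, False, False, True, False, True, True]
records = [
    {"model": "pythia2.8b", "search": "best_first", "budget": 16, "success": ok}
    for ok in run
]
# Dummy records for a hypothetical model evaluated with a larger budget.
records += [
    {"model": "toy-model", "search": "best_first", "budget": 64, "success": ok}
    for ok in [True, True, False, True]
]

print(compare_models(records, "best_first", 16))  # {'pythia2.8b': 0.6}
```

The hypothetical `toy-model` run is excluded from the comparison because its budget differs, even though it uses the same search algorithm.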


Feel free to try different temperatures, numbers of samples, etc., to see how performance varies.

-----------

### Evaluating neural theorem provers

Above, we evaluated on hand-specified theorems. In practice, evaluation is done in two settings:


1. Benchmarks
2. Test split

Benchmarks provide theorem statements that characterize performance on a particular kind of theorem (e.g., competition problems or undergraduate math) and can test a model's robustness to distribution shift (e.g., competition problems for a model trained on mathlib).

A test split measures performance on theorems drawn from the same distribution as the training set.

#### Benchmarks in neural theorem proving

[**MiniF2F** [Zheng et al ICLR 2022]](https://arxiv.org/abs/2109.00110) is a widely used benchmark of 488 problem statements drawn from the AIME, AMC, and the International Mathematical Olympiad (IMO), as well as material from high-school and undergraduate mathematics courses.

Specifically, given a formal statement $x_\text{formal statement}$, our model must produce a correct formal proof $y_1,\ldots,y_{T_x}$. Below is an example:

> As a reference point, we show the informal statement and informal proof, though the model is only given the formal statement. (The informal annotations were added in [Jiang et al ICLR 2023](https://arxiv.org/abs/2210.12283)).

```python
# !pip install datasets
from IPython.display import Markdown
import datasets

minif2f = datasets.load_dataset('hoskinson-center/minif2f-lean4')

idx = 35

example = minif2f['validation'][idx]
Markdown(
    '#### Problem: ' + example['id'] +
    '\n\n#### Formal statement \n\n' + '```lean\n' + example['formal_statement'] + '\n```\n' +
    '\n\n#### Informal statement\n\n' + example['informal_stmt'] +
    '\n#### Informal proof\n\n' + example['informal_proof']
)
```

#### Problem: induction_divisibility_3divnto3m2n

#### Formal statement 

```lean
theorem induction_divisibility_3divnto3m2n
 (n : ℕ) :
 3 ∣ n^3 + 2 * n := sorry
```


#### Informal statement

Show that for any natural number $n \in \mathbb{N}$, $3 \mid n^3 + 2n$ .
#### Informal proof

We show the result by induction on $n$. The result is trivial for $n=0$. Let us assume it is true for $n \geq 0$.
We have $(n+1)^3+2(n+1) = (n^3+3n^2+3n+1) + (2n+2) = n^3+2n + 3n^2+3n+3$. From the induction hypothesis, we know that $3$ divides $n^3+2n$. Since $3$ also divides $3n^2+3n+3$, the result is also true in $n+1$ and we have by induction that the result is true for all $n$.
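The miniF2F formal statements end in `:= sorry`, while the statements passed to `proofsearch.best_first_search` earlier end in `:= by {}`. A small helper could bridge the two formats; the name `to_search_template` is hypothetical, and the sketch assumes every statement ends exactly in `:= sorry`, as in the example above.

```python
def to_search_template(formal_statement: str) -> str:
    """Replace a trailing ':= sorry' with ':= by {}' (the search input format)."""
    stmt = formal_statement.rstrip()
    if not stmt.endswith(":= sorry"):
        raise ValueError("expected a statement ending in ':= sorry'")
    return stmt[: -len("sorry")] + "by {}"


statement = """theorem induction_divisibility_3divnto3m2n
  (n : ℕ) :
  3 ∣ n^3 + 2 * n := sorry"""
print(to_search_template(statement))
```

In the notebook, the result of `to_search_template(example['formal_statement'])` could then be passed to the search together with a header; we have not checked whether the `header` defined earlier provides every import that all miniF2F problems need.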

The problems vary in difficulty. Some problems can be closed easily (especially when the model uses built-in tactics such as `simp`), while others require long-form reasoning that can also be difficult to formalize. Here is a success case and a failure case (at least with this search budget):

```python
evaluation_theorems = [
    """theorem mathd_algebra_15 (s : ℕ → ℕ → ℕ) (h₀ : ∀ a b, 0 < a ∧ 0 < b → s a b = a ^ (b : ℕ) + b ^ (a : ℕ)) : s 2 6 = 100 := by {}""",
    """theorem imo_2001_p6 (a b c d : ℕ) (h₀ : 0 < a ∧ 0 < b ∧ 0 < c ∧ 0 < d) (h₁ : d < c) (h₂ : c < b) (h₃ : b < a) (h₄ : a * c + b * d = (b + d + a - c) * (b + d + c - a)) : ¬Nat.Prime (a * b + c * d) := by {}"""
]

for theorem in evaluation_theorems:
    result = proofsearch.best_first_search(
        model, tokenizer, header, theorem,
        max_iters=16,
        temperatures=[0.0],
        num_samples=4,
        verbose=True
    )
    print("Success: %s" % result['success'])
    if result['success']:
        print_result(result)
```


```text
--- current:
	theorem mathd_algebra_15 (s : ℕ → ℕ → ℕ) (h₀ : ∀ a b, 0 < a ∧ 0 < b → s a b = a ^ (b : ℕ) + b ^ (a : ℕ)) : s 2 6 = 100 := by

--- type-checked candidates:
	(-0.295) simp [h₀]
	(-0.426) simp [h₀, pow_succ]
Success: True
theorem mathd_algebra_15 (s : ℕ → ℕ → ℕ) (h₀ : ∀ a b, 0 < a ∧ 0 < b → s a b = a ^ (b : ℕ) + b ^ (a : ℕ)) : s 2 6 = 100 := by
	simp [h₀]

--- current:
	theorem imo_2001_p6 (a b c d : ℕ) (h₀ : 0 < a ∧ 0 < b ∧ 0 < c ∧ 0 < d) (h₁ : d < c) (h₂ : c < b) (h₃ : b < a) (h₄ : a * c + b * d = (b + d + a - c) * (b + d + c - a)) : ¬Nat.Prime (a * b + c * d) := by

--- type-checked candidates:
Success: False
```





#### Other benchmarks

[**ProofNet** [Azerbayev et al 2023]](https://arxiv.org/abs/2302.12433) is a recent benchmark targeting undergraduate-level mathematics. It consists of 371 problems drawn from popular undergraduate pure mathematics textbooks, covering topics such as real and complex analysis, linear algebra, abstract algebra, and topology.

ProofNet theorems tend to depend on more background knowledge than competition problems, which means that a learned model needs to use theorems and definitions from a wider subset of mathematics. 


#### In-distribution test split
We can also test the model on theorems from the same distribution as those it was trained on. For our model, this means splitting the mathlib4 repository into training theorems and evaluation theorems. A common strategy is to sample the evaluation theorems uniformly at random, so the resulting evaluation set covers a range of topics.
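A minimal sketch of such a uniform split, assuming we already have a list of theorem names extracted from mathlib4. The helper name and `seed` parameter are illustrative, and this is not the exact procedure behind the evaluation described next.

```python
import random


def uniform_split(theorem_names, num_test, seed=0):
    """Hold out `num_test` theorems sampled uniformly at random (no replacement)."""
    # Sort first so that the seed alone determines the split, whatever the input order.
    names = sorted(set(theorem_names))
    rng = random.Random(seed)
    test = set(rng.sample(names, num_test))
    train = [name for name in names if name not in test]
    return train, sorted(test)


names = ["one_le_zpow", "Nat.le_add_right", "Nat.gcd_comm", "sub_pos", "Function.Injective.comp"]
train, test = uniform_split(names, num_test=2, seed=0)
print(len(train), len(test))  # 3 2
```

Fixing the seed (and the input order) makes the split reproducible, which matters when different models are compared on the same held-out theorems.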

We ran proof search with the `llmstep-mathlib4-pythia2.8b` model, and it closed 48.8\% of the theorems in the validation split using best-first search with a beam size of 32. Below is one successful proof from a smaller evaluation run on 200 theorems:


> Note: `pylean` is less suitable as an interaction tool when evaluating mathlib theorems, since mathlib has many versions and the files can be large.
> Instead, we used LeanDojo for interaction; see `proofsearch_dojo.py`. The 48.8% excludes theorems that did not initialize successfully.
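The choice of denominator matters when comparing pass rates across evaluations. A minimal sketch, assuming each evaluated theorem is labeled with one of the hypothetical statuses `proved`, `failed`, or `init_error`:

```python
from collections import Counter


def pass_rate(outcomes, exclude=("init_error",)):
    """Fraction of 'proved' outcomes, ignoring statuses listed in `exclude`."""
    counts = Counter(outcomes)
    attempted = sum(n for status, n in counts.items() if status not in exclude)
    return counts["proved"] / attempted if attempted else 0.0


# Dummy outcomes for illustration only (not from the run described above).
outcomes = ["proved"] * 4 + ["failed"] * 4 + ["init_error"] * 2
print(pass_rate(outcomes))              # 0.5  (4 / 8)
print(pass_rate(outcomes, exclude=()))  # 0.4  (4 / 10)
```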

```python
import json

with open('./data/successes_mathlib4_200_wellecks_llmstep-mathlib4-pythia2.8b.json') as f:
    successes = json.load(f)['results']

example = [x for x in successes if x['theorem'] == 'one_le_zpow'][0]
# print_result expects a 'theorem_statement' field; use the theorem name instead.
example['theorem_statement'] = example['theorem']
print_result(example)
```

```text
one_le_zpow
	lift n to ℕ using hn
	rw [zpow_ofNat]
	exact one_le_pow_of_one_le' H n
```





Several works report nontrivial performance on mathlib, suggesting that existing models may already lead to useful tools for certain kinds of proofs in this setting. For instance, [Yang et al 2023](https://arxiv.org/pdf/2306.15626.pdf) report 47.5\% pass@1 on mathlib3 (and 51.4\% with a retrieval-augmented model), while [Polu et al ICLR 2023](https://arxiv.org/abs/2202.01344) achieve over 70\% pass@8 on mathlib3 using a variant of reinforcement learning (and a large search budget). [Yang et al 2023](https://arxiv.org/pdf/2306.15626.pdf) also show that a prompted GPT-4 substantially underperforms these models, achieving 28.8\% pass@1.
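pass@k is usually read as the probability that at least one of k independent proof-search attempts succeeds. When n ≥ k attempts are run per theorem and c of them succeed, one standard estimator (introduced for code generation by Chen et al. 2021) is $1 - \binom{n-c}{k} / \binom{n}{k}$, averaged over theorems. We have not checked which estimator the papers above use, so treat this as a sketch of the metric rather than their exact protocol.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Estimate P(at least one of k attempts succeeds) from n attempts with c successes."""
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("need 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        return 1.0  # every size-k subset of the n attempts contains a success
    return 1.0 - comb(n - c, k) / comb(n, k)


# Per-theorem estimates are averaged over the benchmark (dummy success counts).
per_theorem = [pass_at_k(n=8, c=c, k=1) for c in (0, 2, 8)]
print(sum(per_theorem) / len(per_theorem))
```

With k = 1 the estimator reduces to c / n, the empirical success rate of a single attempt.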

### Testing on unseen premises

Above, we used test theorems that were sampled uniformly at random from mathlib. Alternative train/test splits can test other aspects of performance.

For instance, consider what happens when mathlib is updated, say $M\rightarrow M'$. Suppose that $M'$ has a new definition $D'$ (e.g. `def math_object_123`) and a new theorem $T'$ (e.g. `theorem math_object_123_assoc`). Now consider another theorem $T''$ about $D'$ that uses $T'$ in its proof. If our next-step predictor $p_\theta$ was trained on $M$, it never observed either $D'$ or $T'$ during training. Thus it is very unlikely that $p_\theta$ would use knowledge of $D'$ or $T'$ when proving $T''$.

[Yang et al 2023](https://arxiv.org/pdf/2306.15626.pdf) create a `novel_premises` split in which the human-written proof of each test theorem uses at least one premise (theorem or definition) that is not in the training set. For instance, if our training set were drawn from $M$, a test theorem's proof could use $D'$ or $T'$. The authors find that performance is lower than on the `random` split, even when the model is augmented with retrieved premises:

<img src="./images/leandojo_1.png" width="700 px">
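Under the simplifying assumption that we know which premises each human-written proof uses, the defining condition of such a split reduces to a set check. This sketches only that membership test, not the full procedure used to build the LeanDojo split; the premise and theorem names are hypothetical, following the $D'$/$T'$ example above.

```python
def novel_premise_theorems(proof_premises, train_premises):
    """Theorems whose human-written proof uses a premise unseen in training.

    proof_premises: dict mapping theorem name -> set of premise names its proof uses.
    train_premises: set of premise names that appear in the training data.
    """
    return sorted(
        name
        for name, used in proof_premises.items()
        if not used <= train_premises  # at least one premise outside training
    )


train_premises = {"add_comm", "Nat.le_add_right"}  # premises available in M
proof_premises = {
    "old_theorem": {"add_comm"},
    "math_object_123_assoc": {"math_object_123", "add_comm"},  # T' uses D'
    "new_theorem_T2": {"math_object_123_assoc"},  # T'' uses T'
}
print(novel_premise_theorems(proof_premises, train_premises))
# ['math_object_123_assoc', 'new_theorem_T2']
```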

#### Implications for using machine learning in new proof developments

The "novel premises" scenario described above comes up frequently when working on a new formalization project. Often a project develops its own definitions and theorems that are crucial for subsequent proofs. As one example, consider this student project from the University of Washington [eXperimental Lean Lab](https://sites.math.washington.edu/~jarod/xll.html):
- [Banach Fixed Point Theorem by Lawrence Lin](https://github.com/Vilin97/LLL/blob/113f8bd767273484db6cd3c040c12f4a74c8fad4/lawrence/banach%20fpt.lean)

The goal of the project is to prove the [Banach Fixed Point Theorem](https://en.wikipedia.org/wiki/Banach_fixed-point_theorem), which is stated at the bottom of the file (note that this project uses Lean 3 syntax):

<img src="./images/banach/banach_1.png" width="1000 px" />

The theorem statement uses the author's own definitions of `complete_metric_space` and `contraction_mapping`, which are located earlier in the file:

<img src="./images/banach/banach_2.png" width="1000 px" />

<img src="./images/banach/banach_3.png" width="700 px" />


The proof uses various lemmas defined by the author. For instance, the first line of the proof uses `contraction_sequence_converges`:

<img src="./images/banach/banach_4.png" width="1000 px" />

<img src="./images/banach/banach_5.png" width="900 px" />

None of these definitions (`complete_metric_space`, `contraction_mapping`) or lemmas (e.g. `contraction_sequence_converges`) come from Mathlib4, so our model $p_\theta$ will never have seen them during training. Although $p_\theta$ may become very good at using Mathlib4 definitions and theorems, it will not know to use these new definitions or lemmas in its proofs.

Developing an effective way to handle novel contexts such as these is a key open problem in neural theorem proving. Moreover, this example shows that formalizing requires many auxiliary tasks beyond proving a single theorem, such as developing definitions or useful lemmas. These auxiliary tasks are understudied in the context of neural theorem proving.

#### Next steps

In the final notebook, we will build a VSCode tool that generates next-step suggestions with our model, enabling a form of "human-machine collaboration". 
Building a tool is also helpful for thinking about practical requirements (e.g. runtime, generalizing to different projects).