#### Neural next-step prediction | part 1: data
Tutorial on neural theorem proving\
Author: Sean Welleck

----------------

#### High-level goal

Our goal is to train a neural next-step prediction model, $p(y_t|x_t)$. Here $x_t$ is a _proof state_ and $y_t$ is a _next-step_, i.e. the tactic applied to $x_t$ (for example, `exact h`).

To do so, we will create a dataset $\mathcal{D}=\{(x_t,y_t)\}$ from human-written proofs. 

We can then train a neural next-step prediction model using a next-token prediction loss on the dataset.
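Concretely, suppose each next-step $y_t$ is tokenized as $(y_t^{(1)},\ldots,y_t^{(L_t)})$ and the model $p_\theta$ has parameters $\theta$ (this notation is introduced here for illustration). A standard next-token prediction loss is the negative log-likelihood of each next-step given its proof state:

$$
\mathcal{L}(\theta) = -\sum_{(x_t,y_t)\in\mathcal{D}} \sum_{\ell=1}^{L_t} \log p_\theta\big(y_t^{(\ell)} \mid y_t^{(<\ell)}, x_t\big)
$$

By the chain rule, $\log p_\theta(y_t\mid x_t)=\sum_{\ell=1}^{L_t} \log p_\theta\big(y_t^{(\ell)}\mid y_t^{(<\ell)},x_t\big)$, so minimizing $\mathcal{L}(\theta)$ maximizes the log-likelihood of the human-written next-steps over $\mathcal{D}$.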

#### Simple example

To see what proof states and next-steps look like, let's look at an example human-written theorem and proof:



```
!cat ../ntp_lean/examples/example0.lean
```

```lean
import Mathlib.Data.Nat.Prime

theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 := by 
  rw [Nat.coprime] at h  
  exact h  
```

We would like to transform this theorem and proof into a sequence of (proof_state, next_step) examples.

First, notice that the proof has two steps:

1. $y_1=$ `rw [Nat.coprime] at h`
2. $y_2=$ `exact h`

We can inspect the proof states manually in VSCode.

For example, placing the cursor before $y_1$ gives us the proof state $x_1$ (shown as "Tactic state"):

![Proof state x_1 shown in VSCode](images/proof_state_1.png)

That is, the image above corresponds to $(x_1,y_1)$ defined as:

  $x_1$: 
  ```
    m n : ℕ
    h : Nat.coprime m n
    ⊢ Nat.gcd m n = 1
  ```

  $y_1$: `rw [Nat.coprime] at h`


Similarly, we can get the proof state $x_2$ prior to the step $y_2$ (`exact h`):

![Proof state x_2 shown in VSCode](images/proof_state_2.png)

After step $y_2$, the proof is complete: the proof state $x_3$ says we have "No goals":

![Proof state x_3 ("No goals") shown in VSCode](images/proof_state_3.png)

In summary, it is possible to *manually* transform the theorem and proof into a sequence $[(x_1,y_1),(x_2,y_2),(x_3)]$.
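In code, this manual transformation amounts to pairing each proof state with the step taken from it. Below is a minimal sketch with the states typed in by hand from the VSCode view; the variable names are illustrative only.

```python
# Proof states x_1, x_2, x_3 as displayed in VSCode's "Tactic state" panel.
states = [
    "m n : ℕ\nh : Nat.coprime m n\n⊢ Nat.gcd m n = 1",  # x_1
    "m n : ℕ\nh : Nat.gcd m n = 1\n⊢ Nat.gcd m n = 1",  # x_2 (after the rw)
    "No goals",                                          # x_3 (proof complete)
]

# Next-steps y_1, y_2 from the human-written proof.
steps = ["rw [Nat.coprime] at h", "exact h"]

# Pair each state with the step taken from it. zip stops at the shorter
# list, so the final state x_3 (which has no next-step) is dropped.
pairs = list(zip(states, steps))

for t, (x, y) in enumerate(pairs, start=1):
    print(f"x{t}:\n{x}\ny{t}: {y}\n")
```

Only $(x_1,y_1)$ and $(x_2,y_2)$ become training examples; $x_3$ only marks the end of the proof.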

## Automatically extracting proof states and next-steps 

To scale up data collection, we need a way to *automatically* extract proof states and next-steps from human-written proofs.



A new open-source library by Kaiyu Yang et al. called [LeanDojo](https://leandojo.org/) can automatically extract (proof state, next-step) pairs from Lean proofs. This idea originated in [Han et al., ICLR 2022](https://github.com/jesse-michael-han/lean-step-public). We will look at a simplified version of what LeanDojo does.

The core idea is to (1) transform a Lean file into abstract syntax trees using Lean, and (2) post-process the abstract syntax trees into a dataset. Lean 4's powerful metaprogramming functionality gives us the tools for step (1).

#### 1. Transform a Lean file

Conceptually, we want a script:

$\quad f_{\text{extract}}(\text{lean file})\rightarrow \text{ASTs}.$

We run a simplified version of the script `ExtractData.lean` from LeanDojo:

```
!cd ../../ && lake env lean --run partI_nextstep/ntp_lean/ExtractSimple.lean partI_nextstep/ntp_lean/examples/example0.lean
```

```
Input file: partI_nextstep/ntp_lean/examples/example0.lean
AST: partI_nextstep/ntp_lean/examples/example0.ast.json
```


The output file `example0.ast.json` includes proof states and abstract syntax trees for the commands in `example0.lean`.

Here are the proof states for our example:

```python
import json
ast = json.load(open('../../partI_nextstep/ntp_lean/examples/example0.ast.json'))
ast['tactics']
```

```
[{'stateBefore': 'm n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1',
  'stateAfter': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',
  'pos': 101,
  'endPos': 122},
 {'stateBefore': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',
  'stateAfter': 'no goals',
  'pos': 127,
  'endPos': 134}]
```

Notice that these are the proof states we saw above in VSCode, with line breaks collapsed into spaces.
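The `stateAfter` of each tactic is the `stateBefore` of the next one, so for a straight-line proof like this one the full sequence $[x_1, x_2, x_3]$ can be read off the `tactics` list. A small sketch follows; the helper name `state_sequence` is hypothetical, and its chaining check assumes a linear proof (it would not hold for proofs whose tactics branch or nest).

```python
def state_sequence(tactics):
    """Return [x_1, ..., x_{T+1}] from a list of extracted tactic records.

    Assumes a straight-line proof, where each tactic's stateAfter equals
    the following tactic's stateBefore.
    """
    if not tactics:
        return []
    for prev, nxt in zip(tactics, tactics[1:]):
        if prev['stateAfter'] != nxt['stateBefore']:
            raise ValueError("tactics do not form a single linear chain")
    return [t['stateBefore'] for t in tactics] + [tactics[-1]['stateAfter']]


# The two records printed above, copied as a literal.
tactics = [
    {'stateBefore': 'm n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1',
     'stateAfter': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',
     'pos': 101, 'endPos': 122},
    {'stateBefore': 'm n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1',
     'stateAfter': 'no goals',
     'pos': 127, 'endPos': 134},
]

for t, x in enumerate(state_sequence(tactics), start=1):
    print(f"x{t}: {x}")
```

In the notebook, `state_sequence(ast['tactics'])` should give the same three states.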

Here is the abstract syntax tree for the theorem declaration, truncated to depth 4:

```python
import pprint
pprint.pprint(ast['commandASTs'][1]['node'], depth=4)
```

```
{'args': [{'node': {'args': [...],
                    'info': 'none',
                    'kind': 'Lean.Parser.Command.declModifiers'}},
          {'node': {'args': [...],
                    'info': 'none',
                    'kind': 'Lean.Parser.Command.theorem'}}],
 'info': 'none',
 'kind': 'Lean.Parser.Command.declaration'}
```
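The printed tree is a nested dictionary: each node has a `kind` and a list of `args`, and child nodes appear in `args` wrapped as `{'node': {...}}`. As an illustration of traversing this shape, here is a hypothetical depth-first walker (`collect_kinds` is not part of `ntp_python` or LeanDojo, and entries in `args` that are not wrapped nodes are simply skipped):

```python
def collect_kinds(node, depth=0, out=None):
    """Depth-first walk returning (depth, kind) for every node in the tree.

    Assumes the shape printed above: each node is a dict with 'kind' and
    'args', and child nodes appear in 'args' wrapped as {'node': {...}}.
    Other kinds of entries in 'args' are skipped.
    """
    if out is None:
        out = []
    out.append((depth, node.get('kind')))
    for arg in node.get('args', []):
        if isinstance(arg, dict) and 'node' in arg:
            collect_kinds(arg['node'], depth + 1, out)
    return out


# Toy tree mirroring the top two levels of the printout above.
toy = {
    'kind': 'Lean.Parser.Command.declaration', 'info': 'none',
    'args': [
        {'node': {'kind': 'Lean.Parser.Command.declModifiers', 'info': 'none', 'args': []}},
        {'node': {'kind': 'Lean.Parser.Command.theorem', 'info': 'none', 'args': []}},
    ],
}

for depth, kind in collect_kinds(toy):
    print('  ' * depth + kind)
# Lean.Parser.Command.declaration
#   Lean.Parser.Command.declModifiers
#   Lean.Parser.Command.theorem
```

In the notebook, `collect_kinds(ast['commandASTs'][1]['node'])` would walk the full declaration. Looking for `Lean.Parser.Command.theorem` nodes is one way to locate theorems, though not necessarily how `get_theorem` does it.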


#### 2. Post-process the extracted data

Next, we post-process the extracted data into a dataset:

$\quad f_{\text{post-process}}(\text{ASTs}, \text{lean file})\rightarrow \{(x_t,y_t)\}.$

To do so, we use the collected proof states, traverse the AST, and recover each next-step from the original Lean file using its `pos` and `endPos` offsets.\
See `ntp_python.postprocess_ast` for an example of a (naive) traversal that extracts the theorem name.

Postprocessing `example0.lean` in this way gives us two $(x_t,y_t)$ pairs:

```python
import sys
sys.path.append('../')
from ntp_python.postprocess_ast import get_theorem
from collections import defaultdict

theorem2examples = defaultdict(list)

lean_file = open('../../partI_nextstep/ntp_lean/examples/example0.lean').read()
for item in ast['tactics']:
    theorem = get_theorem(item['pos'], ast)
    theorem2examples[theorem].append({
        'x': item['stateBefore'],
        'y': lean_file[item['pos']:item['endPos']],
    })

for theorem, examples in theorem2examples.items():
    print("Theorem: ", theorem[:60], '...', sep=' ')
    for t, example in enumerate(examples):
        print(f"--- x{t+1} ---", example['x'], sep='\n')
        print(f"--- y{t+1} ---", example['y'], sep='\n')
        print()
```

```
Theorem:  theorem test_thm (m n : Nat) (h : m.coprime n) : m.gcd n = 1 ...
--- x1 ---
m n : ℕ h : Nat.coprime m n ⊢ Nat.gcd m n = 1
--- y1 ---
rw [Nat.coprime] at h

--- x2 ---
m n : ℕ h : Nat.gcd m n = 1 ⊢ Nat.gcd m n = 1
--- y2 ---
exact h
```

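The slice `lean_file[item['pos']:item['endPos']]` indexes a Python `str` by character. That works for `example0.lean`, which is pure ASCII, but Lean's `String.Pos` is a UTF-8 byte position, and mathlib files are full of symbols like `ℕ` and `→`. Assuming `pos`/`endPos` are such byte offsets, a safer sketch slices the encoded bytes instead (`slice_by_bytes` is a hypothetical helper):

```python
def slice_by_bytes(source: str, start: int, end: int) -> str:
    """Return the text between UTF-8 byte offsets [start, end) of `source`."""
    return source.encode('utf-8')[start:end].decode('utf-8')


# A hypothetical one-line proof containing a multi-byte character (ℕ is 3 bytes).
src = "theorem t (n : ℕ) : n = n := by rfl"
start = src.encode('utf-8').index(b"rfl")  # byte offset of the tactic
end = start + len("rfl")

print(slice_by_bytes(src, start, end))  # rfl
print(src[start:end])                   # l  (character slicing is off by 2)
```

With this helper, the loop above would read `'y': slice_by_bytes(lean_file, item['pos'], item['endPos'])`.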


The core extraction code in LeanDojo is in [ExtractData.lean](https://github.com/lean-dojo/LeanDojo/blob/main/src/lean_dojo/data_extraction/ExtractData.lean) if you are curious.

## Scaling up data collection
In general, Lean projects are more complex than the simple example above. For instance, projects may:
1. have a large number of files
2. have dependencies on other files or projects
3. have complex file structure that our naive postprocessing doesn't handle

An example is the [mathlib project](https://leanprover-community.github.io/mathlib-overview.html). Mathlib itself changes rapidly, and other Lean projects may depend on specific versions of it. [LeanDojo](https://leandojo.readthedocs.io/en/latest/index.html) provides tools for handling this complexity.

#### Extracting 90k+ theorems with LeanDojo

The LeanDojo tool can extract data from an *arbitrary Lean GitHub repository*. Conceptually,

$\quad f_{\text{leandojo}}(\text{lean repository})\rightarrow \mathcal{D}.$

It supports parallelism, keeps track of versions and dependencies for extracted data, and its post-processing handles more complex scenarios.

**Example**\
Here is what the interface would look like for [extracting a dataset from Mathlib4](https://github.com/lean-dojo/LeanDojo/blob/main/scripts/generate-benchmark-lean4.ipynb):

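```python
from lean_dojo import LeanGitRepo, trace

# Pin the repository to a specific commit so the extracted data is reproducible.
URL = "https://github.com/leanprover-community/mathlib4"
COMMIT = "5a919533f110b7d76410134a237ee374f24eaaad"
repo = LeanGitRepo(URL, COMMIT)

# Build and trace the repo to extract ASTs and proof states
# (slow for a project the size of mathlib4).
traced_repo = trace(repo)
```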

To avoid possible dependency issues, we won't run LeanDojo directly here. However, the LeanDojo authors provide the extracted data online, so we will download it for this tutorial:

```python
import json
import sys
import pprint
sys.path.append('../')
from ntp_python.data import _download_and_unpack

_download_and_unpack(
    tarball_url='https://zenodo.org/record/8040110/files/leandojo_benchmark_4_v1.tar.gz',
    data_dir='../data',
    overwrite=False
)

train = json.load(open('../data/leandojo_benchmark_4/random/train.json'))
train = [x for x in train if len(x['traced_tactics']) > 0]
print("Number of non-empty training proofs: ", len(train), sep=' ')
pprint.pprint(train[0])
```

```
Number of non-empty training proofs:  41944
{'commit': '5a919533f110b7d76410134a237ee374f24eaaad',
 'end': [308, 76],
 'file_path': 'Mathlib/Analysis/BoxIntegral/Box/Basic.lean',
 'full_name': 'BoxIntegral.Box.withBotCoe_inj',
 'start': [307, 1],
 'traced_tactics': [{'state_after': 'no goals',
                     'state_before': 'ι : Type u_1\n'
                                     'I✝ J✝ : Box ι\n'
                                     'x y : ι → ℝ\n'
                                     'I J : WithBot (Box ι)\n'
                                     '⊢ ↑I = ↑J ↔ I = J',
                     'tactic': 'simp only [Subset.antisymm_iff, ← '
                               'le_antisymm_iff, withBotCoe_subset_iff]'}],
 'url': 'https://github.com/leanprover-community/mathlib4'}
```
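Each record holds one theorem, with its steps under `traced_tactics`. Flattening the records gives the $(x_t, y_t)$ pairs of $\mathcal{D}$. Below is a minimal sketch using the field names shown in the printout; `to_examples` is a hypothetical helper, not part of LeanDojo.

```python
def to_examples(theorems):
    """Flatten LeanDojo benchmark records into (x_t, y_t) next-step examples."""
    examples = []
    for thm in theorems:
        for step in thm['traced_tactics']:
            examples.append({
                'theorem': thm['full_name'],
                'x': step['state_before'],  # proof state x_t
                'y': step['tactic'],        # next-step y_t
            })
    return examples


# The record printed above, trimmed to the fields used here.
record = {
    'full_name': 'BoxIntegral.Box.withBotCoe_inj',
    'traced_tactics': [{
        'state_before': 'ι : Type u_1\nI✝ J✝ : Box ι\nx y : ι → ℝ\n'
                        'I J : WithBot (Box ι)\n⊢ ↑I = ↑J ↔ I = J',
        'state_after': 'no goals',
        'tactic': 'simp only [Subset.antisymm_iff, ← le_antisymm_iff, withBotCoe_subset_iff]',
    }],
}

for ex in to_examples([record]):
    print(ex['theorem'])
    print(ex['x'])
    print('->', ex['y'])
```

On the downloaded data, `to_examples(train)` yields one example per traced tactic, so it produces at least as many examples as there are non-empty proofs.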


#### Next steps
In part 2, we'll train a neural next-step generation model on this mathlib4 dataset.