// Used by INSTALL.md

// Vale-level state variables. The {:state ...} attribute on each one names the
// piece of machine state it stands for: `ok` is tied to ok(), and x, y, z, w
// are tied to reg(X), reg(Y), reg(Z), reg(W) respectively.
var ok:bool {:state ok()};
var x:int {:state reg(X)};
var y:int {:state reg(Y)};
var z:int {:state reg(Z)};
var w:int {:state reg(W)};

// nat: ints in int_range(0, _) (presumably: lower bound 0, no upper bound).
type nat:Type(0) := int_range(0, _);
// Sequence and map type constructors, declared extern (seq is also marked {:primitive}).
type seq(a:Type(0)):Type(0) {:primitive} extern;
type map(a:Type(0), b:Type(0)):Type(0) extern;
// Register type and register constants made visible to Vale code.
// Only X and Z are declared here; the verbatim `datatype reg` in this file has X, Y, Z, W.
type reg:Type(0) extern;
const X:reg extern;
const Z:reg extern;
// Extern index ([]) and update ([ := ]) operators: first the seq pair, then the map pair.
function operator([]) #[a:Type(0)](s:seq(a), i:int):a extern; // TODO: requires clause
function operator([ := ]) #[a:Type(0)](s:seq(a), i:int, v:a):seq(a) extern; // TODO: requires clause
function operator([]) #[a:Type(0), b:Type(0)](m:map(a, b), key:a):b extern; // TODO: requires clause
function operator([ := ]) #[a:Type(0), b:Type(0)](m:map(a, b), key:a, v:b):map(a, b) extern;

// Operand type `opr`, declared as `int @ reg` (presumably: the operand's value is
// an int and the operand itself is identified by a reg). An opr may be any of the
// state variables x, y, z, or w, each usable as inout.
operand_type opr:int @ reg := inout x | inout y | inout z | inout w;

#verbatim

///////////////////////////////////////////////////////////////////////////////
// Trusted machine definition
// The four registers of the machine.
datatype reg = X | Y | Z | W

// Machine instructions (their meaning is given by evalIns):
//   InsImm(dstImm, imm)    -- store the constant imm into register dstImm
//   InsAdd(dstAdd, srcAdd) -- store dstAdd + srcAdd into register dstAdd
datatype ins =
  InsImm(dstImm:reg, imm:int)
| InsAdd(dstAdd:reg, srcAdd:reg)

// Condition used by IfElse and While: OLe(r1, r2) is evaluated by evalOBool
// as "value of r1 <= value of r2".
datatype obool = OLe(r1:reg, r2:reg)

// Structured code (two mutually recursive datatypes).
// `codes` is a list of `code` values: CNil is empty, va_CCons(hd, tl) prepends hd to tl.
datatype codes = CNil | va_CCons(hd:code, tl:codes)
// `code` is one of: a single instruction, a block (list of code), an if/else
// on an obool condition, or a while loop on an obool condition.
datatype code =
  Ins(ins:ins)
| Block(block:codes)
| IfElse(ifCond:obool, ifTrue:code, ifFalse:code)
| While(whileCond:obool, whileBody:code)

// States may be "good" (ok == true) or "bad" (ok == false)
// `regs` maps registers to their integer values; it need not contain every
// register (see getReg for how missing registers are read).
datatype state = state(ok:bool, regs:map<reg, int>)

// Value of register r in state s. A register that has no entry in s.regs
// reads as 0, which makes this function total.
function getReg(s:state, r:reg):int
{
    if r !in s.regs then 0 else s.regs[r]
}

// Evaluates condition `cond` in state s: true when the value of cond.r1 is
// less than or equal to the value of cond.r2 (both read through getReg).
predicate evalOBool(s:state, cond:obool)
{
    getReg(s, cond.r1) <= getReg(s, cond.r2)
}

// Evaluation:
// We want to prove that:
//   - a program never reaches a bad state (ok == false)
//   - if a program terminates, it terminates in a good state that satisfies some postcondition
// We do this by modeling all possible finite executions from state s0 to state sN:
//   - if a program goes bad, it will do so in a finite number of steps
//   - if a program terminates, it will do so in a finite number of steps
// evalCode(c, s0, sN) says that it is possible for code c to step from state s0 to state sN in
// a finite number of steps.
// For clarity's sake, when a program reaches a bad state, it stops (no more steps).

// Single-instruction step relation: holds when s1 is exactly s0 with only the
// destination register updated (the ok field is carried over unchanged).
//   InsImm: destination receives the constant imm.
//   InsAdd: destination receives getReg(dst) + getReg(src), both read in s0.
predicate evalIns(ins:ins, s0:state, s1:state)
{
    match ins
        case InsImm(dst, imm) => s1 == s0.(regs := s0.regs[dst := imm])
        case InsAdd(dst, src) => s1 == s0.(regs := s0.regs[dst := getReg(s0, dst) + getReg(s0, src)])
}

// Step relation for a list of code.
//   - Empty list: the state is unchanged (sN == s0).
//   - Otherwise: the head steps from s0 to some intermediate state s1; if s1 is
//     good, the tail continues from s1 to sN, and if s1 is bad, execution stops
//     there (sN == s1).
predicate evalBlock(block:codes, s0:state, sN:state)
{
    if block.CNil? then sN == s0
    else exists s1:state ::
        evalCode(block.hd, s0, s1) && (if s1.ok then evalBlock(block.tl, s1, sN) else s1 == sN)
}

// Step relation for a while loop with condition b and body c that runs the
// body n times.
//   - n == 0: the condition is false in s0 and the state is unchanged.
//   - n > 0:  the condition is true in s0, the body steps from s0 to some s1,
//             and then either the remaining n - 1 iterations run from s1 (s1 good)
//             or execution stops at s1 (s1 bad, sN == s1).
// The `decreases c, n` measure pairs with evalCode's `decreases c, 0`.
predicate evalWhile(b:obool, c:code, n:nat, s0:state, sN:state)
    decreases c, n
{
    if n == 0 then !evalOBool(s0, b) && s0 == sN
    else exists s1:state ::
        evalOBool(s0, b) && evalCode(c, s0, s1) && (if s1.ok then evalWhile(b, c, n - 1, s1, sN) else s1 == sN)
}

// Top-level step relation: code c can step from s0 to sN.
// It only ever holds when the start state s0 is good (s0.ok).
//   Ins:    one instruction step (evalIns).
//   Block:  the list of code is evaluated in order (evalBlock).
//   IfElse: the branch chosen by the condition in s0 is evaluated.
//   While:  there is some iteration count n for which evalWhile holds.
predicate evalCode(c:code, s0:state, sN:state)
    decreases c, 0
{
    s0.ok
 && (match c
        case Ins(ins) => evalIns(ins, s0, sN)
        case Block(block) => evalBlock(block, s0, sN)
        case IfElse(cond, ifT, ifF) => if evalOBool(s0, cond) then evalCode(ifT, s0, sN) else evalCode(ifF, s0, sN)
        case While(cond, body) => exists n:nat :: evalWhile(cond, body, n, s0, sN)
    )
}

///////////////////////////////////////////////////////////////////////////////
// Untrusted Vale interface

// An operand is just a register, and an operand's value is an int.
type opr = reg
type va_operand_opr = opr
type va_value_opr = int
// Operand constructors: a register used as an operand is the register itself.
function method va_op_opr_reg(r:reg):opr { r }
function method va_op_cmp_reg(r:reg):opr { r }
// Source/destination validity checks: every operand is accepted in every state.
predicate va_is_src_opr(o:opr, s:va_state) { true }
predicate va_is_dst_opr(o:opr, s:va_state) { true }

// va_-prefixed aliases for the machine's code, codes, and state types
// (presumably the names that Vale-generated code refers to).
type va_code = code
type va_codes = codes
type va_state = state

// State getters: the ok flag of s, and the value of register r in s
// (via getReg, so a register missing from s.regs reads as 0).
function va_get_ok(s:va_state):bool { s.ok }
function va_get_reg(r:reg, s:va_state):int { getReg(s, r) }

// State updaters: each returns sK with one component overwritten by sM's value.
// va_update_ok copies the ok flag from sM into sK.
function va_update_ok(sM:va_state, sK:va_state):va_state { sK.(ok := sM.ok) }
// va_update_reg copies register r from sM into sK; r must be present in sM.regs.
function va_update_reg(r:reg, sM:va_state, sK:va_state):va_state requires r in sM.regs { sK.(regs := sK.regs[r := sM.regs[r]]) }

// Operand-level updater: since an opr is a register, this simply forwards to
// va_update_reg, copying operand o's register from sM into sK.
// Requires o to be present in sM.regs.
function va_update_operand_opr(o:opr, sM:va_state, sK:va_state):va_state
    requires o in sM.regs
{
    va_update_reg(o, sM, sK)
}

// State equality, spelled out component by component: same ok flag and same
// register map.
predicate va_state_eq(s0:va_state, s1:va_state)
{
    s0.ok == s1.ok
 && s0.regs == s1.regs
}

// va_-prefixed constructors for code values; each wraps the corresponding
// datatype constructor without adding any logic.
function method va_CNil():codes { CNil }
function method va_Block(block:codes):code { Block(block) }
function method va_IfElse(ifb:obool, ift:code, iff:code):code { IfElse(ifb, ift, iff) }
function method va_While(whileb:obool, whilec:code):code { While(whileb, whilec) }
// Builds the "o1 <= o2" condition from two operands.
function method va_cmp_le(o1:opr, o2:opr):obool { OLe(o1, o2) }

// va_-prefixed field accessors for code values. Each requires that c was built
// with the matching constructor (Block, IfElse, or While).
function va_get_block(c:va_code):va_codes requires c.Block? { c.block }
function va_get_ifCond(c:code):obool requires c.IfElse? { c.ifCond }
function va_get_ifTrue(c:code):code requires c.IfElse? { c.ifTrue }
function va_get_ifFalse(c:code):code requires c.IfElse? { c.ifFalse }
function va_get_whileCond(c:code):obool requires c.While? { c.whileCond }
function va_get_whileBody(c:code):code requires c.While? { c.whileBody }

// Opaque wrappers around evalWhile and evalCode. Their bodies are hidden from
// the verifier until a proof explicitly reveals them (reveal_evalWhileOpaque /
// reveal_evalCodeOpaque, as the lemmas in this file do).
predicate{:opaque} evalWhileOpaque(b:obool, c:code, n:nat, s0:state, sN:state) { evalWhile(b, c, n, s0, sN) }
predicate{:opaque} evalCodeOpaque(c:code, s0:state, sN:state) { evalCode(c, s0, sN) }

// For the proof, we define a more liberal evaluation relation evalCode_lax that
// allows arbitrary "zombie steps" from a bad state to any other state.
// The proof proceeds as follows:
//   - suppose s0.ok; i.e., we start in a good state
//   - suppose evalCode(c, s0, sN); i.e., c takes non-zombie steps from s0 to sN
//   - then evalCode_lax(c, s0, sN) holds; i.e., c takes possibly-zombie steps from s0 to sN
//   - we first use evalCode_lax to prove some initial basic properties of the evaluation
//   - we finally prove that the possibly-zombie steps are in fact non-zombie steps,
//     by proving that ok holds for all intermediate states s0,s1,s2,...,sN
// Ultimately, we still prove precisely what we want: if s0.ok and evalCode(c, s0, sN),
// then sN.ok and postcondition(sN).
// It may seem strange to allow zombie steps and then prove that there are no
// zombie steps, but this allows us to factor the proof for procedures
// so that the proof of non-zombieness comes after other parts of the proof.
// Specifically, the "refined" lemmas will allow zombie steps, and only at the
// end of each refined lemma will we call the "abstract" lemma that
// retroactively eliminates the possibility that there were zombie steps.
// Lax evaluation relations: they constrain sN only when the start state s0 is
// good. From a bad s0 they hold for every sN (the "zombie steps").
predicate evalCode_lax(c:code, s0:state, sN:state) { s0.ok ==> evalCodeOpaque(c, s0, sN) }
predicate evalWhile_lax(b:obool, c:code, n:nat, s0:state, sN:state) { s0.ok ==> evalWhileOpaque(b, c, n, s0, sN) }

// Value of operand o in state s. An operand is a register, so this is just
// getReg (a register missing from s.regs evaluates to 0).
function va_eval_opr(s:state, o:opr):int
{
    getReg(s, o)
}

// A state is valid when every register has an entry in its register map.
predicate valid_state(s:state)
{
  forall r:reg :: r in s.regs
}

// Bundled precondition: block0 is non-empty with head c, the whole block
// steps (laxly) from s0 to sN, and s0 is a valid state.
predicate va_require(block0:va_codes, c:va_code, s0:va_state, sN:va_state)
{
    block0.va_CCons?
 && block0.hd == c
 && evalCode_lax(va_Block(block0), s0, sN)
 && valid_state(s0)
}

// Bundled postcondition: b0 is non-empty with tail b1; if s1 is good then the
// head of b0 steps (laxly) from s0 to s1; the tail b1 steps (laxly) from s1 to
// sN; and s1 is a valid state.
predicate va_ensure(b0:va_codes, b1:va_codes, s0:va_state, s1:va_state, sN:va_state)
{
    b0.va_CCons?
 && b0.tl == b1
 && (s1.ok ==> evalCode_lax(b0.hd, s0, s1))
 && evalCode_lax(va_Block(b1), s1, sN)
 && valid_state(s1)
}

// Bundled while-loop invariant: n is a non-negative iteration count, the loop
// (condition b, body c) runs (laxly) n iterations from s0 to sN, and s0 is a
// valid state. n is typed int here, so n >= 0 is checked explicitly.
predicate va_whileInv(b:obool, c:code, n:int, s0:va_state, sN:va_state)
{
    n >= 0
 && evalWhile_lax(b, c, n, s0, sN)
 && valid_state(s0)
}

// Lemma with no requires/ensures clauses and an empty body; it proves nothing.
// NOTE(review): presumably a hook that Vale-generated code calls for
// instruction procedures -- no caller is visible in this file.
lemma va_ins_lemma(b0:code, s0:va_state)
{
}

// Splits the evaluation of a non-empty block.
// Given that Block(b0) steps (laxly) from s0 to sN, returns the head c1, the
// tail b1, and an intermediate state s1 such that c1 steps (laxly) from s0 to
// s1 and Block(b1) steps (laxly) from s1 to sN.
lemma va_lemma_block(b0:va_codes, s0:state, sN:state) returns(s1:state, c1:va_code, b1:va_codes)
    requires b0.va_CCons?
    requires evalCode_lax(va_Block(b0), s0, sN)
    ensures  b0 == va_CCons(c1, b1)
    ensures  evalCode_lax(c1, s0, s1)
    ensures  evalCode_lax(va_Block(b1), s1, sN)
{
    reveal_evalCodeOpaque();
    c1 := b0.hd;
    b1 := b0.tl;
    if (s0.ok)
    {
        // Good start state: the lax premise gives a real evalBlock, so pick the
        // intermediate state that its existential promises.
        assert evalBlock(b0, s0, sN);
        s1 :| evalCode(b0.hd, s0, s1) && (if s1.ok then evalBlock(b0.tl, s1, sN) else s1 == sN);
    }
    else
    {
        // Bad start state: both lax postconditions hold vacuously for s1 == s0.
        s1 := s0;
    }
}

// An empty block does nothing: returns sM equal to s0, and shows that when s0
// is good the final state sN equals s0.
lemma va_lemma_empty(s0:va_state, sN:va_state) returns(sM:va_state)
    requires evalCode_lax(va_Block(va_CNil()), s0, sN)
    ensures  s0 == sM
    ensures  s0.ok ==> s0 == sN
{
    reveal_evalCodeOpaque();
    sM := s0;
}

// Opens up a while loop: given that While(b, c) steps (laxly) from s0 to sN,
// returns an iteration count n for which evalWhile_lax holds, plus s1 == s0.
lemma va_lemma_while(b:obool, c:code, s0:va_state, sN:va_state) returns(n:nat, s1:va_state)
    requires evalCode_lax(While(b, c), s0, sN)
    ensures  evalWhile_lax(b, c, n, s0, sN)
    ensures  s1 == s0
{
    reveal_evalCodeOpaque();
    reveal_evalWhileOpaque();
    if (s0.ok)
    {
        // Good start state: evalCode's While case supplies an iteration count.
        assert evalCode(While(b, c), s0, sN);
        n :| evalWhile(b, c, n, s0, sN);
    }
    else
    {
        // Bad start state: evalWhile_lax is vacuous, so any n will do.
        n := 0;
    }
    s1 := s0;
}

// Peels one iteration off a loop that still has n > 0 iterations to run.
// Returns s0' == s0 and a state s1 such that the body c steps (laxly) from s0'
// to s1 and the remaining n - 1 iterations run (laxly) from s1 to sN. When s0
// is good, the loop condition is known to hold in s0.
lemma va_lemma_whileTrue(b:obool, c:code, n:nat, s0:va_state, sN:va_state) returns(s0':va_state, s1:va_state)
    requires n > 0
    requires evalWhile_lax(b, c, n, s0, sN)
    ensures  s0' == s0
    ensures  s0.ok ==> evalOBool(s0, b)
    ensures  evalCode_lax(c, s0', s1)
    ensures  evalWhile_lax(b, c, n - 1, s1, sN)
{
    reveal_evalCodeOpaque();
    reveal_evalWhileOpaque();
    s0' := s0;
    if (s0.ok)
    {
        // Good start state: take the witness from evalWhile's n > 0 case.
        s1 :| evalOBool(s0, b) && evalCode(c, s0, s1) && (if s1.ok then evalWhile(b, c, n - 1, s1, sN) else s1 == sN);
    }
    else
    {
        // Bad start state: the lax postconditions are vacuous for s1 == s0.
        s1 := s0;
    }
}

// Loop exit: for a loop with 0 iterations left, returns s1 == s0 and shows
// that, when s0 is good, the condition is false in s0 and s1 equals sN.
lemma va_lemma_whileFalse(b:obool, c:code, s0:va_state, sN:va_state) returns(s1:va_state)
    requires evalWhile_lax(b, c, 0, s0, sN)
    ensures  s1 == s0
    ensures  s0.ok ==> !evalOBool(s0, b)
    ensures  s0.ok ==> s1 == sN
{
    reveal_evalCodeOpaque();
    reveal_evalWhileOpaque();
    // When s0 is good, evalWhile with n == 0 forces s0 == sN, so either branch
    // yields a state equal to s0.
    s1 := if s0.ok then sN else s0;
}

#endverbatim

// Instruction procedure for InsImm: afterwards dst equals the inline constant imm.
// The body reveals evalCodeOpaque, presumably so the proof can see the
// instruction's semantics from evalIns.
procedure Imm(out dst:opr, inline imm:int)
    {:instruction Ins(InsImm(dst, imm))}
    ensures
        dst == imm;
{
    reveal evalCodeOpaque;
}

// Instruction procedure for InsAdd: afterwards dst equals its old value plus
// the old value of src. The body reveals evalCodeOpaque, presumably so the
// proof can see the instruction's semantics from evalIns.
procedure Add(inout dst:opr, in src:opr)
    {:instruction Ins(InsAdd(dst, src))}
    ensures
        dst == old(dst) + old(src);
{
    reveal evalCodeOpaque;
}

// Straight-line example: sets y to 1 and adds y to x twice, so a non-negative
// x on entry is at least 2 on exit. Modifies x and y.
procedure test1()
    requires
        x >= 0;
    ensures
        x >= 2;
    modifies
        x; y;
{
    Imm(y, 1);
    Add(x, y);
    Add(x, y);
}

#verbatim
// Placeholder predicates used by the example procedures.
// Post is opaque and always true; proofs must reveal it to learn that.
predicate{:opaque} Post(x:int, y:int, g:int) { true }
// Inv is Post plus a non-negative counter i.
predicate Inv(x:int, y:int, g:int, i:int) { Post(x, y, g) && i >= 0 }
// R relates h to g: h is exactly g + 1.
predicate R(h:int, g:int) { h == g + 1 }
// Easy is always true (and not opaque).
predicate Easy(a:int, b:int, c:int, d:int) { true }
#endverbatim

// Vale-side extern declarations for the verbatim predicates Post, Inv, R, and
// Easy, so that Vale procedures can mention them in specifications.
function Post(x:int, y:int, g:int):bool extern;
function Inv(x:int, y:int, g:int, i:int):bool extern;
function R(h:int, g:int):bool extern;
function Easy(a:int, b:int, c:int, d:int):bool extern;

// --------------------------------------------------------------

// One loop iteration: increments x by 1 (using z as scratch) and returns the
// ghost value h2 == h1 + 1. Carries Inv from counter i to counter i + 1.
// Reads y; modifies x and z.
procedure LoopBody(ghost g:int, ghost i:int, ghost h1:int) returns(ghost h2:int)
    requires
        g <= x;
        0 <= h1;
        Inv(x, y, g, i);
    ensures
        x == old(x) + 1;
        h1 < h2;
        Inv(x, y, g, i + 1);
    reads
        y;
    modifies
        x; z;
{
    // Inv is defined in terms of the opaque Post, so reveal it for the proof.
    reveal Post;
    Imm(z, 1);
    Add(x, z);
    h2 := h1 + 1;
}

// While-loop example: repeats LoopBody while x <= y, so x > y on exit.
// Preserves g <= x and Post(x, y, g); the returned ghost h2 is at least h1.
// Reads y; modifies x and z.
procedure Loop(ghost g:int, ghost h1:int) returns(ghost h2:int)
    requires/ensures
        g <= x;
    requires
        0 <= h1;
        Post(x, y, g);
    ensures
        x > y;
        Post(x, y, g);
        h1 <= h2;
    reads
        y;
    modifies
        x; z;
{
    // Ghost iteration counter threaded through Inv.
    ghost var i:nat := 0;
    h2 := h1;
    while (x <= y)
        invariant
            g <= x;
            0 <= i;
            Inv(x, y, g, i);
            0 <= h1;
            h1 <= h2;
        decreases
            y - x; // TODO: currently, this is ignored
    {
        h2 := LoopBody(g, i, h2);
        i := i + 1;
    }
}

// Same contract and body as LoopBody: increments x by 1 (using z as scratch),
// returns h2 == h1 + 1, and carries Inv from counter i to i + 1.
// Reads y; modifies x and z.
procedure ULoopBody(ghost g:int, ghost i:int, ghost h1:int) returns(ghost h2:int)
    requires
        g <= x;
        0 <= h1;
        Inv(x, y, g, i);
    ensures
        x == old(x) + 1;
        h1 < h2;
        Inv(x, y, g, i + 1);
    reads
        y;
    modifies
        x; z;
{
    // Inv is defined in terms of the opaque Post, so reveal it for the proof.
    reveal Post;
    Imm(z, 1);
    Add(x, z);
    h2 := h1 + 1;
}

// Variant of Loop built on ULoopBody. Differences from Loop: it requires
// Inv(x, y, g, 0) instead of Post(x, y, g), does not ensure Post on exit, and
// its loop invariant does not restate 0 <= h1.
// Reads y; modifies x and z.
procedure ULoop(ghost g:int, ghost h1:int) returns(ghost h2:int)
    requires/ensures
        g <= x;
    requires
        0 <= h1;
        Inv(x, y, g, 0);
    ensures
        x > y;
        h1 <= h2;
    reads
        y;
    modifies
        x; z;
{
    // Ghost iteration counter threaded through Inv.
    ghost var i:nat := 0;
    h2 := h1;
    while (x <= y)
        invariant
            g <= x;
            0 <= i;
            Inv(x, y, g, i);
            h1 <= h2;
        decreases
            y - x;
    {
        h2 := ULoopBody(g, i, h2);
        i := i + 1;
    }
}

#verbatim
// Placeholder invariant for the rotation example: opaque and always true.
predicate{:opaque} RotateInv(n:nat, a:int, b:int, c:int, d:int) { true }

// Maps an index onto a register, cycling with period 4:
// remainder 0 -> W, 1 -> X, 2 -> Y, 3 -> Z.
function RotateReg(i:int):reg
{
    var slot := i % 4;
    // Dafny's % with a positive divisor yields 0..3, so the final branch is slot == 0.
    if slot == 1 then X
    else if slot == 2 then Y
    else if slot == 3 then Z
    else W
}
#endverbatim
// Vale-side extern declarations for the verbatim RotateInv and RotateReg.
function RotateInv(n:nat, a:int, b:int, c:int, d:int):bool extern;
function RotateReg(i:int):reg extern;

// One rotation step. Emits no instructions (the body only reveals RotateInv);
// it turns RotateInv(n, a, b, c, d) into RotateInv(n - 1, b, c, d, a), i.e. the
// operands shifted one position to the left. Requires n > 0.
procedure RotateBody(inline n:nat, inout a:opr, inout b:opr, inout c:opr, inout d:opr)
    requires
        n > 0;
        RotateInv(n, a, b, c, d);
    ensures
        RotateInv(#nat(n - 1), b, c, d, a); // note: we could have given n type pos instead of nat to avoid the cast
{
    reveal RotateInv;
}

// Recursive, inline-unrolled rotation: applies RotateBody n times, rotating
// the operand order by one position on each recursive call.
// The first requires block pins each operand to a specific register computed
// from n via RotateReg. The ensures clause states RotateInv(0, ...) over the
// operands rotated by n positions (indices taken modulo 4).
procedure RotateLoop(inline n:nat, inout a:opr, inout b:opr, inout c:opr, inout d:opr)
    {:recursive}
    requires
        @a == RotateReg(16 - n);
        @b == RotateReg(17 - n);
        @c == RotateReg(18 - n);
        @d == RotateReg(19 - n);
    requires
        RotateInv(n, a, b, c, d);
    ensures
        let arr := seq(a, b, c, d) in RotateInv(0, arr[n % 4], arr[(n + 1) % 4], arr[(n + 2) % 4], arr[(n + 3) % 4]);
{
    inline if (n > 0)
    {
        RotateBody(n, a, b, c, d);
        // Recurse with n - 1 and the operands shifted left by one.
        RotateLoop(#nat(n - 1), b, c, d, a);
    }
}

// Entry point of the rotation example: runs RotateLoop with n == 16 on
// (w, x, y, z), turning RotateInv(16, w, x, y, z) into RotateInv(0, w, x, y, z).
// Sixteen rotations is a multiple of four, so the operand order is unchanged.
procedure RotateProc()
    modifies
        w; x; y; z;
    requires
        RotateInv(16, w, x, y, z);
    ensures
        RotateInv(0, w, x, y, z);
{
    RotateLoop(16, w, x, y, z);
}

// --------------------------------------------------------------

// Inline if / else-if / else example: for inline n in {0, 1, 2}, one of three
// branches is selected, each loading a constant equal to n into z and adding it
// to x, so x increases by exactly n. Keeps x >= 0. Modifies x and z.
procedure InlineIfElse(inline n:nat)
    requires
        n < 3;
    requires/ensures
        x >= 0;
    ensures
        x == old(x) + n;
    modifies
        x; z;
{
    inline if (n == 0)
    {
        Imm(z, 0);
        Add(x, z);
    }
    else if (n == 1)
    {
        Imm(z, 1);
        Add(x, z);
    }
    else
    {
        // n < 3 and n is a nat, so this branch is the n == 2 case.
        Imm(z, 2);
        Add(x, z);
    }
}

// --------------------------------------------------------------

// Inline-if example with code before and after the branch, interleaved with
// updates to a ghost variable h. x gains 1 before the branch, n inside it
// (0 or 1), and 1 after it, giving old(x) + n + 2. Keeps x >= 0.
// Modifies x and z.
procedure InlineIfInTheMiddle(inline n:nat)
    requires
        n < 2;
    requires/ensures
        x >= 0;
    ensures
        x == old(x) + n + 2;
    modifies
        x; z;
{
    Imm(z, 1);
    ghost var h: int;
    h := 0;
    Add(x, z);
    inline if (n == 0)
    {
        h := h + 100;
        Imm(z, 0);
        Add(x, z);
    }
    else
    {
        Imm(z, 1);
        Add(x, z);
        h := h + 20;
    }
    Imm(z, 1);
    h := h + 5;
    Add(x, z);
    // h records which branch was taken: 100 + 5 for n == 0, 20 + 5 otherwise.
    assert n == 0 ==> h == 105;
    assert n != 0 ==> h == 25;
}

// --------------------------------------------------------------

// Recursive inline unrolling: for n > 0, loads n into z, adds it to x, and
// recurses with n - 1. The contract states only that x grows by at least n and
// stays non-negative. Modifies x and z.
procedure Unroll(inline n:nat)
    {:recursive}
    requires/ensures
        x >= 0;
    ensures
        x >= old(x) + n;
    modifies
        x; z;
{
    inline if (n > 0)
    {
        Imm(z, n);
        Add(x, z);
        Unroll(#nat(n - 1));
    }
}

// --------------------------------------------------------------

// Ghost-return example: increments x by 1 (using z as scratch) and returns the
// ghost value h == g + 1, which satisfies R(h, g). Carries Inv from counter i
// to i + 1. The ensures clauses `n || !n` and `src < 0 || 0 <= src` are
// tautologies that merely mention the inline parameter n and the operand src.
// Reads y; modifies x and z.
procedure GhostReturn(ghost g:int, inline n:bool, in src:opr, ghost i:int) returns (ghost h:int)
    requires
        g <= x;
        Inv(x, y, g, i);
    ensures
        x == old(x) + 1;
        n || !n;
        src < 0 || 0 <= src;
        Inv(x, y, g, i + 1);
        R(h, g);
    reads
        y;
    modifies
        x; z;
{
    // Inv is defined in terms of the opaque Post, so reveal it for the proof.
    reveal Post;
    Imm(z, 1);
    Add(x, z);
    h := g + 1;
}

// Calls GhostReturn once and checks the relation R on the returned ghost value.
// Preserves kg <= x. Reads y; declares x, z, and w as modified.
procedure GhostReturnCaller(ghost kg:int, in source:opr)
    requires/ensures
        kg <= x;
    requires
        x <= y;
        Post(x, y, kg);
    ensures
        x > y ==> Post(x, y, kg);
    reads
        y;
    modifies
        x; z; w;
{
    ghost var ki:int := 0;
    ghost var kh:int := 0;
    // NOTE(review): n is not used anywhere else in this body.
    let n:int := 17;
    kh := GhostReturn(kg, true, source, ki);
    assert R(kh, kg);
    ki := ki + 1;
}

// --------------------------------------------------------------

// Minimal-contract variant of GhostReturn: trivial requires/ensures (`true`).
// Loads 1 into z and adds it to x; the ghost return h is never assigned.
// Reads y; modifies x and z.
procedure GhostReturnY(ghost g:int, inline n:bool, in src:opr, ghost i:int) returns (ghost h:int)
    requires
        true;
    ensures
        true;
    reads
        y;
    modifies
        x; z;
{
    Imm(z, 1);
    Add(x, z);
}

// Calls GhostReturnY twice with the same arguments. Requires that the operand
// `source` is neither register X nor register Z (the two registers that
// GhostReturnY modifies). Reads y; declares x, z, and w as modified.
procedure GhostReturnCallerY(ghost kg:int, in source:opr)
    requires
         @source != X;
         @source != Z;
    requires
        Easy(x, y, kg, 0);
    ensures
        true;
    reads
        y;
    modifies
        x; z; w;
{
    ghost var ki := 0;
    ghost var kh:int := 0;
    kh := GhostReturnY(kg, true, source, ki);
    kh := GhostReturnY(kg, true, source, ki);
}

// Same body as GhostReturnY, but with no requires/ensures clauses at all.
// Loads 1 into z and adds it to x; the ghost return h is never assigned.
// Reads y; modifies x and z.
procedure UGhostReturnY(ghost g:int, inline n:bool, in src:opr, ghost i:int) returns (ghost h:int)
    reads
        y;
    modifies
        x; z;
{
    Imm(z, 1);
    Add(x, z);
}

// Variant of GhostReturnCallerY built on UGhostReturnY: the two register
// restrictions on `source` are joined with && in one requires block, and there
// is no ensures clause. Reads y; declares x, z, and w as modified.
procedure UGhostReturnCallerY(ghost kg:int, in source:opr)
    requires
        @source != X && @source != Z;
        Easy(x, y, kg, 0);
    reads
        y;
    modifies
        x; z; w;
{
    ghost var ki := 0;
    ghost var kh:int := 0;
    kh := UGhostReturnY(kg, true, source, ki);
    kh := UGhostReturnY(kg, true, source, ki);
}

// --------------------------------------------------------------

// Empty procedure that takes one input operand and declares x as modified.
procedure ProcA(in src:opr)
    modifies x;
{
}

// Calls ProcA twice with the same operand. Requires that `source` is not
// register X, the register ProcA declares as modified.
procedure ProcB(in source:opr)
    requires @source != X;
    modifies x;
{
    ProcA(source);
    ProcA(source);
}

// --------------------------------------------------------------