{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Language cascades | part 1: introduction\n",
    "Tutorial on neural theorem proving\\\n",
    "Author: Sean Welleck\n",
    "\n",
    "----------------\n",
    "\n",
    "Tools such as [Chat-GPT]() show the flexibility of modern neural language generators.\\\n",
    "Namely, a single generation system can often perform a task by simply providing a suitable *prompt*:\n",
    "\n",
    "$\\quad y=f(p_\\theta(\\cdot|x;P)),$\n",
    "\n",
    "where $x$ is an input, $P$ is a prompt, and $f(\\cdot)$ is a decoding algorithm.\n",
    "\n",
    "\n",
    "Let's look at one of these functions:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import os\n",
    "import openai\n",
    "\n",
    "class LMFunction(object):\n",
    "    def __init__(self, engine='gpt-3.5-turbo', max_tokens=512):\n",
    "        self.engine = engine\n",
    "        self.max_tokens = max_tokens\n",
    "        self.openai = openai\n",
    "        openai.api_key = os.environ['OPENAI_API_KEY']\n",
    "\n",
    "    def _call_api(self, prompt, engine, max_tokens, max_retries=10, retry_wait=2):\n",
    "        for i in range(max_retries):\n",
    "            try:\n",
    "                return self.openai.ChatCompletion.create(\n",
    "                    model=engine,\n",
    "                    messages=[\n",
    "                        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "                        {\"role\": \"user\", \"content\": prompt}\n",
    "                    ],\n",
    "                    max_tokens=max_tokens,\n",
    "                    temperature=1.0\n",
    "                )\n",
    "            except self.openai.error.OpenAIError as e:\n",
    "                time.sleep(retry_wait)\n",
    "        return {'choices': [{'message': {'content': ''}}]}\n",
    "\n",
    "    def _parse_message(self, msg):\n",
    "        try:\n",
    "            content = msg['choices'][0]['message']['content']\n",
    "            content = content.strip().split('\\n')[0]\n",
    "        except (IndexError, KeyError):\n",
    "            content = ''\n",
    "        return content\n",
    "\n",
    "    def f(self, prompt, x):\n",
    "        msg = self._call_api(\n",
    "            prompt=prompt+x,\n",
    "            engine=self.engine,\n",
    "            max_tokens=self.max_tokens\n",
    "        )\n",
    "        evaluation = self._parse_message(msg)\n",
    "        return evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('1947256', False),\n",
       " ('1945256', False),\n",
       " ('1946256', False),\n",
       " ('1947176', True),\n",
       " ('1947176', True),\n",
       " ('1947056', False),\n",
       " ('1947456', False),\n",
       " ('1947256', False),\n",
       " ('1947256', False),\n",
       " ('1947256', False)]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = \"\"\"Multiply two numbers. Here are some examples:\n",
    "432*342=147744\n",
    "98*19=1862\n",
    "\"\"\"\n",
    "\n",
    "p = LMFunction('gpt-4')\n",
    "\n",
    "outputs = [p.f(prompt, '872*2233=') for _ in range(10)]\n",
    "[(output, output==str(872*2233)) for output in outputs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1947176"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "872*2233"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result above shows two interesting things:\n",
    "1. The function is stochastic; it can return different answers each time it is called.\n",
    "2. The function is capable of producing a correct answer; it gets it correct 2 times.\n",
    "\n",
    "Therefore, one way a stochastic function like this is useful is to pair it with a reliable verifier.\n",
    "\n",
    "\n",
    "#### Why is this useful?\n",
    "The main attraction of these functions is their flexibility. For instance,\n",
    "it is easy to implement a function that maps language input to a call to Sympy:\n",
    "\n",
    "```\n",
    "    sympy_expression ~ f(\"872 times 2233 =\")\n",
    "    answer = g(sympy_expression)\n",
    "```\n",
    "where $g(\\cdot)$ is Sympy evaluation.\n",
    "\n",
    "This yields functionality that is difficult to program manually:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(800+72)*2233\n",
      "1947176\n"
     ]
    }
   ],
   "source": [
    "from sympy.parsing.sympy_parser import parse_expr\n",
    "\n",
    "prompt = \"\"\"Solve the multiplication problem by writing input to a sympy function.\n",
    "Do not add additional text. Here are some examples:\n",
    "\n",
    "432 multiplied by 342 is: 432*342\n",
    "98* 19 is how much? 98*19\n",
    "\"\"\"\n",
    "\n",
    "p = LMFunction('gpt-4')\n",
    "\n",
    "g = parse_expr\n",
    "\n",
    "sympy_expression = p.f(prompt, 'There are 800+72 apples in a barrel. How many apples in 2233 barrels?')\n",
    "answer = g(sympy_expression)\n",
    "print(sympy_expression)\n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1947176"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "872*2233"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Language cascades\n",
    "\n",
    "A [language model cascade [Dohan et al 2022]](https://arxiv.org/abs/2207.10342) formalizes the idea of composing multiple functions, some of which are stochastic functions implemented by a language model.\n",
    "\n",
    "The result can be seen as a probabilistic program, whose samples \"execute the function\", e.g.:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 4544$"
      ],
      "text/plain": [
       "4544"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def multiply(x):\n",
    "    y1 = p.f(prompt, x)\n",
    "    y2 = g(y1)\n",
    "    return y2\n",
    "\n",
    "\n",
    "multiply('I bought 32 cases of apples, with one hundred and 42 apples per case. How many total apples?')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4544"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "32*142"
   ]
  },
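  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because a cascade mixes stochastic and deterministic steps, a sample can also fail partway, for example when the model's output is not a valid expression. One possible way to handle this is to resample until the deterministic step succeeds. In the sketch below, `cascade_with_retry` and `evaluate_product` are hypothetical helpers, `evaluate_product` is a deliberately tiny stand-in for `g` that only handles `a*b`, and `canned` replaces calls to `p.f` with made-up replies.\n",
    "\n",
    "```python\n",
    "def evaluate_product(expression):\n",
    "    # Deterministic step: parse 'a*b' and multiply; raises ValueError otherwise.\n",
    "    left, right = expression.split('*')\n",
    "    return int(left) * int(right)\n",
    "\n",
    "def cascade_with_retry(sample_fn, eval_fn, x, max_samples=5):\n",
    "    # Stochastic step followed by a deterministic step, resampling on failure.\n",
    "    for _ in range(max_samples):\n",
    "        y1 = sample_fn(x)\n",
    "        try:\n",
    "            return eval_fn(y1)\n",
    "        except ValueError:\n",
    "            continue\n",
    "    return None\n",
    "\n",
    "# Made-up model replies: the first one is not a valid expression.\n",
    "canned = iter(['The answer is 32*142', '32*142'])\n",
    "\n",
    "def sample_fn(x):\n",
    "    return next(canned)\n",
    "\n",
    "result = cascade_with_retry(sample_fn, evaluate_product, 'How many apples are in 32 cases of 142?')\n",
    "print(result)  # 4544\n",
    "```\n",
    "\n",
    "The first reply fails to parse and is discarded; the second evaluates to 4544."
   ]
  },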
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Cascades for neural theorem proving\n",
    "\n",
    "The two ideas mentioned above: composing multiple functions and using a verifier, make neural theorem proving a natural setting for language cascades.\n",
    "\n",
    "Namely, the goal will be to decompose theorem proving into different functions, then use the proof assistant to verify the final output.\n",
    "\n",
    "In the next notebook, we will see a cascade called [Draft, Sketch, Prove [Jiang et al ICLR 2023]](https://arxiv.org/abs/2210.12283) that does so with three components: \\\n",
    "**draft** an informal proof, **sketch** a formal proof, and **prove** the remaining gaps.\n",
    "\n",
    "The end result is a model and proof search procedure that is qualitatively much different than the next-step predictors we used in Part I."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}