{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simonのアルゴリズム\n",
    "f(x) = f(x xor s) となる周期 s を求める\n",
    "\n",
    "入力を n bit とすると、オラクル呼び出しは O(n) 回、古典的な後処理（ガウスの消去法）は O(n^3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from blueqat import Circuit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## f(x) を出力するオラクル\n",
    "n = 2, 周期 s = 1, 2, 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# s = 1\n",
    "def n2s1(c):\n",
    "    return c.cx[0, 3]\n",
    "\n",
    "# s = 2\n",
    "def n2s2(c):\n",
    "    return c.cx[1, 2].x[1].cx[1, 3].x[1]\n",
    "\n",
    "# s = 3\n",
    "def n2s3(c):\n",
    "    return c.ccx[0, 1, 2].ccx[0, 1, 3].x[:2].ccx[0, 1, 2].ccx[0, 1, 3].x[:2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "n = 3, s = 5\n",
    "\n",
    "    000: 101\n",
    "    001: 001\n",
    "    010: 100\n",
    "    011: 000\n",
    "    100: 001\n",
    "    101: 101\n",
    "    110: 000\n",
    "    111: 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def n3s3(c):\n",
    "    return c.x[1].cx[1,5].x[1:3].ccx[0,2,3].x[0,2].ccx[0,2,3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 各オラクルの出力"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'0000': 27, '1001': 22, '0100': 23, '1101': 28})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply the s = 1 oracle to H on qubits 0-1, measure all 4 qubits, 100 shots\n",
    "n2s1(Circuit(4).h[:2]).m[:].run(shots=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'0110': 22, '1110': 31, '0001': 22, '1001': 25})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply the s = 2 oracle to H on qubits 0-1, measure all 4 qubits, 100 shots\n",
    "n2s2(Circuit(4).h[:2]).m[:].run(shots=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'0011': 34, '1111': 27, '0100': 26, '1000': 13})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply the s = 3 oracle to H on qubits 0-1, measure all 4 qubits, 100 shots\n",
    "n2s3(Circuit(4).h[:2]).m[:].run(shots=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'010100': 11,\n",
       "         '110000': 15,\n",
       "         '101101': 13,\n",
       "         '100001': 9,\n",
       "         '001001': 14,\n",
       "         '111100': 12,\n",
       "         '000101': 13,\n",
       "         '011000': 13})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply the n = 3 oracle to H on qubits 0-2, measure all 6 qubits, 100 shots\n",
    "n3s3(Circuit(6).h[:3]).m[:].run(shots=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## sを求める"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_s(n, oracle):\n",
    "    \"\"\"Recover the hidden period s of an n-bit oracle with Simon's algorithm.\n",
    "\n",
    "    Args:\n",
    "        n: number of input bits; the circuit uses 2*n qubits.\n",
    "        oracle: function that takes a Circuit and returns it with f appended.\n",
    "\n",
    "    Returns:\n",
    "        s as an int, with input qubit 0 as the most significant bit.\n",
    "    \"\"\"\n",
    "    # H on the inputs, oracle, H on the inputs again, then sample the inputs\n",
    "    counts = oracle(Circuit(n*2).h[:n]).h[:n].m[:n].run(shots=100)\n",
    "    # Every distinct outcome y gives one equation y . s = 0 (mod 2). Use all of\n",
    "    # them rather than only the first n, so that no available equation is lost\n",
    "    # and a run with fewer than n distinct outcomes does not index out of range.\n",
    "    A = [[int(y[j]) for j in range(n)] for y in counts]\n",
    "    b = [0] * len(A)\n",
    "\n",
    "    _, _, x = gauss_elimination(A, b)\n",
    "    s = 0\n",
    "    # only the first n entries are bits of s; fold them most significant first\n",
    "    for d in x[:n]:\n",
    "        s = (s << 1) | d\n",
    "    return s"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ({0,1}, xor, *)上のガウスの消去法\n",
    "O(n^3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gauss_elimination(A, b):\n",
    "    n = len(A)\n",
    "    m = len(A[0])\n",
    "    rank = 0\n",
    "    cj = 0\n",
    "    while rank < n and cj < m:\n",
    "        for i in range(rank+1, n):\n",
    "            if A[i][cj] > A[rank][cj]:\n",
    "                A[i], A[rank] = A[rank], A[i]\n",
    "                b[i], b[rank] = b[rank], b[i]\n",
    "        if A[rank][cj]:\n",
    "            d = A[rank][cj] ^ 1\n",
    "            for j in range(m):\n",
    "                A[rank][j] ^= d\n",
    "            b[rank] ^= d\n",
    "            for i in range(rank+1, n):\n",
    "                k = A[i][cj]\n",
    "                for j in range(m):\n",
    "                    A[i][j] ^= k*A[rank][j]\n",
    "                b[i] ^= k*b[rank]\n",
    "            rank += 1\n",
    "        cj += 1\n",
    "    for i in range(rank, n):\n",
    "        if b[i]:\n",
    "            return False, rank, None\n",
    "    if rank < m or cj < m:\n",
    "        ci = rank\n",
    "        for i in range(min(n, m)):\n",
    "            if A[i][i] == 0:\n",
    "                if i != ci:\n",
    "                    A[i], A[ci] = A[ci], A[i]\n",
    "                    b[i], b[ci] = b[ci], b[i]\n",
    "                ci += 1\n",
    "                for j in range(m):\n",
    "                    A[i][j] = 0\n",
    "                A[i][i] = 1\n",
    "                b[i] = 1\n",
    "    for j in range(m-1, -1, -1):\n",
    "        for i in range(j):\n",
    "            b[i] ^= b[j] * A[i][j]\n",
    "    return True, rank, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "5\n"
     ]
    }
   ],
   "source": [
    "# Recover s for each oracle defined above (n = 2 three times, then n = 3)\n",
    "print(get_s(2, n2s1))\n",
    "print(get_s(2, n2s2))\n",
    "print(get_s(2, n2s3))\n",
    "print(get_s(3, n3s3))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}