{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary \n",
"This notebook contains:\n",
"- torch implementations of a few linear algebra techniques:\n",
" - forward- and back-solving \n",
" - LDLt decomposition\n",
" - QR decomposition via Householder reflections\n",
"\n",
"- initial implementations of secure linear regression and Jonathan Bloom's [DASH](https://github.com/jbloom22/DASH/) that leverage PySyft for secure computation.\n",
"\n",
    "These implementations of linear regression and DASH are not yet strictly secure, in that a few final steps are performed on the local worker for now. That's because our implementations of LDLt decomposition, QR decomposition, etc. don't quite work for the PySyft `AdditiveSharingTensor` just yet. They should in principle (they're compositions of operations that SPDZ supports), but there are still a few details to hammer out.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Contents \n",
"[Ordinary least squares regression and LDLt decomposition](#OLSandLDLt)\n",
"* [LDLt decomposition, forward/back-solving](#LDLt)\n",
"* [Secure linear regression example](#OLS)\n",
"\n",
"[DASH](#dashqr)\n",
"* [QR decomposition via Householder transforms](#qr)\n",
"* [DASH example](#dash)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: Logging before flag parsing goes to stderr.\n",
"W0710 23:13:43.013911 4542494144 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/Users/andrew/.virtualenvs/pysyft/lib/python3.7/site-packages/tf_encrypted-0.5.6-py3.7-macosx-10.14-x86_64.egg/tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so'\n",
"W0710 23:13:43.023926 4542494144 deprecation_wrapper.py:119] From /Users/andrew/.virtualenvs/pysyft/lib/python3.7/site-packages/tf_encrypted-0.5.6-py3.7-macosx-10.14-x86_64.egg/tf_encrypted/session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
"\n"
]
}
],
"source": [
"import numpy as np\n",
"import torch as th\n",
"import syft as sy\n",
"from scipy import stats"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Setting up Sandbox...\n",
"\t- Hooking PyTorch\n",
"\t- Creating Virtual Workers:\n",
"\t\t- bob\n",
"\t\t- theo\n",
"\t\t- jason\n",
"\t\t- alice\n",
"\t\t- andy\n",
"\t\t- jon\n",
"\tStoring hook and workers as global variables...\n",
"\tLoading datasets from SciKit Learn...\n",
"\t\t- Boston Housing Dataset\n",
"\t\t- Diabetes Dataset\n",
"\t\t- Breast Cancer Dataset\n",
    "\t\t- Digits Dataset\n",
"\t\t- Iris Dataset\n",
"\t\t- Wine Dataset\n",
"\t\t- Linnerud Dataset\n",
"\tDistributing Datasets Amongst Workers...\n",
"\tCollecting workers into a VirtualGrid...\n",
"Done!\n"
]
}
],
"source": [
"sy.create_sandbox(globals())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "# Ordinary least squares regression and LDLt decomposition"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LDLt decomposition, forward/back-solving\n",
"\n",
"These are torch implementations of basic linear algebra routines we'll use to perform regression (and also in parts of the next section). \n",
"- Forward/back-solving allows us to solve triangular linear systems efficiently and stably.\n",
    "- LDLt decomposition lets us write a symmetric matrix as a product LDL^t where L is lower-triangular and D is diagonal (^t denotes transpose). It plays a role similar to Cholesky decomposition (which is normally available as a method of a torch tensor), but doesn't require computing square roots. This makes LDLt a better fit for the secure setting."
]
},
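  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the no-square-roots point (a plain-NumPy sketch, independent of the torch implementation below), here is the $LDL^t$ decomposition of a $2 \\times 2$ symmetric matrix worked out by hand:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# LDL^t of a 2x2 symmetric matrix [[a, b], [b, c]]:\n",
    "# d1 = a, l21 = b / a, d2 = c - l21^2 * d1 -- no square roots needed\n",
    "a = np.array([[4., 2.],\n",
    "              [2., 3.]])\n",
    "d1 = a[0, 0]                  # 4.0\n",
    "l21 = a[1, 0] / d1            # 0.5\n",
    "d2 = a[1, 1] - l21 ** 2 * d1  # 2.0\n",
    "L = np.array([[1., 0.], [l21, 1.]])\n",
    "D = np.diag([d1, d2])\n",
    "assert np.allclose(L @ D @ L.T, a)\n",
    "```"
   ]
  },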
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def _eye(n):\n",
" \"\"\"th.eye doesn't seem to work after hooking torch, so just adding\n",
" a workaround for now.\n",
" \"\"\"\n",
" return th.FloatTensor(np.eye(n))\n",
"\n",
"\n",
"def ldlt_decomposition(x):\n",
    "    \"\"\"Decompose the square, symmetric, full-rank matrix X as X = LDL^t, where\n",
    "    - L is lower triangular\n",
    "    - D is diagonal.\n",
" \"\"\"\n",
" n, _ = x.shape\n",
" l, diag = _eye(n), th.zeros(n).float()\n",
"\n",
" for j in range(n):\n",
" diag[j] = x[j, j] - (th.sum((l[j, :j] ** 2) * diag[:j]))\n",
" for i in range(j + 1, n):\n",
" l[i, j] = (x[i, j] - th.sum(diag[:j] * l[i, :j] * l[j, :j])) / diag[j]\n",
"\n",
" return l, th.diag(diag), l.transpose(0, 1)\n",
"\n",
"\n",
"def back_solve(u, y):\n",
" \"\"\"Solve Ux = y for U a square, upper triangular matrix of full rank\"\"\"\n",
" n = u.shape[0]\n",
" x = th.zeros(n)\n",
" for i in range(n - 1, -1, -1):\n",
" x[i] = (y[i] - th.sum(u[i, i+1:] * x[i+1:])) / u[i, i]\n",
"\n",
" return x.reshape(-1, 1)\n",
"\n",
"\n",
"def forward_solve(l, y):\n",
" \"\"\"Solve Lx = y for L a square, lower triangular matrix of full rank.\"\"\"\n",
" n = l.shape[0]\n",
" x = th.zeros(n)\n",
" for i in range(0, n):\n",
" x[i] = (y[i] - th.sum(l[i, :i] * x[:i])) / l[i, i]\n",
"\n",
" return x.reshape(-1, 1)\n",
"\n",
"\n",
"def invert_triangular(t, upper=True):\n",
" \"\"\"\n",
" Invert by repeated forward/back-solving.\n",
    "    TODO:\n",
    "    - could be made more efficient with a vectorized implementation of forward/back-solving\n",
    "    - detection and validation around triangularity/squareness\n",
" \"\"\"\n",
" solve = back_solve if upper else forward_solve\n",
" t_inv = th.zeros_like(t)\n",
" n = t.shape[0]\n",
" for i in range(n):\n",
" e = th.zeros(n, 1)\n",
" e[i] = 1.\n",
" t_inv[:, [i]] = solve(t, e)\n",
" return t_inv\n",
"\n",
"\n",
"def solve_symmetric(a, y):\n",
" \"\"\"Solve the linear system Ax = y where A is a symmetric matrix of full rank.\"\"\"\n",
" l, d, lt = ldlt_decomposition(a)\n",
" \n",
" # TODO: more efficient to just extract diagonal of d as 1D vector and scale?\n",
" x_ = forward_solve(l.mm(d), y)\n",
" return back_solve(lt, x_)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PASSED for tensor([[1., 2., 3.],\n",
" [2., 1., 2.],\n",
" [3., 2., 1.]])\n",
"PASSED for tensor([[1., 2., 3.],\n",
" [2., 1., 2.],\n",
" [3., 2., 1.]]), tensor([[1.],\n",
" [2.],\n",
" [3.]])\n"
]
}
],
"source": [
"\"\"\"\n",
"Basic tests for LDLt decomposition.\n",
"\"\"\"\n",
"\n",
"def _assert_small(x, failure_msg=None, threshold=1E-5):\n",
" norm = x.norm()\n",
" assert norm < threshold, failure_msg\n",
"\n",
"\n",
"def test_ldlt_case(a):\n",
" l, d, lt = ldlt_decomposition(a)\n",
" _assert_small(l - lt.transpose(0, 1))\n",
" _assert_small(l.mm(d).mm(lt) - a, 'Decomposition is inaccurate.')\n",
" _assert_small(l - th.tril(l), 'L is not lower triangular.')\n",
" _assert_small(th.triu(th.tril(d)) - d, 'D is not diagonal.')\n",
" print(f'PASSED for {a}')\n",
" \n",
"\n",
"def test_solve_symmetric_case(a, x):\n",
" y = a.mm(x)\n",
" _assert_small(solve_symmetric(a, y) - x)\n",
" print(f'PASSED for {a}, {x}')\n",
"\n",
" \n",
"a = th.tensor([[1, 2, 3],\n",
" [2, 1, 2],\n",
" [3, 2, 1]]).float()\n",
"\n",
"x = th.tensor([1, 2, 3]).float().reshape(-1, 1)\n",
"\n",
"test_ldlt_case(a)\n",
"test_solve_symmetric_case(a, x)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Secure linear regression example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Problem\n",
"We're solving \n",
"$$ \\min_\\beta \\|X \\beta - y\\|_2 $$\n",
"in the situation where the data $(X, y)$ is horizontally partitioned (each worker $w$ owns chunks $X_w, y_w$ of the rows of $X$ and $y$).\n",
"\n",
"#### Goals\n",
"We want to do this \n",
"* securely \n",
"* without network overhead or MPC-related costs that scale with the number of rows of $X$. \n",
"\n",
"#### Plan\n",
"\n",
"1. (**local plaintext compression**): each worker locally computes $X_w^t X_w$ and $X_w^t y_w$ in plain text. This is the only step that depends on the number of rows of X, and it's performed in plaintext.\n",
    "2. (**secure summing**): securely compute the sums $$\\begin{align}X^t X &= \\sum_w X^t_w X_w \\\\ X^t y &= \\sum_w X^t_w y_w \\end{align}$$ as AdditiveSharingTensors. Some worker or other party (here the local worker) will hold pointers to those two AdditiveSharingTensors.\n",
"3. (**secure solve**): We can then solve $X^tX\\beta = X^ty$ for $\\beta$ by a sequence of operations on those pointers (specifically, we apply `solve_symmetric` defined above).\n",
"\n",
"#### Example data: \n",
"The correct $\\beta$ is $[1, 2, -1]$"
]
},
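  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The compression step works because $X^tX$ and $X^ty$ decompose as sums of the per-worker pieces. A quick plain-NumPy check of that identity (with made-up chunk sizes):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.normal(size=(9, 3))\n",
    "y = rng.normal(size=(9, 1))\n",
    "X_chunks, y_chunks = np.split(X, 3), np.split(y, 3)  # each \"worker\" holds 3 rows\n",
    "\n",
    "# sums of the local products equal the global products\n",
    "assert np.allclose(sum(Xw.T @ Xw for Xw in X_chunks), X.T @ X)\n",
    "assert np.allclose(sum(Xw.T @ yw for Xw, yw in zip(X_chunks, y_chunks)), X.T @ y)\n",
    "```"
   ]
  },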
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"X = th.tensor(10 * np.random.randn(30000, 3))\n",
"y = (X[:, 0] + 2 * X[:, 1] - X[:, 2]).reshape(-1, 1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Split the data into chunks and send a chunk to each worker, storing pointers to chunks in two `MultiPointerTensor`s."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"workers = [alice, bob, theo]\n",
"crypto_provider = jon\n",
"chunk_size = int(X.shape[0] / len(workers))\n",
"\n",
"\n",
"def _get_chunk_pointers(data, chunk_size, workers):\n",
" return [\n",
" data[(i * chunk_size):((i+1)*chunk_size), :].send(worker)\n",
" for i, worker in enumerate(workers)\n",
" ] \n",
"\n",
"\n",
"X_ptrs = sy.MultiPointerTensor(\n",
" children=_get_chunk_pointers(X, chunk_size, workers))\n",
"y_ptrs = sy.MultiPointerTensor(\n",
" children=_get_chunk_pointers(y, chunk_size, workers))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### local compression\n",
    "This is the only step that depends on the number of rows of $X, y$, and it's performed locally on each worker in plain text. The result is two `MultiPointerTensor`s with pointers to each worker's summand of $X^tX$ (or $X^ty$)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"Xt_ptrs = X_ptrs.transpose(0, 1)\n",
"\n",
"XtX_summand_ptrs = Xt_ptrs.mm(X_ptrs)\n",
"Xty_summand_ptrs = Xt_ptrs.mm(y_ptrs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### secure sum\n",
"We add those summands up in two steps:\n",
"- share each summand among all other workers\n",
"- move the resulting pointers to one place (here just the local worker) and add 'em up."
]
},
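  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the secure sum rests on additive secret sharing: each summand is split into random shares that individually reveal nothing but add back up to the original value. A toy integer sketch of the idea (PySyft's `AdditiveSharingTensor` takes care of fixed-precision encoding and distributing the shares for us):\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "Q = 2 ** 31 - 1  # arithmetic is done modulo a large number\n",
    "\n",
    "def share(x, n_workers=3):\n",
    "    \"\"\"Split x into n_workers additive shares modulo Q.\"\"\"\n",
    "    shares = [random.randrange(Q) for _ in range(n_workers - 1)]\n",
    "    shares.append((x - sum(shares)) % Q)\n",
    "    return shares\n",
    "\n",
    "def reconstruct(shares):\n",
    "    return sum(shares) % Q\n",
    "\n",
    "secrets = [12, 30, 7]\n",
    "shared = [share(s) for s in secrets]\n",
    "\n",
    "# each worker sums the shares it holds; recombining gives the true sum\n",
    "share_sums = [sum(col) % Q for col in zip(*shared)]\n",
    "assert reconstruct(share_sums) == sum(secrets)  # 49\n",
    "```"
   ]
  },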
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def _generate_shared_summand_pointers(\n",
" summand_ptrs, \n",
" workers, \n",
" crypto_provider):\n",
"\n",
" for worker_id, summand_pointer in summand_ptrs.child.items():\n",
" shared_summand_pointer = summand_pointer.fix_precision().share(\n",
" *workers, crypto_provider=crypto_provider)\n",
" yield shared_summand_pointer.get()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"XtX_shared = sum(\n",
" _generate_shared_summand_pointers(\n",
" XtX_summand_ptrs, workers, crypto_provider))\n",
"\n",
"Xty_shared = sum(_generate_shared_summand_pointers(\n",
" Xty_summand_ptrs, workers, crypto_provider))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### secure solve\n",
"The coefficient $\\beta$ is the solution to\n",
"$$X^t X \\beta = X^t y$$\n",
"\n",
"We solve for $\\beta$ using `solve_symmetric`. Critically, this is a composition of linear operations that should be supported by `AdditiveSharingTensor`. Unlike the classic Cholesky decomposition, the $LDL^t$ decomposition in step 1 does not involve taking square roots, which would be challenging.\n",
"\n",
"\n",
"**TODO**: there's still some additional work required to get `solve_symmetric` working for `AdditiveSharingTensor`, so we're performing the final linear solve publicly for now."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"beta = solve_symmetric(XtX_shared.get().float_precision(), Xty_shared.get().float_precision())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1.0000],\n",
" [ 2.0000],\n",
" [-1.0000]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"beta"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "# DASH and QR decomposition"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## QR decomposition\n",
"\n",
    "An $m \\times n$ real matrix $A$ with $m \\geq n$ can be written as $$A = QR$$ for $Q$ orthogonal and $R$ upper triangular. This is helpful in solving systems of equations, among other things. It is also central to the compression idea of [DASH](https://arxiv.org/pdf/1901.09531.pdf)."
]
},
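  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The building block of the implementation below is the Householder reflection: choosing $v = x + \\mathrm{sign}(x_0)\\|x\\|e_1$ (normalized) makes the reflector $I - 2vv^t$ send $x$ to a multiple of $e_1$. A quick plain-NumPy check:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "x = np.array([3., 4., 0.])\n",
    "v = x.copy()\n",
    "v[0] += np.sign(v[0]) * np.linalg.norm(x)  # v = x + sign(x0) * ||x|| * e1\n",
    "v /= np.linalg.norm(v)\n",
    "\n",
    "H = np.eye(3) - 2 * np.outer(v, v)  # Householder reflector\n",
    "y = H @ x\n",
    "\n",
    "# everything below the first entry is (numerically) zero; the norm is preserved\n",
    "assert np.allclose(y[1:], 0)\n",
    "assert np.isclose(abs(y[0]), np.linalg.norm(x))  # == 5.0\n",
    "```"
   ]
  },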
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Full QR decomposition via Householder transforms, \n",
"following Numerical Linear Algebra (Trefethen and Bau).\n",
"\"\"\"\n",
"\n",
"def _apply_householder_transform(a, v):\n",
" return a - 2 * v.mm(v.transpose(0, 1).mm(a))\n",
"\n",
"\n",
"def _build_householder_matrix(v):\n",
" n = v.shape[0]\n",
" u = v / v.norm()\n",
" return _eye(n) - 2 * u.mm(u.transpose(0, 1))\n",
"\n",
"\n",
    "def _householder_qr_step(a):\n",
    "    x = a[:, 0].reshape(-1, 1)\n",
    "    alpha = x.norm()  # ||x||, reused below so the norm of x is only computed once\n",
    "    u = x.copy()\n",
    "\n",
    "    # note: can get better stability by multiplying alpha by sign(u[0, 0])\n",
    "    # (where sign(0) = 1); is this supported in the secure context?\n",
    "    u[0, 0] += alpha\n",
    "\n",
    "    u /= u.norm()\n",
    "    a = _apply_householder_transform(a, u)\n",
    "\n",
    "    return a, u\n",
"\n",
"\n",
"def _recover_q(householder_vectors):\n",
" \"\"\"\n",
" Build the matrix Q from the Householder transforms.\n",
" \"\"\"\n",
" n = len(householder_vectors)\n",
"\n",
" def _apply_transforms(x):\n",
" \"\"\"Trefethen and Bau, Algorithm 10.3\"\"\"\n",
" for k in range(n-1, -1, -1):\n",
" x[k:, :] = _apply_householder_transform(\n",
" x[k:, :], \n",
" householder_vectors[k])\n",
" return x\n",
"\n",
" m = householder_vectors[0].shape[0]\n",
" q = th.zeros(m, m)\n",
" \n",
" # Determine q by evaluating it on a basis\n",
" for i in range(m):\n",
" e = th.zeros(m, 1)\n",
" e[i] = 1.\n",
" q[:, [i]] = _apply_transforms(e)\n",
" \n",
" return q\n",
"\n",
"\n",
"def qr(a, return_q=True):\n",
" \"\"\"\n",
" Args:\n",
" a: shape (m, n), m >= n\n",
" return_q: bool, whether to reconstruct q \n",
" Returns:\n",
" orthogonal q of shape (m, m) (None if return_q is False)\n",
" upper-triangular of shape (m, n)\n",
" \"\"\"\n",
" m, n = a.shape\n",
" assert m >= n, \\\n",
" f\"Passed a of shape {a.shape}, must have a.shape[0] >= a.shape[1]\"\n",
"\n",
" r = a.copy()\n",
" householder_unit_normal_vectors = []\n",
"\n",
" for k in range(n):\n",
" r[k:, k:], u = _householder_qr_step(r[k:, k:])\n",
" householder_unit_normal_vectors.append(u)\n",
" if return_q:\n",
" q = _recover_q(householder_unit_normal_vectors)\n",
" else:\n",
" q = None\n",
" return q, r\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PASSED for \n",
"tensor([[1., 0., 1.],\n",
" [1., 1., 0.],\n",
" [0., 1., 1.]])\n",
"\n",
"PASSED for \n",
"tensor([[1., 0., 1.],\n",
" [1., 1., 0.],\n",
" [0., 1., 1.],\n",
" [1., 1., 1.]])\n",
"\n"
]
}
],
"source": [
"\"\"\"\n",
"Basic tests for QR decomposition\n",
"\"\"\"\n",
"\n",
"def _test_qr_case(a): \n",
" \n",
" q, r = qr(a)\n",
" \n",
" # actually have QR = A\n",
" _assert_small(q.mm(r) - a, \"QR = A failed\")\n",
"\n",
" # Q is orthogonal\n",
" m, _ = a.shape\n",
" _assert_small(\n",
" q.mm(q.transpose(0, 1)) - _eye(m),\n",
" \"QQ^t = I failed\"\n",
" )\n",
" \n",
" # R is upper triangular\n",
" lower_triangular_entries = th.tensor([\n",
" r[i, j].item() for i in range(r.shape[0]) \n",
" for j in range(i)])\n",
"\n",
" _assert_small(\n",
" lower_triangular_entries,\n",
" \"R is not upper triangular\"\n",
" )\n",
"\n",
" print(f\"PASSED for \\n{a}\\n\")\n",
"\n",
"\n",
"def test_qr():\n",
" _test_qr_case(\n",
" th.tensor([[1, 0, 1],\n",
" [1, 1, 0],\n",
" [0, 1, 1]]).float()\n",
" )\n",
"\n",
" _test_qr_case(\n",
" th.tensor([[1, 0, 1],\n",
" [1, 1, 0],\n",
" [0, 1, 1],\n",
" [1, 1, 1],]).float()\n",
" )\n",
" \n",
"test_qr()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DASH implementation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We follow https://github.com/jbloom22/DASH/. \n",
"\n",
"The overall structure is roughly analogous to the linear regression example above.\n",
"\n",
"- There's a local compression step that's performed separately on each worker in plaintext.\n",
    "- We leverage PySyft's SMPC features to perform secure summation.\n",
"- For now, the last few steps are performed by a single player (the local worker). \n",
" - Again, this could be performed securely, but there are still a few hitches with getting our torch implementation of QR decomposition to work for an `AdditiveSharingTensor`."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def _generate_worker_data_pointers(\n",
" n, m, k, worker,\n",
" beta_correct, gamma_correct, epsilon=0.01\n",
"):\n",
" \"\"\"\n",
" Return pointers to worker-level data.\n",
" Args:\n",
" n: number of rows\n",
    "        m: number of transient features\n",
" k: number of covariates\n",
" beta_correct: coefficients for transient features (tensor of shape (m, 1))\n",
" gamma_correct: coefficients for covariates (tensor of shape (k, 1))\n",
" epsilon: scale of noise added to response\n",
    "    Returns:\n",
" y, X, C: pointers to response, transients, and covariates\n",
" \"\"\"\n",
" X = th.randn(n, m).send(worker)\n",
" C = th.randn(n, k).send(worker)\n",
" \n",
" y = (X.mm(beta_correct.copy().send(worker)).reshape(-1, 1) + \n",
" C.mm(gamma_correct.copy().send(worker)).reshape(-1, 1))\n",
"\n",
" y += (epsilon * th.randn(n, 1)).send(worker)\n",
"\n",
" return y, X, C\n",
"\n",
"\n",
"def _dot(x):\n",
" return (x * x).sum(dim=0).reshape(-1, 1)\n",
"\n",
"\n",
"def _secure_sum(worker_level_pointers, workers, crypto_provider):\n",
" \"\"\"\n",
    "    Securely add up an iterable of pointers to (same-sized) tensors.\n",
" Args:\n",
" worker_level_pointers: iterable of pointer tensors\n",
" workers: list of workers\n",
" crypto_provider: worker\n",
" Returns:\n",
" AdditiveSharingTensor shared among workers\n",
" \"\"\"\n",
" return sum([\n",
" p.fix_precision(precision_fractional=10).share(*workers, crypto_provider=crypto_provider).get()\n",
" for p in worker_level_pointers\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def dash_example_secure(\n",
" workers, crypto_provider,\n",
" n_samples_by_worker, m, k,\n",
" beta_correct, gamma_correct, \n",
" epsilon=0.01\n",
"):\n",
" \"\"\"\n",
" Args:\n",
" workers: list of workers\n",
" crypto_provider: worker\n",
" n_samples_by_worker: dict mapping worker ids to ints (number of rows of data)\n",
" m: number of transients\n",
" k: number of covariates\n",
" beta_correct: coefficient for transient features\n",
" gamma_correct: coefficient for covariates\n",
" epsilon: scale of noise added to response\n",
" Returns:\n",
" beta, sigma, tstat, pval: coefficient of transients and accompanying statistics\n",
" \"\"\"\n",
" # Generate each worker's data\n",
" worker_data_pointers = {\n",
" p: _generate_worker_data_pointers(\n",
" n, m, k, workers[p],\n",
" beta_correct, gamma_correct,\n",
" epsilon=epsilon)\n",
" for p, n in n_samples_by_worker.items()\n",
" }\n",
"\n",
" # to be populated with pointers to results of local, worker-level computations\n",
" Ctys, CtXs, yys, Xys, XXs, Rs = {}, {}, {}, {}, {}, {}\n",
"\n",
" def _sum(pointers):\n",
    "        return _secure_sum(pointers, list(workers.values()), crypto_provider)\n",
"\n",
" # worker-level compression step\n",
" for p, (y, X, C) in worker_data_pointers.items():\n",
" \n",
" # perform worker-level compression step\n",
" yys[p] = y.norm()\n",
" Xys[p] = X.transpose(0, 1).mm(y)\n",
" XXs[p] = _dot(X)\n",
" \n",
" Ctys[p] = C.transpose(0, 1).mm(y)\n",
" CtXs[p] = C.transpose(0, 1).mm(X)\n",
" _, R_full = qr(C, return_q=False)\n",
" Rs[p] = R_full[:k, :]\n",
" \n",
" # Perform secure sum \n",
    "    # - We're returning the result to the local worker and computing there for the rest\n",
    "    #   of the way, but this should be possible to compute via SMPC (on pointers to AdditiveSharingTensors)\n",
    "    # - still a few minor-looking issues with implementing invert_triangular/qr for\n",
    "    #   AdditiveSharingTensor\n",
" yy = _sum(yys.values()).get().float_precision()\n",
" Xy = _sum(Xys.values()).get().float_precision()\n",
" XX = _sum(XXs.values()).get().float_precision()\n",
" \n",
" Cty = _sum(Ctys.values()).get().float_precision()\n",
" CtX = _sum(CtXs.values()).get().float_precision()\n",
"\n",
" # Rest is done publicly on the local worker for now\n",
" _, R_public = qr(\n",
" th.cat([R.get() for R in Rs.values()], dim=0),\n",
" return_q=False)\n",
"\n",
" invR_public = invert_triangular(R_public[:k, :])\n",
"\n",
" Qty = invR_public.transpose(0, 1).mm(Cty)\n",
" QtX = invR_public.transpose(0, 1).mm(CtX)\n",
"\n",
" QtXQty = QtX.transpose(0, 1).mm(Qty)\n",
" QtyQty = _dot(Qty)\n",
" QtXQtX = _dot(QtX)\n",
"\n",
" yyq = yy - QtyQty\n",
" Xyq = Xy - QtXQty\n",
" XXq = XX - QtXQtX\n",
"\n",
" d = sum(n_samples_by_worker.values()) - k - 1\n",
" beta = Xyq / XXq\n",
" sigma = ((yyq / XXq - (beta ** 2)) / d).abs() ** 0.5\n",
" tstat = beta / sigma\n",
" pval = 2 * stats.t.cdf(-abs(tstat), d)\n",
"\n",
" return beta, sigma, tstat, pval"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[1.0198],\n",
" [0.9758],\n",
" [0.9643],\n",
" [1.0004],\n",
" [1.0150],\n",
" [0.9944],\n",
" [0.9961],\n",
" [1.0318],\n",
" [1.0002],\n",
" [0.9830],\n",
" [0.9790],\n",
" [0.9926],\n",
" [1.0008],\n",
" [0.9697],\n",
" [0.9954],\n",
" [0.9995],\n",
" [0.9960],\n",
" [0.9767],\n",
" [1.0059],\n",
" [0.9838],\n",
" [0.9911],\n",
" [1.0179],\n",
" [1.0080],\n",
" [0.9829],\n",
" [0.9937],\n",
" [0.9819],\n",
" [1.0188],\n",
" [0.9811],\n",
" [0.9971],\n",
" [0.9866],\n",
" [1.0117],\n",
" [0.9953],\n",
" [0.9966],\n",
" [0.9952],\n",
" [0.9957],\n",
" [0.9860],\n",
" [1.0206],\n",
" [0.9928],\n",
" [0.9925],\n",
" [1.0149],\n",
" [0.9587],\n",
" [0.9851],\n",
" [1.0102],\n",
" [1.0127],\n",
" [1.0143],\n",
" [1.0050],\n",
" [0.9926],\n",
" [0.9646],\n",
" [0.9966],\n",
" [0.9906],\n",
" [1.0212],\n",
" [0.9948],\n",
" [1.0253],\n",
" [0.9936],\n",
" [0.9834],\n",
" [0.9770],\n",
" [0.9885],\n",
" [0.9890],\n",
" [0.9954],\n",
" [0.9900],\n",
" [0.9795],\n",
" [0.9657],\n",
" [0.9836],\n",
" [1.0042],\n",
" [0.9957],\n",
" [0.9929],\n",
" [1.0127],\n",
" [0.9869],\n",
" [0.9969],\n",
" [1.0172],\n",
" [1.0030],\n",
" [0.9844],\n",
" [1.0121],\n",
" [1.0071],\n",
" [0.9954],\n",
" [0.9936],\n",
" [0.9954],\n",
" [1.0070],\n",
" [0.9928],\n",
" [0.9900],\n",
" [0.9970],\n",
" [0.9992],\n",
" [0.9851],\n",
" [0.9942],\n",
" [0.9710],\n",
" [0.9799],\n",
" [0.9675],\n",
" [1.0246],\n",
" [1.0085],\n",
" [0.9906],\n",
" [0.9984],\n",
" [1.0182],\n",
" [0.9805],\n",
" [0.9905],\n",
" [1.0034],\n",
" [0.9965],\n",
" [0.9983],\n",
" [0.9973],\n",
" [0.9872],\n",
" [0.9937]]), tensor([[0.0032],\n",
" [0.0032],\n",
" [0.0031],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0031],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0031],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0031],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0031],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032],\n",
" [0.0032]]), tensor([[319.5643],\n",
" [309.3729],\n",
" [306.5470],\n",
" [315.1175],\n",
" [318.3896],\n",
" [313.9088],\n",
" [314.0636],\n",
" [322.6956],\n",
" [315.1400],\n",
" [311.2341],\n",
" [310.3611],\n",
" [313.5351],\n",
" [315.4006],\n",
" [307.5811],\n",
" [313.6877],\n",
" [314.7977],\n",
" [314.0446],\n",
" [309.6685],\n",
" [316.4051],\n",
" [311.1368],\n",
" [313.2263],\n",
" [319.4759],\n",
" [316.2847],\n",
" [310.5904],\n",
" [313.5548],\n",
" [311.4909],\n",
" [319.7083],\n",
" [310.5916],\n",
" [314.3991],\n",
" [311.7694],\n",
" [317.7445],\n",
" [314.2973],\n",
" [313.8512],\n",
" [313.7931],\n",
" [313.8863],\n",
" [311.7593],\n",
" [320.1660],\n",
" [313.1064],\n",
" [312.8649],\n",
" [318.5549],\n",
" [305.3120],\n",
" [311.6055],\n",
" [317.7577],\n",
" [318.0811],\n",
" [318.6992],\n",
" [316.0238],\n",
" [313.5295],\n",
" [306.4926],\n",
" [314.0880],\n",
" [312.5821],\n",
" [319.9753],\n",
" [313.6790],\n",
" [321.0503],\n",
" [313.4674],\n",
" [310.9974],\n",
" [310.1643],\n",
" [312.6641],\n",
" [312.5327],\n",
" [313.7359],\n",
" [312.9331],\n",
" [310.3143],\n",
" [307.3026],\n",
" [311.3515],\n",
" [316.4857],\n",
" [314.6939],\n",
" [313.4534],\n",
" [318.1405],\n",
" [312.0404],\n",
" [314.7274],\n",
" [318.8090],\n",
" [315.5815],\n",
" [311.3936],\n",
" [317.9680],\n",
" [317.0735],\n",
" [314.1143],\n",
" [313.6470],\n",
" [313.5154],\n",
" [316.9404],\n",
" [313.7480],\n",
" [312.7427],\n",
" [314.7064],\n",
" [314.5428],\n",
" [311.3454],\n",
" [313.8516],\n",
" [308.4349],\n",
" [310.3224],\n",
" [307.0361],\n",
" [320.6432],\n",
" [317.1022],\n",
" [312.7644],\n",
" [314.5857],\n",
" [319.2417],\n",
" [310.1894],\n",
" [312.5483],\n",
" [315.9466],\n",
" [314.3875],\n",
" [314.4272],\n",
" [314.6446],\n",
" [312.2668],\n",
" [313.8906]]), array([[0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.],\n",
" [0.]]))"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"players = {\n",
" worker.id: worker \n",
" for worker in [alice, bob, theo]\n",
"}\n",
"\n",
"n_samples_by_player = {\n",
" alice.id: 100000,\n",
" bob.id: 200000,\n",
" theo.id: 100000\n",
"}\n",
"\n",
"crypto_provider = jon\n",
"\n",
"m = 100\n",
"k = 3\n",
"d = sum(n_samples_by_player.values()) - k - 1\n",
"\n",
"\n",
"beta_correct = th.ones(m, 1)\n",
"gamma_correct = th.ones(k, 1)\n",
"\n",
"dash_example_secure(\n",
" players, crypto_provider, \n",
" n_samples_by_player, m, k, \n",
" beta_correct, gamma_correct)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}