{ "cells": [ { "cell_type": "markdown", "id": "f3eabf9b-80ee-45aa-a476-e7c113574165", "metadata": {}, "source": [ "# Simple Applications of Homomorphic Encryption" ] }, { "cell_type": "markdown", "id": "07e91816", "metadata": {}, "source": [ "#pip install tenseal torch pandas numpy pytest " ] }, { "cell_type": "code", "execution_count": 1, "id": "56180778-a1aa-4445-99ba-758ce2220708", "metadata": {}, "outputs": [], "source": [ "import tenseal as ts\n", "import torch\n", "import pandas as pd\n", "import random\n", "import numpy as np\n", "\n", "import pytest\n", "from time import time\n", "import sys\n", "import os" ] }, { "cell_type": "markdown", "id": "f1d1a443-af08-4448-88f9-0997b7dc3883", "metadata": {}, "source": [ "## TenSEAL Context\n", "\n", "The TenSEALContext is a special object that holds different encryption keys and parameters for you, so that you only need to use a single object to make your encrypted computation instead of managing all the keys and the HE details. Basically, you will want to create a single TenSEALContext before doing your encrypted computation. Let's see how to create one !" ] }, { "cell_type": "code", "execution_count": 2, "id": "d72654ae-456b-475c-a8bd-536a8cd5f20b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<tenseal.enc_context.Context at 0x1aba5ad2160>" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "context = ts.Context(\n", " ts.SCHEME_TYPE.CKKS,\n", " poly_modulus_degree=8192,\n", " coeff_mod_bit_sizes=[60, 40, 40, 60]\n", ")\n", "context.global_scale = 2**40\n", "# this key is needed for doing dot-product operations\n", "context.generate_galois_keys()\n", "context" ] }, { "cell_type": "markdown", "id": "2cebe4b1-c5dd-41c9-a63c-3206559b958d", "metadata": {}, "source": [ "## Encrypt the data" ] }, { "cell_type": "code", "execution_count": 3, "id": "a7d079ea-f9c2-47bf-9f11-b43c32bb16d5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(<tenseal.tensors.ckksvector.CKKSVector at 0x1aba5ad21f0>,\n", " <tenseal.tensors.ckksvector.CKKSVector at 0x1aba5ad26a0>)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v1 = [0, 1, 2, 3, 4]\n", "v2 = [4, 3, 2, 1, 0]\n", "\n", "enc_v1 = ts.ckks_vector(context, v1)\n", "enc_v2 = ts.ckks_vector(context, v2)\n", "(enc_v1, enc_v2)" ] }, { "cell_type": "code", "execution_count": 4, "id": "94359057-c26d-4ded-9f52-8157542d7e63", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[5]\n" ] } ], "source": [ "print(enc_v1.shape)" ] }, { "cell_type": "markdown", "id": "5d9de7f2-08f5-48b0-a9f3-1ae94721f5aa", "metadata": {}, "source": [ "## Compute different operations over the two vectors locally and decrypt the results" ] }, { "cell_type": "markdown", "id": "3b972d4a-e008-4471-9f21-e51e8eb8f8ec", "metadata": {}, "source": [ "### Addition" ] }, { "cell_type": "code", "execution_count": 5, "id": "9b2dc2ef-e895-42fa-af1d-d5119c6f22ad", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[4.000000000112872,\n", " 4.000000000242888,\n", " 4.000000000972386,\n", " 3.999999999642842,\n", " 3.9999999994070214]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result_add = enc_v1 + enc_v2\n", "decrypted_result = result_add.decrypt()\n", "decrypted_result" ] }, { "cell_type": "code", "execution_count": 6, "id": "fcad2dae-5bb0-4683-aa99-0821374952f1", "metadata": {}, "outputs": [], "source": [ "assert pytest.approx(decrypted_result, abs=10**-3) == 
[v1 + v2 for v1, v2 in zip(v1, v2)]" ] }, { "cell_type": "markdown", "id": "d8736536-9895-4d90-8134-ee7b6c675b05", "metadata": {}, "source": [ "### Multiplication" ] }, { "cell_type": "code", "execution_count": 7, "id": "fe188335-f3e5-4906-8392-f8d264e96bb7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[3.920404223478613e-09,\n", " 3.000000406002841,\n", " 4.000000537622926,\n", " 3.000000400956933,\n", " -3.3927722853377418e-09]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result_mul = enc_v1 * enc_v2\n", "decrypted_result = result_mul.decrypt()\n", "decrypted_result" ] }, { "cell_type": "code", "execution_count": 8, "id": "760b9b12-344d-4ec7-97d1-0fe2f569bb95", "metadata": {}, "outputs": [], "source": [ "assert pytest.approx(decrypted_result, abs=10**-3) == [v1 * v2 for v1, v2 in zip(v1, v2)]" ] }, { "cell_type": "markdown", "id": "fa41f2a8-b90b-4027-8d69-f8e51d12cc16", "metadata": {}, "source": [ "### Polynomial conversion to 1 + X^2 + X^3" ] }, { "cell_type": "code", "execution_count": 9, "id": "b5a8ced2-2123-4370-98c2-296f052be051", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[1.0000000022713067,\n", " 3.0000009493306408,\n", " 13.000006974897476,\n", " 37.00002292364773,\n", " 81.00005366743699]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result_poly = enc_v1.polyval([1,0,1,1]) # 1 + X^2 + X^3\n", "decrypted_result = result_poly.decrypt()\n", "decrypted_result" ] }, { "cell_type": "code", "execution_count": 10, "id": "527ef025-b333-4e90-92d4-d2e439c8c405", "metadata": {}, "outputs": [], "source": [ "assert pytest.approx(decrypted_result, abs=10**-3) == [1 + v**2 + v**3 for v in v1]" ] }, { "cell_type": "markdown", "id": "121e292e-2068-4179-a65d-82171630ed7d", "metadata": {}, "source": [ "## Training a Logistic Regression Model on non-encrypted training data" ] }, { "cell_type": "code", "execution_count": 11, "id": "813a100a-8841-4d3b-a2a6-8138b4f3fb10", "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>male</th>\n", " <th>age</th>\n", " <th>education</th>\n", " <th>currentSmoker</th>\n", " <th>cigsPerDay</th>\n", " <th>BPMeds</th>\n", " <th>prevalentStroke</th>\n", " <th>prevalentHyp</th>\n", " <th>diabetes</th>\n", " <th>totChol</th>\n", " <th>sysBP</th>\n", " <th>diaBP</th>\n", " <th>BMI</th>\n", " <th>heartRate</th>\n", " <th>glucose</th>\n", " <th>TenYearCHD</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1</td>\n", " <td>39</td>\n", " <td>4.0</td>\n", " <td>0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>195.0</td>\n", " <td>106.0</td>\n", " <td>70.0</td>\n", " <td>26.97</td>\n", " <td>80.0</td>\n", " <td>77.0</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>46</td>\n", " <td>2.0</td>\n", " <td>0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>250.0</td>\n", " <td>121.0</td>\n", " <td>81.0</td>\n", " <td>28.73</td>\n", " <td>95.0</td>\n", " <td>76.0</td>\n", " <td>0</td>\n", " 
</tr>\n", " <tr>\n", " <th>2</th>\n", " <td>1</td>\n", " <td>48</td>\n", " <td>1.0</td>\n", " <td>1</td>\n", " <td>20.0</td>\n", " <td>0.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>245.0</td>\n", " <td>127.5</td>\n", " <td>80.0</td>\n", " <td>25.34</td>\n", " <td>75.0</td>\n", " <td>70.0</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>0</td>\n", " <td>61</td>\n", " <td>3.0</td>\n", " <td>1</td>\n", " <td>30.0</td>\n", " <td>0.0</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>225.0</td>\n", " <td>150.0</td>\n", " <td>95.0</td>\n", " <td>28.58</td>\n", " <td>65.0</td>\n", " <td>103.0</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>0</td>\n", " <td>46</td>\n", " <td>3.0</td>\n", " <td>1</td>\n", " <td>23.0</td>\n", " <td>0.0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>285.0</td>\n", " <td>130.0</td>\n", " <td>84.0</td>\n", " <td>23.10</td>\n", " <td>85.0</td>\n", " <td>85.0</td>\n", " <td>0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " male age education currentSmoker cigsPerDay BPMeds prevalentStroke \\\n", "0 1 39 4.0 0 0.0 0.0 0 \n", "1 0 46 2.0 0 0.0 0.0 0 \n", "2 1 48 1.0 1 20.0 0.0 0 \n", "3 0 61 3.0 1 30.0 0.0 0 \n", "4 0 46 3.0 1 23.0 0.0 0 \n", "\n", " prevalentHyp diabetes totChol sysBP diaBP BMI heartRate glucose \\\n", "0 0 0 195.0 106.0 70.0 26.97 80.0 77.0 \n", "1 0 0 250.0 121.0 81.0 28.73 95.0 76.0 \n", "2 0 0 245.0 127.5 80.0 25.34 75.0 70.0 \n", "3 1 0 225.0 150.0 95.0 28.58 65.0 103.0 \n", "4 0 0 285.0 130.0 84.0 23.10 85.0 85.0 \n", "\n", " TenYearCHD \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 1 \n", "4 0 " ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = pd.read_csv(\"framingham.csv\")\n", "data.head()" ] }, { "cell_type": "markdown", "id": "41a6695c-47ee-404d-be3e-ca148dfee5cf", "metadata": {}, "source": [ "### Prepare the data\n", "\n", "We now prepare the training and test data, the dataset was downloaded from Kaggle. T his dataset provides patients' information along with a 10-year risk of future coronary heart disease (CHD) as a label, and the goal is to build a model that can predict this 10-year CHD risk based on patients' information." 
] }, { "cell_type": "code", "execution_count": 12, "id": "19b10a4f-c5bb-4f1c-ad84-51f39b2854ba", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "############# Data summary #############\n", "x_train has shape: torch.Size([780, 9])\n", "y_train has shape: torch.Size([780, 1])\n", "x_test has shape: torch.Size([334, 9])\n", "y_test has shape: torch.Size([334, 1])\n", "#######################################\n" ] } ], "source": [ "torch.random.manual_seed(73)\n", "random.seed(73)\n", "\n", "def split_train_test(x, y, test_ratio=0.3):\n", " idxs = [i for i in range(len(x))]\n", " random.shuffle(idxs)\n", " # delimiter between test and train data\n", " delim = int(len(x) * test_ratio)\n", " test_idxs, train_idxs = idxs[:delim], idxs[delim:]\n", " return x[train_idxs], y[train_idxs], x[test_idxs], y[test_idxs]\n", " \n", "def heart_disease_data(): \n", " data = pd.read_csv(\"framingham.csv\")\n", " # drop rows with missing values\n", " data = data.dropna()\n", " # drop some features\n", " data = data.drop(columns=[\"education\", \"currentSmoker\", \"BPMeds\", \"diabetes\", \"diaBP\", \"BMI\"])\n", " # balance data\n", " grouped = data.groupby('TenYearCHD')\n", " data = grouped.apply(lambda x: x.sample(grouped.size().min(), random_state=73).reset_index(drop=True))\n", " # extract labels\n", " y = torch.tensor(data[\"TenYearCHD\"].values).float().unsqueeze(1)\n", " data = data.drop(labels=\"TenYearCHD\", axis='columns')\n", " # standardize data\n", " data = (data - data.mean()) / data.std()\n", " x = torch.tensor(data.values).float()\n", " return split_train_test(x, y)\n", "\n", "x_train, y_train, x_test, y_test = heart_disease_data()\n", "\n", "print(\"############# Data summary #############\")\n", "print(f\"x_train has shape: {x_train.shape}\")\n", "print(f\"y_train has shape: {y_train.shape}\")\n", "print(f\"x_test has shape: {x_test.shape}\")\n", "print(f\"y_test has shape: {y_test.shape}\")\n", "print(\"#######################################\")" ] }, { "cell_type": "markdown", "id": "b8cc32dd-28f1-425e-8bf9-789d0436793a", "metadata": {}, "source": [ "## Here we define the models used.\n", "\n", "- LR is a simple logistic regression for plain data\n", "\n", "We will start by training a logistic regression model (without any encryption), which can be viewed as a single layer neural network with a single node." 
] }, { "cell_type": "code", "execution_count": 13, "id": "d9772c9e-2059-4e3a-a47f-780eedf30ce7", "metadata": {}, "outputs": [], "source": [ "class LR(torch.nn.Module):\n", "\n", " def __init__(self, n_features):\n", " super(LR, self).__init__()\n", " self.lr = torch.nn.Linear(n_features, 1)\n", " \n", " def forward(self, x):\n", " out = torch.sigmoid(self.lr(x))\n", " return out" ] }, { "cell_type": "markdown", "id": "b27068ad-7461-40ad-86bf-7334a3c15cbb", "metadata": {}, "source": [ "### Train the logistic regression" ] }, { "cell_type": "code", "execution_count": 14, "id": "36e210ff-31b6-4e79-8e6c-cf9f199ec986", "metadata": {}, "outputs": [], "source": [ "n_features = x_train.shape[1]\n", "model = LR(n_features)\n", "\n", "# use gradient descent with a learning_rate=1\n", "optim = torch.optim.SGD(model.parameters(), lr=1)\n", "\n", "# use Binary Cross Entropy Loss\n", "criterion = torch.nn.BCELoss()" ] }, { "cell_type": "code", "execution_count": 15, "id": "f577b622-d98b-4206-a2ac-c5ac9119a30a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss at epoch 1: 0.8504331707954407\n", "Loss at epoch 2: 0.6863384246826172\n", "Loss at epoch 3: 0.6358115673065186\n", "Loss at epoch 4: 0.6193529367446899\n", "Loss at epoch 5: 0.6124349236488342\n", "Loss at epoch 6: 0.6089244484901428\n", "Loss at epoch 7: 0.6069258451461792\n", "Loss at epoch 8: 0.6057038307189941\n", "Loss at epoch 9: 0.6049202084541321\n", "Loss at epoch 10: 0.604399561882019\n", "Loss at epoch 11: 0.6040432453155518\n", "Loss at epoch 12: 0.6037929058074951\n", "Loss at epoch 13: 0.6036127805709839\n", "Loss at epoch 14: 0.6034800410270691\n", "Loss at epoch 15: 0.6033799648284912\n", "Loss at epoch 16: 0.6033029556274414\n", "Loss at epoch 17: 0.6032425165176392\n", "Loss at epoch 18: 0.6031941175460815\n", "Loss at epoch 19: 0.603154718875885\n", "Loss at epoch 20: 0.6031221747398376\n" ] } ], "source": [ "EPOCHS = 20\n", "\n", "def train(model, optim, criterion, x, y, epochs=EPOCHS):\n", " for e in range(1, epochs + 1):\n", " optim.zero_grad()\n", " out = model(x)\n", " loss = criterion(out, y)\n", " loss.backward()\n", " optim.step()\n", " print(f\"Loss at epoch {e}: {loss.data}\")\n", " return model\n", "\n", "model = train(model, optim, criterion, x_train, y_train)" ] }, { "cell_type": "code", "execution_count": 16, "id": "5e2eef5d-9c9e-49ec-991c-4805b9ba5794", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy on plain test_set: 0.7005987763404846\n" ] } ], "source": [ "def accuracy(model, x, y):\n", " out = model(x)\n", " correct = torch.abs(y - out) < 0.5\n", " return correct.float().mean()\n", "\n", "plain_accuracy = accuracy(model, x_test, y_test)\n", "print(f\"Accuracy on plain test_set: {plain_accuracy}\")" ] }, { "cell_type": "markdown", "id": "611dadea-0b28-4c38-b4bd-31f2c17304a2", "metadata": {}, "source": [ "We will be comparing accuracies over encrypted data against the plain_accuracy we got here." ] }, { "cell_type": "markdown", "id": "7c1e30ce-b21f-482a-9312-6f3c95c2b9ec", "metadata": {}, "source": [ "## Encrypted Evaluation on Encrypted Test Data\n", "\n", "In this part, we will just focus on evaluating the logistic regression model with plain parameters (optionally encrypted parameters) on the encrypted test set. 
We first encrypt the test set and then create a PyTorch-like LR model that can evaluate encrypted data:" ] }, { "cell_type": "code", "execution_count": 17, "id": "b895251e-8a30-4e3a-8e7a-9407f167a87d", "metadata": {}, "outputs": [], "source": [ "enc_x_test = [ts.ckks_vector(context, x.tolist()) for x in x_test]" ] }, { "cell_type": "markdown", "id": "b5254ff5-63b2-46b7-8322-e1bd746e6605", "metadata": {}, "source": [ "### We create the model variant that can handle encrypted tensors\n", "- EncryptedLR adapts the forward method to the API exposed by the encrypted ciphertexts." ] }, { "cell_type": "code", "execution_count": 19, "id": "6124c453-8b8c-4a9b-b1b0-ed42a764f486", "metadata": {}, "outputs": [], "source": [ "class EncryptedLR:\n", "    def __init__(self, torch_lr):\n", "        # TenSEAL processes lists and not torch tensors,\n", "        # so we take the parameters out of the PyTorch model\n", "        self.weight = torch_lr.lr.weight.data.tolist()[0]\n", "        self.bias = torch_lr.lr.bias.data.tolist()\n", "\n", "    def forward(self, enc_x):\n", "        # We don't need to apply sigmoid as this model\n", "        # will only be used for evaluation, and the label\n", "        # can be deduced without applying sigmoid\n", "        enc_out = enc_x.dot(self.weight) + self.bias\n", "        return enc_out\n", "\n", "    def __call__(self, *args, **kwargs):\n", "        return self.forward(*args, **kwargs)\n", "\n", "    ## You can use the functions below to perform\n", "    ## the evaluation with an encrypted model\n", "\n", "    def encrypt(self, context):\n", "        self.weight = ts.ckks_vector(context, self.weight)\n", "        self.bias = ts.ckks_vector(context, self.bias)\n", "\n", "    def decrypt(self, context):\n", "        self.weight = self.weight.decrypt()\n", "        self.bias = self.bias.decrypt()\n", "\n", "eelr = EncryptedLR(model)\n", "\n", "# encrypt the model's parameters\n", "eelr.encrypt(context)" ] }, { "cell_type": "code", "execution_count": 20, "id": "1f1eab87-f132-4161-ab21-13ee554b498e", "metadata": {}, "outputs": [], "source": [ "def encrypted_evaluation(model, enc_x_test, y_test):\n", "    correct = 0\n", "    for enc_x, y in zip(enc_x_test, y_test):\n", "        # encrypted evaluation\n", "        enc_out = model(enc_x)\n", "        # plain comparison\n", "        out = enc_out.decrypt()\n", "        out = torch.tensor(out)\n", "        out = torch.sigmoid(out)\n", "        if torch.abs(out - y) < 0.5:\n", "            correct += 1\n", "\n", "    print(f\"Evaluated test_set of {len(enc_x_test)} entries. Accuracy: {correct}/{len(enc_x_test)} = {correct / len(enc_x_test)}\")\n", "    return correct / len(enc_x_test)" ] }, { "cell_type": "code", "execution_count": 21, "id": "acc55ef3-ac7c-42ff-8552-a03e5cb7a9df", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluated test_set of 334 entries. Accuracy: 234/334 = 0.7005988023952096\n", "Difference between plain and encrypted accuracies: 0.0\n" ] } ], "source": [ "encrypted_accuracy = encrypted_evaluation(eelr, enc_x_test, y_test)\n", "diff_accuracy = plain_accuracy - encrypted_accuracy\n", "print(f\"Difference between plain and encrypted accuracies: {diff_accuracy}\")\n", "\n", "if diff_accuracy < 0:\n", "    print(\"Oh! We got a better accuracy on the encrypted test-set! The noise was on our side...\")" ] }, { "cell_type": "markdown", "id": "04a34d9b-e2ab-4f87-9596-dda82ee1fba4", "metadata": {}, "source": [ "We saw that evaluating on the encrypted test set doesn't affect the accuracy."
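] }, { "cell_type": "markdown", "id": "encrypted-eval-timing-note", "metadata": {}, "source": [ "Homomorphic evaluation preserves the accuracy, but it is far more expensive than evaluation on plain data. The cell below is a minimal sketch that reuses the `time` import from the top of the notebook to get a rough sense of that cost; the exact figures depend on your machine and on the CKKS parameters chosen for the context." ] }, { "cell_type": "code", "execution_count": null, "id": "encrypted-eval-timing-code", "metadata": {}, "outputs": [], "source": [ "# Rough timing of the encrypted evaluation (a sketch; results are machine- and parameter-dependent)\n", "t_start = time()\n", "encrypted_evaluation(eelr, enc_x_test, y_test)\n", "t_end = time()\n", "print(f\"Encrypted evaluation of {len(enc_x_test)} samples took {t_end - t_start:.2f} seconds\")"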
] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.7 ('climate')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" }, "vscode": { "interpreter": { "hash": "0d7967a8d9afcf4b31f4d35ec92712493ac98bed77dd4624f3d610c11e265b0a" } } }, "nbformat": 4, "nbformat_minor": 5 }