IBM Federated Learning with Homomorphic Encryption

This notebook demonstrates how to run Federated Learning training experiments with Homomorphic Encryption.

Learning Goals

The learning goals of this notebook are:

  • Create and use the WML Python client to set up and run Federated Learning training jobs.
  • Create the WML assets (initial models and remote training systems), data sets and data handlers, and cryptographic files (certificates and keys) required for running a Federated Learning training job with encryption. All local files are stored in a single directory tree, whose root location can be specified using a setup parameter.
  • Launch a Federated Learning training job by launching the aggregator and multiple parties. The aggregator runs in a cluster, and the parties run on the machine that runs this notebook. The number of parties can be specified using a setup parameter.
  • Monitor the training job.
  • Clean up the WML assets and the local files and directories created by this notebook. You can also reuse assets and files that were created in previous sessions of this notebook and not removed.

Table of Contents

  • Introduction
  • Prerequisites
  • Basic setup
  • Create a WML client
  • Create WML assets
  • Create parties data
  • Create parties cryptographic elements
  • Launch aggregator
  • Launch parties
  • Monitor execution status of the training
  • Cleanup
  • Next steps

Introduction

IBM Federated Learning

IBM Federated Learning enables you to train a machine learning model across multiple decentralized parties that hold local data sets, without sharing those data sets. The parties can be, for example, units within an enterprise, members of a consortium of enterprises, multiple data centers or clouds, or edge devices. This makes it possible to build a collective machine learning model without sharing data between the nodes, which addresses data security, privacy, and regulatory compliance requirements, and eliminates data movement and its associated costs.

In the federated learning training process, the parties build locally trained machine learning models and send these local models to an aggregator. The aggregator fuses the local models into an aggregated model and sends this model back to the parties to continue with the next round of training.
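The fusion step can be pictured with a minimal sketch, assuming each local model is a list of NumPy weight arrays and the aggregator computes a plain, unweighted element-wise average. The function name average_models is hypothetical; the fusion algorithms in IBM Federated Learning (selected through fusion_type in the aggregator configuration) may weight parties differently, for example by the number of local samples.

```python
import numpy as np

def average_models(party_models):
    """Fuse local models by element-wise averaging of each parameter array.

    party_models: one entry per party; each entry is a list of numpy arrays
    (one per layer), with matching shapes across parties.
    """
    num_parties = len(party_models)
    # zip(*party_models) groups the same layer from every party together.
    return [sum(layers) / num_parties for layers in zip(*party_models)]

# Usage: two parties, each holding a two-layer "model".
party_a = [np.array([1.0, 2.0]), np.array([[0.0]])]
party_b = [np.array([3.0, 4.0]), np.array([[2.0]])]
print(average_models([party_a, party_b]))
```

The fused model is then sent back to every party as the starting point of the next round.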

For additional details, see the IBM Federated Learning documentation.

Homomorphic encryption support in IBM Federated Learning

IBM Federated Learning uses SSL-secured connections between the parties and the aggregator for communicating the machine learning models. However, in this setting the aggregator can still see the unencrypted local and aggregated models.

IBM Federated Learning also includes homomorphic encryption capabilities that enhance the parties’ data privacy and security when the aggregator operates in a less-trusted environment and the parties want to avoid revealing the local models and the aggregated models to the aggregator.

Homomorphic encryption (HE) is a form of encryption that makes it possible to perform computations on encrypted data without decrypting it. The results of the computations remain encrypted; when decrypted, they are the same as the results that would have been produced had the computations been performed on the unencrypted data.

In federated learning, homomorphic encryption enables the parties to homomorphically encrypt their local model updates before sending them to the aggregator. The aggregator sees only the homomorphically encrypted local model updates, and therefore cannot learn anything from this information. Specifically, the aggregator is not able to reverse-engineer the local model updates to discover information on local training data. The aggregator fuses the local model updates in their encrypted form, obtaining an encrypted aggregated model. Then the aggregator sends the encrypted aggregated model to the parties, which decrypt it and continue with the next round of training.

Homomorphic encryption is a form of public key cryptography. It uses a public key for encryption and a private key for decryption. In IBM Federated Learning with homomorphic encryption, the parties (also named “remote training systems”) share the private HE key, and the aggregator has only the public HE key. Each party encrypts its local model update using the public HE key, and sends its encrypted local model update to the aggregator. Since the aggregator does not have the private HE key, it cannot decrypt the encrypted local model updates.

The aggregator uses its public HE key to fuse the encrypted local model updates into a new encrypted aggregated model. This encrypted aggregated model is sent to the parties, which decrypt it using their private HE key, and continue the model training process.
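The additive property that makes encrypted fusion possible can be illustrated with a toy version of the textbook Paillier cryptosystem. This sketch rests on loud assumptions: Paillier is only additively homomorphic, the tiny primes make it completely insecure, and it is not the scheme that IBM Federated Learning uses. In the real setting the aggregator would hold only the public values (N, G) and the parties the private values (LAM, MU); here everything lives in one process for brevity.

```python
import math
import random

# Toy textbook Paillier parameters: tiny primes, INSECURE, illustration only.
P, Q = 293, 433
N = P * Q
N_SQ = N * N
G = N + 1                                          # public
LAM = (P - 1) * (Q - 1) // math.gcd(P - 1, Q - 1)  # private
MU = pow(LAM, -1, N)                               # private; valid because G = N + 1

def encrypt(m):
    """Party step: encrypt 0 <= m < N using only the public values (N, G)."""
    r = random.randrange(1, N)
    while math.gcd(r, N) != 1:
        r = random.randrange(1, N)
    return (pow(G, m, N_SQ) * pow(r, N, N_SQ)) % N_SQ

def fuse_encrypted(ciphertexts):
    """Aggregator step: multiplying ciphertexts adds the hidden plaintexts (mod N)."""
    result = 1
    for c in ciphertexts:
        result = (result * c) % N_SQ
    return result

def decrypt(c):
    """Party step: recover the plaintext with the private values (LAM, MU)."""
    return ((pow(c, LAM, N_SQ) - 1) // N * MU) % N

# Three parties encrypt integer-encoded updates; the aggregator fuses them blind.
updates = [120, 45, 300]
fused = fuse_encrypted([encrypt(u) for u in updates])
print(decrypt(fused))                  # 465, the sum of the updates
print(decrypt(fused) / len(updates))   # 155.0, the average computed by the parties
```

The aggregator never calls decrypt; it only multiplies ciphertexts, which is the role the text above assigns to it.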

IBM Federated Learning makes it easy to use homomorphic encryption in model training: you only specify simple parameters in the configurations of the aggregator and the parties. IBM Federated Learning includes a mechanism that automatically and securely generates homomorphic encryption keys and distributes them among the parties participating in a training experiment.

Prerequisites

  1. The currently supported operating system and architecture for running parties in Federated Learning experiments with homomorphic encryption is Linux x86. Therefore, this notebook must run on a Linux x86 platform.
  2. Install the IBM Watson Machine Learning Python client package with homomorphic encryption support within the Python environment in which this notebook runs, using the following command:
     pip install 'ibm_watsonx_ai[fl-rt23.1-py3.10,fl-crypto]'
     You can use the installation cell in the next section of this notebook to perform this installation.
     This installation is required for any Python environment that will be used for running parties in Federated Learning experiments with homomorphic encryption.
  3. In your IBM Cloud account:
       1. Obtain your cloud user IAM ID: open the Users page, click the relevant user, click Details, and copy the IAM ID. The IAM ID has the format IBMid-<aaa>.
       2. Generate and obtain an API key from the API keys page.
  4. Create a Watson Machine Learning service instance (a free plan is offered). In your Watson Machine Learning service instance:
       1. Create a new project or use an existing one for this notebook. To create a project, open the Projects page and click New project. For additional details, see: Creating a project. Obtain the ID of the project from Projects > Specific project > Manage > Project ID.
       2. Associate your project with the Watson Machine Learning service by opening Projects > Specific project > Manage > Services & integrations > Associate service. For additional details, see: Associating services.
       3. Obtain the location name of your Watson Machine Learning service instance from the location drop-down menu in the top toolbar. Example location names: us-south, eu-gb, eu-de, jp-tok.
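To check the first prerequisite from Python, a sketch like the following can be run in each party environment. The helper name is hypothetical, and treating `x86_64` (or `AMD64`) as the value that `platform.machine()` reports on Linux x86 is an assumption.

```python
import platform

def is_supported_party_platform() -> bool:
    """Return True when running on Linux with an x86-64 processor."""
    # Assumption: platform.machine() reports 'x86_64' (or 'AMD64') on Linux x86.
    return platform.system() == "Linux" and platform.machine().lower() in ("x86_64", "amd64")

if not is_supported_party_platform():
    print("Warning: parties with homomorphic encryption require Linux x86, "
          "but this machine reports {} {}".format(platform.system(), platform.machine()))
```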

Basic setup

Install the IBM Watson Machine Learning Python client package with homomorphic encryption support, within the Python environment in which this notebook runs, if this package is not yet installed in this environment.

In [ ]:
%pip install --upgrade 'ibm_watsonx_ai[fl-rt23.1-py3.10,fl-crypto]'

The following cell applies base definitions for the notebook.

User action: Before running the following cell, replace the mandatory TBDs in the cell with your information and review the optional TBDs.

In [ ]:
import os
import subprocess
import importlib.util
import urllib3
import requests
urllib3.disable_warnings()

# Check that ibm_watsonx_ai is importable by this notebook's Python kernel
# (a `pip list` lookup may query a different environment than the running kernel).
if importlib.util.find_spec('ibm_watsonx_ai') is None:
    raise Exception('ibm-watsonx-ai package must be installed in the environment')

base_dir = os.getcwd() # TBD [optional] A base directory under which the notebook work directory will be created. Default is the current work directory.
nb_dir = os.path.join(base_dir, 'fl_fhe_nb')
data_path = os.path.join(nb_dir, 'data')
model_path = os.path.join(nb_dir, 'model')
crypto_path = os.path.join(nb_dir, 'crypto')
exec_path = os.path.join(nb_dir, 'exec')
# Create the work directories; exist_ok avoids errors when reusing a previous session's tree.
for path in (data_path, model_path, crypto_path, exec_path):
    os.makedirs(path, exist_ok=True)
os.chdir(exec_path)

PROJECT_ID = '' # TBD [mandatory] See the prerequisites section for details.
CLOUD_USERID = '' # TBD [mandatory] See the prerequisites section for details.
IAM_APIKEY = '' # TBD [mandatory] See the prerequisites section for details.
WML_SERVICES_LOCATION = '' # TBD [mandatory] See the prerequisites section for details.
WML_SERVICES_URL = 'https://' + WML_SERVICES_LOCATION + '.ml.cloud.ibm.com'
NUM_RTS = 3 # TBD [optional] The number of parties (remote training systems) in the training experiment.
SW_SPEC_NAME = 'runtime-23.1-py3.10'
HW_SPEC_NAME = 'S'
RSC_TAGS = ['wml_fl_fhe_nb_example']
TIMEOUT_TRAINING_SEC = 600
crypto_file_ext = 'v1'
asym_file_is = crypto_path + "/is_asym_" + crypto_file_ext + ".pem"
cert_file_is = crypto_path + "/is_cert_" + crypto_file_ext + ".pem"
asym_file_sb = crypto_path + "/sb_asym_" + crypto_file_ext + "_"
csr_file_sb = crypto_path + "/sb_csr_" + crypto_file_ext + "_"
cert_file_sb = crypto_path + "/sb_cert_" + crypto_file_ext + "_"
prt_data_file_prefix = 'data_party_'
NUM_MODELS = int(1)
MODEL_NAME = 'pytorch'
MODEL_TYPE = 'pytorch-onnx_2.0'
INIT_MODEL_FILE_NAME = 'pt_mnist_init_model.zip'
INIT_MODEL_URL = 'https://github.com/IBMDataScience/sample-notebooks/raw/master/Files/pt_mnist_init_model.zip'
DATA_HANDLER_FILE_NAME = 'mnist_pytorch_data_handler.py'
DATA_HANDLER_CLASS_NAME = 'MnistPytorchDataHandler'
DATASET_FILE_NAME = 'mnist.npz'
DATASET_URL = 'https://api.dataplatform.cloud.ibm.com/v2/gallery-assets/entries/85ae67d0cf85df6cf114d0664194dc3b/data'

heartbeat_resp = requests.get(WML_SERVICES_URL + "/wml_services/training/heartbeat", verify=False)
print("Heartbeat response %s" % heartbeat_resp.content.decode("utf-8"))

Create a WML client

This section creates and activates a WML client, which lets you interact with your WML instance.

In [ ]:
from ibm_watsonx_ai import APIClient

wml_credentials = {
    "url": WML_SERVICES_URL,
    "apikey": IAM_APIKEY
}
wml_client = APIClient(wml_credentials)
wml_client.set.default_project(PROJECT_ID)

Create WML assets

The WML assets created in this notebook are initial models and remote training systems.
In this section you can either create new assets, or reuse assets that were created in a previous session of this notebook and not removed.

Create new assets

Store initial model assets in the cluster

Initial untrained model assets are required for Federated Learning.
In this notebook, an untrained PyTorch model is used.
For additional details see the documentation on creating initial models.

First, we download a pre-built initial model.

In [ ]:
import shutil
print("Downloading initial model")
init_model_file_path = os.path.join(model_path, INIT_MODEL_FILE_NAME)
with requests.get(INIT_MODEL_URL, stream=True) as r:
    r.raise_for_status()  # stop on HTTP errors instead of saving an error page as the model file
    with open(init_model_file_path, 'wb') as f:
        shutil.copyfileobj(r.raw, f)
print('Model stored in: ' + str(init_model_file_path))
print("Done")

Next, we upload the initial model as an asset into the cluster.

In [ ]:
print("Storing initial model")
sw_spec_id = wml_client.software_specifications.get_id_by_name(SW_SPEC_NAME)
untrained_model_ids = {}
model_metadata = {
    wml_client.repository.ModelMetaNames.NAME: MODEL_NAME,
    wml_client.repository.ModelMetaNames.TYPE: MODEL_TYPE,
    wml_client.repository.ModelMetaNames.SOFTWARE_SPEC_UID: sw_spec_id,
    wml_client.repository.ModelMetaNames.TAGS: RSC_TAGS
}
untrained_model_details = wml_client.repository.store_model(os.path.join(model_path, INIT_MODEL_FILE_NAME), model_metadata)
untrained_model_ids[MODEL_NAME] = wml_client.repository.get_model_id(untrained_model_details)
print('Model id: ' + str(untrained_model_ids[MODEL_NAME]))
print('Done')

Create remote training systems in the cluster

A Remote Training System (RTS) asset defines a party that connects to the aggregator for a training experiment.
For additional details see the corresponding documentation page.

In [ ]:
print("Creating Remote Training Systems")
remote_training_systems = []
for i in range(NUM_RTS):
    rts_metadata = {
        wml_client.remote_training_systems.ConfigurationMetaNames.NAME: "Party_"+str(i),
        wml_client.remote_training_systems.ConfigurationMetaNames.TAGS: RSC_TAGS,
        wml_client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: {"name" : "IBM", "region": "US"},
        wml_client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES: [{"id": CLOUD_USERID, "type": "user"}],
        wml_client.remote_training_systems.ConfigurationMetaNames.REMOTE_ADMIN: {"id": CLOUD_USERID, "type":"user"}
    }
    rts = wml_client.remote_training_systems.store(rts_metadata)
    rts_id = wml_client.remote_training_systems.get_id(rts)
    print('Remote training system Party_' + str(i) + ' id: ' + str(rts_id))
    remote_training_systems.append({'id': rts_id, 'required': True})
print('Done')

Reuse existing assets

Run the following cell if you are reusing WML assets that were created in previous sessions.
This code rebuilds the notebook's internal lists of models and remote training systems from the existing assets, identified by their tags. These lists are used in later operations of this notebook.

In [ ]:
import json

FORCE_REBUILD_DS = False

print("Models:")
if 'untrained_model_ids' not in globals() or FORCE_REBUILD_DS or \
    len(untrained_model_ids) != NUM_MODELS:
    untrained_model_ids = {}
    load_models_dict = True
else:
    load_models_dict = False
models = wml_client.repository.get_model_details(get_all=True)
for m in models['resources']:
    md = m['metadata']
    if not 'tags' in md or md['tags'] != RSC_TAGS:
        continue
    if load_models_dict:
        untrained_model_ids[md['name']] = md['id']
    print('{}: {}'.format(md['name'],md['id']))

print("Remote Training Systems:")
if 'remote_training_systems' not in globals() or FORCE_REBUILD_DS or \
    len(remote_training_systems) != NUM_RTS:
    remote_training_systems = []
    load_rts_lst = True
else:
    load_rts_lst = False
rts = wml_client.remote_training_systems.get_details()
for r in rts['resources']:
    md = r['metadata']
    if not 'tags' in md or md['tags'] != RSC_TAGS:
        continue
    if load_rts_lst:
        remote_training_systems.append({'id': md['id'], 'required': True})
    print('{}: {}'.format(md['name'],md['id']))
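Both loops above apply the same filter: keep only the resources whose tags equal RSC_TAGS. That logic could be factored into a small helper; filter_tagged_resources is a hypothetical name, and it assumes the response layout the cell above already relies on (a list of dicts, each with a metadata dict holding name, id, and optional tags).

```python
def filter_tagged_resources(resources, tags):
    """Return {name: id} for resources whose metadata tags equal `tags` exactly."""
    selected = {}
    for resource in resources:
        md = resource.get('metadata', {})
        if md.get('tags') != tags:
            continue  # skip assets that were not created by this notebook
        selected[md['name']] = md['id']
    return selected

# Usage with the variables from the cell above:
# untrained_model_ids = filter_tagged_resources(models['resources'], RSC_TAGS)
# remote_training_systems = [{'id': rts_id, 'required': True}
#     for rts_id in filter_tagged_resources(rts['resources'], RSC_TAGS).values()]
```

Keying the result by name means that if several sessions left assets with the same name, only the last one is kept, whereas the list-based loop for remote training systems keeps all of them.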

Create parties data

This section downloads the MNIST data set and splits it into subsets for the parties.
Then, it defines and stores a data handler.

Download data set and split it for the parties

In [ ]:
import os
import requests
import numpy as np
import shutil

def load_mnist(normalize=True, download_dir=''):
    """
    Download the MNIST data set (in the mnist.npz format used by `keras.datasets.mnist.load_data`).
    On first download, the (optionally normalized) arrays are written back to the local file.
    :param normalize: whether or not to scale pixel values to [0, 1]
    :type normalize: `bool`
    :param download_dir: directory to download data
    :type download_dir: `str`
    :return: 2 tuples containing training and testing data respectively
    :rtype: (`np.ndarray`, `np.ndarray`), (`np.ndarray`, `np.ndarray`)
    """
    local_file = os.path.join(download_dir, DATASET_FILE_NAME)
    if not os.path.isfile(local_file):
        with requests.get(DATASET_URL, stream=True) as r:
            with open(local_file, 'wb') as f:
                shutil.copyfileobj(r.raw, f)
        with np.load(local_file, allow_pickle=True) as mnist:
            x_train, y_train = mnist['x_train'], mnist['y_train']
            x_test, y_test = mnist['x_test'], mnist['y_test']
            if normalize:
                x_train = x_train.astype('float32')
                x_test = x_test.astype('float32')
                x_train /= 255
                x_test /= 255
        np.savez(local_file, x_train=x_train, y_train=y_train,
                 x_test=x_test, y_test=y_test)
    else:
        with np.load(local_file, allow_pickle=True) as mnist:
            x_train, y_train = mnist['x_train'], mnist['y_train']
            x_test, y_test = mnist['x_test'], mnist['y_test']
    return (x_train, y_train), (x_test, y_test)

def save_mnist_party_data(nb_dp_per_party, should_stratify, party_folder, dataset_folder):
    """
    Splits MNIST into per-party subsets and saves each subset as an .npz file
    :param nb_dp_per_party: the number of training data points each party should have
    :type nb_dp_per_party: `list[int]`
    :param should_stratify: True if data should be assigned proportional to source class distributions
    :type should_stratify: `bool`
    :param party_folder: folder to save party data
    :type party_folder: `str`
    :param dataset_folder: folder in which to download and store the full dataset
    :type dataset_folder: `str`
    """
    if not os.path.exists(dataset_folder):
        os.makedirs(dataset_folder)
    (x_train, y_train), (x_test, y_test) = load_mnist(download_dir=dataset_folder)
    labels, train_counts = np.unique(y_train, return_counts=True)
    te_labels, test_counts = np.unique(y_test, return_counts=True)
    diff_labels = np.all(np.isin(labels, te_labels))
    num_train = np.shape(y_train)[0]
    num_test = np.shape(y_test)[0]
    num_labels = np.shape(np.unique(y_test))[0]
    nb_parties = len(nb_dp_per_party)
    if should_stratify:
        train_probs = {
            label: train_counts[label] / float(num_train) for label in labels}
        test_probs = {label: test_counts[label] /
                        float(num_test) for label in te_labels}
    else:
        train_probs = {label: 1.0 / len(labels) for label in labels}
        test_probs = {label: 1.0 / len(te_labels) for label in te_labels}
    for idx, dp in enumerate(nb_dp_per_party):
        # Per-sample sampling weights, derived from each sample's class probability.
        train_p = np.array([train_probs[label] for label in y_train])
        train_p /= np.sum(train_p)
        train_indices = np.random.choice(num_train, dp, p=train_p)
        test_p = np.array([test_probs[label] for label in y_test])
        test_p /= np.sum(test_p)
        test_indices = np.random.choice(
            num_test, int(num_test / nb_parties), p=test_p)
        x_train_pi = x_train[train_indices]
        y_train_pi = y_train[train_indices]
        x_test_pi = x_test[test_indices]
        y_test_pi = y_test[test_indices]
        name_file = prt_data_file_prefix + str(idx) + '.npz'
        name_file = os.path.join(party_folder, name_file)
        np.savez(name_file, x_train=x_train_pi, y_train=y_train_pi,
                 x_test=x_test_pi, y_test=y_test_pi)
    print('Data saved in ' + party_folder)
    return

save_mnist_party_data(nb_dp_per_party=[200 for _ in range(NUM_RTS)], should_stratify=False, 
    party_folder=data_path, dataset_folder=data_path)
print('Done')
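Because each party's subset is drawn with np.random.choice (which samples with replacement by default), it can be useful to inspect what a party actually received. A small helper such as summarize_party_file (hypothetical, not part of the original notebook) could report the array shapes and per-class label counts of one party file:

```python
import numpy as np

def summarize_party_file(npz_path):
    """Return the array shapes and per-class training label counts of one party file."""
    with np.load(npz_path) as data:
        summary = {key: data[key].shape for key in ("x_train", "y_train", "x_test", "y_test")}
        labels, counts = np.unique(data["y_train"], return_counts=True)
    summary["train_label_counts"] = dict(zip(labels.tolist(), counts.tolist()))
    return summary

# Usage with this notebook's names, e.g. for party 0:
# print(summarize_party_file(os.path.join(data_path, prt_data_file_prefix + '0.npz')))
```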

Define and store a data handler

This section creates a data handler Python file for training a PyTorch model on the MNIST dataset.
For additional details see the corresponding documentation page.

In [ ]:
%%writefile mnist_pytorch_data_handler.py
import numpy as np
from ibmfl.data.data_handler import DataHandler

class MnistPytorchDataHandler(DataHandler):
    """
    Data handler for the MNIST dataset to train using PyTorch.
    """

    def __init__(self, data_config=None):
        super().__init__()
        self.file_name = None
        if data_config is not None:
            if 'npz_file' in data_config:
                self.file_name = data_config['npz_file']
        # Load the datasets.
        (self.x_train, self.y_train), (self.x_test, self.y_test) = self.load_dataset()
        # Pre-process the datasets.
        self.preprocess()

    def get_data(self):
        """
        Gets the preprocessed MNIST training and testing data.

        :return: training and testing data
        :rtype: `tuple`
        """
        return (self.x_train, self.y_train), (self.x_test, self.y_test)

    def load_dataset(self, nb_points=500):
        """
        Loads the training and testing datasets from the local .npz file
        given by the 'npz_file' entry of the data configuration.
        This handler does not download data; if the file cannot be loaded,
        an IOError is raised.

        :param nb_points: Unused; this handler always loads the full contents
        of the file.
        :type nb_points: `int`
        :return: training and testing datasets
        :rtype: `tuple`
        """
        try:
            with np.load(self.file_name) as data_train:
                x_train = data_train['x_train']
                y_train = data_train['y_train']
                x_test = data_train['x_test']
                y_test = data_train['y_test']
        except Exception:
            # str() keeps the message valid even when no 'npz_file' was configured.
            raise IOError('Unable to load training data from path '
                          'provided in config file: ' +
                          str(self.file_name))
        return (x_train, y_train), (x_test, y_test)

    def preprocess(self):
        """
        Preprocesses the training and testing datasets: reshapes the images
        into float32 arrays of shape (num_samples, 1, 28, 28), i.e. channels
        first as PyTorch convolution layers expect, and casts the labels to int64.

        :return: None
        """
        img_rows, img_cols = 28, 28
        self.x_train = self.x_train.astype('float32').reshape(self.x_train.shape[0], 1, img_rows, img_cols)
        self.x_test = self.x_test.astype('float32').reshape(self.x_test.shape[0], 1, img_rows, img_cols)
        self.y_train = self.y_train.astype('int64')
        self.y_test = self.y_test.astype('int64')
In [ ]:
import shutil
shutil.move(os.path.join('.', DATA_HANDLER_FILE_NAME), os.path.join(data_path, DATA_HANDLER_FILE_NAME))

Create parties cryptographic elements

This section creates the certificate and key files required for running a Federated Learning training experiment with encryption.
Two methods for creating the cryptographic files are provided in this section: using the Python cryptography package, or using openssl. Use either one of these methods.

Homomorphic encryption keys are generated and distributed automatically and securely among the parties for each experiment. Only the parties participating in an experiment have access to the homomorphic encryption private key generated for the experiment.
To facilitate this generation and distribution process, the following steps must be performed before an experiment:

  • All the parties participating in the experiment must agree on a single Certificate Authority.
  • Each party must be provisioned with a certificate from the agreed Certificate Authority.
  • Each party must be provisioned with an RSA key pair. The RSA public key must be included in the aforementioned party certificate.

An RSA key pair and certificate for a party must be generated using the following parameters and guidelines:

  • Key type: RSA.
  • Key size: 4096 bit.
  • Public exponent: 65537.
  • No password for the RSA key file.
  • Hash algorithm: SHA256.
  • The key and certificate files must be in PEM format.
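The guidelines above can be checked before an experiment starts. The helper check_party_key_and_cert below is hypothetical (it is not part of IBM Federated Learning or the WML client); it is a sketch that uses the same cryptography package as Method 1 and tests each listed requirement, plus the requirement that the party certificate contains the party's RSA public key. PEM format is enforced implicitly, because the PEM loaders raise ValueError on other encodings.

```python
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa

def check_party_key_and_cert(key_pem: bytes, cert_pem: bytes) -> list:
    """Return a list of problems; an empty list means every check passed."""
    try:
        key = serialization.load_pem_private_key(key_pem, password=None)
    except TypeError:
        # cryptography raises TypeError when the key is encrypted and no password is given.
        return ["RSA key file is password-protected"]
    cert = x509.load_pem_x509_certificate(cert_pem)
    if not isinstance(key, rsa.RSAPrivateKey):
        return ["key type is not RSA"]
    problems = []
    if key.key_size != 4096:
        problems.append("key size is {} bits, expected 4096".format(key.key_size))
    if key.public_key().public_numbers().e != 65537:
        problems.append("public exponent is not 65537")
    if not isinstance(cert.signature_hash_algorithm, hashes.SHA256):
        problems.append("certificate is not signed with SHA256")
    if cert.public_key().public_numbers() != key.public_key().public_numbers():
        problems.append("certificate does not contain this party's RSA public key")
    return problems

# Usage with this notebook's file names, e.g. for party 0:
# with open(asym_file_sb + "0.pem", "rb") as k, open(cert_file_sb + "0.pem", "rb") as c:
#     print(check_party_key_and_cert(k.read(), c.read()))
```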

Each party must be configured with paths to the following files:

  • Certificate of the Certificate Authority.
  • Certificate of the party issued by the Certificate Authority (includes the RSA public key of the party).
  • RSA private key of the party.

Further details on this configuration are provided in the notebook section Launch parties.

In this notebook, we generate and provision self-signed certificates.

Method 1: Using Python Cryptography package

In [ ]:
import os
import datetime
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography import x509
from cryptography.x509.oid import NameOID

class CryptoRsa():

    KEY_SIZE = 4096
    PUBLIC_EXPONENT = 65537
    CRYPTO_HASH = hashes.SHA256()

    def __init__(self):
        self.private_key = CryptoRsa.generate_key()

    @staticmethod
    def generate_key():
        private_key = rsa.generate_private_key(
            public_exponent=CryptoRsa.PUBLIC_EXPONENT,
            key_size=CryptoRsa.KEY_SIZE,
        )
        return private_key

    def get_public_key(self, type: str = "obj"):
        if self.private_key is None:
            raise Exception("self.private_key is None")
        if type == "obj":
            ret = self.private_key.public_key()
        elif type == "pem":
            ret = self.private_key.public_key().public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo
            )
        else:
            raise Exception("Invalid type=" + repr(type))
        return ret

    def write_key_file(self, file_path: str):
        if self.private_key is None:
            raise Exception("self.private_key is None")
        pem = self.private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )
        with open(file_path, "wb") as key_file:
            key_file.write(pem)
        return

if not os.path.exists(crypto_path):
    os.makedirs(crypto_path)

issuer = x509.Name([
    x509.NameAttribute(NameOID.COUNTRY_NAME, u"US"),
    x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u"California"),
    x509.NameAttribute(NameOID.LOCALITY_NAME, u"San Francisco"),
    x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"Issuer Company"),
    x509.NameAttribute(NameOID.COMMON_NAME, u"mysite.com"),
])
subject = x509.Name([
    x509.NameAttribute(NameOID.COUNTRY_NAME, u"US"),
    x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u"California"),
    x509.NameAttribute(NameOID.LOCALITY_NAME, u"San Francisco"),
    x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"Subject Company"),
    x509.NameAttribute(NameOID.COMMON_NAME, u"mysite.com"),
])

issuer_key = CryptoRsa()
issuer_key.write_key_file(asym_file_is)

cert_is = x509.CertificateBuilder().subject_name(
    issuer
).issuer_name(
    issuer
).public_key(
    issuer_key.get_public_key()
).serial_number(
    x509.random_serial_number()
).not_valid_before(
    datetime.datetime.utcnow()
).not_valid_after(
    datetime.datetime.utcnow() + datetime.timedelta(days=1000)
).add_extension(
    x509.SubjectAlternativeName([x509.DNSName(u"localhost")]),
    critical=False,
).sign(issuer_key.private_key, CryptoRsa.CRYPTO_HASH)
with open(cert_file_is, "wb") as f:
    f.write(cert_is.public_bytes(serialization.Encoding.PEM))

for idx in range(NUM_RTS):
    asym_file_path = asym_file_sb+str(idx)+".pem"
    cert_file_path = cert_file_sb+str(idx)+".pem"
    subject_key = CryptoRsa()
    subject_key.write_key_file(asym_file_path)
    cert_sb = x509.CertificateBuilder().subject_name(
        subject
    ).issuer_name(
        issuer
    ).public_key(
        subject_key.get_public_key()
    ).serial_number(
        x509.random_serial_number()
    ).not_valid_before(
        datetime.datetime.utcnow()
    ).not_valid_after(
        datetime.datetime.utcnow() + datetime.timedelta(days=1000)
    ).add_extension(
        x509.SubjectAlternativeName([x509.DNSName(u"localhost")]),
        critical=False,
    ).sign(issuer_key.private_key, CryptoRsa.CRYPTO_HASH)
    with open(cert_file_path, "wb") as f:
        f.write(cert_sb.public_bytes(serialization.Encoding.PEM))

print('Done')

Method 2: Using openssl

In [ ]:
import os
import subprocess

def run_openssl(args, step):
    """Run openssl with an argument list (no shell, so paths with spaces are safe)."""
    result = subprocess.run(["openssl"] + args)
    if result.returncode != 0:
        raise Exception("openssl for {} failed: {}".format(step, result.returncode))

if not os.path.exists(cert_file_is):
    # Self-signed issuer (CA) certificate and its unencrypted RSA 4096 key.
    run_openssl(["req", "-x509", "-newkey", "rsa:4096", "-sha256", "-days", "365", "-nodes",
                 "-subj", "/C=US/ST=California/L=San Francisco/O=Issuer Company/OU=Org/CN=www.iscompany.com",
                 "-keyout", asym_file_is, "-out", cert_file_is], "issuer")

for idx in range(NUM_RTS):
    asym_file_path = asym_file_sb + str(idx) + ".pem"
    csr_file_path = csr_file_sb + str(idx) + ".pem"
    cert_file_path = cert_file_sb + str(idx) + ".pem"
    if not os.path.exists(cert_file_path):
        # Step 1: create the party's RSA 4096 key and a certificate signing request.
        run_openssl(["req", "-newkey", "rsa:4096", "-nodes",
                     "-subj", "/C=US/ST=California/L=San Francisco/O=SB Company/OU=Org/CN=www.sbcompany.com",
                     "-keyout", asym_file_path, "-out", csr_file_path], "subject step 1")
        # Step 2: sign the request with the issuer's key to produce the party certificate.
        run_openssl(["x509", "-req", "-CAcreateserial", "-CA", cert_file_is, "-CAkey", asym_file_is,
                     "-sha256", "-days", "365", "-in", csr_file_path, "-out", cert_file_path], "subject step 2")

print('Done')
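Whichever method you used, each party certificate should name the agreed Certificate Authority as its issuer and carry a valid signature from the issuer's key. The hypothetical helper is_signed_by below checks both with the cryptography package; it assumes RSA keys with PKCS#1 v1.5 signatures, which is what both methods above produce by default.

```python
from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric import padding

def is_signed_by(cert_pem: bytes, ca_cert_pem: bytes) -> bool:
    """Return True if cert_pem names the CA as issuer and carries a valid CA signature."""
    cert = x509.load_pem_x509_certificate(cert_pem)
    ca_cert = x509.load_pem_x509_certificate(ca_cert_pem)
    if cert.issuer != ca_cert.subject:
        return False
    try:
        # Verify the signature over the certificate body with the CA's public key.
        ca_cert.public_key().verify(
            cert.signature,
            cert.tbs_certificate_bytes,
            padding.PKCS1v15(),
            cert.signature_hash_algorithm,
        )
    except InvalidSignature:
        return False
    return True

# Usage with this notebook's file names:
# with open(cert_file_is, "rb") as f:
#     ca_pem = f.read()
# for idx in range(NUM_RTS):
#     with open(cert_file_sb + str(idx) + ".pem", "rb") as f:
#         print(idx, is_signed_by(f.read(), ca_pem))
```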

Launch aggregator

This section launches the Federated Learning aggregator for the experiment.

To run the experiment with homomorphic encryption, the aggregator’s configuration must specify the following fusion type:
"fusion_type": "crypto_iter_avg".

The aggregator’s configuration can also include a crypto object that specifies the required encryption level. For example:

"crypto": {
    "cipher_spec": "encryption_level_1"
}

If this object is not specified, the default level, encryption_level_1, is used.
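If you switch encryption levels between experiments, a small helper like the hypothetical he_fusion_settings below could build the HE-specific part of the aggregator configuration. It assumes, as described in this section, that fusion_type must be crypto_iter_avg and that cipher_spec accepts encryption_level_1 through encryption_level_4; it is not part of the WML client API.

```python
def he_fusion_settings(encryption_level=1):
    """Return the aggregator settings that enable homomorphic encryption."""
    if encryption_level not in (1, 2, 3, 4):
        raise ValueError("encryption_level must be 1, 2, 3, or 4")
    return {
        "fusion_type": "crypto_iter_avg",
        "crypto": {"cipher_spec": "encryption_level_{}".format(encryption_level)},
    }

print(he_fusion_settings(2))
```

For example, calling fl_conf.update(he_fusion_settings(2)) in the aggregator cell would replace its fusion_type and crypto entries while leaving the rest of the configuration unchanged.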

There are four encryption levels, from level 1 to level 4. The levels differ in security and precision, and each level requires more resources (for example, computation, memory, and network bandwidth) than the level below it. The security level corresponds to the strength of the encryption system, typically measured by the number of operations that an attacker must perform to break the system. The precision level corresponds to the precision of the encryption system's outcomes: a higher precision level means that cryptographic operations are accurate to a larger number of digits before and after the decimal point. Higher precision levels reduce the loss of model accuracy caused by the encryption operations.
The encryption levels are:

  • Encryption level 1 provides high security and good precision, and is the default level.
  • Encryption level 2 provides high security and high precision, and requires more resources than level 1.
  • Encryption level 3 provides extra high security and good precision, and requires more resources than level 2.
  • Encryption level 4 provides extra high security and high precision, and requires more resources than level 3.
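To make the notion of precision more concrete, the following sketch uses fixed-point encoding as an analogy: real-valued weights are scaled by 2**frac_bits and rounded to integers so that integer-only arithmetic can add them, and more fractional bits leave a smaller error in the aggregated result. This illustrates the idea under that assumption only; it does not describe how IBM Federated Learning implements its encryption levels, and fixed_point_average is a hypothetical name.

```python
import numpy as np

def fixed_point_average(party_weights, frac_bits):
    """Average float weight arrays after encoding each as integers scaled by 2**frac_bits."""
    scale = 1 << frac_bits
    encoded = [np.round(w * scale).astype(np.int64) for w in party_weights]
    total = np.sum(encoded, axis=0)  # integer addition: the only step an aggregator performs
    return total.astype(np.float64) / scale / len(party_weights)

rng = np.random.default_rng(0)
weights = [rng.normal(size=1000) for _ in range(3)]
exact = np.mean(weights, axis=0)
for bits in (8, 20):
    error = np.max(np.abs(fixed_point_average(weights, bits) - exact))
    print("fractional bits: {:2d}  max abs error: {:.2e}".format(bits, error))
```

Each rounding error is at most half a unit of 2**-frac_bits, so the averaged error is bounded by 0.5 / 2**frac_bits; the price of more fractional bits is larger numbers to carry, which mirrors the resource cost of higher precision levels.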

For additional details on launching the aggregator see the corresponding documentation page.

In [ ]:
fl_conf = {
    "model": {
      "type": MODEL_NAME,
      "spec": {
        "id": untrained_model_ids[MODEL_NAME]
      },
      "model_file": "pytorch_sequence.pt"
    },
    "fusion_type": "crypto_iter_avg",
    "crypto": {
      "cipher_spec": "encryption_level_1"
    },
    "epochs": 1,
    "rounds": 2,
    "metrics": "accuracy",
    "remote_training": {
      "max_timeout": TIMEOUT_TRAINING_SEC,
      "quorum": 1,
      "remote_training_systems": remote_training_systems,
    },
    "software_spec": {
      "name": SW_SPEC_NAME
    },
    "hardware_spec": {
      "name": HW_SPEC_NAME
    }
}
aggregator_metadata = {
    wml_client.training.ConfigurationMetaNames.NAME: 'aggregator_he',
    wml_client.training.ConfigurationMetaNames.DESCRIPTION: '',
    wml_client.training.ConfigurationMetaNames.TAGS: RSC_TAGS,
    wml_client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCES: [],
    wml_client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE: {
        "type": "container",
        "name": "outputData",
        "connection": {},
        "location": {
          "path": "."
        }
    },
    wml_client.training.ConfigurationMetaNames.FEDERATED_LEARNING: fl_conf
}
print("Prepared config for aggregator with model type {}".format(MODEL_NAME))
aggregator = wml_client.training.run(aggregator_metadata, asynchronous=True)
print("Created Aggregator")
training_id = wml_client.training.get_id(aggregator)
print("Training id: " + str(training_id))
print("RTS: " + str(remote_training_systems))

Launch parties

This section launches the Federated Learning parties for the experiment.

To run the experiment with homomorphic encryption, each party's configuration must include a crypto object in the info section of its local_training configuration. The crypto object specifies the CA certificate, the party's own certificate, and the party's asymmetric key file.

For additional details on launching the parties, see the corresponding documentation page.
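The crypto section used for each party follows a fixed pattern: one shared CA certificate plus a per-party certificate and asymmetric key whose file names end with the party index. The sketch below builds that section for one party and can optionally check that the files exist before the party is started. build_party_crypto_info and the example paths are hypothetical; the key names are copied from the party configuration in this notebook.

```python
import os


def build_party_crypto_info(idx, ca_cert_path, cert_prefix, key_prefix, check_files=False):
    """Return the "info" section with the crypto settings for party number `idx`."""
    distribution = {
        "ca_cert_file_path": ca_cert_path,                     # CA certificate shared by all parties
        "my_cert_file_path": cert_prefix + str(idx) + ".pem",  # this party's certificate
        "asym_key_file_path": key_prefix + str(idx) + ".pem",  # this party's asymmetric key
    }
    if check_files:
        missing = [p for p in distribution.values() if not os.path.isfile(p)]
        if missing:
            raise FileNotFoundError("missing crypto files: {}".format(missing))
    return {"crypto": {"key_manager": {"key_mgr_info": {"distribution": distribution}}}}


info = build_party_crypto_info(0, "/tmp/ca.pem", "/tmp/party_cert_", "/tmp/party_key_")
print(info["crypto"]["key_manager"]["key_mgr_info"]["distribution"]["my_cert_file_path"])
# prints /tmp/party_cert_0.pem
```

Checking the files up front surfaces a missing certificate or key locally, instead of as a party failure after the aggregator has started.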

In [ ]:
import os

for idx, prt in enumerate(remote_training_systems):
    party_metadata = {
        # Homomorphic encryption: the shared CA certificate plus this party's
        # certificate and asymmetric key (file names suffixed by the party index)
        wml_client.remote_training_systems.ConfigurationMetaNames.LOCAL_TRAINING: {
            "info": {
                "crypto": {
                    "key_manager": {
                        "key_mgr_info": {
                            "distribution": {
                                "ca_cert_file_path": cert_file_is,
                                "my_cert_file_path": cert_file_sb + str(idx) + '.pem',
                                "asym_key_file_path": asym_file_sb + str(idx) + '.pem'
                            }
                        }
                    }
                }
            }
        },
        # Each party reads its own local data set through the data handler
        wml_client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
            "info": {
                "npz_file": os.path.join(data_path, prt_data_file_prefix + str(idx) + '.npz')
            },
            "name": DATA_HANDLER_CLASS_NAME,
            "path": os.path.join(data_path, DATA_HANDLER_FILE_NAME)
        }
    }
    print("Connecting party id {} to aggregator id {}, model type {}".format(prt['id'], training_id, MODEL_NAME))
    party = wml_client.remote_training_systems.create_party(prt['id'], party_metadata)
    party.monitor_logs("ERROR")
    party.run(aggregator_id=training_id, asynchronous=True, verify=False)
    print("Party {} is running".format(prt['id']))
print('Done')

Monitor execution status of the training

This section monitors the execution status of the training experiment.

For additional details on monitoring the experiment, see the corresponding documentation page.

In [ ]:
import time
import json

TERMINAL_STATES = ('completed', 'failed', 'canceled')

def monitor_training(training_id):
    print('Monitoring training id: {}'.format(training_id))
    MAX_ITER = 240      # maximum number of polls (240 x 10 s = 40 minutes)
    SLP_TIME_SEC = 10   # seconds to wait between polls
    aggregator_status = wml_client.training.get_status(training_id)
    aggregator_state = aggregator_status['state']
    n_iter = 0
    # Poll until the training reaches a terminal state or the polling budget runs out
    while n_iter < MAX_ITER and aggregator_state not in TERMINAL_STATES:
        print("Elapsed time: {} seconds, State: {}".format(n_iter * SLP_TIME_SEC, aggregator_state))
        time.sleep(SLP_TIME_SEC)
        aggregator_status = wml_client.training.get_status(training_id)
        aggregator_state = aggregator_status['state']
        n_iter += 1
    # Check the state rather than the counter, so a job that finishes on the last poll is not reported as a timeout
    if aggregator_state not in TERMINAL_STATES:
        raise Exception("Training did not finish after {} seconds".format(n_iter * SLP_TIME_SEC))
    print("Final status: " + json.dumps(aggregator_status, indent=4))

if 'training_id' in globals():
    monitor_training(training_id)
else:
    trn = wml_client.training.get_details(get_all=True)
    for t in trn['resources']:
        md = t['metadata']
        if 'tags' in md and md['tags'] == RSC_TAGS:
            monitor_training(md['id'])
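The polling loop in the monitoring cell above depends on a live WML service. One way to exercise the same logic offline is to pass the status lookup and the sleep function in as parameters, as in the minimal sketch below. poll_until_terminal and the scripted fake_status are hypothetical names used only for illustration; the terminal state names are the ones checked in the monitoring cell.

```python
import time

TERMINAL_STATES = ("completed", "failed", "canceled")


def poll_until_terminal(get_status, max_polls=240, sleep_sec=10, sleep=time.sleep):
    """Call get_status() until its 'state' is terminal; raise TimeoutError otherwise."""
    status = get_status()
    polls = 0
    while status["state"] not in TERMINAL_STATES and polls < max_polls:
        sleep(sleep_sec)
        status = get_status()
        polls += 1
    if status["state"] not in TERMINAL_STATES:
        raise TimeoutError("not finished after {} seconds".format(polls * sleep_sec))
    return status


# Offline usage: a scripted sequence of states and a no-op sleep
states = iter(["pending", "running", "running", "completed"])


def fake_status():
    return {"state": next(states)}


print(poll_until_terminal(fake_status, sleep=lambda s: None)["state"])  # completed
```

Injecting the sleep function keeps the test instantaneous while the production call can still use the default time.sleep.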

Cleanup

Use this section to delete the training jobs, assets, and local files created using this notebook.

Remove WML assets

Remove training jobs

In [ ]:
print('Removing training jobs')
trn = wml_client.training.get_details(get_all=True)
for t in trn['resources']:
    md = t['metadata']
    if 'tags' in md and md['tags'] == RSC_TAGS:
        wml_client.training.cancel(md['id'], hard_delete=True)
        print('Deleted {}: {}'.format(md['name'],md['id']))
print('Done')

Remove remote training systems and models

In [ ]:
print('Removing remote training systems')
rts = wml_client.remote_training_systems.get_details(get_all=True)
for r in rts['resources']:
    md = r['metadata']
    if 'tags' in md and md['tags'] == RSC_TAGS:
        wml_client.repository.delete(md['id'])
        print('Deleted {}: {}'.format(md['name'],md['id']))

print('Removing models')
models = wml_client.repository.get_model_details(get_all=True)
for m in models['resources']:
    md = m['metadata']
    if 'tags' in md and md['tags'] == RSC_TAGS:
        wml_client.repository.delete(md['id'])
        print('Deleted {}: {}'.format(md['name'],md['id']))

print('Done')
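The monitoring and cleanup cells above repeat the same filter: keep only the resources whose metadata tags equal RSC_TAGS. A minimal sketch of factoring that filter into a helper is shown below. resources_with_tags is a hypothetical name, and the sample dictionary is made up; it only mimics the resources/metadata layout that those cells read from the get_details results.

```python
def resources_with_tags(details, tags):
    """Yield the metadata of each resource in `details` whose tags equal `tags`."""
    for resource in details.get("resources", []):
        md = resource.get("metadata", {})
        # Same condition as the notebook cells: tags present and equal to the expected list
        if "tags" in md and md["tags"] == tags:
            yield md


# Made-up sample in the same shape as the details dictionaries used above
sample = {"resources": [
    {"metadata": {"id": "a1", "name": "aggregator_he", "tags": ["fl_he"]}},
    {"metadata": {"id": "b2", "name": "other_job"}},
]}
for md in resources_with_tags(sample, ["fl_he"]):
    print("Would delete {}: {}".format(md["name"], md["id"]))
```

In the notebook, each cleanup loop could then iterate over resources_with_tags(wml_client.training.get_details(get_all=True), RSC_TAGS) and similar calls, so the tag condition lives in one place.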

Remove local files

In [ ]:
import shutil
shutil.rmtree(nb_dir)
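shutil.rmtree raises FileNotFoundError if nb_dir was already removed, for example in an earlier session of this notebook. A more defensive variant is sketched below; remove_local_dir is a hypothetical helper, and the demo uses a temporary directory in place of nb_dir.

```python
import os
import shutil
import tempfile


def remove_local_dir(path):
    """Delete the directory tree at `path` if it exists; return True if something was removed."""
    if not os.path.isdir(path):
        print("Nothing to remove at {}".format(path))
        return False
    shutil.rmtree(path)
    print("Removed {}".format(path))
    return True


demo_dir = tempfile.mkdtemp()                            # stand-in for nb_dir
open(os.path.join(demo_dir, "party0.npz"), "w").close()  # an empty placeholder file
first = remove_local_dir(demo_dir)    # the directory existed, so it is removed
second = remove_local_dir(demo_dir)   # already gone, so nothing happens
```

With this helper, rerunning the cleanup cell is harmless even when the local directory tree no longer exists.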

Next steps

You successfully completed this notebook!

Check out our online documentation and the IBM Federated Learning documentation for more tutorials and samples.


Copyright © IBM Corp. 2022-2024. This notebook and its source code are released under the terms of the MIT License.