{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading and saving models\n", "\n", "Data linking tasks can take a long time to run, so it is often useful to save a trained model. For example, this lets you apply the estimated parameters to new data, or restart the iterations from where they left off.\n", "\n", "In this demo, we save a model to a JSON file and then reload it.\n", "\n", "It assumes you have already completed the [data deduplication quick start](quickstart_demo_deduplication.ipynb)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Imports and setup\n", "\n", "The following boilerplate code sets up the Spark session and some other non-essential configuration options, such as logging." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import logging\n", "from utility_functions.demo_utils import get_spark\n", "\n", "logging.basicConfig()  # Makes log messages print in Jupyter Lab\n", "\n", "# Set to DEBUG if you want splink to log the SQL statements it's executing under the hood\n", "logging.getLogger(\"splink\").setLevel(logging.INFO)\n", "spark = get_spark()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Read in data and run linking" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:splink.iterate:Iteration 0 complete\n", "INFO:splink.params:The maximum change in parameters was 0.5087412834167481 for key π_gamma_surname_prob_dist_non_match_level_2_probability\n", "INFO:splink.iterate:Iteration 1 complete\n", "INFO:splink.params:The maximum change in parameters was 0.0954439640045166 for key π_gamma_surname_prob_dist_match_level_2_probability\n", "INFO:splink.iterate:Iteration 2 complete\n", "INFO:splink.params:The maximum change in parameters was 0.021286725997924805 for key 
π_gamma_dob_prob_dist_non_match_level_0_probability\n", "INFO:splink.iterate:Iteration 3 complete\n", "INFO:splink.params:The maximum change in parameters was 0.010865330696105957 for key π_gamma_dob_prob_dist_non_match_level_0_probability\n", "INFO:splink.iterate:Iteration 4 complete\n", "INFO:splink.params:The maximum change in parameters was 0.008596867322921753 for key π_gamma_email_prob_dist_match_level_0_probability\n" ] }, { "data": { "text/plain": [ " tf_adjusted_match_prob match_probability unique_id_l unique_id_r \\\n", "0 0.999991 0.999646 0 3 \n", "1 0.999645 0.985811 0 2 \n", "2 0.999645 0.985811 0 1 \n", "3 0.988914 0.916171 1 3 \n", "4 0.997900 0.983115 1 2 \n", "\n", " first_name_l first_name_r gamma_first_name \\\n", "0 Julia Julia 2 \n", "1 Julia Julia 2 \n", "2 Julia Julia 2 \n", "3 Julia Julia 2 \n", "4 Julia Julia 2 \n", "\n", " prob_gamma_first_name_non_match prob_gamma_first_name_match \\\n", "0 0.47229 0.567037 \n", "1 0.47229 0.567037 \n", "2 0.47229 0.567037 \n", "3 0.47229 0.567037 \n", "4 0.47229 0.567037 \n", "\n", " first_name_adj ... city_l city_r gamma_city prob_gamma_city_non_match \\\n", "0 0.975943 ... London None -1 1.00000 \n", "1 0.975943 ... London London 1 0.14658 \n", "2 0.975943 ... London London 1 0.14658 \n", "3 0.975943 ... London None -1 1.00000 \n", "4 0.975943 ... London London 1 0.14658 \n", "\n", " prob_gamma_city_match email_l email_r \\\n", "0 1.000000 hannah88@powers.com hannah88opowersc@m \n", "1 0.780896 hannah88@powers.com hannah88@powers.com \n", "2 0.780896 hannah88@powers.com hannah88@powers.com \n", "3 1.000000 hannah88@powers.com hannah88opowersc@m \n", "4 0.780896 hannah88@powers.com hannah88@powers.com \n", "\n", " gamma_email prob_gamma_email_non_match prob_gamma_email_match \n", "0 1 0.007089 0.894782 \n", "1 1 0.007089 0.894782 \n", "2 1 0.007089 0.894782 \n", "3 1 0.007089 0.894782 \n", "4 1 0.007089 0.894782 \n", "\n", "[5 rows x 31 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = spark.read.parquet(\"data/fake_1000.parquet\")\n", "\n", "settings = {\n", " \"link_type\": \"dedupe_only\",\n", " \"max_iterations\": 5,\n", " \"blocking_rules\": [\n", " \"l.first_name = r.first_name\",\n", " \"l.surname = r.surname\",\n", " \"l.dob = r.dob\"\n", " ],\n", " \"comparison_columns\": [\n", " {\n", " \"col_name\": \"first_name\",\n", " \"num_levels\": 3,\n", " \"term_frequency_adjustments\": True\n", " },\n", " {\n", " \"col_name\": \"surname\",\n", " \"num_levels\": 3,\n", " \"term_frequency_adjustments\": True\n", " },\n", " {\n", " \"col_name\": \"dob\"\n", " },\n", " {\n", " \"col_name\": \"city\"\n", " },\n", " {\n", " \"col_name\": \"email\"\n", " }\n", " ]\n", "}\n", "\n", "from splink import Splink\n", "\n", "linker = Splink(settings, spark, df=df)\n", "# Runs the iterations shown in the log output, then returns the scored comparisons\n", "df_e = linker.get_scored_comparisons()\n", "df_e.limit(5).toPandas()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3: Save model\n", "\n", "We save the model settings, current parameters and iteration history to a file called `saved_model.json`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "linker.save_model_as_json(\"saved_model.json\", overwrite=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4: Reload model\n", "\n", "Reloading the model creates a new `Splink` object. Its settings are populated from those saved in the JSON file, and its parameters (the `m_probabilities` and `u_probabilities`) are restored from the same file. Any further iterations therefore start from the saved values rather than from scratch." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:splink.iterate:Iteration 0 complete\n", "INFO:splink.params:The maximum change in parameters was 0.006465733051300049 for key π_gamma_email_prob_dist_match_level_1_probability\n", "INFO:splink.iterate:Iteration 1 complete\n", "INFO:splink.params:The maximum change in parameters was 0.0047650933265686035 for key π_gamma_email_prob_dist_match_level_0_probability\n", "INFO:splink.iterate:Iteration 2 complete\n", "INFO:splink.params:The maximum change in parameters was 0.0035470128059387207 for key π_gamma_email_prob_dist_match_level_0_probability\n", "INFO:splink.iterate:Iteration 3 complete\n", "INFO:splink.params:The maximum change in parameters was 0.0026850104331970215 for key 
π_gamma_email_prob_dist_match_level_1_probability\n", "INFO:splink.iterate:Iteration 4 complete\n", "INFO:splink.params:The maximum change in parameters was 0.0020679831504821777 for key π_gamma_email_prob_dist_match_level_1_probability\n" ] }, { "data": { "text/plain": [ " tf_adjusted_match_prob match_probability unique_id_l unique_id_r \\\n", "0 1.000000 0.999962 0 3 \n", "1 0.999989 0.997613 0 2 \n", "2 0.999989 0.997613 0 1 \n", "3 0.999653 0.984606 1 3 \n", "4 0.999936 0.997146 1 2 \n", "\n", " first_name_l first_name_r gamma_first_name \\\n", "0 Julia Julia 2 \n", "1 Julia Julia 2 \n", "2 Julia Julia 2 \n", "3 Julia Julia 2 \n", "4 Julia Julia 2 \n", "\n", " prob_gamma_first_name_non_match prob_gamma_first_name_match \\\n", "0 0.469065 0.568189 \n", "1 0.469065 0.568189 \n", "2 0.469065 0.568189 \n", "3 0.469065 0.568189 \n", "4 0.469065 0.568189 \n", "\n", " first_name_adj ... city_l city_r gamma_city prob_gamma_city_non_match \\\n", "0 0.995446 ... London None -1 1.000000 \n", "1 0.995446 ... London London 1 0.140833 \n", "2 0.995446 ... London London 1 0.140833 \n", "3 0.995446 ... London None -1 1.000000 \n", "4 0.995446 ... London London 1 0.140833 \n", "\n", " prob_gamma_city_match email_l email_r \\\n", "0 1.000000 hannah88@powers.com hannah88opowersc@m \n", "1 0.769179 hannah88@powers.com hannah88@powers.com \n", "2 0.769179 hannah88@powers.com hannah88@powers.com \n", "3 1.000000 hannah88@powers.com hannah88opowersc@m \n", "4 0.769179 hannah88@powers.com hannah88@powers.com \n", "\n", " gamma_email prob_gamma_email_non_match prob_gamma_email_match \n", "0 1 0.001349 0.875251 \n", "1 1 0.001349 0.875251 \n", "2 1 0.001349 0.875251 \n", "3 1 0.001349 0.875251 \n", "4 1 0.001349 0.875251 \n", "\n", "[5 rows x 31 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from splink import load_from_json\n", "linker_2 = load_from_json(\"saved_model.json\", spark=spark, df=df)\n", "\n", "# Run another set of iterations, starting from the reloaded parameters\n", "df_e_2 = linker_2.get_scored_comparisons()\n", "df_e_2.limit(5).toPandas()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# The charts now show 10 iterations: 5 from the original run and 5 after reloading\n", 
"linker_2.params.all_charts_write_html_file(\"more_charts.html\", overwrite=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Extracting m and u probabilities\n", "\n", "You can also print the estimated m and u probabilities for each comparison column, for example to copy and paste them into a settings object as starting values." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gamma_first_name\n", "\"m_probabilities\": [0.36426427960395813, 0.06754633039236069, 0.5681893825531006],\n", "\"u_probabilities\": [0.5297731161117554, 0.0011621455196291208, 0.46906471252441406]\n", "gamma_surname\n", "\"m_probabilities\": [0.3780061900615692, 0.05709553509950638, 0.5648982524871826],\n", "\"u_probabilities\": [0.3240525722503662, 1.755945777404122e-05, 0.6759299039840698]\n", "gamma_dob\n", "\"m_probabilities\": [0.1348014920949936, 0.8651984930038452],\n", "\"u_probabilities\": [0.9818383455276489, 0.01816166192293167]\n", "gamma_city\n", "\"m_probabilities\": [0.23082058131694794, 0.7691794037818909],\n", "\"u_probabilities\": [0.8591668009757996, 0.14083316922187805]\n", "gamma_email\n", "\"m_probabilities\": [0.12474856525659561, 0.8752514123916626],\n", "\"u_probabilities\": [0.9986510276794434, 0.0013489817501977086]\n" ] } ], "source": [ "linker_2.params._print_m_u_probs()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }