{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction\n", "\n", "This IPython notebook illustrates how to evaluate a selected learning-based matcher (here, a decision tree) on held-out labeled data. First, we need to import the py_entitymatching package and other libraries as follows:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Import py_entitymatching package\n", "import py_entitymatching as em\n", "import os\n", "import pandas as pd\n", "\n", "# Random seed, passed as random_state so the results are reproducible\n", "seed = 0" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Get the datasets directory\n", "datasets_dir = em.get_install_path() + os.sep + 'datasets'\n", "\n", "path_A = datasets_dir + os.sep + 'dblp_demo.csv'\n", "path_B = datasets_dir + os.sep + 'acm_demo.csv'\n", "path_labeled_data = datasets_dir + os.sep + 'labeled_data_demo.csv'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Metadata file is not present in the given path; proceeding to read the csv file.\n", "Metadata file is not present in the given path; proceeding to read the csv file.\n" ] } ], "source": [ "# Load the two input tables, each keyed by its 'id' column\n", "A = em.read_csv_metadata(path_A, key='id')\n", "B = em.read_csv_metadata(path_B, key='id')\n", "# Load the pre-labeled data\n", "S = em.read_csv_metadata(path_labeled_data, \n", " key='_id',\n", " ltable=A, rtable=B, \n", " fk_ltable='ltable_id', fk_rtable='rtable_id')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, split the labeled data into a development set (I) and an evaluation set (J), and convert the development set into feature vectors." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Split S into a development set I and an evaluation set J (50/50)\n", "IJ = em.split_train_test(S, train_proportion=0.5, random_state=seed)\n", "I = IJ['train']\n", "J = IJ['test']" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Generate a set of 
features\n", "F = em.get_features_for_matching(A, B, validate_inferred_attr_types=False)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Convert I into feature vectors using F\n", "H = em.extract_feature_vecs(I, \n", " feature_table=F, \n", " attrs_after='label',\n", " show_progress=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Compute accuracy of X (Decision Tree) on J" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It involves the following steps:\n", "\n", "1. Train X using H\n", "2. Convert J into a set of feature vectors (L)\n", "3. Predict on L using X\n", "4. Evaluate the predictions" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Instantiate the matcher to evaluate (X), seeded for reproducibility.\n", "dt = em.DTMatcher(name='DecisionTree', random_state=seed)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Identifier and label columns that must not be used as features\n", "non_feature_attrs = ['_id', 'ltable_id', 'rtable_id', 'label']\n", "\n", "# Train X on the feature vectors of I (H)\n", "dt.fit(table=H, \n", " exclude_attrs=non_feature_attrs, \n", " target_attr='label')\n", "\n", "# Convert J into a set of feature vectors using F\n", "L = em.extract_feature_vecs(J, feature_table=F,\n", " attrs_after='label', show_progress=False)\n", "\n", "# Predict on L; the predicted label goes in 'predicted' and its probability in 'proba'\n", "predictions = dt.predict(table=L, exclude_attrs=non_feature_attrs, \n", " append=True, target_attr='predicted', inplace=False, return_probs=True,\n", " probs_attr='proba')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | _id | \n", "ltable_id | \n", "rtable_id | \n", "predicted | \n", "proba | \n", "
---|---|---|---|---|---|
124 | \n", "124 | \n", "l1647 | \n", "r366 | \n", "0 | \n", "0.0 | \n", "
54 | \n", "54 | \n", "l332 | \n", "r1463 | \n", "0 | \n", "0.0 | \n", "
268 | \n", "268 | \n", "l1499 | \n", "r1725 | \n", "0 | \n", "0.0 | \n", "
293 | \n", "293 | \n", "l759 | \n", "r1749 | \n", "1 | \n", "1.0 | \n", "
230 | \n", "230 | \n", "l1580 | \n", "r1711 | \n", "1 | \n", "1.0 | \n", "