{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test For `DriftCheckerEstimator`-`pydrift` \n",
    "\n",
    "We're going to test how it works with the famous titanic dataset\n",
    "\n",
    "# Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from sklearn import set_config\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.compose import make_column_transformer\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.preprocessing import OrdinalEncoder\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from catboost import CatBoostClassifier\n",
    "\n",
    "from pydrift import DriftCheckerEstimator\n",
    "from pydrift.exceptions import ColumnsNotMatchException\n",
    "from pydrift.constants import PATH_DATA, RANDOM_STATE\n",
    "from pydrift.models import cat_features_fillna\n",
    "from pydrift.exceptions import DriftEstimatorException\n",
    "\n",
    "\n",
    "set_config(display='diagram')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read Data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_titanic = pd.read_csv(PATH_DATA / 'titanic.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Constants "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_LENGTH = df_titanic.shape[0]\n",
    "TARGET = 'Survived'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Split\n",
    "\n",
    "50% sample will give us a non-drift problem\n",
    "\n",
    "We drop Ticket and Cabin features because of cardinality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = df_titanic.drop(columns=['Ticket', 'Cabin', 'PassengerId', 'Name', TARGET])\n",
    "y = df_titanic[TARGET]\n",
    "\n",
    "cat_features = (X\n",
    "                .select_dtypes(include=['category', 'object'])\n",
    "                .columns)\n",
    "\n",
    "X_filled = cat_features_fillna(X, cat_features)\n",
    "\n",
    "X_train_filled, X_test_filled, y_train, y_test = train_test_split(\n",
    "    X_filled, y, test_size=.5, random_state=RANDOM_STATE, stratify=y\n",
    ")\n",
    "\n",
    "catboost_classifier = CatBoostClassifier(\n",
    "    num_trees=5,\n",
    "    max_depth=3,\n",
    "    cat_features=cat_features,\n",
    "    random_state=RANDOM_STATE,\n",
    "    verbose=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build Pipeline With DriftCheckerEstimator\n",
    "\n",
    "Catboost as estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>div.sk-top-container {color: black;background-color: white;}div.sk-toggleable {background-color: white;}label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.2em 0.3em;box-sizing: border-box;text-align: center;}div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}div.sk-estimator {font-family: monospace;background-color: #f0f8ff;margin: 0.25em 0.25em;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;}div.sk-estimator:hover {background-color: #d4ebff;}div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 2em;bottom: 0;left: 50%;}div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;}div.sk-item {z-index: 1;}div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;}div.sk-parallel-item {display: flex;flex-direction: column;position: relative;background-color: white;}div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}div.sk-parallel-item:only-child::after {width: 0;}div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0.2em;box-sizing: border-box;padding-bottom: 0.1em;background-color: white;position: relative;}div.sk-label label {font-family: monospace;font-weight: bold;background-color: white;display: inline-block;line-height: 1.2em;}div.sk-label-container {position: relative;z-index: 2;text-align: center;}div.sk-container {display: inline-block;position: relative;}</style><div class=\"sk-top-container\"><div class=\"sk-container\"><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"bed3b610-44de-4e33-aa54-761dd0d44b26\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"bed3b610-44de-4e33-aa54-761dd0d44b26\">Pipeline</label><div class=\"sk-toggleable__content\"><pre>Pipeline(steps=[('driftcheckerestimator',\n",
       "                 DriftCheckerEstimator(column_names=Index(['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'], dtype='object'),\n",
       "                                       ml_classifier_model=<catboost.core.CatBoostClassifier object at 0x7ff2fa244dc0>))])</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"2efa4ab2-d7ce-48f2-b99b-57b1443fd3a2\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"2efa4ab2-d7ce-48f2-b99b-57b1443fd3a2\">driftcheckerestimator: DriftCheckerEstimator</label><div class=\"sk-toggleable__content\"><pre>DriftCheckerEstimator(column_names=Index(['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'], dtype='object'),\n",
       "                      ml_classifier_model=<catboost.core.CatBoostClassifier object at 0x7ff2fa244dc0>)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"280681a7-01da-4a50-8b00-bde98cc0c218\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"280681a7-01da-4a50-8b00-bde98cc0c218\">CatBoostClassifier</label><div class=\"sk-toggleable__content\"><pre><catboost.core.CatBoostClassifier object at 0x7ff2fa244dc0></pre></div></div></div></div></div></div></div></div></div></div></div></div>"
      ],
      "text/plain": [
       "Pipeline(steps=[('driftcheckerestimator',\n",
       "                 DriftCheckerEstimator(column_names=Index(['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'], dtype='object'),\n",
       "                                       ml_classifier_model=<catboost.core.CatBoostClassifier object at 0x7ff2fa244dc0>))])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pipeline_catboost_drift_checker = make_pipeline(\n",
    "    DriftCheckerEstimator(ml_classifier_model=catboost_classifier, column_names=X.columns)\n",
    ")\n",
    "\n",
    "display(pipeline_catboost_drift_checker)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Let´s Fit And Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC training data: 0.86\n",
      "AUC testing data: 0.84\n"
     ]
    }
   ],
   "source": [
    "pipeline_catboost_drift_checker.fit(X_train_filled, y_train)\n",
    "\n",
    "y_score_train = pipeline_catboost_drift_checker.predict_proba(X_train_filled)[:, 1]\n",
    "y_score_test = pipeline_catboost_drift_checker.predict_proba(X_test_filled)[:, 1]\n",
    "\n",
    "auc_train = roc_auc_score(y_true=y_train, y_score=y_score_train)\n",
    "auc_test = roc_auc_score(y_true=y_test, y_score=y_score_test)\n",
    "\n",
    "print(f'AUC training data: {auc_train:.2f}')\n",
    "print(f'AUC testing data: {auc_test:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Same With Logistic Regression Pipeline "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>div.sk-top-container {color: black;background-color: white;}div.sk-toggleable {background-color: white;}label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.2em 0.3em;box-sizing: border-box;text-align: center;}div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}div.sk-estimator {font-family: monospace;background-color: #f0f8ff;margin: 0.25em 0.25em;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;}div.sk-estimator:hover {background-color: #d4ebff;}div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 2em;bottom: 0;left: 50%;}div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;}div.sk-item {z-index: 1;}div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;}div.sk-parallel-item {display: flex;flex-direction: column;position: relative;background-color: white;}div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}div.sk-parallel-item:only-child::after {width: 0;}div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0.2em;box-sizing: border-box;padding-bottom: 0.1em;background-color: white;position: relative;}div.sk-label label {font-family: monospace;font-weight: bold;background-color: white;display: inline-block;line-height: 1.2em;}div.sk-label-container {position: relative;z-index: 2;text-align: center;}div.sk-container {display: inline-block;position: relative;}</style><div class=\"sk-top-container\"><div class=\"sk-container\"><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"14349a73-0d27-467e-b441-51ef40f0e57c\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"14349a73-0d27-467e-b441-51ef40f0e57c\">Pipeline</label><div class=\"sk-toggleable__content\"><pre>Pipeline(steps=[('columntransformer',\n",
       "                 ColumnTransformer(transformers=[('pipeline',\n",
       "                                                  Pipeline(steps=[('simpleimputer',\n",
       "                                                                   SimpleImputer(strategy='most_frequent')),\n",
       "                                                                  ('ordinalencoder',\n",
       "                                                                   OrdinalEncoder())]),\n",
       "                                                  Index(['Sex', 'Embarked'], dtype='object')),\n",
       "                                                 ('simpleimputer',\n",
       "                                                  SimpleImputer(strategy='median'),\n",
       "                                                  Index(['Pclass', 'Age', 'SibSp', 'Parch', 'Fare'], dtype='object'))])),\n",
       "                ('driftcheckerestimator',\n",
       "                 DriftCheckerEstimator(column_names=Index(['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'], dtype='object'),\n",
       "                                       ml_classifier_model=LogisticRegression(max_iter=1000,\n",
       "                                                                              random_state=1994)))])</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"d82612f9-5fca-4f86-8815-50fe1de71408\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"d82612f9-5fca-4f86-8815-50fe1de71408\">columntransformer: ColumnTransformer</label><div class=\"sk-toggleable__content\"><pre>ColumnTransformer(transformers=[('pipeline',\n",
       "                                 Pipeline(steps=[('simpleimputer',\n",
       "                                                  SimpleImputer(strategy='most_frequent')),\n",
       "                                                 ('ordinalencoder',\n",
       "                                                  OrdinalEncoder())]),\n",
       "                                 Index(['Sex', 'Embarked'], dtype='object')),\n",
       "                                ('simpleimputer',\n",
       "                                 SimpleImputer(strategy='median'),\n",
       "                                 Index(['Pclass', 'Age', 'SibSp', 'Parch', 'Fare'], dtype='object'))])</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"449d95ed-3aa4-4b66-bc70-abd36d04b7c8\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"449d95ed-3aa4-4b66-bc70-abd36d04b7c8\">pipeline</label><div class=\"sk-toggleable__content\"><pre>Index(['Sex', 'Embarked'], dtype='object')</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"3f1c38fc-75d6-49ec-bd6d-17bbbc2b1f56\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"3f1c38fc-75d6-49ec-bd6d-17bbbc2b1f56\">SimpleImputer</label><div class=\"sk-toggleable__content\"><pre>SimpleImputer(strategy='most_frequent')</pre></div></div></div><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"550188a3-99ed-40d6-966e-9da051f8dfdd\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"550188a3-99ed-40d6-966e-9da051f8dfdd\">OrdinalEncoder</label><div class=\"sk-toggleable__content\"><pre>OrdinalEncoder()</pre></div></div></div></div></div></div></div></div><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"503c97a4-d8d0-49ea-b9c5-33c491600f97\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"503c97a4-d8d0-49ea-b9c5-33c491600f97\">simpleimputer</label><div class=\"sk-toggleable__content\"><pre>Index(['Pclass', 'Age', 'SibSp', 'Parch', 'Fare'], dtype='object')</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"ca9cd588-45aa-49a7-a780-4d3cbec38925\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"ca9cd588-45aa-49a7-a780-4d3cbec38925\">SimpleImputer</label><div class=\"sk-toggleable__content\"><pre>SimpleImputer(strategy='median')</pre></div></div></div></div></div></div></div></div><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"e1e6e6df-ac5a-45a1-8b0f-5d751bea1a77\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"e1e6e6df-ac5a-45a1-8b0f-5d751bea1a77\">driftcheckerestimator: DriftCheckerEstimator</label><div class=\"sk-toggleable__content\"><pre>DriftCheckerEstimator(column_names=Index(['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'], dtype='object'),\n",
       "                      ml_classifier_model=LogisticRegression(max_iter=1000,\n",
       "                                                             random_state=1994))</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"4407d737-eec9-42c6-87e2-db0dfaffb7ec\" type=\"checkbox\" ><label class=\"sk-toggleable__label\" for=\"4407d737-eec9-42c6-87e2-db0dfaffb7ec\">LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression(max_iter=1000, random_state=1994)</pre></div></div></div></div></div></div></div></div></div></div></div></div>"
      ],
      "text/plain": [
       "Pipeline(steps=[('columntransformer',\n",
       "                 ColumnTransformer(transformers=[('pipeline',\n",
       "                                                  Pipeline(steps=[('simpleimputer',\n",
       "                                                                   SimpleImputer(strategy='most_frequent')),\n",
       "                                                                  ('ordinalencoder',\n",
       "                                                                   OrdinalEncoder())]),\n",
       "                                                  Index(['Sex', 'Embarked'], dtype='object')),\n",
       "                                                 ('simpleimputer',\n",
       "                                                  SimpleImputer(strategy='median'),\n",
       "                                                  Index(['Pclass', 'Age', 'SibSp', 'Parch', 'Fare'], dtype='object'))])),\n",
       "                ('driftcheckerestimator',\n",
       "                 DriftCheckerEstimator(column_names=Index(['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'], dtype='object'),\n",
       "                                       ml_classifier_model=LogisticRegression(max_iter=1000,\n",
       "                                                                              random_state=1994)))])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=.5, random_state=RANDOM_STATE, stratify=y\n",
    ")\n",
    "\n",
    "categorical_pipeline = make_pipeline(\n",
    "    SimpleImputer(strategy='most_frequent'),\n",
    "    OrdinalEncoder()\n",
    ")\n",
    "\n",
    "column_transformer = make_column_transformer(\n",
    "    (categorical_pipeline, cat_features),\n",
    "    (SimpleImputer(strategy='median'), X_train.select_dtypes(include='number').columns)\n",
    ")\n",
    "\n",
    "pipeline_lr_drift_checker = make_pipeline(\n",
    "    column_transformer,\n",
    "    DriftCheckerEstimator(\n",
    "        ml_classifier_model=LogisticRegression(max_iter=1000,\n",
    "                                               random_state=RANDOM_STATE),\n",
    "        column_names=X.columns\n",
    "    )\n",
    ")\n",
    "\n",
    "display(pipeline_lr_drift_checker)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Let´s Fit And Predict "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC training data: 0.851\n",
      "AUC testing data: 0.855\n"
     ]
    }
   ],
   "source": [
    "pipeline_lr_drift_checker.fit(X_train, y_train)\n",
    "\n",
    "y_score_train = pipeline_lr_drift_checker.predict_proba(X_train)[:, 1]\n",
    "y_score_test = pipeline_lr_drift_checker.predict_proba(X_test)[:, 1]\n",
    "\n",
    "auc_train = roc_auc_score(y_true=y_train, y_score=y_score_train)\n",
    "auc_test = roc_auc_score(y_true=y_test, y_score=y_score_test)\n",
    "\n",
    "print(f'AUC training data: {auc_train:.3f}')\n",
    "print(f'AUC testing data: {auc_test:.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ok, Now With Drifted Data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = df_titanic.drop(columns=['Ticket', 'Cabin', 'PassengerId', 'Name', TARGET])\n",
    "y = df_titanic[TARGET]\n",
    "\n",
    "cat_features = (X\n",
    "                .select_dtypes(include=['category', 'object'])\n",
    "                .columns)\n",
    "\n",
    "X_filled = cat_features_fillna(X, cat_features)\n",
    "\n",
    "X_train_filled, X_test_filled, y_train, y_test = train_test_split(\n",
    "    X_filled, y, test_size=.5, random_state=RANDOM_STATE, stratify=y\n",
    ")\n",
    "\n",
    "df_train_filled = pd.concat([X_train_filled, y_train], axis=1)\n",
    "df_train_filled_drifted = df_train_filled[(df_train_filled['Pclass'] > 1) & (df_train_filled['Fare'] > 10)].copy()\n",
    "\n",
    "X_train_filled_drifted = df_train_filled_drifted.drop(columns=TARGET)\n",
    "y_train_filled_drifted = df_train_filled_drifted[TARGET]\n",
    "\n",
    "df_test_filled = pd.concat([X_test_filled, y_test], axis=1)\n",
    "df_test_filled_drifted = df_test_filled[~(df_test_filled['Pclass'] > 1) & (df_test_filled['Fare'] > 10)].copy()\n",
    "\n",
    "X_test_filled_drifted = df_test_filled_drifted.drop(columns=TARGET)\n",
    "y_test_filled_drifted = df_test_filled_drifted[TARGET]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Let´s Try To Fit And Predict \n",
    "\n",
    "DriftEstimatorException tells you that there are some data drifts, you can acces to `drifted_columns` attribute to ckeck them, we will do in the next cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drift found in numerical columns check step\n",
      "Drift found in categorical columns check step\n",
      "Drift found in discriminative model step\n",
      "Drift found in pipeline_catboost_drift_checker\n"
     ]
    }
   ],
   "source": [
    "pipeline_catboost_drift_checker.fit(X_train_filled_drifted, y_train_filled_drifted)\n",
    "\n",
    "y_score_train = pipeline_catboost_drift_checker.predict_proba(X_train_filled_drifted)[:, 1]\n",
    "\n",
    "try:\n",
    "    y_score_test = pipeline_catboost_drift_checker.predict_proba(X_test_filled_drifted)[:, 1]\n",
    "\n",
    "    auc_train = roc_auc_score(y_true=y_train_filled_drifted, y_score=y_score_train)\n",
    "    auc_test = roc_auc_score(y_true=y_test_filled_drifted, y_score=y_score_test)\n",
    "\n",
    "    print(f'AUC training data: {auc_train:.2f}')\n",
    "    print(f'AUC testing data: {auc_test:.2f}')\n",
    "except DriftEstimatorException:\n",
    "    print('Drift found in pipeline_catboost_drift_checker')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# You Can Get Drifted Features From `DriftCheckerEstimator` Object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Fare, Parch, Age, Embarked, Pclass'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "drifted_features = (\n",
    "    pipeline_catboost_drift_checker\n",
    "    .named_steps['driftcheckerestimator']\n",
    "    .get_drifted_features()\n",
    ")\n",
    "\n",
    "drifted_features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# You Can Also Get High Cardinality Features\n",
    "\n",
    "None in this case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "''"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(\n",
    "    pipeline_catboost_drift_checker\n",
    "    .named_steps['driftcheckerestimator']\n",
    "    .get_high_cardinality_features()\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}