{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Time series classification with `wtk`\n",
    "\n",
    "Load a UCR dataset, compute Wasserstein distance matrices, and classify with an SVM: either via `krein_svm_grid_search`, or by training an `SVC` on precomputed kernel matrices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "# Configure logging before the library imports below: basicConfig() does\n",
    "# nothing once the root logger already has a handler attached.\n",
    "logging.basicConfig(level=logging.DEBUG)\n",
    "\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "from wtk import get_kernel_matrix, transform_to_dist_matrix\n",
    "from wtk.utilities import get_ucr_dataset, krein_svm_grid_search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything a reader might want to tune lives in this cell.\n",
    "UCR_DATA_DIR = '../data/UCR/raw_data/'\n",
    "DATASET_NAME = 'ItalyPowerDemand'\n",
    "\n",
    "SUBSEQUENCE_LENGTH = 10  # k, third argument of transform_to_dist_matrix\n",
    "GAMMA = 0.2  # used for BOTH kernel matrices so train and test stay consistent\n",
    "SVM_C = 5  # C of the hand-trained SVC"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read UCR data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, y_train, X_test, y_test = get_ucr_dataset(UCR_DATA_DIR, DATASET_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute Wasserstein distance matrices with subsequence length $k=10$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "D_train, D_test = transform_to_dist_matrix(X_train, X_test, SUBSEQUENCE_LENGTH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the grid search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Starting analysis\n",
      "INFO:root:Accuracy = 95.34\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best C: 1.0\n",
      "Best gamma: 0.1\n"
     ]
    }
   ],
   "source": [
    "svm_clf = krein_svm_grid_search(D_train, D_test, y_train, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Alternatively: Get the kernel matrices computed from the distance matrices ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The same GAMMA is passed to both calls; only the psd flag differs.\n",
    "K_train = get_kernel_matrix(D_train, psd=True, gamma=GAMMA)\n",
    "K_test = get_kernel_matrix(D_test, psd=False, gamma=GAMMA)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ... and train your own classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9640427599611273"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# kernel='precomputed' makes SVC treat K_train / K_test as kernel matrices.\n",
    "clf = SVC(C=SVM_C, kernel='precomputed')\n",
    "clf.fit(K_train, y_train)\n",
    "y_pred = clf.predict(K_test)\n",
    "accuracy_score(y_test, y_pred)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}