class CountVectorizer():
    """Minimal bag-of-words vectorizer producing a sparse term-count matrix.

    Documents are lower-cased and split into tokens of two or more word
    characters. ``fit`` learns a vocabulary whose column indices follow the
    sorted order of the terms; ``transform`` counts occurrences of those
    terms per document and ignores every other token.

    Attributes
    ----------
    vocabulary_ : dict
        Mapping ``term -> column index``; set by ``fit``.
    """

    # Compiled once for the whole class instead of once per document.
    _token_pattern = re.compile(r"\b\w\w+\b")

    def _analyze(self, doc):
        """Return the list of lower-cased tokens found in ``doc``."""
        return self._token_pattern.findall(doc.lower())

    def _count_vocab(self, X, fixed_vocabulary):
        """Count tokens per document and assemble the CSR count matrix.

        Parameters
        ----------
        X : iterable of str
            Raw text documents.
        fixed_vocabulary : bool
            ``False`` builds a new vocabulary in first-seen order; any other
            value reuses ``self.vocabulary_`` and skips unknown tokens.

        Returns
        -------
        vocabulary : dict
            Mapping ``term -> column index`` used for the matrix.
        Xt : scipy.sparse.csr_matrix
            Integer matrix of shape ``(n_documents, len(vocabulary))``.

        Raises
        ------
        ValueError
            If ``X`` is a single string rather than an iterable of documents.
        """
        # A bare string would otherwise be iterated character by character,
        # silently yielding one empty "document" per character.
        if isinstance(X, str):
            raise ValueError(
                "Iterable over raw text documents expected, "
                "string object received."
            )

        grow_vocabulary = fixed_vocabulary is False
        vocabulary = {} if grow_vocabulary else self.vocabulary_

        values = []
        j_indices = []
        indptr = [0]
        for doc in X:
            feature_counter = {}
            for feature in self._analyze(doc):
                feature_idx = vocabulary.get(feature)
                if feature_idx is None:
                    if not grow_vocabulary:
                        # Unknown term under a fixed vocabulary: ignore it.
                        continue
                    # New terms get the next free column index.
                    feature_idx = vocabulary[feature] = len(vocabulary)
                feature_counter[feature_idx] = (
                    feature_counter.get(feature_idx, 0) + 1
                )
            values.extend(feature_counter.values())
            j_indices.extend(feature_counter.keys())
            indptr.append(len(j_indices))

        # Explicit dtype: an empty ``values`` list would otherwise be
        # converted to a float64 array, making the result dtype depend on
        # whether any token happened to be counted.
        Xt = csr_matrix((values, j_indices, indptr),
                        shape=(len(indptr) - 1, len(vocabulary)),
                        dtype=np.int64)
        return vocabulary, Xt

    def fit(self, X):
        """Learn the vocabulary of ``X`` and return ``self``.

        Column indices are reassigned so that they follow the sorted order
        of the terms rather than the order of first appearance.
        """
        vocabulary, _ = self._count_vocab(X, fixed_vocabulary=False)
        self.vocabulary_ = {
            term: idx for idx, term in enumerate(sorted(vocabulary))
        }
        return self

    def transform(self, X):
        """Return the term-count matrix of ``X`` under the fitted vocabulary."""
        _, Xt = self._count_vocab(X, fixed_vocabulary=True)
        return Xt

    def get_feature_names(self):
        """Return the vocabulary terms as a sorted list."""
        return sorted(self.vocabulary_.keys())