{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import numpy as np\n",
    "from scipy.sparse import csr_matrix\n",
    "from sklearn.datasets import fetch_20newsgroups\n",
    "from sklearn.feature_extraction.text import CountVectorizer as skCountVectorizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CountVectorizer():\n",
    "    \"\"\"Convert a collection of text documents to a matrix of token counts.\n",
    "\n",
    "    Minimal re-implementation of scikit-learn's CountVectorizer with its\n",
    "    default settings: lowercasing, tokens of 2+ word characters, and an\n",
    "    alphabetically ordered vocabulary.\n",
    "    \"\"\"\n",
    "\n",
    "    # Compiled once at class level so _analyze does not recompile per call.\n",
    "    _token_pattern = re.compile(r\"\\b\\w\\w+\\b\")\n",
    "\n",
    "    def _analyze(self, doc):\n",
    "        \"\"\"Lowercase doc and return its list of tokens.\"\"\"\n",
    "        return self._token_pattern.findall(doc.lower())\n",
    "\n",
    "    def _count_vocab(self, X, fixed_vocabulary):\n",
    "        \"\"\"Count tokens per document, building the vocabulary if needed.\n",
    "\n",
    "        Returns (vocabulary, Xt) where Xt is a CSR matrix of shape\n",
    "        (n_docs, n_features). When fixed_vocabulary is True, tokens absent\n",
    "        from self.vocabulary_ are silently skipped.\n",
    "        \"\"\"\n",
    "        vocabulary = self.vocabulary_ if fixed_vocabulary else {}\n",
    "        values = []\n",
    "        j_indices = []\n",
    "        indptr = [0]\n",
    "        for doc in X:\n",
    "            feature_counter = {}\n",
    "            for feature in self._analyze(doc):\n",
    "                if not fixed_vocabulary:\n",
    "                    # Assign the next free column index to unseen tokens.\n",
    "                    vocabulary.setdefault(feature, len(vocabulary))\n",
    "                elif feature not in vocabulary:\n",
    "                    continue  # token unseen during fit; ignore it\n",
    "                feature_idx = vocabulary[feature]\n",
    "                feature_counter[feature_idx] = feature_counter.get(feature_idx, 0) + 1\n",
    "            values.extend(feature_counter.values())\n",
    "            j_indices.extend(feature_counter.keys())\n",
    "            indptr.append(len(j_indices))\n",
    "        Xt = csr_matrix((values, j_indices, indptr),\n",
    "                        shape=(len(indptr) - 1, len(vocabulary)))\n",
    "        return vocabulary, Xt\n",
    "\n",
    "    def fit(self, X):\n",
    "        \"\"\"Learn the vocabulary dict of all tokens in X; return self.\"\"\"\n",
    "        vocabulary, _ = self._count_vocab(X, fixed_vocabulary=False)\n",
    "        # Re-number columns so feature indices follow alphabetical order,\n",
    "        # matching scikit-learn's behavior.\n",
    "        for new_val, term in enumerate(sorted(vocabulary)):\n",
    "            vocabulary[term] = new_val\n",
    "        self.vocabulary_ = vocabulary\n",
    "        return self\n",
    "\n",
    "    def transform(self, X):\n",
    "        \"\"\"Transform documents to a document-term CSR matrix.\"\"\"\n",
    "        _, Xt = self._count_vocab(X, fixed_vocabulary=True)\n",
    "        return Xt\n",
    "\n",
    "    def get_feature_names(self):\n",
    "        \"\"\"Return the vocabulary terms in column (alphabetical) order.\"\"\"\n",
    "        return sorted(self.vocabulary_.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate our implementation against scikit-learn's on real text,\n",
    "# at several corpus sizes.\n",
    "docs = fetch_20newsgroups(remove=('headers', 'footers', 'quotes')).data\n",
    "for n_docs in [10, 100, 1000]:\n",
    "    train_docs, test_docs = docs[:n_docs], docs[n_docs:2 * n_docs]\n",
    "    ours = CountVectorizer().fit(train_docs)\n",
    "    theirs = skCountVectorizer().fit(train_docs)\n",
    "    # Identical learned vocabularies ...\n",
    "    assert np.array_equal(ours.get_feature_names(), theirs.get_feature_names())\n",
    "    # ... identical counts on the training documents ...\n",
    "    assert np.array_equal(ours.transform(train_docs).toarray(),\n",
    "                          theirs.transform(train_docs).toarray())\n",
    "    # ... and on unseen documents (out-of-vocabulary tokens are dropped).\n",
    "    assert np.array_equal(ours.transform(test_docs).toarray(),\n",
    "                          theirs.transform(test_docs).toarray())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}