{"nbformat": 4, "cells": [{"source": "# CSRNDArray - NDArray in Compressed Sparse Row Storage Format\n\nMany real-world datasets deal with high dimensional sparse feature vectors. Take, for instance, a recommendation system where the number of categories and users is on the order of millions. The purchase data for each category by user would show that most users only make a few purchases, leading to a dataset with high sparsity (i.e. most of the elements are zeros).\n\nStoring and manipulating such large sparse matrices in the default dense structure results in wasted memory and processing on the zeros. To take advantage of the sparse structure of the matrix, the `CSRNDArray` in MXNet stores the matrix in [compressed sparse row (CSR)](https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_.28CSR.2C_CRS_or_Yale_format.29) format and uses specialized algorithms in operators.\n**The format is designed for 2D matrices with a large number of columns,\nand each row is sparse (i.e. with only a few nonzeros).**\n\n## Advantages of Compressed Sparse Row NDArray (CSRNDArray)\nFor matrices of high sparsity (e.g. ~1% non-zeros = ~1% density), there are two primary advantages of `CSRNDArray` over the existing `NDArray`:\n\n- memory consumption is reduced significantly\n- certain operations are much faster (e.g. matrix-vector multiplication)\n\nYou may be familiar with the CSR storage format in [SciPy](https://www.scipy.org/) and will note the similarities in MXNet's implementation. However, there are some additional competitive features in `CSRNDArray` inherited from `NDArray`, such as non-blocking asynchronous evaluation and automatic parallelization, that are not available in SciPy's flavor of CSR. 
You can find further explanations of the evaluation and parallelization strategy in MXNet in the [NDArray tutorial](https://mxnet.incubator.apache.org/tutorials/basic/ndarray.html#lazy-evaluation-and-automatic-parallelization).\n\nThe introduction of `CSRNDArray` also brings a new attribute, `stype`, to `NDArray` as a holder for storage type info. You can now query **ndarray.stype** in addition to the oft-queried attributes such as **ndarray.shape**, **ndarray.dtype**, and **ndarray.context**. For a typical dense NDArray, the value of `stype` is **\"default\"**. For a `CSRNDArray`, the value of `stype` is **\"csr\"**.\n\n## Prerequisites\n\nTo complete this tutorial, you will need:\n\n- MXNet. See the instructions for your operating system in [Setup and Installation](https://mxnet.io/get_started/install.html).\n- [Jupyter](http://jupyter.org/)\n ```\n pip install jupyter\n ```\n- Basic knowledge of NDArray in MXNet. See the detailed tutorial for NDArray in [NDArray - Imperative tensor operations on CPU/GPU](https://mxnet.incubator.apache.org/tutorials/basic/ndarray.html).\n- SciPy - A section of this tutorial uses the SciPy package in Python. If you don't have SciPy, the example in that section will be skipped.\n- GPUs - A section of this tutorial uses GPUs. 
If you don't have GPUs on your machine, simply set the variable `gpu_device` (set in the GPUs section of this tutorial) to `mx.cpu()`.\n\n## Compressed Sparse Row Matrix\n\nA CSRNDArray represents a 2D matrix as three separate 1D arrays: **data**, **indptr** and **indices**, where the column indices for row `i` are stored in `indices[indptr[i]:indptr[i+1]]` in ascending order, and their corresponding values are stored in `data[indptr[i]:indptr[i+1]]`.\n\n- **data**: CSR format data array of the matrix\n- **indices**: CSR format index array of the matrix\n- **indptr**: CSR format index pointer array of the matrix\n\n### Example Matrix Compression\n\nFor example, given the matrix:\n```\n[[7, 0, 8, 0]\n [0, 0, 0, 0]\n [0, 9, 0, 0]]\n```\n\nWe can compress this matrix using CSR, and to do so we need to calculate `data`, `indices`, and `indptr`.\n\nThe `data` array holds all the non-zero entries of the matrix in row-major order. Put another way, you create a data array that has all of the zeros removed from the matrix, row by row, storing the numbers in that order. Your result:\n\n data = [7, 8, 9]\n\nThe `indices` array stores the column index for each non-zero element in `data`. As you cycle through the data array, starting with 7, you can see it is in column 0. Then looking at 8, you can see it is in column 2. Lastly, 9 is in column 1. Your result:\n\n indices = [0, 2, 1]\n\nThe `indptr` array is what will help identify the rows where the data appears. It stores the offset into `data` of the first non-zero element of each row of the matrix. This array always starts with 0, because the entries of the first row begin at offset 0 of `data`, so indptr[0] is 0. Each subsequent value indptr[i] is the cumulative number of non-zero elements in the first `i` rows. Looking at the first row of the matrix you can see two non-zero values, so indptr[1] is 2. The next row contains all zeros, so the aggregate is still 2, so indptr[2] is 2. 
Finally, you see the last row contains one non-zero element, bringing the aggregate to 3, so indptr[3] is 3. To reconstruct the dense matrix, you will use `data[0:2]` and `indices[0:2]` for the first row, `data[2:2]` and `indices[2:2]` for the second row (which contains all zeros), and `data[2:3]` and `indices[2:3]` for the third row. Your result:\n\n indptr = [0, 2, 2, 3]\n\nNote that in MXNet, the column indices for a given row are always sorted in ascending order,\nand duplicated column indices for the same row are not allowed.\n\n## Array Creation\n\nThere are a few different ways to create a `CSRNDArray`, but first let's recreate the matrix we just discussed using the `data`, `indices`, and `indptr` we calculated in the previous example.\n\nYou can create a CSRNDArray with data, indices and indptr by using the `csr_matrix` function:", "cell_type": "markdown", "metadata": {}}, {"source": "import mxnet as mx\n# Create a CSRNDArray with Python lists\nshape = (3, 4)\ndata_list = [7, 8, 9]\nindices_list = [0, 2, 1]\nindptr_list = [0, 2, 2, 3]\na = mx.nd.sparse.csr_matrix(data_list, indptr_list, indices_list, shape)\n# Inspect the matrix\na.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "import numpy as np\n# Create a CSRNDArray with NumPy arrays\ndata_np = np.array([7, 8, 9])\nindptr_np = np.array([0, 2, 2, 3])\nindices_np = np.array([0, 2, 1])\nb = mx.nd.sparse.csr_matrix(data_np, indptr_np, indices_np, shape)\nb.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "# Compare the two. 
They are exactly the same.\n{'a':a.asnumpy(), 'b':b.asnumpy()}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "You can create an MXNet CSRNDArray from a `scipy.sparse.csr_matrix` object by using the `array` function:", "cell_type": "markdown", "metadata": {}}, {"source": "try:\n    import scipy.sparse as spsp\n    # generate a csr matrix in scipy\n    c = spsp.csr_matrix((data_np, indices_np, indptr_np), shape=shape)\n    # create a CSRNDArray from a scipy csr object\n    d = mx.nd.sparse.array(c)\n    print('d:{}'.format(d.asnumpy()))\nexcept ImportError:\n    print(\"scipy package is required\")", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "What if you have a big set of data and you haven't calculated indices or indptr yet? Let's create a simple CSRNDArray from an existing array of data and derive those values with built-in functions. We can mock up a \"big\" dataset in which a random subset of the entries is non-zero, then compress it by using the `tostype` function, which is explained further in the [Storage Type Conversion](#storage-type-conversion) section:", "cell_type": "markdown", "metadata": {}}, {"source": "big_array = mx.nd.round(mx.nd.random.uniform(low=0, high=1, shape=(1000, 100)))\nprint(big_array)\nbig_array_csr = big_array.tostype('csr')\n# Access indices array\nindices = big_array_csr.indices\n# Access indptr array\nindptr = big_array_csr.indptr\n# Access data array\ndata = big_array_csr.data\n# When most entries are zero, the total size of the `data`, `indices` and `indptr` arrays is much smaller than the dense array.\n# Here roughly half of the entries round to 1, so this example shows the conversion rather than the memory savings.", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "You can also create a CSRNDArray from another one by using the `array` function and specifying the element data type with the `dtype` option,\nwhich accepts a NumPy type. 
By default, `float32` is used.", "cell_type": "markdown", "metadata": {}}, {"source": "# Float32 is used by default\ne = mx.nd.sparse.array(a)\n# Create a 16-bit float array\nf = mx.nd.array(a, dtype=np.float16)\n(e.dtype, f.dtype)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Inspecting Arrays\n\nSeveral methods and attributes are available for inspecting CSR arrays:\n* **.asnumpy()**\n* **.data**\n* **.indices**\n* **.indptr**\n\nAs you have seen already, we can inspect the contents of a `CSRNDArray` by copying\nits contents into a dense `numpy.ndarray` with the `asnumpy` function.", "cell_type": "markdown", "metadata": {}}, {"source": "a.asnumpy()", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "You can also inspect the internal storage of a CSRNDArray by accessing attributes such as `indptr`, `indices` and `data`:", "cell_type": "markdown", "metadata": {}}, {"source": "# Access data array\ndata = a.data\n# Access indices array\nindices = a.indices\n# Access indptr array\nindptr = a.indptr\n{'a.stype': a.stype, 'data':data, 'indices':indices, 'indptr':indptr}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Storage Type Conversion\n\nYou can convert storage types with:\n* **tostype**\n* **cast_storage**\n\nTo convert an NDArray to a CSRNDArray and vice versa, use the `tostype` function:", "cell_type": "markdown", "metadata": {}}, {"source": "# Create a dense NDArray\nones = mx.nd.ones((2,2))\n# Cast the storage type from `default` to `csr`\ncsr = ones.tostype('csr')\n# Cast the storage type from `csr` to `default`\ndense = csr.tostype('default')\n{'csr':csr, 'dense':dense}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "You can also convert the storage type with the `cast_storage` operator:", "cell_type": "markdown", "metadata": {}}, {"source": "# Create a dense 
NDArray\nones = mx.nd.ones((2,2))\n# Cast the storage type to `csr`\ncsr = mx.nd.sparse.cast_storage(ones, 'csr')\n# Cast the storage type to `default`\ndense = mx.nd.sparse.cast_storage(csr, 'default')\n{'csr':csr, 'dense':dense}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Copies\n\nYou can use the `copy` method which makes a deep copy of the array and its data, and returns a new array.\nYou can also use the `copyto` method or the slice operator `[]` to deep copy to an existing array.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.ones((2,2)).tostype('csr')\nb = a.copy()\nc = mx.nd.sparse.zeros('csr', (2,2))\nc[:] = a\nd = mx.nd.sparse.zeros('csr', (2,2))\na.copyto(d)\n{'b is a': b is a, 'b.asnumpy()':b.asnumpy(), 'c.asnumpy()':c.asnumpy(), 'd.asnumpy()':d.asnumpy()}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "If the storage types of source array and destination array do not match,\nthe storage type of destination array will not change when copying with `copyto` or\nthe slice operator `[]`.", "cell_type": "markdown", "metadata": {}}, {"source": "e = mx.nd.sparse.zeros('csr', (2,2))\nf = mx.nd.sparse.zeros('csr', (2,2))\ng = mx.nd.ones(e.shape)\ne[:] = g\ng.copyto(f)\n{'e.stype':e.stype, 'f.stype':f.stype, 'g.stype':g.stype}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Indexing and Slicing\nYou can slice a CSRNDArray on axis 0 with operator `[]`, which copies the slices and returns a new CSRNDArray.", "cell_type": "markdown", "metadata": {}}, {"source": "a = mx.nd.array(np.arange(6).reshape(3,2)).tostype('csr')\nb = a[1:2].asnumpy()\nc = a[:].asnumpy()\n{'a':a, 'b':b, 'c':c}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Note that multi-dimensional indexing or slicing along a particular axis is currently not supported for a CSRNDArray.\n\n## Sparse 
Operators and Storage Type Inference\n\nOperators that have specialized implementation for sparse arrays can be accessed in `mx.nd.sparse`. You can read the [mxnet.ndarray.sparse API documentation](https://mxnet.incubator.apache.org/versions/master/api/python/ndarray/sparse.html) to find what sparse operators are available.", "cell_type": "markdown", "metadata": {}}, {"source": "shape = (3, 4)\ndata = [7, 8, 9]\nindptr = [0, 2, 2, 3]\nindices = [0, 2, 1]\na = mx.nd.sparse.csr_matrix(data, indptr, indices, shape) # a csr matrix as lhs\nrhs = mx.nd.ones((4, 1)) # a dense vector as rhs\nout = mx.nd.sparse.dot(a, rhs) # invoke sparse dot operator specialized for dot(csr, dense)\n{'out':out}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "For any sparse operator, the storage type of output array is inferred based on inputs. You can either read the documentation or inspect the `stype` attribute of the output array to know what storage type is inferred:", "cell_type": "markdown", "metadata": {}}, {"source": "b = a * 2 # b will be a CSRNDArray since zero multiplied by 2 is still zero\nc = a + mx.nd.ones(shape=(3, 4)) # c will be a dense NDArray\n{'b.stype':b.stype, 'c.stype':c.stype}", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "For operators that don't specialize in sparse arrays, we can still use them with sparse inputs with some performance penalty. 
In MXNet, dense operators require all inputs and outputs to be in the dense format.\n\nIf sparse inputs are provided, MXNet will convert sparse inputs into dense ones temporarily, so that the dense operator can be used.\n\nIf sparse outputs are provided, MXNet will convert the dense outputs generated by the dense operator into the provided sparse format.", "cell_type": "markdown", "metadata": {}}, {"source": "e = mx.nd.sparse.zeros('csr', a.shape)\nd = mx.nd.log(a) # dense operator with a sparse input\ne = mx.nd.log(a, out=e) # dense operator with a sparse output\n{'a.stype':a.stype, 'd.stype':d.stype, 'e.stype':e.stype} # stypes of a and e will not be changed", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Note that warning messages will be printed when such a storage fallback event happens. If you are using a Jupyter notebook, the warning message will be printed in your terminal console.\n\n## Data Loading\n\nYou can load data in batches from a CSRNDArray using `mx.io.NDArrayIter`:", "cell_type": "markdown", "metadata": {}}, {"source": "# Create the source CSRNDArray\ndata = mx.nd.array(np.arange(36).reshape((9,4))).tostype('csr')\nlabels = np.ones([9, 1])\nbatch_size = 3\ndataiter = mx.io.NDArrayIter(data, labels, batch_size, last_batch_handle='discard')\n# Inspect the data batches\n[batch.data[0] for batch in dataiter]", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "You can also load data stored in the [libsvm file format](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/) using `mx.io.LibSVMIter`, where the format is: ``