{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "a0268e63", "metadata": {}, "outputs": [], "source": [ "from tensorbay import GAS\n", "from tensorbay.dataset import Segment\n", "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow import keras as tfk\n", "import math\n", "import pnlp\n", "from pnlp import Dict\n", "from tensorflow.data import Dataset\n", "from transformers import AutoTokenizer\n", "from tensorbay.dataset import Dataset as TensorBayDataset\n", "from tensorbay.opendataset import Newsgroups20\n", "from typing import Union\n", "\n", "token = \"Accesskey-098e0c26fdc79f31a085f5b897052ba4\"" ] }, { "cell_type": "markdown", "id": "44a8c8a0", "metadata": {}, "source": [ "## Exploration" ] }, { "cell_type": "markdown", "id": "dabfa3ec", "metadata": {}, "source": [ "### Seg" ] }, { "cell_type": "code", "execution_count": null, "id": "c1d9f9d3", "metadata": {}, "outputs": [], "source": [ "seg = \"20news-18828\"\n", "seg = \"20news-bydate-train\"\n", "gas = GAS(token)\n", "dataset_client = gas.get_dataset(\"Newsgroups20\")\n", "segments = dataset_client.list_segment_names()\n", "segment = Segment(seg, dataset_client)" ] }, { "cell_type": "code", "execution_count": null, "id": "e7b461f0", "metadata": {}, "outputs": [], "source": [ "ele = segment[0]" ] }, { "cell_type": "markdown", "id": "c51eb70e", "metadata": {}, "source": [ "### DS" ] }, { "cell_type": "code", "execution_count": null, "id": "445ddfa1", "metadata": {}, "outputs": [], "source": [ "ds = TensorBayDataset(\"Newsgroups20\", gas)" ] }, { "cell_type": "code", "execution_count": null, "id": "f2abc886", "metadata": {}, "outputs": [], "source": [ "ds.keys()" ] }, { "cell_type": "code", "execution_count": null, "id": "3be66d25", "metadata": {}, "outputs": [], "source": [ "ds.catalog.classification.categories" ] }, { "cell_type": "markdown", "id": "2bc7c63f", "metadata": {}, "source": [ "### LocalDS" ] }, { "cell_type": "code", "execution_count": null, "id": "09b587b0", "metadata": {}, "outputs": [], "source": [ "dataset = Newsgroups20(\"./data/\")" ] }, { "cell_type": "code", "execution_count": null, "id": "0b24a7ca", "metadata": {}, "outputs": [], "source": [ "seg = dataset[\"20news-18828\"]\n", "x = seg[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "5f3e84df", "metadata": {}, "outputs": [], "source": [ "x.label" ] }, { "cell_type": "code", "execution_count": null, "id": "c73feacd", "metadata": {}, "outputs": [], "source": [ "x.open().read()[:5]" ] }, { "cell_type": "code", "execution_count": null, "id": "6ab415c0", "metadata": {}, "outputs": [], "source": [ "i = 0\n", "for v in seg:\n", " i += 1\n", " pass\n", "i" ] }, { "cell_type": "markdown", "id": "569a9d6e", "metadata": {}, "source": [ "## Usage" ] }, { "cell_type": "code", "execution_count": 2, "id": "0e604c7d", "metadata": {}, "outputs": [], "source": [ "class NewsGroupSegment:\n", "\n", " def __init__(\n", " self, \n", " client: Union[str, GAS], \n", " segment_name: str, \n", " tokenizer_path: str, \n", " label_file: str, \n", " max_length: int = 512\n", " ):\n", " if isinstance(client, GAS):\n", " self.dataset = TensorBayDataset(\"Newsgroups20\", client)\n", " elif isinstance(client, str):\n", " self.dataset = Newsgroups20(client)\n", " else:\n", " raise ValueError(\"Invalid dataset client\")\n", " self.segment = self.dataset[segment_name]\n", " self.max_length = max_length\n", " self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)\n", " labels = pnlp.read_lines(label_file)\n", " self.category_to_index = dict(zip(labels, range(len(labels))))\n", "\n", " def __call__(self):\n", " for data in self.segment:\n", " with data.open() as fp:\n", " txt = fp.read().decode(\"utf8\", errors=\"ignore\")\n", " ids = self.tokenizer.encode(\n", " txt, max_length=self.max_length, truncation=True, padding=\"max_length\"\n", " )\n", " input_tensor = tf.convert_to_tensor(np.array(ids), dtype=tf.int32)\n", " category = self.category_to_index[data.label.classification.category]\n", " category_tensor = tf.convert_to_tensor(category, dtype=tf.int32)\n", " yield input_tensor, category_tensor" ] }, { "cell_type": "code", "execution_count": 3, "id": "b3a6ebbe", "metadata": {}, "outputs": [], "source": [ "def text_cnn(config, inputs):\n", " embed = tfk.layers.Embedding(config.vocab_size, config.embed_size,\n", " embeddings_initializer=tfk.initializers.RandomUniform(minval=-1, maxval=1),\n", " input_length=config.max_len,\n", " name='embedding')(inputs)\n", " embed = tfk.layers.Reshape((config.max_len, config.embed_size, 1), name='add_channel')(embed)\n", "\n", " pool_outputs = []\n", " for filter_size in list(map(int, config.filter_sizes.split(','))):\n", " conv = tfk.layers.Conv2D(config.num_filters, \n", " (filter_size, config.embed_size), \n", " strides=(1, 1), \n", " padding='valid',\n", " data_format='channels_last', \n", " activation='relu',\n", " kernel_initializer='glorot_normal',\n", " bias_initializer=tfk.initializers.constant(0.1),\n", " name='convolution_{:d}'.format(filter_size)\n", " )(embed)\n", " pool = tfk.layers.MaxPool2D(pool_size=(config.max_len - filter_size + 1, 1),\n", " strides=(1, 1), padding='valid',\n", " data_format='channels_last',\n", " name='max_pooling_{:d}'.format(filter_size))(conv)\n", " pool_outputs.append(pool)\n", "\n", " z = tfk.layers.concatenate(pool_outputs, axis=-1, name='concatenate')\n", " z = tfk.layers.Flatten(data_format='channels_last', name='flatten')(z)\n", " z = tfk.layers.Dropout(config.dropout, name='dropout')(z)\n", " return z" ] }, { "cell_type": "code", "execution_count": 4, "id": "1e762bae", "metadata": {}, "outputs": [], "source": [ "def build_model(config, module):\n", " inputs = tfk.Input(shape=(config.max_len, ), name='input_data')\n", " z = module(config, inputs)\n", " outputs = tfk.layers.Dense(config.num_classes, activation='softmax',\n", " kernel_initializer='glorot_normal',\n", " bias_initializer=tfk.initializers.constant(0.1),\n", " kernel_regularizer=tfk.regularizers.l2(config.regularizers_lambda),\n", " bias_regularizer=tfk.regularizers.l2(config.regularizers_lambda),\n", " name='dense')(z)\n", " model = tfk.Model(inputs=inputs, outputs=outputs)\n", " return model" ] }, { "cell_type": "code", "execution_count": 5, "id": "8875cd30", "metadata": {}, "outputs": [], "source": [ "config = Dict({\n", " \"vocab_size\": 21128,\n", " \"embed_size\": 256,\n", " \"max_len\": 512,\n", " \"num_filters\": 128,\n", " \"filter_sizes\": \"2,3,4\",\n", " \"dropout\": 0.1,\n", " \"regularizers_lambda\": 0.01,\n", " \"num_classes\": 20\n", "})" ] }, { "cell_type": "code", "execution_count": 7, "id": "b0bff0c4", "metadata": {}, "outputs": [], "source": [ "max_len = 512\n", "batch_size = 32\n", "segment_name = \"20news-18828\"\n", "client = \"./data/\" # GAS(token)\n", "data = NewsGroupSegment(client, segment_name, \"./bert/\", \"labels.txt\", max_len)\n", "epochs = 5\n", "steps_per_epoch = math.ceil(len(data.segment) / batch_size)\n", "\n", "\n", "dataset = Dataset.from_generator(\n", " data,\n", " output_signature=(\n", " tf.TensorSpec(shape=(max_len, ), dtype=tf.float32),\n", " tf.TensorSpec(shape=(), dtype=tf.int32),\n", " ),\n", ").shuffle(buffer_size=len(data.segment), reshuffle_each_iteration=True).batch(batch_size).repeat(epochs)" ] }, { "cell_type": "code", "execution_count": 8, "id": "212d0948", "metadata": {}, "outputs": [], "source": [ "model = build_model(config, text_cnn)" ] }, { "cell_type": "code", "execution_count": 9, "id": "32fa417e", "metadata": {}, "outputs": [], "source": [ "model.compile(\n", " optimizer=tfk.optimizers.Adamax(learning_rate=1e-3),\n", " loss=tfk.losses.SparseCategoricalCrossentropy(),\n", " metrics=[tfk.metrics.SparseCategoricalAccuracy()],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "03dd220e", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] } ], "source": [ "model.fit(dataset, epochs=epochs, steps_per_epoch=steps_per_epoch)" ] }, { "cell_type": "markdown", "id": "664bc0ac", "metadata": {}, "source": [ "## Summary" ] }, { "cell_type": "markdown", "id": "7c7feb46", "metadata": {}, "source": [ "一些使用中的问题:\n", "\n", "- 数据集文件或 Segment 无法 shuffle,可以按文件随机选择某个文件中的样例\n", "- 加载异常慢,尤其是在线加载,可以先把文件 Load 到本地\n", "- 20Newsgroup 数据集类别有误,只有 16 个类别,实际 20 个" ] }, { "cell_type": "markdown", "id": "8bcefd46", "metadata": {}, "source": [ "## Reference\n", "\n", "- [TensorFlow — TensorBay documentation](https://tensorbay-python-sdk.graviti.com/en/stable/integrations/tensorflow.html)\n", "- [20 Newsgroups — TensorBay documentation](https://tensorbay-python-sdk.graviti.com/en/stable/examples/Newsgroups20.html#newsgroups)\n", "- [Home Page for 20 Newsgroups Data Set](http://qwone.com/~jason/20Newsgroups/)\n", "- [Python SDK - 太子长琴 / 20 Newsgroups - Graviti](https://gas.graviti.cn/dataset/yam/Newsgroups20/code/python-sdk)\n", "- [ShaneTian/TextCNN: TextCNN by TensorFlow 2.0.0 ( tf.keras mainly ).](https://github.com/ShaneTian/TextCNN)" ] }, { "cell_type": "code", "execution_count": null, "id": "c3459ebc", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "165px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 5 }