{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "sweet-external",
   "metadata": {},
   "source": [
    "# How to use the PCMF package\n",
    "\n",
    "This notebook is a minimal end-to-end example of `Positive_Collective_Matrix_Factorization` (PCMF). Two randomly generated, non-negative matrices `X` and `Y` that share the same rows are factorized together, and a per-element weight matrix `wY` is applied to `Y`'s term of the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "outer-viewer",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data generation\n",
    "import numpy as np\n",
    "\n",
    "# PCMF\n",
    "from PCMF.Positive_CMF import Positive_Collective_Matrix_Factorization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config-constants",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: random seed and sizes of the example matrices.\n",
    "SEED = 0\n",
    "N_ROWS = 1000   # rows shared by X and Y\n",
    "N_COLS_X = 100\n",
    "N_COLS_Y = 30\n",
    "\n",
    "# Seed NumPy's global generator so the example data below is reproducible.\n",
    "# NOTE: PCMF may initialise its factors from a different random source,\n",
    "# so training losses could still vary between runs (not verified here).\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dirty-manual",
   "metadata": {},
   "source": [
    "## Make a sample dataset\n",
    "\n",
    "`np.random.random_sample` draws uniformly from [0, 1), so every entry of `X`, `Y` and `wY` is non-negative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "creative-lobby",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Randomly generate the two original matrices; they share their rows.\n",
    "X = np.random.random_sample((N_ROWS, N_COLS_X))\n",
    "Y = np.random.random_sample((N_ROWS, N_COLS_Y))\n",
    "X.shape, Y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "committed-texture",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-element weights for Y's term in the loss function (same shape as Y).\n",
    "wY = np.random.random_sample(Y.shape)\n",
    "wY.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "visible-helen",
   "metadata": {},
   "source": [
    "## Training PCMF\n",
    "\n",
    "The model is built from `X` and `Y` (note the library's spelling `lamda`). `train` is then called with sigmoid links for both matrices, no weights on `X`, and the `wY` weights on `Y`.\n",
    "\n",
    "With `verbose=50`, `train` prints the loss every 50 optimisation steps. In an earlier, unseeded run the loss fell from about 17.7 at step 0 to about 0.93 at step 500."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "historical-officer",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmf = Positive_Collective_Matrix_Factorization(X, Y, alpha=0.5, d_hidden=12, lamda=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "posted-cooling",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `train` returns a tuple of NumPy arrays (presumably the fitted factor\n",
    "# matrices). Keep it in a variable rather than echoing every value, and\n",
    "# display only the shapes.\n",
    "factors = cmf.train(link_X='sigmoid', link_Y='sigmoid',\n",
    "                    weight_X=None, weight_Y=wY,\n",
    "                    optim_steps=501, verbose=50, lr=0.05)\n",
    "[factor.shape for factor in factors]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}