{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### 一.变分EM算法\n", "在介绍LDA的变分EM实现之前,首先我们要弄懂什么是变分EM,变分推断我们在之前提过,EM算法也在之前提过,可是将它俩凑一起似乎就不认识了呢....,这里还请大家回去看一下15章的这俩小结[《变分推断的原理推导》](https://nbviewer.jupyter.org/github/zhulei227/ML_Notes/blob/master/notebooks/15_01_VI_%E5%8F%98%E5%88%86%E6%8E%A8%E6%96%AD%E7%9A%84%E5%8E%9F%E7%90%86%E6%8E%A8%E5%AF%BC.ipynb)以及[《变分推断与EM的关系》](https://nbviewer.jupyter.org/github/zhulei227/ML_Notes/blob/master/notebooks/15_02_VI_%E5%8F%98%E5%88%86%E6%8E%A8%E6%96%AD%E4%B8%8EEM%E7%9A%84%E5%85%B3%E7%B3%BB.ipynb) 看完之后,也许就能猜出变分EM是要怎么做了,下面我再做一个简单的说明\n", "\n", "![avatar](./source/15_EM中三者间的关系.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "如上图,我们知道带参数的对数似然函数$ln\\ p(X\\mid\\theta)$,可以拆解为一个证据下界ELBO函数$L(q,\\theta)$和一个KL距离$KL(q||p)$,这里的$q$即是我们的变分分布,为了使它更加简单通常对其做一个平均场假设,即各个隐变量(组)之间是独立的: \n", "\n", "$$\n", "q(z)=q(z_1)q(z_2)\\cdots q(z_n)\n", "$$ \n", "\n", "而$p$则是复杂的后验概率分布: \n", "\n", "$$\n", "p(z)=p(z\\mid X,\\theta)\n", "$$ \n", "\n", "好的,在此基础上,我们来看看变分推断和EM分别要做怎么样的一件事: \n", "\n", "#### 变分推断\n", "变分推断要做的事情是让简单的变分分布去近似复杂的后验分布: \n", "\n", "$$\n", "q(z)\\rightarrow p(z)\n", "$$ \n", "\n", "它并不关心$\\theta$,将其视作一个常数处理,通过最大化ELBO函数来使得$KL(q||p)$最小化,从而使$q$与$p$近似,即它要做的是如下的优化问题: \n", "\n", "$$\n", "q^*=arg\\max_{q}L(q,\\theta)\n", "$$ \n", "\n", "#### EM\n", "而EM算法的初衷是要通过优化$\\theta$使得对数似然函数极大化,它令: \n", "\n", "$$\n", "q(z)=p(z\\mid X,\\theta^{old})\n", "$$ \n", "\n", "这时$KL(q||p)=0$,所以有: \n", "\n", "$$\n", "L(q,\\theta)=ln\\ p(X\\mid \\theta)\n", "$$ \n", "\n", "这时,对ELBO函数极大化等价于对对数似然函数极大化,从而得到最优解: \n", "\n", "$$\n", "\\theta=arg\\max_{\\theta}L(q\\mid\\theta),q(z)=p(z\\mid X,\\theta^{old})\n", "$$ \n", "\n", "显然,使用EM算法的前提是后验概率分布$p(z\\mid X,\\theta)$的形式比较方便求解,如果它很复杂呢?那变分EM就诞生了...\n", "\n", "#### 变分EM\n", "\n", "变分EM不改EM的初衷,即要使得对数似然函数$ln\\ p(X\\mid \\theta)$极大化,同时对于$p(z\\mid X,\\theta)$利用一个简单的变分分布$q(z)$去近似,所以变分EM算法的优化变量包括两个:$q$和$\\theta$ \n", "\n", "$$\n", "q^*,\\theta^*=arg\\max_{q,\\theta}L(q,\\theta)\n", "$$ \n", "而优化过程通常采用坐标轮换法,即: \n", "\n", "(1)E步:固定$\\theta$,求$L(q,\\theta)$对$q$的最大化; \n", "(2)M步:固定$q$,求$L(q,\\theta)$对$\\theta$的最大化; \n", "\n", "如果对变分EM没问题了,接下来就要考虑利用变分EM去求LDA模型的什么分布勒?回想一下上一节的Gibbs采样,我们去求相同的分布不就可以了,但LDA的变分EM求解还做了进一步的简化" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 二.ELBO推导\n", "变分EM对LDA模型做了如下简化\n", "![avatar](./source/16_变分EM的LDA模型.png)\n", "省略了超参数$\\beta$,其中$\\alpha,\\varphi$为模型参数,$\\theta,z$是隐变量,$w$是可观测变量,为了简便,一次只考虑一个文本,记作$w=(w_1,w_2,...,w_N)$,对应的主题序列$z=(z_1,z_2,...,z_N)$,对应的话题分布为$\\theta$,所以其联合概率分布可以表示为: \n", "\n", "$$\n", "p(\\theta,z,w\\mid\\alpha,\\varphi)=p(\\theta\\mid\\alpha)\\prod_{n=1}^Np(z_n\\mid\\theta)p(w_n\\mid z_n,\\varphi)\n", "$$ \n", "\n", "所以,我们需要去近似后验概率$p(\\theta,z\\mid w,\\alpha,\\varphi)$,可以定义变分分布为: \n", "\n", "$$\n", "q(\\theta,z\\mid\\gamma,\\eta)=q(\\theta\\mid\\gamma)\\prod_{n=1}^Nq(z_n\\mid\\eta_n)\n", "$$ \n", "\n", "其中,$\\gamma=(\\gamma_1,\\gamma_2,...,\\gamma_K)$是狄利克雷分布参数,$\\eta=(\\eta_1,\\eta_2,...,\\eta_n)$是多项分布参数,变量$\\theta$和$z$的各个分量都是条件独立的,它的盘子图如下: \n", "![avatar](./source/16_变分EM的变分分布.png) \n", "\n", "所以,其证据下界ELBO可以写作: \n", "\n", "$$\n", "L(\\gamma,\\eta,\\alpha,\\varphi)=E_q[ln\\ p(\\theta,z,w\\mid\\alpha,\\varphi)]-E_q[ln\\ q(\\theta,z\\mid\\gamma,\\eta)]\n", "$$ \n", "\n", "其中数学期望是对分布$q(\\theta,z\\mid\\gamma,\\eta)$定义的,为了方便简写为$E_q[\\cdot]$,$\\gamma$和$\\eta$是变分分布的参数,$\\alpha$和$\\varphi$是LDA模型的参数" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 三.参数求解\n", "\n", 
"所以接下来,就是固定$\\gamma,\\eta$,求$L(\\gamma,\\eta,\\alpha,\\varphi)$对$\\alpha,\\varphi$的极大化,然后固定$\\alpha,\\varphi$,求$L(\\gamma,\\eta,\\alpha,\\varphi)$对$\\gamma,\\eta$的极大化,持续下去直到收敛...推导过程的公式不想码了,哈哈哈,自己看《统计学习方法》吧,下面就直接写参数求解的算法了,并对必要的符号做说明\n", "\n", "#### 算法一:对变分参数$\\gamma,\\eta$估计\n", ">初始化:对所有的$k$和$n$,$\\eta_{nk}^{(0)}=1/K$($k$表示主题,$n$表示当前文本的第$n$个位置,$K$表示总主题数) \n", "\n", ">初始化:对所有的$k$,$\\gamma_k=\\alpha_k+N/K$($k$表示主题,$K$表示总主题数,$N$表示当前文本的总字数) \n", "\n", ">重复\n", ">>对$n=1:N$\n", ">>>对$k=1:K$\n", ">>>>$\\eta_{nk}^{(t+1)}=\\varphi_{kv}exp\\left[\\Psi(\\gamma_k^{(t)})-\\Psi(\\sum_{l=1}^K\\gamma_l^{(t)})\\right]$\n", "\n", ">>>规范化$\\eta_{nk}^{(t+1)}$使其和为1\n", "\n", ">>$\\gamma^{(t+1)}=\\alpha+\\sum_{n=1}^N\\eta_n^{(t+1)}$\n", "\n", ">直到收敛\n", "\n", "这里,$\\Psi(\\cdot)$为digamma函数,即: \n", "\n", "$$\n", "\\Psi(x)=\\frac{d\\ ln\\ \\Gamma(x)}{dx}\n", "$$\n", "$\\Psi(\\cdot)$可以使用`scipy.special.digamma`直接求解,哈哈哈~\n", "\n", "#### 算法二:对LDA参数$\\alpha,\\varphi$估计\n", "基于上面的$\\gamma,\\eta$可以写出$\\alpha,\\varphi$的计算公式,...省略了推导过程....\n", "\n", "$$\n", "\\varphi_{kv}=\\sum_{m=1}^M\\sum_{n=1}^{N_m}\\eta_{mnk}w_{mn}^v\n", "$$ \n", "\n", "其中,$\\eta_{mnk}$表示第$m$个文本的第$n$个单词属于第$k$个话题的概率,$w_{mn}^v$在第$m$个文本的第$n$个单词是单词集合的第$v$个单词时取值为1,否则为0,而$\\alpha$的更新为 \n", "\n", "$$\n", "\\alpha_{new}=\\alpha_{old}-H(\\alpha_{old})^{-1}g(\\alpha_{old})\n", "$$ \n", "\n", "其中,$g(\\cdot)$表示其梯度,计算公式为: \n", "\n", "$$\n", "\\frac{\\partial L}{\\partial \\alpha_k}=M\\left[\\Psi(\\sum_{l=1}^K\\alpha_l)-\\Psi(\\alpha_k) \\right]-\\sum_{m=1}^M\\left[\\Psi(\\gamma_{mk})-\\Psi(\\sum_{l=1}^K\\gamma_{ml})\\right]\n", "$$ \n", "\n", "而$H$表示Hessian矩阵,计算公式如下: \n", "\n", "$$\n", "\\frac{\\partial^2 L}{\\partial\\alpha_k\\partial\\alpha_l}=M\\left[\\Psi'\\left(\\sum_{l=1}^K\\alpha_l\\right)-\\delta(k,l)\\Psi'(\\alpha_k)\\right]\n", "$$ \n", "\n", "其中,$\\delta(k,l)$是delta函数\n", "\n", "所以,对**算法一**和**算法二**交替迭代,直到收敛即是我们想要的结果,另外对于$\\alpha$的更新,为了方便,笔者将Hessian矩阵的逆$H^{-1}$这一部分替换为学习率进行计算,即将二阶的牛顿法替换为一阶的梯度下降法(其实是能力有限,不知道$\\psi'$和$delta$如何计算...哈哈哈)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "隐狄利克雷分布的代码实现,包括Gibbs采样和变分EM算法,代码封装在ml_models.latent_dirichlet_allocation\n", "\"\"\"\n", "import numpy as np\n", "from scipy.special import digamma\n", "\n", "\n", "class LDA(object):\n", " def __init__(self, alpha=None, beta=None, K=10, tol=1e-3, epochs=100, method=\"gibbs\", lr=1e-5):\n", " \"\"\"\n", " :param alpha: 主题分布的共轭狄利克雷分布的超参数\n", " :param beta: 单词分布的共轭狄利克雷分布的超参数\n", " :param K: 主题数量\n", " :param tol:容忍度,允许tol的隐变量差异\n", " :param epochs:最大迭代次数\n", " :param method:优化方法,默认gibbs,另外还有变分EM,vi_em\n", " :param lr:学习率,对vi_em生效\n", " \"\"\"\n", " self.alpha = alpha\n", " self.beta = beta\n", " self.K = K\n", " self.tol = tol\n", " self.epochs = epochs\n", " self.method = method\n", " self.lr = lr\n", " self.phi = None # 主题-单词矩阵\n", "\n", " def _init_params(self, W):\n", " \"\"\"\n", " 初始化参数\n", " :param W:\n", " :return:\n", " \"\"\"\n", " M = len(W) # 文本数\n", " V = 0 # 词典大小\n", " I = 0 # 单词总数\n", " for w in W:\n", " V = max(V, max(w))\n", " I += len(w)\n", " V += 1 # 包括0\n", " # 文本话题计数\n", " N_M_K = np.zeros(shape=(M, self.K))\n", " N_M = np.zeros(M)\n", " # 话题单词计数\n", " N_K_V = np.zeros(shape=(self.K, V))\n", " N_K = np.zeros(self.K)\n", " # 初始化隐状态,计数矩阵\n", " Z = [] # 隐状态,与W一一对应\n", " p = [1 / self.K] * self.K\n", " hidden_status = list(range(self.K))\n", " for m, w in enumerate(W):\n", " z = np.random.choice(hidden_status, len(w), replace=True, p=p).tolist()\n", " Z.append(z)\n", " for n, k in enumerate(z):\n", " v = 
w[n]\n", " N_M_K[m][k] += 1\n", " N_M[m] += 1\n", " N_K_V[k][v] += 1\n", " N_K[k] += 1\n", " # 初始化alpha和beta\n", " if self.alpha is None:\n", " self.alpha = np.ones(self.K)\n", " if self.beta is None:\n", " self.beta = np.ones(V)\n", " return Z, N_M_K, N_M, N_K_V, N_K, M, V, I, hidden_status\n", "\n", " def _fit_gibbs(self, W):\n", " \"\"\"\n", " :param W: 文本集合[[...],[...]]\n", " :return:\n", " \"\"\"\n", " Z, N_M_K, N_M, N_K_V, N_K, M, V, I, hidden_status = self._init_params(W)\n", " for _ in range(self.epochs):\n", " error_num = 0\n", " for m, w in enumerate(W):\n", " z = Z[m]\n", " for n, topic in enumerate(z):\n", " word = w[n]\n", " N_M_K[m][topic] -= 1\n", " N_M[m] -= 1\n", " N_K_V[topic][word] -= 1\n", " N_K[topic] -= 1\n", " # 采样一个新k\n", " p = [] # 更新多项分布\n", " for k_ in range(self.K):\n", " p_ = (N_K_V[k_][word] + self.beta[word]) * (N_M_K[m][k_] + self.alpha[topic]) / (\n", " (N_K[k_] + np.sum(self.beta)) * (N_M[m] + np.sum(self.alpha)))\n", " p.append(p_)\n", " ps = np.sum(p)\n", " p = [p_ / ps for p_ in p]\n", " topic_new = np.random.choice(hidden_status, 1, p=p)[0]\n", " if topic_new != topic:\n", " error_num += 1\n", " Z[m][n] = topic_new\n", " N_M_K[m][topic_new] += 1\n", " N_M[m] += 1\n", " N_K_V[topic_new][word] += 1\n", " N_K[topic_new] += 1\n", " if error_num / I < self.tol:\n", " break\n", "\n", " # 计算参数phi\n", " self.phi = N_K_V / np.sum(N_K_V, axis=1, keepdims=True)\n", "\n", " def _fit_vi_em(self, W):\n", " \"\"\"\n", " 分为两部分,迭代计算:\n", " (1)给定lda参数,更新变分参数\n", " (2)给定变分参数,更新lda参数\n", " :param W:\n", " :return:\n", " \"\"\"\n", " V = 0 # 词典大小\n", " for w in W:\n", " V = max(V, max(w))\n", " V += 1\n", " M = len(W)\n", "\n", " # 给定lda参数,更新变分参数\n", " def update_vi_params(alpha, phi):\n", " eta = []\n", " gamma = []\n", " for w in W:\n", " N = len(w)\n", " eta_old = np.ones(shape=(N, self.K)) * (1 / self.K)\n", " gamma_old = alpha + N / self.K\n", " eta_new = np.zeros_like(eta_old)\n", " for _ in range(self.epochs):\n", " for n in range(0, N):\n", " for k in range(0, self.K):\n", " eta_new[n, k] = phi[k, w[n]] * np.exp(digamma(gamma_old[k]) - digamma(np.sum(gamma_old)))\n", " eta_new = eta_new / np.sum(eta_new, axis=1, keepdims=True)\n", " gamma_new = alpha + np.sum(eta_new, axis=0)\n", " if (np.sum(np.abs(gamma_new - gamma_old)) + np.sum(np.abs((eta_new - eta_old)))) / (\n", " (N + 1) * self.K) < self.tol:\n", " break\n", " else:\n", " eta_old = eta_new.copy()\n", " gamma_old = gamma_new.copy()\n", " eta.append(eta_new)\n", " gamma.append(gamma_new)\n", " return eta, gamma\n", "\n", " # 给定变分参数,更新lda参数\n", " def update_lda_params(eta, gamma, alpha_old):\n", " # 更新phi\n", " phi = np.zeros(shape=(self.K, V))\n", " for m, w in enumerate(W):\n", " for n, word in enumerate(w):\n", " for k in range(0, self.K):\n", " for v in range(0, V):\n", " phi[k, v] += eta[m][n, k] * (word == v)\n", " # 更新alpha\n", " d_alpha = []\n", " for k, alpha_ in enumerate(alpha_old):\n", " tmp = M * (digamma(np.sum(alpha_old)) - digamma(alpha_))\n", " for m in range(M):\n", " tmp -= (digamma(gamma[m][k]) - digamma(np.sum(gamma[m])))\n", " d_alpha.append(tmp)\n", " alpha_new = alpha_old - self.lr * np.asarray(d_alpha)\n", " alpha_new = np.where(alpha_new < 0.0, 0.0, alpha_new)\n", " alpha_new = alpha_new / (1e-9 + np.sum(alpha_new)) * self.K\n", " phi = phi / (np.sum(phi, axis=1, keepdims=True) + 1e-9)\n", " return alpha_new, phi\n", "\n", " # 初始化alpha和phi\n", " alpha_old = np.random.random(self.K)\n", " phi_old = np.random.random(size=(self.K, V))\n", " phi_old = phi_old / np.sum(phi_old, axis=1, 
keepdims=True)\n", " for _ in range(self.epochs):\n", " eta, gamma = update_vi_params(alpha_old, phi_old)\n", " alpha_new, phi_new = update_lda_params(eta, gamma, alpha_old)\n", " if (np.sum(np.abs(alpha_new - alpha_old)) + np.sum(np.abs((phi_new - phi_old)))) / (\n", " (V + 1) * self.K) < self.tol:\n", " break\n", " else:\n", " alpha_old = alpha_new.copy()\n", " phi_old = phi_new.copy()\n", " self.phi = phi_new\n", "\n", " def fit(self, W):\n", " if self.method == \"gibbs\":\n", " self._fit_gibbs(W)\n", " else:\n", " self._fit_vi_em(W)\n", "\n", " def transform(self, W):\n", " rst = []\n", " for w in W:\n", " tmp = np.zeros(shape=self.K)\n", " for v in w:\n", " try:\n", " v_ = self.phi[:, v]\n", " except:\n", " v_ = np.zeros(shape=self.K)\n", " tmp += v_\n", " if np.sum(tmp) > 0:\n", " tmp = tmp / np.sum(tmp)\n", " rst.append(tmp)\n", " return np.asarray(rst)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 六.测试" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "docs=[\n", " [\"有\",\"微信\",\"红包\",\"的\",\"软件\"],\n", " [\"微信\",\"支付\",\"不行\",\"的\"],\n", " [\"我们\",\"需要\",\"稳定的\",\"微信\",\"支付\",\"接口\"],\n", " [\"申请\",\"公众号\",\"认证\"],\n", " [\"这个\",\"还有\",\"几天\",\"放\",\"垃圾\",\"流量\"],\n", " [\"可以\",\"提供\",\"聚合\",\"支付\",\"系统\"]\n", "]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[0, 1, 2, 3, 4],\n", " [1, 5, 6, 3],\n", " [7, 8, 9, 1, 5, 10],\n", " [11, 12, 13],\n", " [14, 15, 16, 17, 18, 19],\n", " [20, 21, 22, 5, 23]]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "word2id={}\n", "idx=0\n", "W=[]\n", "for doc in docs:\n", " tmp=[]\n", " for word in doc:\n", " if word in word2id:\n", " tmp.append(word2id[word])\n", " else:\n", " word2id[word]=idx\n", " idx+=1\n", " tmp.append(word2id[word])\n", " W.append(tmp)\n", "W" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "lda=LDA(epochs=200,method=\"vi_em\")\n", "lda.fit(W)\n", "trans=lda.transform(W)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.15263755445005087" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#第二句和第三句应该比较近似,因为它们都含有“微信”,“支付”\n", "trans[1].dot(trans[2])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.0021851131099540695" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#而第二句和第四句的相似度显然不如第二句和第三句\n", "trans[1].dot(trans[3])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.00023384778414162343" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#当然第二句和第五句的差距也有些大\n", "trans[1].dot(trans[4])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.15459772783826697" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#而第一句和第二句都含有“微信”,所以相似度会比第四、五句高,但这里比第三句高...\n", "trans[1].dot(trans[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从结果来看还基本能接受的,还有训练速度会比gibbs快不少,另外代码效率还有不少的优化空间~~~" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { 
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }