{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AdaDelta算法\n",
"\n",
"除了RMSProp算法以外,另一个常用优化算法AdaDelta算法也针对AdaGrad算法在迭代后期可能较难找到有用解的问题做了改进 [1]。有意思的是,AdaDelta算法没有学习率这一超参数。\n",
"\n",
"## 算法\n",
"\n",
"AdaDelta算法也像RMSProp算法一样,使用了小批量随机梯度$\\boldsymbol{g}_t$按元素平方的指数加权移动平均变量$\\boldsymbol{s}_t$。在时间步0,它的所有元素被初始化为0。给定超参数$0 \\leq \\rho < 1$(对应RMSProp算法中的$\\gamma$),在时间步$t>0$,同RMSProp算法一样计算\n",
"\n",
"$$\\boldsymbol{s}_t \\leftarrow \\rho \\boldsymbol{s}_{t-1} + (1 - \\rho) \\boldsymbol{g}_t \\odot \\boldsymbol{g}_t. $$\n",
"\n",
"与RMSProp算法不同的是,AdaDelta算法还维护一个额外的状态变量$\\Delta\\boldsymbol{x}_t$,其元素同样在时间步0时被初始化为0。我们使用$\\Delta\\boldsymbol{x}_{t-1}$来计算自变量的变化量:\n",
"\n",
"$$ \\boldsymbol{g}_t' \\leftarrow \\sqrt{\\frac{\\Delta\\boldsymbol{x}_{t-1} + \\epsilon}{\\boldsymbol{s}_t + \\epsilon}} \\odot \\boldsymbol{g}_t, $$\n",
"\n",
"其中$\\epsilon$是为了维持数值稳定性而添加的常数,如$10^{-5}$。接着更新自变量:\n",
"\n",
"$$\\boldsymbol{x}_t \\leftarrow \\boldsymbol{x}_{t-1} - \\boldsymbol{g}'_t. $$\n",
"\n",
"最后,我们使用$\\Delta\\boldsymbol{x}_t$来记录自变量变化量$\\boldsymbol{g}'_t$按元素平方的指数加权移动平均:\n",
"\n",
"$$\\Delta\\boldsymbol{x}_t \\leftarrow \\rho \\Delta\\boldsymbol{x}_{t-1} + (1 - \\rho) \\boldsymbol{g}'_t \\odot \\boldsymbol{g}'_t. $$\n",
"\n",
"可以看到,如不考虑$\\epsilon$的影响,AdaDelta算法与RMSProp算法的不同之处在于使用$\\sqrt{\\Delta\\boldsymbol{x}_{t-1}}$来替代超参数$\\eta$。\n",
"\n",
"\n",
"## 从零开始实现\n",
"\n",
"AdaDelta算法需要对每个自变量维护两个状态变量,即$\\boldsymbol{s}_t$和$\\Delta\\boldsymbol{x}_t$。我们按AdaDelta算法中的公式实现该算法。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "11"
}
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import d2ltorch as d2lt\n",
"import torch\n",
"from torch import optim\n",
"\n",
"features, labels = d2lt.get_data_ch7()\n",
"\n",
"def init_adadelta_states():\n",
" s_w, s_b = torch.zeros(features.shape[1], 1), torch.zeros(1)\n",
" delta_w, delta_b = torch.zeros(features.shape[1], 1), torch.zeros(1)\n",
" return ((s_w, delta_w), (s_b, delta_b))\n",
"\n",
"def adadelta(params, states, hyperparams):\n",
" rho, eps = hyperparams['rho'], 1e-5\n",
" for p, (s, delta) in zip(params, states):\n",
" s.data = rho * s.data + (1 - rho) * (p.grad.data ** 2)\n",
" g = ((delta.data + eps).sqrt() / (s.data + eps).sqrt()) * p.grad.data\n",
" p.data -= g.data\n",
" delta.data = rho * delta.data + (1 - rho) * g.data * g.data\n",
" p.grad.data.zero_()"
]
},
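{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training on the full data set, the short standalone trace below (a minimal illustration, not part of the book's code; it assumes the toy objective $f(x) = 0.1x^2$, whose gradient is $0.2x$) prints a few AdaDelta steps on a single scalar. It shows how the ratio $\\sqrt{(\\Delta x_{t-1} + \\epsilon)/(s_t + \\epsilon)}$ acts as an adaptive step size in place of a fixed learning rate."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A scalar trace of AdaDelta on the toy objective f(x) = 0.1 * x^2\n",
"# (illustrative only; variable names mirror the formulas above)\n",
"rho, eps = 0.9, 1e-5\n",
"x, s, delta = 1.0, 0.0, 0.0\n",
"for t in range(5):\n",
"    g = 0.2 * x                          # gradient of f at x\n",
"    s = rho * s + (1 - rho) * g * g      # EWMA of squared gradients\n",
"    g_prime = ((delta + eps) / (s + eps)) ** 0.5 * g  # rescaled gradient\n",
"    x -= g_prime                         # update the variable\n",
"    delta = rho * delta + (1 - rho) * g_prime * g_prime  # EWMA of squared updates\n",
"    print('step %d: x = %.6f' % (t, x))"
]
},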
{
"cell_type": "markdown",
"metadata": {},
"source": [
"使用超参数$\\rho=0.9$来训练模型。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "12"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 0.244901, 0.091154 sec per epoch\n"
]
}
],
"source": [
"d2lt.train_ch7(adadelta, init_adadelta_states(), {'rho': 0.9}, features,\n",
" labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 简洁实现\n",
"\n",
"通过名称为“Adadelta”的`Optimizer`实例,我们便可使用PyTorch提供的AdaDelta算法。它的超参数可以通过`rho`来指定。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "9"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 0.256363, 0.076244 sec per epoch\n"
]
}
],
"source": [
"d2lt.train_nn_ch7(optim.Adadelta, {'rho': 0.9}, features, labels)"
]
},
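{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides `rho`, PyTorch's `Adadelta` also accepts `eps` and an `lr` argument that rescales the computed update (its default of 1.0 matches the original algorithm). The sketch below (an illustrative usage example on synthetic data, not part of the book's experiment) shows the optimizer in an ordinary training loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: optim.Adadelta in a plain training loop on synthetic data\n",
"import torch\n",
"from torch import nn, optim\n",
"\n",
"net = nn.Linear(5, 1)                      # a toy linear model\n",
"X = torch.randn(100, 5)                    # synthetic inputs\n",
"y = X.sum(dim=1, keepdim=True)             # synthetic targets\n",
"loss_fn = nn.MSELoss()\n",
"# rho and eps match the formulas above; lr=1.0 (the default) leaves the\n",
"# computed update unscaled\n",
"optimizer = optim.Adadelta(net.parameters(), rho=0.9, eps=1e-5)\n",
"for epoch in range(3):\n",
"    optimizer.zero_grad()\n",
"    l = loss_fn(net(X), y)\n",
"    l.backward()\n",
"    optimizer.step()\n",
"    print('epoch %d, loss %.4f' % (epoch, l.item()))"
]
},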
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 小结\n",
"\n",
"* AdaDelta算法没有学习率超参数,它通过使用有关自变量更新量平方的指数加权移动平均的项来替代RMSProp算法中的学习率。\n",
"\n",
"## 练习\n",
"\n",
"* 调节AdaDelta算法中超参数$\\rho$的值,观察实验结果。\n",
"\n",
"\n",
"\n",
"## 参考文献\n",
"\n",
"[1] Zeiler, M. D. (2012). ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.\n",
"\n",
"## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/2277)\n",
"\n",
""
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:pytorch]",
"language": "python",
"name": "conda-env-pytorch-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}