{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 0-1 背包问题详解\n", "\n", "[@jameslao](https://github.com/jameslao) / [www.jlao.net](http://www.jlao.net)\n", "\n", "## 背包问题是啥\n", "\n", "不知道你还记不记得我们小时候看过的《太阳山》的故事:\n", "\n", "\"cover\"\n", "\n", "> 从前,有弟兄俩,老大是个贪心的富人,老二是个穷人。\n", "> \n", "> 老二没有地,种老大的地。他一年到头在地里做活;可是老是吃不上,穿不上,租子交不够。\n", "> \n", "> 有一天,老二上山打柴。天都快黑了,他还坐在山上;他发愁,加上打柴,租子还是交不够。这时候,忽然刮起一阵黑风,飞来一只大鸟。大鸟落在他跟前,对他说:“你不用发愁。太阳山上有很多金子、银子,还有宝石。我背你上太阳山去吧!你从那儿可以拿些金子回来。”\n", "> \n", "> 老二爬到大鸟的背上。大鸟叫他闭上眼睛。他刚闭上眼睛,觉得飕一下子,就听见大鸟说:“到了。你看,有多少金子!去拿吧!你可不要贪心。这是太阳山。如果我们在这儿待得时间太长,赶上太阳回来,就把你烧死了。那时候,我也没办法救你。”\n", "> \n", "> 老二一边答应着,一边从大鸟的背上跳下来。哎呀!遍地是金子、银子、宝石,金光闪闪,晃得人睁不开眼睛。老二想了想,就拿一小块金子,装进口袋,要大鸟背他回去。大鸟问他为什么只拿一小块儿。他说够了。\n", "> \n", "> 老二又爬到大鸟的背上,闭上眼睛。飕一下子,大鸟落到他家门口了。他谢过大鸟,大鸟就飞走了。\n", "> \n", "> 老二有了金子,买了地,盖了房子。他早晨起来去种地,晚上回到新盖的房子里休息,日子过得很好。\n", "> \n", "> 老大见老二有金子,很眼红。他问老二的金子是哪儿来的。老二把大鸟背自己到太阳山的事告诉他了。\n", "> \n", "> 第二天,老大学着老二到山上去打柴。天快黑了,他故意不回家,坐在山上假装发愁。大鸟飞来了,也答应背他上太阳山去拿些金子。\n", "> \n", "> 大鸟背老大到了太阳山,嘱咐他不要贪心,像嘱咐老二一样,可是他连哼也不哼。他看见那么多金子、银子、宝石,忙着张开大布袋子,捡最大块儿的金子往里装。\n", "> \n", "> 老大拼命的往大布袋子里装金子,装个没完。大鸟催他回去,他说再让他拿几块。大鸟再催他回去,他还是这么说。最后,大鸟说:“时候到了!太阳马上就回来了!”他这才站起来,朝着大鸟走,摇摇晃晃的,大布袋子压得他走不稳了。没走几步,他又弯下腰来,说:“等一会儿。我再拿几颗宝石吧!”\n", "> \n", "> 正在这时候,火红的太阳回来了。它把烈火似的阳光喷到太阳山上。大鸟飞走了。老大被烧死了。\n", "\n", "要是换你去拿,你会怎么拿呢?你是只拿一小块,还是拿到自己走不动?有没有什么办法,能够拿到最多的金子呢?\n", "\n", "所谓背包 (knapsack) 问题,就是说一个要去爬山的人(或者大盗之类的也行),他有很多东西可以带,他的包也足够大,但是背得动的总重量有一个上限 $ W$ (大家都是人)。假设我们一共有 $ n$ 种物品,每件物品 $ j$ 都有一个非负的重量 $ w_j$(氢气球之类不算)和一个非负的价值 $ v_j$ (我就不举例了)。那么应该怎样挑选,才能让背包中装的物品的总价值最高?\n", "\n", "也就是说要让目标函数\n", "\n", "$$ \\sum_{j=1}^n v_j x_j $$\n", "\n", "最大化。\n", "\n", "如果每种物品只有一件,要么带要么不带,那就是最基本的 **0-1 背包问题** 。这时要满足约束条件\n", "\n", "$$ \\sum_{j=1}^n w_j x_j \\le W, \\; x_j \\in \\{0, 1\\}. $$\n", "\n", "如果物品 $j$ 有最多 $b_j$ 件可拿,就称为**多重背包问题**或者**有界背包问题**,约束条件是\n", "\n", "$$ \\sum_{j=1}^n w_j x_j \\le W, \\; x_j \\in \\{0, 1, \\ldots, b_j\\}. $$\n", "\n", "如果可以敞开拿呢,就称为**完全背包问题**或者**无界背包问题** 。\n", "\n", "## 这问题有啥搞头\n", "\n", "我这么宅的人,一年也不会出去背包几趟,我也没那个机会去太阳山拣金子宝石或者去抢珠宝店,犯得着费这个劲研究这东西吗?\n", "\n", "但是——比如说你有一笔钱来投资,有 $n$ 种可能的投资选择,投资 $j$ 需要 $w_j$ 的资金,会创造 $v_j$ 的收益(这儿先不说风险的事儿……)。那解决了这个背包问题,不就得出了一个获得最大收益的最优组合吗?还有比如货轮装货,也是类似的问题。\n", "\n", "除了前面提到的这些实际应用,0-1 背包问题的意义还在于它是最简单的**整数规划问题**,并且可以为许多复杂问题提供一个解决的途径。\n", "\n", "可是我们有电脑呀,算一下还不快吗?最暴力的方法就是检查所有可能的 $x_j$ 组合,选出来最优的就行了。可是这个组合有 $2^n$ 个,只要 $n$ 稍微大一点点,就已经超出了计算机能够解决的范畴了(你可以算一下比如 $2^{64}$ 会怎么样……)。\n", "\n", "多年来,人们对这个看似简单的问题做出了相当多的研究,也得出了若干种能够比较高效解决这个问题的方法,成千上万个物品也不在话下了。我们后面一一加以介绍。\n", "\n", "## 哪能搞法——动态规划初试\n", "\n", "现在一共有 $n$ 个物品要考虑,总重量不超过 $W$。我们倒着来考虑一下,最优的背包中,包含不包含第 $n$ 个物品呢?\n", "\n", "* 如果**不**包含第 $n$ 个物品,那么这个物品就可以直接去掉了,最优解就是只有前 $n-1$ 个物品放在 $W$ 的背包中的结果;\n", "* 如果包含第 $n$ 个物品,那背包就被它占掉了 $w_n$ ,现在的问题就是要把前 $n-1$ 个物品放在 $W - w_n$ 的背包中怎么放法了。\n", "\n", "看出来了吗?这两种情况中,总价值比较高的就是我们所要的最优解。如果用 $f_n(W)$ 来表示总价值的话,那么\n", "\n", "$$ f_n(W) = \\max \\begin{cases} f_{n-1} (W) \\\\ f_{n-1} (W - w_n) + v_n \\end{cases} $$\n", "\n", "这就是我们所要的**状态转移方程**。\n", "\n", "空说不太好说,我们来举一个具体的例子吧。为了方便起见,我们规定输入的格式是这样的:第一行的两个数分别是重量上限 $W$ 和物品总数 $n$,接下来的每一行则是物品的价值 $v_j$ 和重量 $w_j$。比方说[输入文件](0-1-Knapsack/test/knapsack_tiny.txt)是:\n", "\n", "```\n", "16 4\n", "30 4\n", "20 5\n", "40 10\n", "10 3\n", "```\n", "\n", "意思是我们的背包总重不超过 16,有这样几件物品:\n", "\n", "| 物品 | 价值 | 重量 |\n", "| --- | --- | --- |\n", "| 1 | 30 | 4 |\n", "| 2 | 20 | 5 |\n", "| 3 | 40 | 10 |\n", "| 4 | 10 | 3 |\n", "\n", "我们按照前面的思路,写出来的代码是这样的:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Max value: 70\n", "item: 2 value: 40 weight: 10\n", "item: 0 value: 30 weight: 4\n" ] } ], "source": [ "def knapsack(v, w, limit, n):\n", " F = [[0] * (limit + 1) for x in range(n + 1)] \n", " for i in range(0, n): # F[-1] is all 0.\n", " for j in range(limit + 1):\n", " if j >= w[i]:\n", " F[i][j] = max(F[i - 1][j], F[i - 1][j - w[i]] + v[i])\n", " else:\n", " F[i][j] = F[i - 1][j]\n", " return F\n", "\n", "if __name__ == \"__main__\":\n", " with open(\"0-1-Knapsack/test/knapsack_tiny.txt\") as f:\n", " limit, n = map(int, f.readline().split())\n", " v, w = zip(*[map(int, ln.split()) for ln in f.readlines()])\n", "\n", " F = knapsack(v, w, limit, n)\n", " print(\"Max value:\", F[n - 1][limit])\n", "\n", " \"\"\" Display selected items\"\"\"\n", " y = limit\n", " for i in range(n - 1, -1, -1):\n", " if F[i][y] > F[i - 1][y]:\n", " print (\"item: \", i, \"value:\", v[i], \"weight:\", w[i])\n", " y -= w[i]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这里利用了 python 允许数组下标为负,即 `F[-1]` 实际上就是 `F[n]`,可以减少一些边界条件的处理。最后显示加入了哪些物品时,只需对所有物品进行回溯,看看在哪个物品上导致 `F` 增加,即说明这个物品被加入了背包里。\n", "\n", "我顺便打印了 `F` 好让大家有一个更直观的感觉:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "data": { "text/plain": [ "[[0, 0, 0, 0, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30],\n", " [0, 0, 0, 0, 30, 30, 30, 30, 30, 50, 50, 50, 50, 50, 50, 50, 50],\n", " [0, 0, 0, 0, 30, 30, 30, 30, 30, 50, 50, 50, 50, 50, 70, 70, 70],\n", " [0, 0, 0, 10, 30, 30, 30, 40, 40, 50, 50, 50, 60, 60, 70, 70, 70],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "F" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这里有一个稍微大一点的[例子](0-1-Knapsack/test/knapsack_small.txt),运行之后会输出结果:\n", "\n", "
Max value: 12248\n",
    "item:  13 value: 3878 weight: 9656\n",
    "item:  12 value: 1513 weight: 3926\n",
    "item:  7 value: 2890 weight: 7280\n",
    "item:  5 value: 1022 weight: 2744\n",
    "item:  2 value: 2945 weight: 7390
\n", "\n", "看到这两个循环,很明显这个算法的时间复杂度是 $O(nW)$,这个数组 `F` 的空间复杂度也是 $O(nW)$。能不能稍微改得好一点呢?\n", "\n", "从上面这个循环中我们发现,最重要的值实际上是 `F[n-1]`,前面的那些值我们并不太关心。那能不能不要保留这些值呢?如果我们原地刷新 `F[i][j] = max(F[i][j], F[i][j - w[i]] + v[i])` …… 不行,这样的话,如果左边的 `j` 先被刷新,右边再碰到同样的 `j - w[i]` 就已经不是上一轮的值了。那么……要是换从右边开始更新呢?那就没问题啦!咱们的空间复杂度就降到 $O(W)$ 啦!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def knapsack(v, w, limit, n):\n", " F = [0] * (limit + 1)\n", " for i in range(n):\n", " for j in range(limit, w[i], -1):\n", " F[j] = max(F[j - w[i]] + v[i], F[j])\n", " return F" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "基本上,一般的书讲背包算法的动态规划也就到此为止了。真的就到此为止了吗?$O(nW)$ 就让我们满意了吗?你想想看,这玩意和 $W$ 相关,可 $W$ 是一个**数字**啊!我们哪天要是不开心买了个[大包包](0-1-Knapsack/test/knapsack_big.txt)难道就要算好久好久了吗?(有兴趣的可以试试……)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 哪能搞法——动态规划再试\n", "\n", "不知道有没有读者试过了这个[大包包](0-1-Knapsack/test/knapsack_big.txt),我试了一下,在我的 2.30GHz i5 5300U 上,这个函数用 cPython 跑了…… 1632.07 秒,将近半个小时。照这个搞法,每次那个包有多大,就得搞出多少项。那个包的大小可是个任意数啊?这 $ W$ 比 $ 2^n$ 大都说不定呢?\n", "\n", "咱们再来琢磨琢磨这个事儿。看看刚才那个最小的例子,也就是背包总重不超过 16,有这样几件物品:\n", "\n", "| 物品 | 价值 | 重量 |\n", "| --- | --- | --- |\n", "| 0 | 30 | 4 |\n", "| 1 | 20 | 5 |\n", "| 2 | 40 | 10 |\n", "| 3 | 10 | 3 |\n", "\n", "我们产生的 $ F$ 是这样的:\n", "\n", "```\n", "[0, 0, 0, 0, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30]\n", "[0, 0, 0, 0, 30, 30, 30, 30, 30, 50, 50, 50, 50, 50, 50, 50, 50]\n", "[0, 0, 0, 0, 30, 30, 30, 30, 30, 50, 50, 50, 50, 50, 70, 70, 70]\n", "[0, 0, 0, 10, 30, 30, 30, 40, 40, 50, 50, 50, 60, 60, 70, 70, 70]\n", "```\n", "\n", "你什么感觉,是不是**怎么这么多重复项**??我们真的需要保留这么多重复项吗?\n", "\n", "动态规划说到底,是把 $ n$ 个对象的情况化归成 $ n-x$ 个对象的子问题。我们这个 $ F$ 里面实际的信息量是什么呢?\n", "\n", "* 第一轮,加入 0 号物品,可以用 4 的重量做到 30 的价值。\n", "* 第二轮,加入 1 号物品,可以用 5 的重量做到 20 的价值——慢着,我们前一轮已经用 4 的重量做到的 30 的价值,你用 5 才做到 20,这完全没有意义嘛!所以它并没有表现在 $ F$ 中。然而如果同时加入 0 号和 1 号物品,我们就可以用 9 的重量做到 50 的价值。这是一个新的信息。\n", "* 第三轮,再加入 2 号物品,同理,我们只有在重量达到 14,价值达到 70 的时候才是一个新的信息。\n", "* 第四轮,我们用 3 的重量就做到了 10 的价值,而以前 3 的重量什么也做不出。那么这也是一个新信息。\n", "\n", "看出什么端倪了么?我们其实只需要保留有信息量的部分,即用最少的重量做到某一个价值的时刻。如果有两个状态 $ (V_1, W_1)$ 和 $ (V_2, W_2)$,使得 $ W_2 > W_1$,但 $V_2 \\leq V_1$,也就是说状态 2 的重量大但价值不高的话,那么就说明状态 2 是占劣势而无需继续考虑的。如果能把这些劣势状态通通舍弃,我们要处理的数据量(通常)可以大大缩减。\n", "\n", "道理理解了,如何实现呢?我们可以用元组 (tuple) 来表示价值和重量,并把它们按顺序排起来。仍以上面的例子来说:\n", "\n", "* 最初我们什么也没有,只有一个元组`(0, 0)`\n", "* 我们试着加上第 0 号元素,得到了`(30, 4)`。与最初的元组合并,于是变成了`(0, 0), (30, 4)`。\n", "* 再试着对每一种情况加上第 1 号元素,得到了`(20, 5), (50, 9)`。与上一轮的元组合并,由于`(20, 5)`处于劣势被舍弃了。合并的结果是`(0, 0), (30, 4), (50, 9)`。\n", "* 继续做这个操作,加上第 2 号元素得到`(40, 10), (70, 14), (90, 19)`。最后一个由于超限被舍弃了,再于上一轮合并,得到`(0, 0), (30, 4), (50, 9), (70, 14)`。\n", "* 最后一轮再加上去,终于得到了最终结果`(0, 0), (10, 3), (30, 4), (40, 7), (50, 9), (60, 12), (70, 14)。`\n", "\n", "\n", "\n", "我的这个实现为了合并的方便使用了 python 中的双向链表 `collections.deque`。它不但写起来方便,而且实测下来比直接用列表和指针还要快一些(原因未知):" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Max value: 4243395 Total weight: 1999783\n" ] } ], "source": [ "from collections import deque\n", "import sys\n", "INF = float(\"inf\")\n", "\n", "def knapsack(vw, limit, n):\n", " vw = sorted(vw, key=lambda x : x[1], reverse=True) # Accelerate\n", " A = deque([(0, 0)])\n", "\n", " for i in range(0, n):\n", " B = deque() # find all possiblities after adding one new item\n", " for item in A:\n", " if item[1] + vw[i][1] > limit: # A is sorted\n", " break\n", " B.append((item[0] + vw[i][0], item[1] + vw[i][1]))\n", "\n", " level, merge = -1, deque() # the bar keeps going up\n", " while A or B: # merging the two queues\n", " ia, ib = A[0][1] if A else INF, B[0][1] if B else INF\n", " x = A.popleft() if (ia < ib) else B.popleft()\n", " if x[0] > level:\n", " merge.append(x)\n", " level = x[0]\n", " A = merge\n", " return A[-1]\n", "\n", "if __name__ == \"__main__\":\n", " with open(\"0-1-Knapsack/test/knapsack_big.txt\") as f:\n", " limit, n = map(int, f.readline().split())\n", " vw = [tuple(map(int, ln.split())) for ln in f.readlines()]\n", " A = knapsack(vw, limit, n)\n", " print(\"Max value:\", A[0], \"Total weight:\", A[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "用这个来解决[大包包](0-1-Knapsack/test/knapsack_big.txt)是轻轻松松啦!在我机器上的运行时间大约只有半秒钟,真的可谓是秒杀!我们这个算法的时间复杂度是 $O(ns)$,其中 $s$ 是上面列出来的独立状态数。这个状态数既不可能超过 $W$,也不可能超过 $2^n$。对于分布比较均匀的背包物品来说,这个方法还是相当奏效的。\n", "\n", "也许你注意到了第 8 行有一条奇怪的排序……这个算法虽然对数据的顺序本身没有要求,但是我发现,如果对数据按照重量从大到小排序,速度比乱序提升了大约四倍。我也试过其他的排序方法,比如按照价值与重量的比值,但在这个特定例子上均不太理想。\n", "\n", "但是呢……遇到[更大的包包](0-1-Knapsack/test/knapsack_1000000_10000.txt)怎么办……我跑了一下 cPython 需要 677.3 秒…… PyPy 快一些需要 65 秒左右。\n", "\n", "其实,我们并不需要在动态规划这一棵树上吊死的呢…… " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 哪能搞法——分支界限\n", "\n", "我们先把动态规划的想法放一下,退一步来看看这个问题。每一个物品都可能放或者不放,如果用一个决策树来表示的话,就得到这样一棵满二叉树:\n", "\n", "\"fulltree\"\n", "\n", "这对于 $n$ 个物品就有 $2^n$ 种放法,总共要计算 $2^{n+1} - 2$ 个结点,当然是太多了。不过,有没有什么办法能让遍历的结点少一点呢?\n", "\n", "比较直观的,至少有这么两个:\n", "\n", "+ 重量超限的——如果一个结点重量超限,那再往上加东西的也不用看了。\n", "+ 最下面一层的右侧打叉的结点——结果肯定和上一层一样,也不用再算一遍了。\n", "\n", "\n", "还有吗?比如说,有些结点重量大而价值却很低,再怎么往上加东西也没有竞争力,能不能尽早找出这种“没有希望”的结点呢?\n", "\n", "对于一些固定的物品,包包里能装走的价值是有一个**上限**的。包的大小是一定的,那么为了让总体价值最大,就要让单位重量的价值最大。怎么才能让单位重量的价值最大呢?我们把所有的物品按照价值与重量的比值($v_j / w_j$)由大到小排序,先尽量拣着重量小价值大的东西装,而把又重又不值钱的放在后面。\n", "\n", "最后包包还没装满又放不下下一个物品怎么办呢?那我们就假装这个物品可以切开,然后把包包塞满。这样一来,现在得到的这个总价值就是一个上限,因为单位价值最高的那些已经都塞进来啦。\n", "\n", "我们把物品排好顺序挨个考虑,如果有个结点已经考虑了前 $i$ 个物品,得到的总价值是 $V_0$,总重量是 $W_0$;我们按照单位价值从大到小挨个往里放,放到第 $k$ 个物品放不下了,这时候包里的总重量是\n", "\n", "$$\n", "\\tilde{W} = W_0 + \\sum_{j=i+1}^{k-1} w_j\n", "$$\n", "\n", "而总价值的上限就是\n", "\n", "$$\n", "V_{\\rm bound} = \\left(V_0 + \\sum_{j=i+1}^{k-1} v_j \\right) + \\big(W - \\tilde{W}\\big)\\cdot \\frac{v_k}{w_k}\n", "$$\n", "\n", "其中前面一项是放进去的物品总价值,后面一项是把第 $k$ 个物品切开后塞满包包的价值。\n", "\n", "如果用 $V_{\\max}$ 表示已经出现过的最高价值,那遇到 $V_{\\rm bound} \\leq V_{\\max}$ 的结点就可以直接把它剪掉啦。\n", "\n", "这个程序写出来可以是这样的:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "12248\n" ] } ], "source": [ "import sys\n", "\n", "def knapsack(vw, limit):\n", " \n", " def bound(v, w, j):\n", " if j >= len(vw) or w > limit:\n", " return -1\n", " else:\n", " while j < len(vw) and w + vw[j][1] <= limit:\n", " v, w, j = v + vw[j][0], w + vw[j][1], j + 1\n", " if j < len(vw):\n", " v += (limit - w) * vw[j][0] / (vw[j][1] * 1.0)\n", " return v\n", "\n", " def traverse(v, w, j):\n", " nonlocal maxValue\n", " if bound(v, w, j) >= maxValue: # promising \n", " if w + vw[j][1] <= limit: # w/ j\n", " maxValue = max(maxValue, v + vw[j][0])\n", " traverse(v + vw[j][0], w + vw[j][1], j + 1)\n", " if j < len(vw) - 1: # w/o j\n", " traverse(v, w, j + 1)\n", " return\n", "\n", " maxValue = 0\n", " traverse(0, 0, 0)\n", " return maxValue\n", " \n", "if __name__ == \"__main__\":\n", " with open(\"0-1-Knapsack/test/knapsack_small.txt\") as f:\n", " limit, n = map(int, f.readline().split())\n", " vw = [] # value, weight, value density\n", " for ln in f.readlines():\n", " vl, wl = tuple(map(int, ln.split()))\n", " vw.append([vl, wl, vl / (wl * 1.0)])\n", "\n", " print(knapsack(sorted(vw, key=lambda x : x[2], reverse=True), limit))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "简单解释一下,先对 vw 按照单位重量的价值排序,然后利用 bound 函数确定价值上限。如果价值上限超过了已经出现的最大价值,再分别计算加上当前物品和不加当前物品的两种情况,否则就跳过。\n", "\n", "我们还拿前面的那个小例子来看,它执行过程可以用下面这个图来表示:\n", "\n", "\"knapsack_tiny\"\n", "\n", "显然这是一个深度优先搜索,图中每个结点分别列出了它的当前价值、当前重量和价值上限,红色结点表示重量超限,黄色结点表示价值上限不足,加粗的圈是最优解。这个树看起来比上面那个小多了吧!\n", "\n", "如果包包大一点,效果就更明显了,比如这个有 15 个物品的包包,它的搜索过程是这样的(点击看大图):\n", "\n", "\"knapsack_15\"\n", "\n", "看看我们剪掉了多少红圈黄圈!看起来很不错,让我们来试试前面的那些大包包…… 诶怎么回事……\n", " \n", "
RuntimeError: maximum recursion depth exceeded while calling a Python object
\n", "\n", "好吧递归太深了……我们可以加大递归深度:\n", " \n", "```python\n", "import thread\n", "...\n", "def main():\n", "...\n", "\n", "if __name__ == \"__main__\":\n", " threading.stack_size(67108864) # 64MB stack\n", " sys.setrecursionlimit(20000)\n", " thread = threading.Thread(target=main)\n", " thread.start()\n", "```\n", "\n", "在 Windows 下面递归深度除了受 recursion limit 约束,还受栈空间限制,所以这里用 thread 来设定栈空间,而它只对新建的线程有效。如果不扩大栈空间,2000 个项目还可以,10000 个就不行了。\n", "\n", "当然,所有的递归都是可以解开的,为什么非要用递归呢!直接把后面待处理的结点压进栈里不就行了吗,当然要注意压栈的顺序变了:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "127132909\n", "Wall time: 3.44 s\n" ] } ], "source": [ "def knapsack(vw, limit):\n", " \n", " def bound(v, w, j):\n", " if j >= len(vw) or w > limit:\n", " return -1\n", " else:\n", " while j < len(vw) and w + vw[j][1] <= limit:\n", " v, w, j = v + vw[j][0], w + vw[j][1], j + 1\n", " if j < len(vw):\n", " v += (limit - w) * vw[j][0] / (vw[j][1] * 1.0)\n", " return v\n", " \n", " stack = [(0, 0, 0)]\n", " maxValue = -1\n", " while stack:\n", " v, w, j = stack.pop()\n", " if bound(v, w, j) >= maxValue:\n", " if j < len(vw) - 1: \n", " stack.append([v, w, j + 1]) \n", " if w + vw[j][1] <= limit:\n", " maxValue = max(maxValue, v + vw[j][0])\n", " stack.append([v + vw[j][0], w + vw[j][1], j + 1])\n", " \n", " return maxValue\n", "\n", "if __name__ == \"__main__\":\n", " with open(\"0-1-Knapsack/test/knapsack_extra.txt\") as f:\n", " limit, n = map(int, f.readline().split())\n", " vw = [] # value, weight, value density\n", " for ln in f.readlines():\n", " vl, wl = tuple(map(int, ln.split()))\n", " vw.append([vl, wl, vl / (wl * 1.0)])\n", "\n", "%time print(knapsack(sorted(vw, key=lambda x : x[2], reverse=True), limit))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "那个有 10000 个元素的大包包也完全不在话下,瞬间就可以解开啦!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 再接再厉\n", "\n", "前面的方法已经非常快了,应该说很满意。那还有没有再优化的余地呢?\n", "\n", "我们再来仔细看一下这个过程:深度优先搜索,逐级回溯,典型的二叉树遍历方式。为了方便说明,我标出了结点遍历的顺序:\n", "\n", "\"knapsack_tiny_arrow\"\n", "\n", "\n", "一切看起来都挺好的。不过再仔细看一下——hmm,似乎还是有些地方可以琢磨琢磨。比如说,遍历到第二个物品,也就是价值 \\$50,重量 6,价值上限为 \\$60 的那个结点的时候,旁边那个价值 \\$30 的结点的价值上限却是 \\$76.7。也就是说,我们现在遍历的这个结点,其实“希望”并不如另外的那个大。而最终结果也证明,沿着价值上限为 \\$60 的结点找下去似乎是多做了“无用功”,而全局最大值正是在那个价值上限为 \\$76.7 的结点之下。\n", "\n", "这就给我们一个启发——价值上限最大的结点,是不是更可能是拥有全局最大价值呢?即便不是,价值上限大,也更有可能更快地达到较高的价值,从而更快地提高 maxValue,剪掉更多毫无希望的枝节。\n", "\n", "那么如何实现呢?现在下一个遍历结点的是从栈中弹出的,弹出的顺序取决于元素压栈的顺序。有什么办法能让弹出的顺序和价值上限相关呢?价值上限越大的优先级越高——啊哈,优先队列!Python 里面实现起来再容易不过了,因为它提供了 heapq 这个二叉堆。不过这个堆默认是最小堆,也就是堆顶上的元素是最小的。怎么把它变成最大堆呢?只要把元素取相反数就可以了嘛。实现起来只需要在原来的程序上稍作改动:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "127132909\n", "CPU times: user 2 s, sys: 0 ns, total: 2 s\n", "Wall time: 2 s\n" ] } ], "source": [ "from heapq import *\n", "\n", "def knapsack(vw, limit):\n", "\n", " def bound(v, w, j):\n", " if j >= len(vw) or w > limit:\n", " return -1\n", " else:\n", " while j < len(vw) and w + vw[j][1] <= limit:\n", " v, w, j = v + vw[j][0], w + vw[j][1], j + 1\n", " if j < len(vw):\n", " v += (limit - w) * vw[j][0] / (vw[j][1] * 1.0)\n", " return v\n", "\n", " maxValue = 0\n", " PQ = [[-bound(0, 0, 0), 0, 0, 0]] # -bound to keep maxheap\n", " while PQ:\n", " b, v, w, j = heappop(PQ)\n", " if b <= -maxValue: # promising\n", " if w + vw[j][1] <= limit:\n", " maxValue = max(maxValue, v + vw[j][0])\n", " heappush(PQ, [-bound(v + vw[j][0], w + vw[j][1], j + 1), \n", " v + vw[j][0], w + vw[j][1], j + 1])\n", " if j < len(vw) - 1:\n", " heappush(PQ, [-bound(v, w, j + 1), v, w, j + 1])\n", " return maxValue\n", "\n", "if __name__ == \"__main__\":\n", "# with open(sys.argv[1] if len(sys.argv) > 1 else sys.exit(1)) as f:\n", " with open(\"0-1-Knapsack/test/knapsack_extra.txt\") as f:\n", " limit, n = map(int, f.readline().split())\n", " vw = [] # value, weight, value density\n", " for ln in f.readlines():\n", " vl, wl = tuple(map(int, ln.split()))\n", " vw.append([vl, wl, vl / (wl * 1.0)])\n", "\n", "%time print(knapsack(sorted(vw, key=lambda x : x[2], reverse=True), limit))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "那么这次的遍历顺序变成什么样了呢?\n", "\n", "\"knapsack_tiny_pq_arrow\"\n", "\n", "虽然看着有点眼花缭乱,但其实很明确——找价值上限最高的那个。另外你注意到了吗?遍历的结点少了一个!也许你会说这少一个有什么大不了的……可是元素一多可就显出来啦!比如这个大包包,原先的方法遍历了多少个结点呢?8512 个。换用二叉堆之后呢?只有 2618 个!而这个有 10000 个元素的大包包呢?遍历结点数从 246426 降到了 47445,运行时间不到原先的三分之一(虽然原先也才 0.6 秒)!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }