Skip to content

回退背包 学习笔记

摘自 blog III | 回退背包 - 洛谷专栏

讲解

我们通过一个例题来引入回退背包:

P4141 消失之物

题目大意:有 n 件物品,每件物品都有一个体积 wi。对于每一个 i=1,2,,n,需要回答在不选第 i 件物品的情况下,有多少种方案可以填满容积为 x 的背包。

不考虑限制,这题就是一个朴素的 01 背包。而对于 “不选某件物品” 的限制,可以看作是从原背包中退掉了这个物品,所以叫回退背包。

一种最暴力的想法是,每次枚举不能选的物品,然后对剩下的物品做 01 背包,一共做 n 次。这样的复杂度是 O(n2x)

每次都重新求一次背包显然太慢了,我们尝试用回退的思想。先对所有的物品做一次背包,得到一个 dp 数组。dpj 代表填满容量 j 的方案数。

然后对每个限制,开一个 ans 数组,记 ansj 代表限制不选该物品时填满容量 j 的方案数。

设我们不能选的是物品 i,其体积为 wi,然后就有以下的式子:

  • 对于 j<wi,有 ansj=dpj。因为这个时候显然没法选到物品 i
  • 对于 jwi,有 ansj=dpjansjwi

这个要怎么理解呢?我们实际上是要把 dpj 中选了物品 i 的那些方案给退掉,从退掉物品 i 的角度去考虑,这个不合法的方案数显然与 “填满容量 jwi 且不含物品 i ” 的方案数是相等的。而后者恰好就是 ansjwi

这样就做完了。可以看到,我们处理一个限制的复杂度是 O(x) 的,所以总复杂度是 O(nx)

main code

cpp
void solve() {
    int n, m;
    cin >> n >> m;
    vector<int> v(n + 1);
    for (int i = 1; i <= n; i++)
        cin >> v[i];
    vector<int> dp(m + 1);
    dp[0] = 1;
    for (int i = 1; i <= n; i++) {
        for (int j = m; j >= v[i]; j--) {
            dp[j] = (dp[j] + dp[j - v[i]]) % 10;
        }
    }

    for (int i = 1; i <= n; i++) {
        vector<int> ans(m + 1);
        for (int j = 0; j < v[i]; j++)
            ans[j] = dp[j];
        for (int j = v[i]; j <= m; j++)
            ans[j] = (dp[j] - ans[j - v[i]] + 10) % 10;
        for (int j = 1; j <= m; j++)
            cout << ans[j];
        cout << endl;
    }
}

值得一提的是,具有回退性质的背包种类很少,大多数都是 01 背包,这与回退背包的主要思想有关。回退背包的思想是:背包问题与物品的求解顺序无关,因此每一个物品都可以认为是最后一个被求解的,那就可以对着 dp 的式子逆向操作,退回这个物品的 dp 状态。

其实回退背包的回退转移式更多是通过对原转移式逆向操作得到的,因为有一些题目很难通过纯分析来推导回退转移式。

如何通过逆向操作得到回退转移式呢?我们用上面的题目为例,原转移式有:

  • 遍历顺序 j:m0;
  • 对于 j<wifj=gj;
  • 对于 jwifj=gj+gjwi.

那么逆向操作即为:

  • 遍历顺序 j:0m;
  • 对于 j<wigj=fj;
  • 对于 jwigj=fjgjwi.

显然 g 就是我们经过回退后的答案数组。在遍历顺序合理的情况下,fg 可以合并为一个数组,如上面的式子 fg 均可以用一个 dp 代替。

需要注意的是,由于回退背包需要对原转移式进行逆向操作,而有些背包的转移式是 不存在逆向操作 的,例如最经典的 01 背包:

fj=max(fj,gjwi+vi)

显然,这样的转移式就没办法回退了。如果要维护这种背包且要支持回退 / 删除,只能另寻他法(比如离线后做线段树分治,代表题目 CF601E

上面的例题作为最基础的回退背包还是太简单了,接下来给大伙上上强度。


题目


1.

Nowcoder890A Blackjack,2019 牛客暑期多校 Day10

题目大意:有 n 张牌,每张牌有一个点数 xi。你可以不断抽牌并随时中止,每次抽牌随机等概率从牌堆抽取一张,抽的牌点数和大于 a 则获胜,但点数和大于 b 则失败。求最优操作下获胜的概率。

1n500, 1a<b500, 1xi500, i=1nxi>b.

首先我们很容易想到 dp。

dpj,k 为抽了 j 张牌,总点数为 k 的方案数。

转移式即为:

dpj,k=dpj,k+dpj1,kxi(j>0, kxi).

那么这似乎就已经做完了。对于选了 j 张总点数为 k 的概率,就是

dpj,kj!(nj)!n!.

后面乘这么一个分式的含义是:选了的牌和没选的牌都可以任意排列,所以情况数是 j!(nj)!,而总排列数是 n!,所以概率就是两者相除。只需要统计 a<kb 的部分就可以了。

但是这样做其实是不对的,样例 2 就是 hack。假如 a=2,b=4,我们选了两张牌 1,3,那么看似我们可以按顺序选 1,33,1,但选 3,1 这种情况是不存在的。因为选了 3 之后已经满足大于 a 的要求了,显然不会再继续选了。

那如何正确计算呢?有一个好的思路是,我们可以钦定最后一张牌选的是哪张,这张牌点数是 xi,那么我们只需要对其他的牌做背包,然后统计

axi<kmin(bxi,a)

的这部分答案就可以了,这样做保证了答案是正确的。

但是显然,每次钦定一张牌都做一次背包是很慢的,所以我们考虑对所有牌做完一次背包后,每次把钦定的那张牌给退掉,这样也达到了相同的效果。

逆向操作一下转移式,可以得到回退转移式。

回退转移式:

dpj,k=dpj,kdpj1,kxi(j>0, kxi).

(由于转移不会影响同一层 j 的答案,所以遍历的顺序没有什么要求

这样做出来的 dp 含义是还没有选第 i 张牌的方案数,由于最后还要把这张牌选上,所以概率计算需要乘以

j!(nj1)!n!

这么一个分式。

需要注意的是,计算 dp 数组的过程中会爆 long long,所以要开 long double。

时间复杂度 O(n3+n2b)

main code

cpp
void solve() {
    int n, a, b;
    cin >> n >> a >> b;
    vector<int> x(n + 1);
    for (int i = 1; i <= n; i++)
        cin >> x[i];

    vector<vector<long double>> dp(n + 1, vector<long double>(b + 1, 0));
    dp[0][0] = 1;
    for (int i = 1; i <= n; i++) {
        for (int j = i; j >= 1; j--) {
            for (int k = x[i]; k <= b; k++) {
                dp[j][k] += dp[j - 1][k - x[i]];
            }
        }
    }

    long double ans = 0;
    for (int i = 1; i <= n; i++) {
        vector<vector<long double>> tdp = dp;
        for (int j = 1; j <= n; j++) {
            for (int k = x[i]; k <= b; k++) {
                tdp[j][k] -= tdp[j - 1][k - x[i]];
            }
        }
        for (int j = 0; j < n; j++) {
            long double sum = 0;
            for (int k = max(0, a - x[i] + 1); k <= min(a, b - x[i]);
                 k++) {
                sum += tdp[j][k];
            }
            for (int k = n - j - 1; k >= 1; k--)
                sum = sum * k / (k + j + 1);
            ans += sum / (j + 1);
        }
    }

    cout << fixed << setprecision(12) << ans << endl;
}

2.

Gym105459E Marble Race,2024 CCPC 哈尔滨站 E

题目大意:m 个球,速度为每单位时间 v1,v2,,vmn 个在负半轴的位置 x1,,xn,每个球都会随机选择一个位置,然后向正方向移动,求坐标中位数在原点的期望时间,对 109+7 取模。

n,m500m 是奇数。

首先有一个不太难想的 O(nm3) 的做法。

由于 m 是奇数,中位数一定是中间那个球的位置。所以我们可以枚举完所有 nm 种某个球在某一个位置的情况。对于一种情况,我们钦定这个球是中间的球。那么在

t=xivj

也就是这个球到达原点的时候,需要有恰好一半的球没到达原点。

所以我们要求的是在这个时间 t,恰好一半的球没到达原点的概率。由于时间确定,所以可以通过枚举位置计算出其他的球无法到达原点的概率。有 p 的概率没到达,1p 的概率能到达,求恰好有

m2

个球还没到达原点的概率,这个可以直接背包去做。做一次的复杂度是 O(m2) 的,于是总复杂度是 O(nm3),显然太慢了。

如何去优化它?观察球的状态数量,虽然一共有 nm 种时间 t,但是对于一个球而言,其能否到达原点的概率 p 实际上只有 n 种不同的取值。也就是每个球的状态只有 n 种,总状态数 n2 种。而一个球在哪个状态只取决于 t 的大小。

所以这启发我们将所有时间 t 进行排序。显然在 t 变化时,不是所有的球状态都发生了变化。具体而言,我们从小到大考虑 t,可以发现每次 t 变化实际上只会影响到这个 t 对应的球的概率 pj,而其他的球是不发生改变的。

所以我们可以先对所有的球进行一次 dp,然后从小到大遍历每个 t,每次回退一个球的贡献,统计答案,然后再把这个球的贡献修改并加入回去。

来看一下转移式。对于一个时间 t 对应的球,有 a 个位置能够让其到达原点,记 dpk 为有 k 个球无法到达原点的概率,其转移式就有:

  • 遍历顺序 k:m0;
  • 对于 k=0dp0=dp0nan;
  • 对于 k>0dpk=dpk1an+dpknan.

那么逆向操作即为:

  • 遍历顺序 k:0m;
  • 对于 k=0dp0=dp0nna;
  • 对于 k>0dpk=(dpkdpk1an)nna.

我们每次是先进行逆向操作,统计贡献后,再正向操作回去。统计贡献前后,因为时间推进,其能够到达原点的位置会多一个,也就是有 a+1=a

为了方便统计,我们可以把 x 先取反再从小到大排序。为了规避浮点误差,每个 t 都用一个 pair 存下 (xi,vj) 来表示。因为排了序的缘故,会发现上面的 axi 的下标 i 是一致的,这样可以方便我们代码的实现。

这样做的时间复杂度降到了 O(nm2)。要注意的是,如果不预处理逆元,则会再带一个 log(109+7) 的常数,这个会 T,所以要预处理一下。

main code

cpp
constexpr int mod = 1e9 + 7;

ll qpow(ll b, ll k) {
    ll res = 1;
    while (k) {
        if (k & 1)
            res = (res * b) % mod;
        b = (b * b) % mod;
        k >>= 1;
    }
    return res;
}

ll inv(ll b) { return qpow(b, mod - 2); }

void solve() {
    int n, m;
    cin >> n >> m;
    vector<ll> x(n + 1), v(m + 1);
    vector<ll> invv(m + 1), invi(n + 1);

    for (int i = 1; i <= n; i++)
        cin >> x[i], x[i] = -x[i], invi[i] = inv(i);
    for (int i = 1; i <= m; i++)
        cin >> v[i], invv[i] = inv(v[i]);
    sort(x.begin() + 1, x.end());

    vector<array<int, 2>> t;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            t.push_back({i, j});
        }
    }
    sort(t.begin(), t.end(), [&](auto A, auto B) {
        return x[A[0]] * v[B[1]] < x[B[0]] * v[A[1]];
    });

    ll ans = 0;
    vector<ll> dp(m + 1, 0);
    dp[0] = 1; // 最开始可以视作时间为 0,没有球能到原点
    for (auto [a, b] : t) {
        dp[0] = dp[0] * (n * invi[n - a + 1] % mod) % mod;
        for (int i = 1; i <= m; i++)
            dp[i] = (dp[i] - dp[i - 1] * ((a - 1) * invi[n] % mod) % mod +
                     mod) %
                    mod * (n * invi[n - a + 1] % mod) % mod;

        ans = (ans + dp[m / 2] * (x[a] * invv[b] % mod) % mod) % mod;

        for (int i = m; i >= 0; i--) {
            dp[i] = dp[i] * ((n - a) * invi[n] % mod) % mod;
            if (i)
                dp[i] =
                    (dp[i] + dp[i - 1] * (a * invi[n] % mod) % mod) % mod;
        }
    }

    cout << ans * invi[n] % mod << endl;
}