移除盒子 - 区间DP算法练习

题目

Leetcode 546. 移除盒子

给出一些不同颜色的盒子 boxes ,盒子的颜色由不同的正数表示。

你将经过若干轮操作去去掉盒子,直到所有的盒子都去掉为止。每一轮你可以移除具有相同颜色的连续 k 个盒子(k >= 1),这样一轮之后你将得到 k * k 个积分。

返回 你能获得的最大积分和 。

示例 1:

1
2
3
4
5
6
7
8
9
输入:boxes = [1,3,2,2,2,3,4,3,1]
输出:23
解释:
[1, 3, 2, 2, 2, 3, 4, 3, 1]
----> [1, 3, 3, 4, 3, 1] (3*3=9 分)
----> [1, 3, 3, 3, 1] (1*1=1 分)
----> [1, 1] (3*3=9 分)
----> [] (2*2=4 分)

示例 2:

1
2
3
输入:boxes = [1,1,1]
输出:9

示例 3:

1
2
输入:boxes = [1]
输出:1

思路

  • 问题可以被分解为多个子问题,通过递归求解并缓存中间结果来提高效率。
  • 递归函数 f(l, r, k) 表示在区间 [l, r] 上,移除盒子的最大分数,k 表示当前右边界 r 右侧有多少个与 boxes[r] 颜色相同的盒子。

详细设计

  1. 构造递归函数 f(l, r, k):
    • l: 左边界,表示当前考虑的盒子区间的左端点。
    • r: 右边界,表示当前考虑的盒子区间的右端点。
    • k: 当前右端点 r 右侧与 boxes[r] 颜色相同的盒子的数量。
  2. 边界判断:
    • 如果l大于r,说明区间为空,返回0
  3. 移除r节点:
    • 移除一个r节点,可以得到的积分是 f(l, r - 1, 0) + (k + 1) ** 2
  4. 合并与 r节点 值相同的节点:
    • 遍历区间 [l, r-1],寻找与右端点 r 颜色相同的盒子,将其与右端点合并,计算新的最大分数。
    • 如果 boxes[i] == boxes[r] 那合并后的可以得到的积分为 f(l, i, k + 1) + f(i + 1, r - 1, 0)),这个和之前的分数比较,保留最大值。
    • 循环结束,返回分数最大值
  5. 递归函数初始: f(0, len(boxes) - 1, 0)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import sys
sys.setrecursionlimit(10**6)

class Solution(object):
def removeBoxes(self, boxes):
"""
:type boxes: List[int]
:rtype: int
"""

def f(l, r, k):
if l > r:
return 0

points = f(l,r-1,0) + (k+1) ** 2

for i in range(l,r):
if boxes[i] == boxes[r]:
points = max( points, f(l,i,k+1) + f(i+1,r-1,0) )

return points

size = len(boxes)
maxpoints = f(0,size-1,0)
return maxpoints

执行结果

可以通过测试用例,提交的时候超时了,卡在下面的数据:

1
[1,2,2,1,1,1,2,1,1,2,1,2,1,1,2,2,1,1,2,2,1,1,1,2,2,2,2,1,2,1,1,2,2,1,2,1,2,2,2,2,2,1,2,1,2,2,1,1,1,2,2,1,2,1,2,2,1,2,1,1,1,2,2,2,2,2,1,2,2,2,2,2,1,1,1,1,1,2,2,2,2,2,1,1,1,1,2,2,1,1,1,1,1,1,1,2,1,2,2,1]

优化

  • 加了合并
  • 加了缓存
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from typing import List
from functools import lru_cache

class Solution:
def removeBoxes(self, boxes: List[int]) -> int:
@lru_cache(None)
def f(l, r, k):
if l > r: return 0

while r > l and boxes[r] == boxes[r-1]:
r -= 1
k += 1

points = f(l,r-1,0) + (k+1) ** 2

for i in range(l,r):
if boxes[i] == boxes[r]:
points = max( points, f(l,i,k+1) + f(i+1,r-1,0) )

return points

size = len(boxes)
maxpoints = f(0,size-1,0)
return maxpoints

请我喝杯咖啡吧~

支付宝
微信