3691. Maximum Total Subarray Value II
Problem
You are given an integer array nums of length n and an integer k. You must pick exactly k distinct non-empty subarrays nums[l..r]. Subarrays may overlap, but the same (l, r) pair cannot be chosen twice.
The value of nums[l..r] is max(nums[l..r]) - min(nums[l..r]), that is, the range maximum minus the range minimum. The total value is the sum of the chosen subarrays' values. Return the maximum total value possible.
Constraints
- 1 <= n == nums.length <= 5 * 10^4
- 0 <= nums[i] <= 10^9
- 1 <= k <= min(10^5, n * (n + 1) / 2)
Worked example
nums = [1,3,2], k = 2 → 4. Pick [1,3] (value 3-1=2) and [1,3,2] (value 3-1=2). Total 2 + 2 = 4.
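Glossary
- Subarray: a contiguous run of elements nums[l..r]. Unlike a subsequence, its elements must be adjacent.
- Range Min/Max Query (RMQ): given a range [l, r], quickly report the maximum or minimum inside it.
- Sparse Table: a preprocessing structure. Building it takes O(n log n); after that, each range max/min query takes O(1). It stores the answer for every window that starts at i and has length 2^j.
- Max-Heap (Priority Queue): a container that removes its current largest element in O(log n) time. C++'s priority_queue is a max-heap by default.
- Greedy: take the best available option at every step, together with a proof that doing so yields the global optimum.
- Monotonicity: a quantity that moves in only one direction as a parameter changes (never decreasing, or never increasing). The key property in this problem is a monotonicity.
- Top-k merge of sorted lists: merge several sorted lists and take the k largest values, which is a classic use of a heap.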
Approach
The brute force is to enumerate all O(n^2) subarrays, compute each value, sort, and sum the top k. But with n up to 5*10^4 there are n(n+1)/2 ≈ 1.25*10^9 subarrays, far too many to list. We need structure.
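Although the brute force is too slow at full scale, it is handy as a reference checker on small inputs. The following is a minimal sketch of it; the name bruteForceTotal is hypothetical and not part of the original write-up.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical helper: sums the k largest subarray values by listing every
// subarray. O(n^2) time and memory, so only usable for small n.
long long bruteForceTotal(const std::vector<int>& nums, int k) {
    const int n = static_cast<int>(nums.size());
    std::vector<long long> values;
    values.reserve(static_cast<std::size_t>(n) * (n + 1) / 2);
    for (int l = 0; l < n; ++l) {
        int hi = nums[l], lo = nums[l];  // running max / min of nums[l..r]
        for (int r = l; r < n; ++r) {
            hi = std::max(hi, nums[r]);
            lo = std::min(lo, nums[r]);
            values.push_back(static_cast<long long>(hi) - lo);
        }
    }
    // The problem guarantees k <= n(n+1)/2; check it rather than assume it.
    assert(k >= 0 && static_cast<std::size_t>(k) <= values.size());
    // Move the k largest values to the front, then add them up.
    std::partial_sort(values.begin(), values.begin() + k, values.end(),
                      std::greater<long long>());
    long long total = 0;
    for (int i = 0; i < k; ++i) total += values[i];
    return total;
}
```

For the worked example, bruteForceTotal({1, 3, 2}, 2) returns 4.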
The key observation (hint 1): for a fixed left endpoint l, the value v(l,r)=max-min is non-decreasing as r grows. Extending the window one step to the right adds an element; the max can only stay or rise, the min can only stay or fall, so max - min can only stay or rise. Equivalently, for fixed l the biggest value sits at r = n-1, and as r shrinks the values are non-increasing.
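The monotonicity can be checked directly on small arrays. The sketch below builds the values v(l, l), v(l, l+1), ..., v(l, n-1) for one fixed l with a running max and min; chainValues and isNonDecreasing are hypothetical names used only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: returns v(l, l), v(l, l+1), ..., v(l, n-1) for one
// fixed left endpoint l.
std::vector<int> chainValues(const std::vector<int>& nums, int l) {
    const int n = static_cast<int>(nums.size());
    assert(l >= 0 && l < n);
    std::vector<int> chain;
    int hi = nums[l], lo = nums[l];
    for (int r = l; r < n; ++r) {
        hi = std::max(hi, nums[r]);  // the max can only stay or rise
        lo = std::min(lo, nums[r]);  // the min can only stay or fall
        chain.push_back(hi - lo);    // so hi - lo can only stay or rise
    }
    return chain;
}

// True when the values never decrease as r grows, as the observation claims.
bool isNonDecreasing(const std::vector<int>& chain) {
    return std::is_sorted(chain.begin(), chain.end());
}
```

For nums = [1,3,2] and l = 0 this yields 0, 2, 2. Read from right to left, that is the descending chain the heap walks.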
So each left endpoint l defines a sorted (descending) chain v(l, n-1) >= v(l, n-2) >= ... >= v(l, l), and we want the k largest values across all n chains. That is the classic "merge sorted lists, take top k" pattern, solved with a max-heap: seed it with every chain's head v(l, n-1); repeatedly pop the maximum, add it to the answer, and push the same chain's next element v(l, r-1) (when r > l). Do this k times. The greedy is correct because the heap always holds the largest not-yet-taken value of every unfinished chain: each chain is descending, so any other remaining value is no larger than its own chain's entry in the heap, and the heap's top is therefore the global maximum among all not-yet-taken values.
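The merge pattern can be seen in isolation on explicit lists, without any range queries. This is a minimal sketch under the assumption that every input list is already sorted in descending order; Entry, ByValue, and topKSum are hypothetical names, not taken from the solutions below.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <vector>

// Hypothetical types and names, introduced only for this sketch.
struct Entry {
    int value;         // the value itself
    std::size_t list;  // which list it came from
    std::size_t pos;   // its position inside that list
};

struct ByValue {
    // With this ordering, priority_queue keeps the largest value on top.
    bool operator()(const Entry& a, const Entry& b) const { return a.value < b.value; }
};

// Sum of the k largest values across lists that are each sorted descending.
long long topKSum(const std::vector<std::vector<int>>& lists, int k) {
    std::priority_queue<Entry, std::vector<Entry>, ByValue> heap;
    for (std::size_t i = 0; i < lists.size(); ++i)
        if (!lists[i].empty()) heap.push(Entry{lists[i][0], i, 0});  // seed with each head

    long long total = 0;
    for (int taken = 0; taken < k; ++taken) {
        assert(!heap.empty());         // caller must not ask for more values than exist
        const Entry top = heap.top();  // largest value not yet taken
        heap.pop();
        total += top.value;
        if (top.pos + 1 < lists[top.list].size())  // advance within the same list only
            heap.push(Entry{lists[top.list][top.pos + 1], top.list, top.pos + 1});
    }
    return total;
}
```

With the three chains of the worked example written out as lists, topKSum({{2, 2, 0}, {1, 0}, {0}}, 2) returns 4. The heap never holds more than one entry per list, which is what keeps each step at O(log n).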
The remaining piece is computing v(l, r) fast. Build two sparse tables (one for range max, one for range min) in O(n log n); each query is then O(1). Overall O((n+k) log n), comfortably fast.
Walkthrough
Input: nums = [1,3,2], k = 2, so n = 3.
First, each chain's head v(l, n-1) = v(l, 2):
| l | range | max | min | value |
|---|---|---|---|---|
| 0 | [1,3,2] | 3 | 1 | 2 |
| 1 | [3,2] | 3 | 2 | 1 |
| 2 | [2] | 2 | 2 | 0 |
Heap is seeded with these three (value, l, r) triples. Now pop k = 2 times:
| Step | Heap before pop (value@(l,r)) | Pop top | Running sum ans | Push next |
|---|---|---|---|---|
| init | 2@(0,2), 1@(1,2), 0@(2,2) | - | 0 | - |
| 1 | 2@(0,2), 1@(1,2), 0@(2,2) | take 2@(0,2) | 0 + 2 = 2 | r=2 > l=0, so push v(0,1) = max(1,3) - min(1,3) = 3 - 1 = 2 → 2@(0,1) |
| 2 | 2@(0,1), 1@(1,2), 0@(2,2) | take 2@(0,1) | 2 + 2 = 4 | r=1 > l=0, so the code still pushes v(0,0) = 0, but k = 2 pops are done and it is never popped |
Answer = 4. This matches the expected output: the two chosen subarrays are [1,3,2] and [1,3], each worth 2.
Solution — C
#include <stdlib.h>

// Algorithm:
// For fixed l, v(l,r) = max - min is non-decreasing in r, so its maximum is at r = n-1.
// Seed a max-heap with v(l,n-1) for every l; pop k times, and after popping (l,r)
// push (l,r-1) when r > l.
// Sparse tables make each range max/min query O(1).

static int LOG[50005];  // LOG[x] = floor(log2(x)), used to pick the table level in a query

// Heap node: value, left index, right index
typedef struct { int val; int l; int r; } Node;

static Node *heap;      // array backing the heap
static int heapSize;    // current number of elements

// Push a node and sift it up to keep max-heap order
static void hpush(Node x) {
    int i = heapSize++;                                    // place at the end; i is the new slot
    heap[i] = x;
    while (i > 0) {                                        // bubble up while bigger than the parent
        int p = (i - 1) / 2;                               // parent index in a 0-based binary heap
        if (heap[p].val >= heap[i].val) break;             // parent is not smaller: done
        Node t = heap[p]; heap[p] = heap[i]; heap[i] = t;  // swap with parent
        i = p;
    }
}

// Pop and return the max (root), then sift down to restore order
static Node hpop(void) {
    Node top = heap[0];                                    // root holds the maximum
    heap[0] = heap[--heapSize];                            // move the last element to the root
    int i = 0;
    while (1) {
        int l = 2 * i + 1, r = 2 * i + 2, m = i;           // children indices
        if (l < heapSize && heap[l].val > heap[m].val) m = l;  // pick the larger child
        if (r < heapSize && heap[r].val > heap[m].val) m = r;
        if (m == i) break;                                 // not smaller than either child: order restored
        Node t = heap[m]; heap[m] = heap[i]; heap[i] = t;  // swap down with the larger child
        i = m;
    }
    return top;
}

// Compute v(l,r) = max - min in O(1) via the sparse tables.
// mx and mn are flattened 2D tables indexed as mx[level * n + i].
static int rangeVal(int *mx, int *mn, int n, int l, int r) {
    int j = LOG[r - l + 1];  // largest j with 2^j <= range length
    int span = 1 << j;       // 1 << j is 2^j
    // Two windows of length 2^j cover [l,r]; overlap is harmless for max/min.
    int hiA = mx[j * n + l], hiB = mx[j * n + (r - span + 1)];
    int hi = hiA > hiB ? hiA : hiB;  // range maximum
    int loA = mn[j * n + l], loB = mn[j * n + (r - span + 1)];
    int lo = loA < loB ? loA : loB;  // range minimum
    return hi - lo;                  // value = max - min
}

long long maxTotalValue(int* nums, int numsSize, int k) {
    int n = numsSize;

    // Precompute floor(log2(i)) for every possible query length
    LOG[1] = 0;
    for (int i = 2; i <= n; i++) LOG[i] = LOG[i / 2] + 1;
    int K = LOG[n] + 1;  // number of sparse-table levels needed

    // Allocate the two sparse tables (flattened)
    int *mx = (int*)malloc((size_t)K * n * sizeof(int));
    int *mn = (int*)malloc((size_t)K * n * sizeof(int));
    for (int i = 0; i < n; i++) { mx[i] = nums[i]; mn[i] = nums[i]; }  // level 0 = single elements

    // Build each level from the previous one
    for (int j = 1; j < K; j++) {
        int half = 1 << (j - 1);  // previous level's window length
        for (int i = 0; i + (1 << j) <= n; i++) {
            int a = (j - 1) * n + i;         // left half, starting at i
            int b = (j - 1) * n + i + half;  // right half, starting at i + half
            mx[j * n + i] = mx[a] > mx[b] ? mx[a] : mx[b];  // combine by taking the max
            mn[j * n + i] = mn[a] < mn[b] ? mn[a] : mn[b];  // combine by taking the min
        }
    }

    // The heap holds at most n nodes at once (one per left endpoint l)
    heap = (Node*)malloc((size_t)n * sizeof(Node));
    heapSize = 0;
    for (int l = 0; l < n; l++) {  // seed each chain's head v(l,n-1)
        Node x = { rangeVal(mx, mn, n, l, n - 1), l, n - 1 };
        hpush(x);
    }

    long long ans = 0;  // the sum can reach 1e5 * 1e9 = 1e14, so long long is required
    for (int t = 0; t < k; t++) {  // take the k largest values
        Node top = hpop();         // maximum among all values not yet taken
        ans += top.val;
        if (top.r > top.l) {       // this chain still has a next element
            Node nx = { rangeVal(mx, mn, n, top.l, top.r - 1), top.l, top.r - 1 };
            hpush(nx);             // push v(l,r-1), the next (not larger) value of this chain
        }
    }

    free(mx); free(mn); free(heap);  // release memory to avoid leaks
    return ans;
}
Solution — C++
#include <vector>
#include <queue>
#include <tuple>
#include <algorithm>
using namespace std;

// Algorithm:
// For fixed l, v(l,r) = max - min is non-decreasing in r; each l forms a descending chain.
// A max-heap merges the n chains to take the top k; sparse tables answer v(l,r) in O(1).
class Solution {
public:
    long long maxTotalValue(vector<int>& nums, int k) {
        int n = nums.size();

        // How many sparse-table levels we need
        int K = 1;
        while ((1 << K) <= n) K++;  // 1 << K is 2^K

        // 2D tables; vector<int>(n) value-initialises every entry to 0
        vector<vector<int>> mx(K, vector<int>(n)), mn(K, vector<int>(n));
        for (int i = 0; i < n; i++) mx[0][i] = mn[0][i] = nums[i];  // level 0 = the elements
        for (int j = 1; j < K; j++)  // build level j from level j-1
            for (int i = 0; i + (1 << j) <= n; i++) {
                int half = 1 << (j - 1);  // previous level's window length
                mx[j][i] = max(mx[j-1][i], mx[j-1][i + half]);  // max of the two halves
                mn[j][i] = min(mn[j-1][i], mn[j-1][i + half]);  // min of the two halves
            }

        // Lambda; [&] captures the surrounding variables by reference
        auto val = [&](int l, int r) -> int {
            int j = 31 - __builtin_clz(r - l + 1);  // __builtin_clz counts leading zeros, giving floor(log2)
            int span = 1 << j;
            int hi = max(mx[j][l], mx[j][r - span + 1]);  // range max
            int lo = min(mn[j][l], mn[j][r - span + 1]);  // range min
            return hi - lo;                                // the value
        };

        // tuple<int,int,int> = (value, l, r). priority_queue is a max-heap by default,
        // and tuples compare by their first element first.
        priority_queue<tuple<int,int,int>> pq;
        for (int l = 0; l < n; l++)
            pq.emplace(val(l, n - 1), l, n - 1);  // emplace constructs the element in place

        long long ans = 0;  // long long to avoid overflow (up to ~1e14)
        while (k--) {       // take the k largest values
            auto [v, l, r] = pq.top();  // structured bindings unpack the tuple
            pq.pop();
            ans += v;
            if (r > l)  // this chain still has a next (not larger) value
                pq.emplace(val(l, r - 1), l, r - 1);  // push v(l,r-1)
        }
        return ans;
    }
};
Complexity
- Time: O((n + k) log n). Building the sparse tables costs O(n log n) (log n levels, each scanning n entries). The heap is seeded with n items and then runs k pop/push pairs at O(log n) each, with each value query answered by the sparse tables in O(1). Here n is the array length and k is the number of subarrays to pick.
- Space: O(n log n). The two sparse tables take O(n log n) each and dominate; the heap never exceeds n nodes, so it is O(n).
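To put a number on the O(n log n) space bound, the level count can be computed the same way the C++ solution does. The sketch below is a back-of-envelope aid only; sparseTableLevels and sparseTableBytes are hypothetical names, and the byte count depends on the platform's sizeof(int).

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: number of sparse-table levels for an array of length n,
// i.e. floor(log2(n)) + 1, using the same loop as the C++ solution.
int sparseTableLevels(int n) {
    assert(n >= 1);
    int levels = 1;
    while ((1 << levels) <= n) ++levels;
    return levels;
}

// Bytes used by one flattened table of levels * n ints.
std::size_t sparseTableBytes(int n) {
    return static_cast<std::size_t>(sparseTableLevels(n)) * n * sizeof(int);
}
```

For n = 5*10^4 this gives 16 levels. Assuming 4-byte ints, that is about 3.2 MB per table, so roughly 6.4 MB for the max and min tables together.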
Pitfalls & Edge Cases
- Overflow: the running sum can reach 10^5 * 10^9 = 10^14, which exceeds the 32-bit int limit, so the accumulator must be long long. A single value max - min <= 10^9 still fits in int.
- One heap entry per chain: after popping (l, r), push only (l, r-1) rather than the whole chain. This keeps the heap size at <= n and guarantees that values are popped in globally non-increasing order.
- Mind the monotonicity direction: v(l, r) grows as r grows, so the chain head (the largest value) is at r = n-1 and moving to r-1 goes down. Reversing the direction breaks the greedy.
- Stop condition r > l: once r == l (a single element, value 0), stop pushing. Otherwise the next query would cover the empty range [l, l-1], which can read out of bounds.
- Single-element array (n = 1): the only subarray has value 0, k must be 1, and the answer is 0. The code handles this naturally: LOG[1] = 0, the sparse table has only level 0, and the query returns 0.
- __builtin_clz(0) is undefined: the query length r - l + 1 is always >= 1, so 0 is never passed and the clz trick is safe here. clz(1) = 31 correctly gives j = 0.
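The overflow bound is easy to confirm with 64-bit arithmetic. This is a small illustrative sketch; worstCaseTotal and overflowsInt32 are hypothetical names, and the inputs are just the upper limits from the constraints.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper: the largest total the constraints allow, k * maxValue,
// computed in 64-bit arithmetic.
std::int64_t worstCaseTotal(std::int64_t k, std::int64_t maxValue) {
    assert(k >= 0 && maxValue >= 0);
    return k * maxValue;
}

// True when a total does not fit in a 32-bit signed int.
bool overflowsInt32(std::int64_t total) {
    return total > std::numeric_limits<std::int32_t>::max();
}
```

With k = 10^5 and a per-subarray value of 10^9 the total is 10^14, which overflowsInt32 reports as too large, while a single value of 10^9 still fits.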