2025-10-12:求出数组的 X 值Ⅱ。用go语言,给出一个只含正整数的数组 nums 和一个正整数 k,同时有若干个查询
2025-10-12:求出数组的 X 值Ⅱ。用go语言,给出一个只含正整数的数组 nums 和一个正整数 k,同时有若干个查询 queries,其中每个查询是四元组 [index, value, start, x]。
可以对数组做一次这样的变动:从末尾删除若干个元素(可以不删,但删除后数组必须至少保留一个元素)。对于某个给定的值 x,把数组的“x 值”定义为:在可以进行上述一次末尾删减的所有可行方案中,使得剩下元素的乘积对 k 取模后等于 x 的方案数。
对于每个查询,按下面顺序处理:
-
将 nums[index] 更新为 value(该修改会影响后续所有查询)。
-
从数组开头删除前 start 个元素(start 可能为 0,表示不删除),得到一个新的子数组。
-
在该子数组内,再次可选择删除其末尾若干元素(可以不删,但最终子数组必须非空),统计使得剩余元素乘积 mod k 等于给定的 xi 的删法数量。
输出一个与 queries 等长的数组 result,其中 result[i] 是第 i 个查询在上述步骤后得到的答案。
1 <= nums[i] <= 1000000000。
1 <= nums.length <= 100000。
1 <= k <= 5。
1 <= queries.length <= 20000。
queries[i] == [indexi, valuei, starti, xi]。
0 <= indexi <= nums.length - 1。
1 <= valuei <= 1000000000。
0 <= starti <= nums.length - 1。
0 <= xi <= k - 1。
输入: nums = [1,1,2,1,1], k = 2, queries = [[2,1,0,1]]。
输出: [5]。
题目来自力扣3525。
1. 关键观察
- 因为
k
很小(最大 5),所以模k
的结果只有0, 1, ..., k-1
这几种可能。 - 乘积模
k
只与每个数模k
的值有关,所以我们可以将每个数取nums[i] % k
来简化计算。 - 问题等价于:给定一个数组,求它的所有非空前缀(因为从末尾删除元素,剩下的就是前缀)中,乘积模
k
等于x
的个数。
注意:这里“从末尾删除”意味着:对于子数组 arr[0..m-1]
,我们可以保留 arr[0..0]
, arr[0..1]
, …, arr[0..m-1]
这些前缀,统计这些前缀里乘积模 k == x
的数量。
2. 数据结构设计
为了高效处理区间查询和单点更新,这里使用了线段树(Segment Tree)。
2.1 线段树节点设计
每个节点 Node
包含:
product
:该区间所有元素的乘积模k
counts
:一个长度为k
的数组,counts[r]
表示:从该区间左端点开始,依次取前缀,这些前缀的乘积模k
等于r
的个数。
例如,对于区间 [a, b, c]
,前缀有:
a
→ 乘积 mod k = r1a*b
→ 乘积 mod k = r2a*b*c
→ 乘积 mod k = r3
counts
会统计这些 r1, r2, r3 中每个余数的出现次数。
2.2 线段树合并操作
当合并左右两个子节点 left
和 right
时:
- 新节点的
product = (left.product * right.product) % k
- 新节点的
counts
初始化为left.counts
的副本 - 对于
right
的每个余数r
(即right.counts[r]
表示右区间的前缀乘积模 k 为 r 的数量),
这些前缀在拼接左区间后,乘积模 k 变为(left.product * r) % k
,所以要把right.counts[r]
加到新节点的counts[(left.product * r) % k]
上。
这样,新节点就统计了整个区间所有前缀的乘积模 k 的分布。
2.3 叶子节点
叶子节点只有一个元素 v
,它的 product = v % k
,counts
中只有 counts[v % k] = 1
,其余为 0。
3. 查询处理流程
对于每个查询 [index, value, start, x]
:
- 单点更新:调用线段树的
Update(index, value)
,将nums[index]
改为value % k
,并更新所有相关节点。 - 区间查询:调用线段树的
Query(start, n-1, x)
,查询区间[start, n-1]
对应的节点。 - 该节点的
counts[x]
就是我们要的答案,因为counts[x]
统计了从start
开始到末尾的所有前缀(即从start
开始,依次增加一个元素直到末尾)的乘积模 k 等于 x 的数量,这正好对应了“从末尾删除若干个元素”的所有方案。
4. 例子说明
以题目给的例子:
nums = [1, 1, 2, 1, 1], k = 2
queries = [[2, 1, 0, 1]]
初始数组模 2 后为 [1, 1, 0, 1, 1]
。
查询处理:
- 更新
nums[2] = 1
(模 2 后为 1),数组变为[1, 1, 1, 1, 1]
。 start = 0
,子数组为整个数组[1, 1, 1, 1, 1]
。- 计算所有非空前缀的乘积模 2:
- [1] → 1
- [1,1] → 1
- [1,1,1] → 1
- [1,1,1,1] → 1
- [1,1,1,1,1] → 1
所有前缀模 2 都是 1,所以x = 1
的方案数 = 5。
输出 [5]
。
5. 复杂度分析
时间复杂度
- 建树:O(n × k),因为每个节点合并是 O(k),n 个元素建树 O(n × k)。
- 每次更新:O(k × log n),因为更新路径长度 O(log n),每层合并 O(k)。
- 每次查询:O(k × log n),原因同上。
- 总复杂度:O((n + q × log n) × k),其中 q 是查询数,k 很小可视为常数,所以约 O((n + q) log n)。
额外空间复杂度
- 线段树节点数 O(n),每个节点 O(k) 空间,所以总空间 O(n × k) = O(n)(因为 k 很小)。
- 递归栈深度 O(log n)。
总结:
- 总时间复杂度:O((n + q) log n)
- 总额外空间复杂度:O(n)
Go完整代码如下:
package main
import (
"fmt"
)
type Node struct {
product int
counts []int
}
type SegmentTree struct {
n int
k int
tree []*Node
}
func resultArray(nums []int, k int, queries [][]int) []int {
st := NewSegmentTree(nums, k)
n := len(nums)
result := make([]int, len(queries))
for i, query := range queries {
index, value, start, x := query[0], query[1], query[2], query[3]
st.Update(index, value)
result[i] = st.Query(start, n-1, x)
}
return result
}
func NewSegmentTree(nums []int, k int) *SegmentTree {
n := len(nums)
st := &SegmentTree{
n: n,
k: k,
tree: make([]*Node, 4*n),
}
st.build(0, n-1, 0, nums)
return st
}
func (st *SegmentTree) Query(start, end, x int) int {
node := st.query(start, end, 0, 0, st.n-1)
return node.counts[x]
}
func (st *SegmentTree) Update(index, value int) {
st.update(index, value, 0, 0, st.n-1)
}
func (st *SegmentTree) build(start, end, treeIndex int, nums []int) {
if start == end {
st.tree[treeIndex] = st.newNode(nums[start])
return
}
mid := start + (end-start)/2
st.build(start, mid, treeIndex*2+1, nums)
st.build(mid+1, end, treeIndex*2+2, nums)
st.tree[treeIndex] = st.merge(st.tree[treeIndex*2+1], st.tree[treeIndex*2+2])
}
func (st *SegmentTree) query(rangeStart, rangeEnd, treeIndex, treeStart, treeEnd int) *Node {
if rangeStart == treeStart && rangeEnd == treeEnd {
return st.tree[treeIndex]
}
mid := treeStart + (treeEnd-treeStart)/2
if rangeEnd <= mid {
return st.query(rangeStart, rangeEnd, treeIndex*2+1, treeStart, mid)
} else if rangeStart > mid {
return st.query(rangeStart, rangeEnd, treeIndex*2+2, mid+1, treeEnd)
} else {
leftNode := st.query(rangeStart, mid, treeIndex*2+1, treeStart, mid)
rightNode := st.query(mid+1, rangeEnd, treeIndex*2+2, mid+1, treeEnd)
return st.merge(leftNode, rightNode)
}
}
func (st *SegmentTree) update(rangeIndex, value, treeIndex, start, end int) {
if start == end {
st.tree[treeIndex] = st.newNode(value)
return
}
mid := start + (end-start)/2
if rangeIndex <= mid {
st.update(rangeIndex, value, treeIndex*2+1, start, mid)
} else {
st.update(rangeIndex, value, treeIndex*2+2, mid+1, end)
}
st.tree[treeIndex] = st.merge(st.tree[treeIndex*2+1], st.tree[treeIndex*2+2])
}
func (st *SegmentTree) newNode(product int) *Node {
product %= st.k
counts := make([]int, st.k)
counts[product] = 1
return &Node{product, counts}
}
func (st *SegmentTree) merge(node1, node2 *Node) *Node {
counts := make([]int, st.k)
// Copy counts from node1
for x := 0; x < st.k; x++ {
counts[x] = node1.counts[x]
}
// Combine with node2
for x := 0; x < st.k; x++ {
newIndex := (node1.product * x) % st.k
counts[newIndex] += node2.counts[x]
}
product := (node1.product * node2.product) % st.k
return &Node{product, counts}
}
func main() {
nums := []int{1, 1, 2, 1, 1}
k := 2
queries := [][]int{{2, 1, 0, 1}}
result := resultArray(nums, k, queries)
fmt.Println(result)
}
Python完整代码如下:
# -*-coding:utf-8-*-
class Node:
def __init__(self, product, k):
self.product = product % k
self.counts = [0] * k
self.counts[self.product] = 1
class SegmentTree:
def __init__(self, nums, k):
self.n = len(nums)
self.k = k
self.tree = [None] * (4 * self.n)
self._build(0, self.n - 1, 0, nums)
def query(self, start, end, x):
node = self._query(start, end, 0, 0, self.n - 1)
return node.counts[x]
def update(self, index, value):
self._update(index, value, 0, 0, self.n - 1)
def _build(self, start, end, tree_index, nums):
if start == end:
self.tree[tree_index] = Node(nums[start], self.k)
return
mid = start + (end - start) // 2
self._build(start, mid, tree_index * 2 + 1, nums)
self._build(mid + 1, end, tree_index * 2 + 2, nums)
self.tree[tree_index] = self._merge(self.tree[tree_index * 2 + 1],
self.tree[tree_index * 2 + 2])
def _query(self, range_start, range_end, tree_index, tree_start, tree_end):
if range_start == tree_start and range_end == tree_end:
return self.tree[tree_index]
mid = tree_start + (tree_end - tree_start) // 2
if range_end <= mid:
return self._query(range_start, range_end, tree_index * 2 + 1,
tree_start, mid)
elif range_start > mid:
return self._query(range_start, range_end, tree_index * 2 + 2,
mid + 1, tree_end)
else:
left_node = self._query(range_start, mid, tree_index * 2 + 1,
tree_start, mid)
right_node = self._query(mid + 1, range_end, tree_index * 2 + 2,
mid + 1, tree_end)
return self._merge(left_node, right_node)
def _update(self, range_index, value, tree_index, start, end):
if start == end:
self.tree[tree_index] = Node(value, self.k)
return
mid = start + (end - start) // 2
if range_index <= mid:
self._update(range_index, value, tree_index * 2 + 1, start, mid)
else:
self._update(range_index, value, tree_index * 2 + 2, mid + 1, end)
self.tree[tree_index] = self._merge(self.tree[tree_index * 2 + 1],
self.tree[tree_index * 2 + 2])
def _merge(self, node1, node2):
counts = [0] * self.k
# Copy counts from node1
for x in range(self.k):
counts[x] = node1.counts[x]
# Combine with node2
for x in range(self.k):
new_index = (node1.product * x) % self.k
counts[new_index] += node2.counts[x]
product = (node1.product * node2.product) % self.k
merged_node = Node(product, self.k)
merged_node.counts = counts
return merged_node
def resultArray(nums, k, queries):
st = SegmentTree(nums, k)
n = len(nums)
result = []
for query in queries:
index, value, start, x = query
st.update(index, value)
result.append(st.query(start, n - 1, x))
return result
# 测试代码
if __name__ == "__main__":
nums = [1, 1, 2, 1, 1]
k = 2
queries = [[2, 1, 0, 1]]
result = resultArray(nums, k, queries)
print(result)
- 点赞
- 收藏
- 关注作者
评论(0)