2025-11-30:树中找到带权中位节点。用go语言,给出一个含 n 个节点(编号 0 到 n-1)的带权无向树,树的根定为节
2025-11-30:树中找到带权中位节点。用go语言,给出一个含 n 个节点(编号 0 到 n-1)的带权无向树,树的根定为节点 0。树用长度为 n-1 的数组 edges 描述,每个 edges[i] = [ui, vi, wi] 表示 ui 与 vi 之间有一条权值为 wi 的边。
在两节点间的路径上,把从起点累积经过的边权和视作距离。所谓“带权中位节点”是指沿从起点 ui 到终点 vi 的路径,从 ui 出发第一个使得累计边权达到(或超过)整条路径总权重一半的节点 x。
现在给出若干查询 queries,每个 queries[j] = [uj, vj] 要求找出 uj 到 vj 路径上的带权中位节点。输出一个数组 ans,其中 ans[j] 是对应查询的带权中位节点编号。
2 <= n <= 100000。
edges.length == n - 1。
edges[i] == [ui, vi, wi]。
0 <= ui, vi < n。
1 <= wi <= 1000000000。
1 <= queries.length <= 100000。
queries[j] == [uj, vj]。
0 <= uj, vj < n。
输入保证 edges 表示一棵合法的树。
输入: n = 2, edges = [[0,1,7]], queries = [[1,0],[0,1]]。
输出: [0,1]。
解释:

| 查询 | 路径 | 边权 | 总路径权值和 | 一半 | 解释 | 答案 |
|---|---|---|---|---|---|---|
| [1, 0] | 1 → 0 | [7] | 7 | 3.5 | 从 1 → 0 的权重和为 7 ≥ 3.5,中位节点是 0。 | 0 |
| [0, 1] | 0 → 1 | [7] | 7 | 3.5 | 从 0 → 1 的权重和为 7 ≥ 3.5,中位节点是 1。 | 1 |
题目来自力扣3585。
步骤概述
- 图的构建:将边列表转换为邻接表表示的树结构。
- LCA预处理:通过DFS计算节点深度和距离,并构建倍增表以支持快速祖先查询。
- 查询处理:对每个查询,计算路径总权值、确定中位点位置,并利用倍增跳跃定位节点。
- 时间复杂度:预处理阶段O(n log n),查询阶段O(q log n),总复杂度O((n + q) log n)。
- 空间复杂度:主要开销来自存储树结构和倍增表,为O(n log n)。
详细分步过程
步骤1: 构建树结构(邻接表)
- 输入:边列表edges,每条边包含两个节点和边权值。
- 过程:
- 初始化一个大小为n的邻接表g,每个节点对应一个列表,存储相邻节点及边权。
- 遍历所有边,由于树是无向的,每条边在邻接表中双向添加(例如,边(u, v, w)会同时添加到g[u]和g[v]的列表中)。
- 目的:为后续DFS遍历提供高效的邻接关系查询。
步骤2: LCA预处理(DFS和倍增表构建)
- DFS遍历(计算深度和距离):
- 从根节点0开始递归遍历树。
- 维护三个数组:
- dep[]:记录每个节点到根节点的深度(根节点深度为0)。
- dis[]:记录每个节点到根节点的路径权值累加和(根节点距离为0)。
- pa[][]:倍增表,pa[x][i]表示节点x的2^i级祖先节点。
- 对于当前节点x,遍历其所有邻居节点y(跳过父节点避免循环)。更新y的深度dep[y] = dep[x] + 1,距离dis[y] = dis[x] + 边权。同时记录y的直接父节点pa[y][0] = x。
- 构建倍增表:
- 计算最大跳跃层级mx = ceil(log₂(n))(例如n=100,000时,mx≈17)。
- 通过动态规划填充pa数组:对于每个层级i(从1到mx-1),遍历所有节点x,若pa[x][i-1]存在,则pa[x][i] = pa[pa[x][i-1]][i-1](即x的2^i祖先等于x的2^{i-1}祖先的2^{i-1}祖先)。
- 目的:将任意两点路径查询转化为O(log n)时间的跳跃操作。
步骤3: 处理查询(定位带权中位节点)
对每个查询queries[j] = [u, v],执行以下子步骤:
- 特判相同节点:若u == v,直接返回u作为中位节点(路径权值为0,节点自身即中点)。
- 计算LCA和路径总权值:
- 调用getLCA(u, v)找到最近公共祖先lca(算法:先将u和v调整到同一深度,然后同步向上跳跃直至相遇)。
- 路径总权值dist = dis[u] + dis[v] - 2*dis[lca](利用到根节点距离的差值计算)。
- 计算半权值阈值half = (dist + 1) / 2(向上取整,确保累计权值≥一半)。
- 确定中位节点位置:
- 判断u到lca的子路径权值是否足够覆盖half:
- 若dis[u] - dis[lca] ≥ half:
- 中位节点位于u到lca的路径上。
- 从u向上回溯至多half-1权值(通过uptoDis函数):沿倍增表从高位到低位尝试跳跃,确保跳跃后累计距离不超过half-1。
- 此时到达节点to,中位节点是to的父节点pa[to][0](再跳一步即超过half)。
- 否则中位节点位于v到lca的路径上:
- 从v向上回溯权值dist - half(即从v出发走剩余路径达到half)。
- 直接调用uptoDis(v, dist - half)定位节点,该节点即为中位节点。
- 若dis[u] - dis[lca] ≥ half:
- 判断u到lca的子路径权值是否足够覆盖half:
- 输出结果:将每个查询的结果存入答案数组ans。
示例验证(针对输入n=2, edges=[[0,1,7]], queries=[[1,0],[0,1]])
- 查询[1,0]:
- LCA为0,路径总权值=7,half=4。
- dis[1]-dis[0]=7≥4,中位在1→0路径。从1回溯min(4-1,7)=3权值(实际回溯0权值,因半路已超),跳至父节点0,输出0。
- 查询[0,1]:
- 路径相同,half=4。dis[0]-dis[0]=0<4,中位在1→0路径。从1回溯7-4=3权值(实际回溯至1本身),输出1。
时间复杂度和空间复杂度
- 时间复杂度:
- 预处理:DFS遍历O(n),倍增表构建O(n log n)。
- 每个查询:LCA计算O(log n),路径权值计算O(1),跳跃操作O(log n)。
- 总时间:O(n log n + q log n),适用于n, q ≤ 100,000。
- 空间复杂度:
- 邻接表O(n),倍增表O(n log n),dep/dis数组O(n)。
- 总空间:O(n log n)。
Go完整代码如下:
package main
import (
"fmt"
"math/bits"
)
func findMedian(n int, edges [][]int, queries [][]int) []int {
type edge struct{ to, wt int }
g := make([][]edge, n)
for _, e := range edges {
x, y, wt := e[0], e[1], e[2]
g[x] = append(g[x], edge{y, wt})
g[y] = append(g[y], edge{x, wt})
}
// 17 可以替换成 bits.Len(uint(n)),但数组内存连续性更好
pa := make([][17]int, n)
dep := make([]int, n)
dis := make([]int, n)
var dfs func(int, int)
dfs = func(x, p int) {
pa[x][0] = p
for _, e := range g[x] {
y := e.to
if y == p {
continue
}
dep[y] = dep[x] + 1
dis[y] = dis[x] + e.wt
dfs(y, x)
}
}
dfs(0, -1)
mx := bits.Len(uint(n))
for i := range mx - 1 {
for x := range pa {
p := pa[x][i]
if p != -1 {
pa[x][i+1] = pa[p][i]
} else {
pa[x][i+1] = -1
}
}
}
uptoDep := func(x, d int) int {
for k := uint(dep[x] - d); k > 0; k &= k - 1 {
x = pa[x][bits.TrailingZeros(k)]
}
return x
}
// 返回 x 和 y 的最近公共祖先(节点编号从 0 开始)
getLCA := func(x, y int) int {
if dep[x] > dep[y] {
x, y = y, x
}
y = uptoDep(y, dep[x]) // 使 y 和 x 在同一深度
if y == x {
return x
}
for i := mx - 1; i >= 0; i-- {
px, py := pa[x][i], pa[y][i]
if px != py {
x, y = px, py // 同时往上跳 2^i 步
}
}
return pa[x][0]
}
// 从 x 往上跳【至多】d 距离,返回最远能到达的节点
uptoDis := func(x, d int) int {
dx := dis[x]
for i := mx - 1; i >= 0; i-- {
p := pa[x][i]
if p != -1 && dx-dis[p] <= d { // 可以跳至多 d
x = p
}
}
return x
}
// 以上是 LCA 模板
ans := make([]int, len(queries))
for i, q := range queries {
x, y := q[0], q[1]
if x == y {
ans[i] = x
continue
}
lca := getLCA(x, y)
disXY := dis[x] + dis[y] - dis[lca]*2
half := (disXY + 1) / 2
if dis[x]-dis[lca] >= half { // 答案在 x-lca 路径中
// 先往上跳至多 half-1,然后再跳一步,就是至少 half
to := uptoDis(x, half-1)
ans[i] = pa[to][0] // 再跳一步
} else { // 答案在 y-lca 路径中
// 从 y 出发至多 disXY-half,就是从 x 出发至少 half
ans[i] = uptoDis(y, disXY-half)
}
}
return ans
}
func main() {
n := 2
edges := [][]int{{0, 1, 7}}
queries := [][]int{{1, 0}, {0, 1}}
result := findMedian(n, edges, queries)
fmt.Println(result)
}

Python完整代码如下:
# -*-coding:utf-8-*-
import math
from typing import List
def findMedian(n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
# 构建图的邻接表
graph = [[] for _ in range(n)]
for e in edges:
x, y, wt = e
graph[x].append((y, wt))
graph[y].append((x, wt))
# 计算倍增数组的深度
mx = n.bit_length()
# 初始化数组
parent = [[-1] * mx for _ in range(n)]
depth = [0] * n
distance = [0] * n
# DFS预处理
def dfs(x: int, p: int):
parent[x][0] = p
for y, wt in graph[x]:
if y == p:
continue
depth[y] = depth[x] + 1
distance[y] = distance[x] + wt
dfs(y, x)
dfs(0, -1)
# 构建倍增数组
for i in range(mx - 1):
for x in range(n):
p = parent[x][i]
if p != -1:
parent[x][i + 1] = parent[p][i]
else:
parent[x][i + 1] = -1
# 将节点x提升到深度d
def upto_depth(x: int, d: int) -> int:
k = depth[x] - d
while k > 0:
step = k & -k # 获取最低位的1
x = parent[x][step.bit_length() - 1]
k -= step
return x
# 获取最近公共祖先
def get_lca(x: int, y: int) -> int:
if depth[x] > depth[y]:
x, y = y, x
y = upto_depth(y, depth[x])
if y == x:
return x
for i in range(mx - 1, -1, -1):
px, py = parent[x][i], parent[y][i]
if px != py:
x, y = px, py
return parent[x][0]
# 从x向上跳至多d距离
def upto_distance(x: int, d: int) -> int:
dx = distance[x]
for i in range(mx - 1, -1, -1):
p = parent[x][i]
if p != -1 and dx - distance[p] <= d:
x = p
return x
# 处理查询
result = []
for q in queries:
x, y = q
if x == y:
result.append(x)
continue
lca = get_lca(x, y)
dis_xy = distance[x] + distance[y] - 2 * distance[lca]
half = (dis_xy + 1) // 2
if distance[x] - distance[lca] >= half:
# 答案在x到lca的路径上
to = upto_distance(x, half - 1)
result.append(parent[to][0])
else:
# 答案在y到lca的路径上
result.append(upto_distance(y, dis_xy - half))
return result
# 测试代码
if __name__ == "__main__":
n = 2
edges = [[0, 1, 7]]
queries = [[1, 0], [0, 1]]
result = findMedian(n, edges, queries)
print(result)

C++完整代码如下:
#include <iostream>
#include <vector>
#include <cmath>
#include <cstring>
using namespace std;
struct Edge {
int to, wt;
};
class TreeMedianFinder {
public:
int n, mx;
vector<vector<Edge>> g;
vector<vector<int>> pa; // pa[x][i]:x 的 2^i 级祖先
vector<int> dep, dis;
TreeMedianFinder(int n, const vector<vector<int>>& edges) : n(n) {
g.assign(n, {});
for (auto& e : edges) {
int x = e[0], y = e[1], wt = e[2];
g[x].push_back({y, wt});
g[y].push_back({x, wt});
}
mx = 32 - __builtin_clz(n); // bits.Len(n)
pa.assign(n, vector<int>(mx, -1));
dep.assign(n, 0);
dis.assign(n, 0);
dfs(0, -1);
// 倍增预处理
for (int i = 0; i < mx - 1; i++) {
for (int x = 0; x < n; x++) {
if (pa[x][i] != -1)
pa[x][i + 1] = pa[pa[x][i]][i];
else
pa[x][i + 1] = -1;
}
}
}
void dfs(int x, int p) {
pa[x][0] = p;
for (auto& e : g[x]) {
int y = e.to;
if (y == p) continue;
dep[y] = dep[x] + 1;
dis[y] = dis[x] + e.wt;
dfs(y, x);
}
}
// 跳到指定深度
int uptoDep(int x, int d) {
int diff = dep[x] - d;
while (diff > 0) {
int k = __builtin_ctz(diff); // 低位 1 的位置
x = pa[x][k];
diff &= diff - 1;
}
return x;
}
// 最近公共祖先
int getLCA(int x, int y) {
if (dep[x] > dep[y]) swap(x, y);
y = uptoDep(y, dep[x]);
if (x == y) return x;
for (int i = mx - 1; i >= 0; i--) {
if (pa[x][i] != pa[y][i]) {
x = pa[x][i];
y = pa[y][i];
}
}
return pa[x][0];
}
// 从 x 往上跳至多 d 距离
int uptoDis(int x, int d) {
int dx = dis[x];
for (int i = mx - 1; i >= 0; i--) {
int p = pa[x][i];
if (p != -1 && dx - dis[p] <= d) {
x = p;
}
}
return x;
}
vector<int> solveQueries(const vector<vector<int>>& queries) {
vector<int> ans;
ans.reserve(queries.size());
for (auto& q : queries) {
int x = q[0], y = q[1];
if (x == y) {
ans.push_back(x);
continue;
}
int lca = getLCA(x, y);
int disXY = dis[x] + dis[y] - 2 * dis[lca];
int half = (disXY + 1) / 2;
if (dis[x] - dis[lca] >= half) {
// 在 x-lca 路径中
int to = uptoDis(x, half - 1);
ans.push_back(pa[to][0]);
} else {
// 在 y-lca 路径中
ans.push_back(uptoDis(y, disXY - half));
}
}
return ans;
}
};
int main() {
int n = 2;
vector<vector<int>> edges = {{0, 1, 7}};
vector<vector<int>> queries = {{1, 0}, {0, 1}};
TreeMedianFinder solver(n, edges);
vector<int> result = solver.solveQueries(queries);
for (int x : result) {
cout << x << " ";
}
cout << endl;
return 0;
}

- 点赞
- 收藏
- 关注作者
评论(0)