2025-11-30:树中找到带权中位节点。用go语言,给出一个含 n 个节点(编号 0 到 n-1)的带权无向树,树的根定为节

举报
福大大架构师每日一题 发表于 2025/11/30 09:28:38 2025/11/30
【摘要】 2025-11-30:树中找到带权中位节点。用go语言,给出一个含 n 个节点(编号 0 到 n-1)的带权无向树,树的根定为节点 0。树用长度为 n-1 的数组 edges 描述,每个 edges[i] = [ui, vi, wi] 表示 ui 与 vi 之间有一条权值为 wi 的边。在两节点间的路径上,把从起点累积经过的边权和视作距离。所谓“带权中位节点”是指沿从起点 ui 到终点 vi...

2025-11-30:树中找到带权中位节点。用go语言,给出一个含 n 个节点(编号 0 到 n-1)的带权无向树,树的根定为节点 0。树用长度为 n-1 的数组 edges 描述,每个 edges[i] = [ui, vi, wi] 表示 ui 与 vi 之间有一条权值为 wi 的边。

在两节点间的路径上,把从起点累积经过的边权和视作距离。所谓“带权中位节点”是指沿从起点 ui 到终点 vi 的路径,从 ui 出发第一个使得累计边权达到(或超过)整条路径总权重一半的节点 x。

现在给出若干查询 queries,每个 queries[j] = [uj, vj] 要求找出 uj 到 vj 路径上的带权中位节点。输出一个数组 ans,其中 ans[j] 是对应查询的带权中位节点编号。

2 <= n <= 100000。

edges.length == n - 1。

edges[i] == [ui, vi, wi]。

0 <= ui, vi < n。

1 <= wi <= 1000000000。

1 <= queries.length <= 100000。

queries[j] == [uj, vj]。

0 <= uj, vj < n。

输入保证 edges 表示一棵合法的树。

输入: n = 2, edges = [[0,1,7]], queries = [[1,0],[0,1]]。

输出: [0,1]。

解释:

在这里插入图片描述

查询 路径 边权 总路径权值和 一半 解释 答案
[1, 0] 1 → 0 [7] 7 3.5 从 1 → 0 的权重和为 7 ≥ 3.5,中位节点是 0。 0
[0, 1] 0 → 1 [7] 7 3.5 从 0 → 1 的权重和为 7 ≥ 3.5,中位节点是 1。 1

题目来自力扣3585。

步骤概述

  1. 图的构建:将边列表转换为邻接表表示的树结构。
  2. LCA预处理:通过DFS计算节点深度和距离,并构建倍增表以支持快速祖先查询。
  3. 查询处理:对每个查询,计算路径总权值、确定中位点位置,并利用倍增跳跃定位节点。
  4. 时间复杂度:预处理阶段O(n log n),查询阶段O(q log n),总复杂度O((n + q) log n)。
  5. 空间复杂度:主要开销来自存储树结构和倍增表,为O(n log n)。

详细分步过程

步骤1: 构建树结构(邻接表)

  • 输入:边列表edges,每条边包含两个节点和边权值。
  • 过程
    • 初始化一个大小为n的邻接表g,每个节点对应一个列表,存储相邻节点及边权。
    • 遍历所有边,由于树是无向的,每条边在邻接表中双向添加(例如,边(u, v, w)会同时添加到g[u]和g[v]的列表中)。
  • 目的:为后续DFS遍历提供高效的邻接关系查询。

步骤2: LCA预处理(DFS和倍增表构建)

  • DFS遍历(计算深度和距离)
    • 从根节点0开始递归遍历树。
    • 维护三个数组:
      • dep[]:记录每个节点到根节点的深度(根节点深度为0)。
      • dis[]:记录每个节点到根节点的路径权值累加和(根节点距离为0)。
      • pa[][]:倍增表,pa[x][i]表示节点x的2^i级祖先节点。
    • 对于当前节点x,遍历其所有邻居节点y(跳过父节点避免循环)。更新y的深度dep[y] = dep[x] + 1,距离dis[y] = dis[x] + 边权。同时记录y的直接父节点pa[y][0] = x。
  • 构建倍增表
    • 计算最大跳跃层级mx = ceil(log₂(n))(例如n=100,000时,mx≈17)。
    • 通过动态规划填充pa数组:对于每个层级i(从1到mx-1),遍历所有节点x,若pa[x][i-1]存在,则pa[x][i] = pa[pa[x][i-1]][i-1](即x的2^i祖先等于x的2^{i-1}祖先的2^{i-1}祖先)。
  • 目的:将任意两点路径查询转化为O(log n)时间的跳跃操作。

步骤3: 处理查询(定位带权中位节点)

对每个查询queries[j] = [u, v],执行以下子步骤:

  1. 特判相同节点:若u == v,直接返回u作为中位节点(路径权值为0,节点自身即中点)。
  2. 计算LCA和路径总权值
    • 调用getLCA(u, v)找到最近公共祖先lca(算法:先将u和v调整到同一深度,然后同步向上跳跃直至相遇)。
    • 路径总权值dist = dis[u] + dis[v] - 2*dis[lca](利用到根节点距离的差值计算)。
    • 计算半权值阈值half = (dist + 1) / 2(向上取整,确保累计权值≥一半)。
  3. 确定中位节点位置
    • 判断u到lca的子路径权值是否足够覆盖half:
      • 若dis[u] - dis[lca] ≥ half:
        • 中位节点位于u到lca的路径上。
        • 从u向上回溯至多half-1权值(通过uptoDis函数):沿倍增表从高位到低位尝试跳跃,确保跳跃后累计距离不超过half-1。
        • 此时到达节点to,中位节点是to的父节点pa[to][0](再跳一步即超过half)。
      • 否则中位节点位于v到lca的路径上:
        • 从v向上回溯权值dist - half(即从v出发走剩余路径达到half)。
        • 直接调用uptoDis(v, dist - half)定位节点,该节点即为中位节点。
  4. 输出结果:将每个查询的结果存入答案数组ans。

示例验证(针对输入n=2, edges=[[0,1,7]], queries=[[1,0],[0,1]])

  • 查询[1,0]
    • LCA为0,路径总权值=7,half=4。
    • dis[1]-dis[0]=7≥4,中位在1→0路径。从1回溯min(4-1,7)=3权值(实际回溯0权值,因半路已超),跳至父节点0,输出0。
  • 查询[0,1]
    • 路径相同,half=4。dis[0]-dis[0]=0<4,中位在1→0路径。从1回溯7-4=3权值(实际回溯至1本身),输出1。

时间复杂度和空间复杂度

  • 时间复杂度
    • 预处理:DFS遍历O(n),倍增表构建O(n log n)。
    • 每个查询:LCA计算O(log n),路径权值计算O(1),跳跃操作O(log n)。
    • 总时间:O(n log n + q log n),适用于n, q ≤ 100,000。
  • 空间复杂度
    • 邻接表O(n),倍增表O(n log n),dep/dis数组O(n)。
    • 总空间:O(n log n)。

Go完整代码如下:

package main

import (
	"fmt"
	"math/bits"
)

func findMedian(n int, edges [][]int, queries [][]int) []int {
	type edge struct{ to, wt int }
	g := make([][]edge, n)
	for _, e := range edges {
		x, y, wt := e[0], e[1], e[2]
		g[x] = append(g[x], edge{y, wt})
		g[y] = append(g[y], edge{x, wt})
	}

	// 17 可以替换成 bits.Len(uint(n)),但数组内存连续性更好
	pa := make([][17]int, n)
	dep := make([]int, n)
	dis := make([]int, n)

	var dfs func(int, int)
	dfs = func(x, p int) {
		pa[x][0] = p
		for _, e := range g[x] {
			y := e.to
			if y == p {
				continue
			}
			dep[y] = dep[x] + 1
			dis[y] = dis[x] + e.wt
			dfs(y, x)
		}
	}
	dfs(0, -1)

	mx := bits.Len(uint(n))
	for i := range mx - 1 {
		for x := range pa {
			p := pa[x][i]
			if p != -1 {
				pa[x][i+1] = pa[p][i]
			} else {
				pa[x][i+1] = -1
			}
		}
	}

	uptoDep := func(x, d int) int {
		for k := uint(dep[x] - d); k > 0; k &= k - 1 {
			x = pa[x][bits.TrailingZeros(k)]
		}
		return x
	}

	// 返回 x 和 y 的最近公共祖先(节点编号从 0 开始)
	getLCA := func(x, y int) int {
		if dep[x] > dep[y] {
			x, y = y, x
		}
		y = uptoDep(y, dep[x]) // 使 y 和 x 在同一深度
		if y == x {
			return x
		}
		for i := mx - 1; i >= 0; i-- {
			px, py := pa[x][i], pa[y][i]
			if px != py {
				x, y = px, py // 同时往上跳 2^i 步
			}
		}
		return pa[x][0]
	}

	// 从 x 往上跳【至多】d 距离,返回最远能到达的节点
	uptoDis := func(x, d int) int {
		dx := dis[x]
		for i := mx - 1; i >= 0; i-- {
			p := pa[x][i]
			if p != -1 && dx-dis[p] <= d { // 可以跳至多 d
				x = p
			}
		}
		return x
	}

	// 以上是 LCA 模板

	ans := make([]int, len(queries))
	for i, q := range queries {
		x, y := q[0], q[1]
		if x == y {
			ans[i] = x
			continue
		}
		lca := getLCA(x, y)
		disXY := dis[x] + dis[y] - dis[lca]*2
		half := (disXY + 1) / 2
		if dis[x]-dis[lca] >= half { // 答案在 x-lca 路径中
			// 先往上跳至多 half-1,然后再跳一步,就是至少 half
			to := uptoDis(x, half-1)
			ans[i] = pa[to][0] // 再跳一步
		} else { // 答案在 y-lca 路径中
			// 从 y 出发至多 disXY-half,就是从 x 出发至少 half
			ans[i] = uptoDis(y, disXY-half)
		}
	}
	return ans
}

func main() {
	n := 2
	edges := [][]int{{0, 1, 7}}
	queries := [][]int{{1, 0}, {0, 1}}
	result := findMedian(n, edges, queries)
	fmt.Println(result)
}

在这里插入图片描述

Python完整代码如下:

# -*-coding:utf-8-*-

import math
from typing import List

def findMedian(n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
    # 构建图的邻接表
    graph = [[] for _ in range(n)]
    for e in edges:
        x, y, wt = e
        graph[x].append((y, wt))
        graph[y].append((x, wt))
    
    # 计算倍增数组的深度
    mx = n.bit_length()
    
    # 初始化数组
    parent = [[-1] * mx for _ in range(n)]
    depth = [0] * n
    distance = [0] * n
    
    # DFS预处理
    def dfs(x: int, p: int):
        parent[x][0] = p
        for y, wt in graph[x]:
            if y == p:
                continue
            depth[y] = depth[x] + 1
            distance[y] = distance[x] + wt
            dfs(y, x)
    
    dfs(0, -1)
    
    # 构建倍增数组
    for i in range(mx - 1):
        for x in range(n):
            p = parent[x][i]
            if p != -1:
                parent[x][i + 1] = parent[p][i]
            else:
                parent[x][i + 1] = -1
    
    # 将节点x提升到深度d
    def upto_depth(x: int, d: int) -> int:
        k = depth[x] - d
        while k > 0:
            step = k & -k  # 获取最低位的1
            x = parent[x][step.bit_length() - 1]
            k -= step
        return x
    
    # 获取最近公共祖先
    def get_lca(x: int, y: int) -> int:
        if depth[x] > depth[y]:
            x, y = y, x
        y = upto_depth(y, depth[x])
        if y == x:
            return x
        
        for i in range(mx - 1, -1, -1):
            px, py = parent[x][i], parent[y][i]
            if px != py:
                x, y = px, py
        return parent[x][0]
    
    # 从x向上跳至多d距离
    def upto_distance(x: int, d: int) -> int:
        dx = distance[x]
        for i in range(mx - 1, -1, -1):
            p = parent[x][i]
            if p != -1 and dx - distance[p] <= d:
                x = p
        return x
    
    # 处理查询
    result = []
    for q in queries:
        x, y = q
        if x == y:
            result.append(x)
            continue
            
        lca = get_lca(x, y)
        dis_xy = distance[x] + distance[y] - 2 * distance[lca]
        half = (dis_xy + 1) // 2
        
        if distance[x] - distance[lca] >= half:
            # 答案在x到lca的路径上
            to = upto_distance(x, half - 1)
            result.append(parent[to][0])
        else:
            # 答案在y到lca的路径上
            result.append(upto_distance(y, dis_xy - half))
    
    return result

# 测试代码
if __name__ == "__main__":
    n = 2
    edges = [[0, 1, 7]]
    queries = [[1, 0], [0, 1]]
    result = findMedian(n, edges, queries)
    print(result)

在这里插入图片描述

C++完整代码如下:

#include <iostream>
#include <vector>
#include <cmath>
#include <cstring>
using namespace std;

struct Edge {
    int to, wt;
};

class TreeMedianFinder {
public:
    int n, mx;
    vector<vector<Edge>> g;
    vector<vector<int>> pa; // pa[x][i]:x 的 2^i 级祖先
    vector<int> dep, dis;

    TreeMedianFinder(int n, const vector<vector<int>>& edges) : n(n) {
        g.assign(n, {});
        for (auto& e : edges) {
            int x = e[0], y = e[1], wt = e[2];
            g[x].push_back({y, wt});
            g[y].push_back({x, wt});
        }
        mx = 32 - __builtin_clz(n); // bits.Len(n)
        pa.assign(n, vector<int>(mx, -1));
        dep.assign(n, 0);
        dis.assign(n, 0);

        dfs(0, -1);

        // 倍增预处理
        for (int i = 0; i < mx - 1; i++) {
            for (int x = 0; x < n; x++) {
                if (pa[x][i] != -1)
                    pa[x][i + 1] = pa[pa[x][i]][i];
                else
                    pa[x][i + 1] = -1;
            }
        }
    }

    void dfs(int x, int p) {
        pa[x][0] = p;
        for (auto& e : g[x]) {
            int y = e.to;
            if (y == p) continue;
            dep[y] = dep[x] + 1;
            dis[y] = dis[x] + e.wt;
            dfs(y, x);
        }
    }

    // 跳到指定深度
    int uptoDep(int x, int d) {
        int diff = dep[x] - d;
        while (diff > 0) {
            int k = __builtin_ctz(diff); // 低位 1 的位置
            x = pa[x][k];
            diff &= diff - 1;
        }
        return x;
    }

    // 最近公共祖先
    int getLCA(int x, int y) {
        if (dep[x] > dep[y]) swap(x, y);
        y = uptoDep(y, dep[x]);
        if (x == y) return x;
        for (int i = mx - 1; i >= 0; i--) {
            if (pa[x][i] != pa[y][i]) {
                x = pa[x][i];
                y = pa[y][i];
            }
        }
        return pa[x][0];
    }

    // 从 x 往上跳至多 d 距离
    int uptoDis(int x, int d) {
        int dx = dis[x];
        for (int i = mx - 1; i >= 0; i--) {
            int p = pa[x][i];
            if (p != -1 && dx - dis[p] <= d) {
                x = p;
            }
        }
        return x;
    }

    vector<int> solveQueries(const vector<vector<int>>& queries) {
        vector<int> ans;
        ans.reserve(queries.size());
        for (auto& q : queries) {
            int x = q[0], y = q[1];
            if (x == y) {
                ans.push_back(x);
                continue;
            }
            int lca = getLCA(x, y);
            int disXY = dis[x] + dis[y] - 2 * dis[lca];
            int half = (disXY + 1) / 2;
            if (dis[x] - dis[lca] >= half) {
                // 在 x-lca 路径中
                int to = uptoDis(x, half - 1);
                ans.push_back(pa[to][0]);
            } else {
                // 在 y-lca 路径中
                ans.push_back(uptoDis(y, disXY - half));
            }
        }
        return ans;
    }
};

int main() {
    int n = 2;
    vector<vector<int>> edges = {{0, 1, 7}};
    vector<vector<int>> queries = {{1, 0}, {0, 1}};

    TreeMedianFinder solver(n, edges);
    vector<int> result = solver.solveQueries(queries);

    for (int x : result) {
        cout << x << " ";
    }
    cout << endl;
    return 0;
}

在这里插入图片描述

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。