Skip to main content
  1. Posts/
  2. Algorithm/

BOJ 32144 트리를 쓰는 트리 문제

·358 words·2 mins
Jiho Kim
Author
Jiho Kim
달려 또 달려

📝 문제 정보
#

🧐 관찰 및 접근
#

  • 문제에 내 이름이 나와서 기분이 좋다.
  • 서브트리를 하나 잡아서, 그 서브트리의 루트와 그 부모와의 연결을 끊고 서브트리의 임의의 정점과 방금 끊은 부모를 연결하는 연산을 할 수 있다.
    • 직관적으로, 해당 서브트리에서 줄 수 있는 길이의 가중치가 깊이에서 지름으로 바뀜을 알 수 있다.
    • Tree DP를 이용한 트리의 지름 구하기 방법을 이용하면 좋을 것 같다.
    • ![[Drawing 2026-01-25 10.09.49.excalidraw.png]]
    • 위와 같은 경우같은게 발생할 것 같다.
    • 두번째로 작은 트리의 지름과 마찬가지로, 기존 지름의 양 끝 점중 한 점은 유지된다.
      • …인줄알았는데 안된다.
      • ![[Drawing 2026-01-25 11.07.01.excalidraw.png]]
      • 다음과 같이 서브노드 안에 기존 트리의 지름이 다 있는 경우도 있다…
    • 결국 문제에서도 말하는 것처럼 문제를 부분트리와 그 나머지로 보면, 리루팅이 가능하지 않을까?
  • 리루팅을 이용해서, 다음과 같은 정보를 저장하자.
    • 서브트리의 지름
    • 해당 서브트리에서 가장 깊은 깊이 두개 (리루팅용)
    • 해당 노드에서 위로 갔을때, 가장 깊은 길이
    • 이는 dfs를 이용해서 구현 가능하고, 시간복잡도는 $O(N)$이다.
  • 디버깅을 위한 예제 2번 그림
    • ![[Drawing 2026-01-25 10.30.52.excalidraw.png]]

💻 풀이
#

  • 코드 (C++):
int N;
vector<int> links[300005];
vector<pii> max_depth[300005];
int diam[300005];
int updepth[300005], updiam[300005];

vector<pii> dfs(int cur, int par){
    vector<pii> v;
    v.push_back({0, cur});
    for(auto nxt: links[cur]){
        if(nxt == par) continue;
        auto nv = dfs(nxt, cur);
        diam[cur] = max(diam[cur], diam[nxt]);
        v.push_back({nv[0].first + 1, nxt});
        sort(all(v), greater<pii>());
        while(v.size() > 2) v.pop_back();
    }
    if(v.size() >= 2) diam[cur] = max(diam[cur], v[0].first + v[1].first);
    else diam[cur] = max(diam[cur], v[0].first);
    max_depth[cur] = v;

    // cout << "cur: " << cur << " diam: " << diam[cur] << " max_depth: ";
    // for(auto p: v) cout << "(" << p.first << "," << p.second << ") ";
    // cout << '\n';
    return v;
}

void dfs2(int cur, int par){
    for(auto nxt: links[cur]){
        if(nxt == par) continue;

        updepth[nxt] = updepth[cur] + 1;
        if(max_depth[cur][0].second != nxt) updepth[nxt] = max(updepth[nxt], max_depth[cur][0].first + 1);
        else if(max_depth[cur].size() >= 2) updepth[nxt] = max(updepth[nxt], max_depth[cur][1].first + 1);
        dfs2(nxt, cur);
    }
}

void solve(){
    cin >> N;
    rep(i, 2, N+1){
        int p; cin >> p;
        links[p].push_back(i);
        links[i].push_back(p);
    }

    dfs(1, -1);
    dfs2(1, -1);
    rep(i, 2, N+1) cout << max(diam[1], updepth[i] + diam[i]) << "\n";
}
🔒

구현 코드 잠금

아래 쿠팡 링크를 방문하시면 코드가 공개됩니다.
광고 수익이 블로그 운영에 도움이 됩니다 🙏

🛒 쿠팡 방문하고 코드 보기

방문 후 잠금이 자동으로 해제됩니다