[BOJ] 백준 15681 트리와 쿼리 - Python/Java

kindof

·

2021. 11. 28. 14:02

문제

https://www.acmicpc.net/problem/15681

 

15681번: 트리와 쿼리

트리의 정점의 수 N과 루트의 번호 R, 쿼리의 수 Q가 주어진다. (2 ≤ N ≤ 105, 1 ≤ R ≤ N, 1 ≤ Q ≤ 105) 이어 N-1줄에 걸쳐, U V의 형태로 트리에 속한 간선의 정보가 주어진다. (1 ≤ U, V ≤ N, U ≠ V)

www.acmicpc.net

해설

해당 문제의 해설은 문제에서 주어지는 [힌트] 부분에 상세하게 제시되어 있어서 이 내용을 따라서 구현하기만 하면 정답을 구할 수 있습니다.

 

그래도 간단하게나마 제 해설을 적어보겠습니다.

 

1. 이 문제에서는 노드들을 그래프 형태로 입력받은 뒤, 트리의 루트 노드가 주어지기 때문에 루트를 기준으로 DFS를 수행하면 트리 구조를 만들어 낼 수 있습니다.

 

즉, 문제의 [힌트] 부분에서 설명하는 것처럼, 트리에서는 어떤 정점의 부모는 하나이거나 없습니다. 따라서, 어떤 정점에 대해 연결된 모든 정점은 최대 한 개의 정점을 제외하면 모두 해당 정점의 자식들이 됩니다. 이에 따라, 부모 정점의 정보를 가져가면서, 부모 정점이 아니면서 자신과 연결되어 있는 모든 정점을 자신의 자식으로, 자신의 자식이 될 정점들의 부모 정점을 자신으로 연결한 뒤 재귀적으로 자식 정점들에게 트리 구성을 요청하면 그래프로 입력받은 노드들을 트리 형태로 변환할 수 있게 됩니다.

 

2. 자신을 루트 노드로 하는 서브트리의 노드 개수를 구하기 위해서는 DP를 이용합니다.

 

맨 아래 자식부터 올라오면서 작은 서브트리의 노드 개수를 구하게 되면, 현재 노드의 관점에서 봤을 때 자식들을 루트로 하는 서브트리의 노드 개수가 구해지게 됩니다.

 

1, 2번에서 설명한 내용이 각각 makeTree(), countSubtreeNodes() 함수로 구현되어 있습니다. 

 

풀이 - 파이썬

import sys
from collections import defaultdict
sys.setrecursionlimit(100000000)
input = sys.stdin.readline

def makeTree(currentNode, parent):
    for node in graph[currentNode]:
        if node != parent:
            childInfo[currentNode].append(node)
            parentInfo[node] = currentNode
            makeTree(node, currentNode)

def countSubtreeNodes(currentNode):
    size[currentNode] = 1 # 자신도 자신을 루트로 하는 서브트리에 포함되므로 0이 아닌 1에서 시작
    for node in childInfo[currentNode]:
        countSubtreeNodes(node)
        size[currentNode] += size[node]

n, r, q = map(int, input().split()) # 정점의 수, 루트 번호, 쿼리의 수
childInfo = defaultdict(list) # 자신의 자식들을 저장
parentInfo = defaultdict() # 자신의 부모를 저장
size = [0] * (n+1)
graph = [[] for _ in range(n+1)] # 그래프 형태로 노드를 입력받는다.
for _ in range(n-1):
    a, b = map(int, input().split())
    graph[a].append(b)
    graph[b].append(a)

makeTree(r, -1) # 주어진 루트를 중심으로 그래프를 트리 형태로 변환한다.
countSubtreeNodes(r) # 각 노드에 대해 서브트리의 정점 개수를 구한다.
for _ in range(q): # 결과 출력
    print(size[int(input())])

 

풀이 - 자바

package PS;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

public class B_15681 {
    static int n, r, q;
    static int[] size;
    static ArrayList<Node>[] graph;

    public static void makeTree(Node currentNode, int parent){
        for(Node node : graph[currentNode.data]){
            if(node.data != parent){
                currentNode.addChild(node);
                node.setParent(currentNode.data);
                makeTree(node, currentNode.data);
            }
        }
    }

    public static void countSubtreeNodes(Node currentNode){
        size[currentNode.data] = 1;
        for(Node node : currentNode.child){
            countSubtreeNodes(node);
            size[currentNode.data] += size[node.data];
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        r = Integer.parseInt(st.nextToken());
        q = Integer.parseInt(st.nextToken());

        size = new int[n+1];
        graph = new ArrayList[n+1];
        for (int i = 1; i < n + 1; i++) graph[i] = new ArrayList<>();
        for(int i = 1; i < n; i++){ // 그래프 형태로 노드 입력받기
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            graph[a].add(new Node(b));
            graph[b].add(new Node(a));
        }

        Node root = new Node(r);
        makeTree(root, -1); // 주어진 루트를 중심으로 그래프를 트리 형태로 변환한다.
        countSubtreeNodes(root); // 각 노드에 대해 서브트리의 정점 개수를 구한다.
        StringBuilder sb = new StringBuilder();
        for(int i = 0; i < q; i++){
            int count = size[Integer.parseInt(br.readLine())];
            sb.append(count).append('\n');
        }
        System.out.print(sb);
    }


    static class Node{
        int data;
        int parent;
        ArrayList<Node> child;

        public Node(int data){
            this.data = data;
            this.child = new ArrayList<>();
        }

        public void addChild(Node child){
            this.child.add(child);
        }
        public void setParent(int parent){
            this.parent = parent;
        }
    }
}