[백준(Python/Java)] 4386_별자리 만들기 - 최소 신장 트리

kindof

·

2021. 9. 18. 14:16

https://www.acmicpc.net/problem/4386

 

4386번: 별자리 만들기

도현이는 우주의 신이다. 이제 도현이는 아무렇게나 널브러져 있는 n개의 별들을 이어서 별자리를 하나 만들 것이다. 별자리의 조건은 다음과 같다. 별자리를 이루는 선은 서로 다른 두 별을 일

www.acmicpc.net

최소 신장 트리(Minimum Spanning Tree)를 이용하는 기본적인 문제로 저는 우선순위 큐를 이용해 해결했습니다.

 

MST는 모든 정점을 연결하는 최소 비용의 간선들의 집합을 구해야 하므로, 임의의 정점 하나를 잡고 시작합니다. 이 임의의 정점도 반드시 언젠가는 포함되어야 하기 때문이죠.

 

그리고 해당 정점에서 연결된 간선들 중에서 최소 비용인 간선을 뽑고, 해당 간선과 이어진 정점이 방문하지 않은 정점이라면 큐에 추가하는 식으로 문제를 해결합니다.

 

그렇게 되면 Greedy한 방식으로 가장 최소 비용을 갖으며, 싸이클이 발생하지 않는 정점을 우선적으로 방문하게 되죠.

 

이렇게 방문하는 정점의 개수가 원래 주어진 정점의 개수 N과 같아지면 탐색을 종료합니다.

 

 

[풀이 - 파이썬]

import heapq

def getDistance(x1, y1, x2, y2):
    return ((x1-x2) ** 2 + (y1-y2) ** 2) ** 0.5

def findMST():
    pq = []
    heapq.heappush(pq, (0, 0))
    count = 0
    answer = 0
    while count < n:
        currDistance, start = heapq.heappop(pq)
        if not visited[start]:
            visited[start] = True
            answer += currDistance
            count += 1
            for i in range(n):
                dist = costMatrix[start][i]
                heapq.heappush(pq, (dist, i))
    return answer

n = int(input())
stars = [list(map(float, input().split())) for _ in range(n)]
costMatrix = [[0] * n for _ in range(n)]
for i in range(n):
    for j in range(n):
        costMatrix[i][j] = getDistance(stars[i][0], stars[i][1], stars[j][0], stars[j][1])

visited = [False] * n
print(round(findMST(),2))

 

[풀이 - 자바]

package PS;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class B_4386 {
    static int n;
    static Point[] stars;
    static double[][] costMatrix;
    static boolean[] visited;
    public static class Point{
        double x,y;
        Point(double x, double y){
            this.x = x;
            this.y = y;
        }
    }

    public static class Edge implements Comparable<Edge>{
        double dist;
        int nodeNum;
        public Edge(double dist, int nodeNum){
            this.dist = dist;
            this.nodeNum = nodeNum;
        }

        @Override
        public int compareTo(Edge other){
            return (int) (this.dist - other.dist);
        }
    }


    public static double getDistance(Point a, Point b){
        return Math.sqrt((a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y));
    }

    public static double findMST(){
        PriorityQueue<Edge> q = new PriorityQueue<>();
        q.offer(new Edge(0,0));
        int count = 0;
        double mst = 0;

        while (count < n){
            Edge edge = q.poll();
            double currDistance = edge.dist;
            int nodeNum = edge.nodeNum;

            if(!visited[nodeNum]){
                visited[nodeNum] = true;
                mst += currDistance;
                count += 1;
                for(int i = 0; i < n; i++){
                    double dist = costMatrix[nodeNum][i];
                    q.offer(new Edge(dist, i));
                }
            }
        }
        return Math.round(mst * 100) / 100.0;
    }

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());

        stars = new Point[n];
        for(int i = 0; i < n; i++){
            st = new StringTokenizer(br.readLine(), " ");
            double x = Double.parseDouble(st.nextToken());
            double y = Double.parseDouble(st.nextToken());
            stars[i] = new Point(x,y);
        }

        costMatrix = new double[n][n];
        for(int i =0; i< n; i++){
            for(int j = 0; j < n; j++){
                costMatrix[i][j] = getDistance(stars[i], stars[j]);
            }
        }

        visited = new boolean[n];
        double answer = findMST();
        System.out.println(answer);
    }
}