알고리즘

Leaf_node 세는 방법

코딩 악귀 2024. 2. 18. 17:26

문제

 

1068번: 트리

첫째 줄에 트리의 노드의 개수 N이 주어진다. N은 50보다 작거나 같은 자연수이다. 둘째 줄에는 0번 노드부터 N-1번 노드까지, 각 노드의 부모가 주어진다. 만약 부모가 없다면 (루트) -1이 주어진다

www.acmicpc.net

 


 

재정의

트리가 구성되어 있을 때, leaf_node를 세리는 문제이다.

특정 노드 한 개를 삭제하였을 때, 남는 leaf_node의 수를 구하면 된다.

leaf_node를 어떻게 다룰지가 핵심인 문제이다.

 

재귀적으로 leaf_node 조합하기

트리는 재귀적 구조로 구성되어 응용하기 좋다.

특정 노드에서, 자식들의 가진 leaf_node를 더하면 총 leaf를 구할 수 있다.

 

이렇게 구현하면 한 가지 문제가 생긴다.

 

A라는 노드에서 자식 B 하나를 가진다.

B 노드는 leaf_node이다.

B노드를 제거한다.

 

이 상황에서 B노드가 제거되어 leaf_node는 한 개가 감소하게 된다.

하지만 A 노드가 새로운 leaf_node가 되어 한 개가 수복 된다.

 

즉, leaf_node를 전부 다 세고 제거하는 방식에는

사이드 이펙트를 고려해야 한다.

 

leaf_node 계수를 미루기

leaf_node의 정의에 대해서 다시 한 번 생각 해 보자.

"누구의 부모도 아닌 노드"이다.

 

근데 문제를 잘 읽어보면, 일반적인 트리 구조로 주어지지 않는다.

index는 노드이고, value는 부모이다.

 

즉, leaf_node는 주어진 배열에서 부모로 지목 받은 적이 없는 것이다.

단순하게 구할 수 있다.

 

그래서 DFS로 삭제 명령을 간선으로 전파 하고,

삭제 된 적도 없으며, 부모로 지목된적 없는 값을 찾으면 된다.

 


 

풀이

 

재귀적 조합 풀이

from collections import Counter


def get_leaf_counter():
    leaf_cnt = 0
    leaf_mapper = {}

    def get_leaf(i):
        nonlocal leaf_cnt

        if len(connection[i]) == 0:
            leaf_mapper[i] = 1
            leaf_cnt += 1
            return 1
        
        if i in leaf_mapper:
            return leaf_mapper[i]

        ret = 0
        for next in connection[i]:
            ret += get_leaf(next)
        
        leaf_mapper[i] = ret
        return ret
    
    get_leaf(-1)
    return leaf_cnt, leaf_mapper


n = int(input())
nodes = list(map(int, input().split()))
child_counter = Counter(nodes)
root = None
connection = {i : [] for i in range(-1, n+1)}

for idx, parent in enumerate(nodes):
    if parent == -1:
        root = idx
    connection[parent].append(idx)

remove = int(input())
leaf_cnt, leaf_mapper = get_leaf_counter()

if remove == root:
    print(0)
else:
    ret = leaf_cnt - leaf_mapper[remove]
    if child_counter[nodes[remove]] == 1:
        ret += 1

    print(ret)

 

leaf_node 지연 계수 풀이

import sys; input = sys.stdin.readline

DEL = -2

n = int(input())
nodes = list(map(int, input().split()))

connetion = {i:[] for i in range(-1, n+1)}
for idx, parent in enumerate(nodes):
    connetion[parent].append(idx)

remove_idx = int(input())

stack = [remove_idx]

while stack:
    node_idx = stack.pop()
    nodes[node_idx] = DEL

    for child in connetion[node_idx]:
        if nodes[child] == DEL\
                or child in stack:
            continue

        stack.append(child)

leaf_node = 0
for node in range(n):
    if node in nodes\
            or nodes[node] == DEL:
        continue
    leaf_node += 1

print(leaf_node)

 


 

피드백

 

당연히 1번 인덱스를 기반으로 한 *2, *2+1 방식의 트리가 주어진 줄 알고 풀다가 시간을 낭비 했다.

하지만 부모를 가르키는 구조인걸 나중에 알고 급하게 dict로 우회 했다.

 

leaf_node를 지연 연산하는 테크닉은 유용한 것 같으니

기억 해 둘 것.