https://www.acmicpc.net/problem/1389
해당 문제를 먼저 dfs 로 풀었는데 계속 메모리 초과가 발생하였다.
import sys
input = sys.stdin.readline
n , m = map(int,input().split())
nodes = [[] for _ in range(n+1)]
stack1 = list()
stack2 = list()
result = [0 for _ in range(n+1)]
for _ in range(m):
start , end = map(int,input().split())
nodes[start].append(end)
nodes[end].append(start)
for i in range(1,n+1):
count = 0
for j in range(1,n+1):
if i == j:
continue
for a in nodes[i]:
stack1.append(a)
while stack1:
count += 1
if j in stack1:
stack1.clear()
break
stack2 = stack1[:]
stack1.clear()
while stack2:
for t in nodes[stack2.pop()]:
stack1.append(t)
result[i] = count
print(result.index(min(result[1:])))
물어보니 문제의 원인은 dfs 로 푸는 것이 아닌 플로이드 워셜 알고리즘으로 해야 한다는 것이었다.
플로이드-워셜(Floyd-Warshall) 알고리즘은 모든 쌍 최단 경로 알고리즘으로, 그래프에서 모든 정점 쌍 간의 최단 경로를 찾는 데 사용됩니다. 이 알고리즘은 동적 계획법을 사용하여 그래프의 모든 경로를 반복적으로 개선하여 최단 경로를 찾습니다.특징
알고리즘 설명
|
# 플로이드 워셜 알고리즘
import sys
def floyd_warshall(n, graph):
dist = [[sys.maxsize] * n for _ in range(n)]
# 초기화
for i in range(n):
for j in range(n):
if i == j:
dist[i][j] = 0
elif graph[i][j] != 0:
dist[i][j] = graph[i][j]
# 최단 경로 갱신
for k in range(n):
for i in range(n):
for j in range(n):
if dist[i][j] > dist[i][k] + dist[k][j]:
dist[i][j] = dist[i][k] + dist[k][j]
return dist
# 예제 입력
n = 4
graph = [
[0, 3, sys.maxsize, 5],
[2, 0, sys.maxsize, 4],
[sys.maxsize, 1, 0, sys.maxsize],
[sys.maxsize, sys.maxsize, 2, 0]
]
# 알고리즘 실행
distances = floyd_warshall(n, graph)
# 결과 출력
for row in distances:
print(row)
아래는 플로이드 워셜 알고리즘으로 해결한 문제이다.
import sys
input = sys.stdin.readline
inf = sys.maxsize
n , m = map(int,input().split())
nodes = [[] for _ in range(n+1)]
dist = [[inf for _ in range(n+1)] for _ in range(n+1)]
for _ in range(m):
start , end = map(int,input().split())
nodes[start].append(end)
nodes[end].append(start)
dist[start][end] = 1
dist[end][start] = 1
for i in range(1,n+1):
for j in range(1,n+1):
for k in range(1,n+1):
if j == k:
dist[j][k] = 0
else:
if dist[j][k] > dist[j][i] + dist[i][k]:
dist[j][k] = dist[j][i] + dist[i][k]
result = sys.maxsize
for r in range(1,n+1):
temp = sum(dist[r][1:])
if result > temp:
result = temp
minnum = r
print(minnum)
'코딩테스트' 카테고리의 다른 글
[python] 백준 9461 번 : 파도반 수열 (실버 3) (2) | 2024.07.23 |
---|---|
[python] 백준 9375 : 패션왕 신해빈 (실버 3) (0) | 2024.07.22 |
[python] 백준 2579 : 계단 오르기 (실버 3) (0) | 2024.07.18 |
[python] 백준 7569 : 토마토 (골드 5) (0) | 2024.07.17 |
[python] 백준 17219 : 비밀번호 찾기 (실버 4) (0) | 2024.07.17 |