2024. 3. 8. 20:23ㆍ알고리즘/정렬
소개
- 병합 정렬은 리스트를 재귀적으로 2개씩 분할하고 개수가 1개가 될 때까지 반복한다. 이 후 분할된 부분에 대해 병합 및 정렬을 진행한다.
- 분할하는데 O(log n) 연산이 들고 병합 및 정렬에 O(n)의 연산이 필요하므로 각 연산을 곱한 O(n log n)의 시간 복잡도를 가진다.
- 주어진 데이터의 정렬 상태와 상관없이 항상 일정한 O(n log n)을 가진다.
- 기존 데이터의 정렬 순서를 보장하는 안정 정렬이다.
전체 코드
def merge_two_area(arr, left, mid, right):
l_idx = left
r_idx = mid + 1
# 정렬값 저장을 위한 임시 리스트 생성
sort_arr = [0] * len(arr)
s_idx = left
while l_idx <= mid and r_idx <= right:
if arr[l_idx] <= arr[r_idx]:
# 왼쪽 우선 merge
sort_arr[s_idx] = arr[l_idx]
l_idx += 1
else:
# 오른쪽 우선 merge
sort_arr[s_idx] = arr[r_idx]
r_idx += 1
s_idx += 1
# 왼쪽이 먼저 완료 -> 오른쪽 merge
if l_idx > mid:
for i in range(r_idx, right + 1):
sort_arr[s_idx] = arr[i]
s_idx += 1
# 오른쪽 먼저 완료 -> 왼쪽 merge
else:
for i in range(l_idx, mid + 1):
sort_arr[s_idx] = arr[i]
s_idx += 1
# left ~ right 까지 merge sort된 부분 대입
for i in range(left, right + 1):
arr[i] = sort_arr[i]
def merge_sort(arr, left, right):
if left < right:
mid = (right + left) // 2
# 좌, 우 분할 정복
merge_sort(arr, left, mid)
merge_sort(arr, mid + 1, right)
# 분할 된 리스트 병합
merge_two_area(arr, left, mid, right)
def sort(nums: list):
merge_sort(nums, 0, len(nums) - 1)
return nums
# 테스트용 리스트
print(sort([8, 2, 3, 7, 1, 5, 4, 6]))
print(sort([70, 89, 95, 34, 74, 56, 96, 70, 84, 33]))
설명
병합 정렬은 크게 리스트를 분할 과정과 병합 과정으로 나뉜다. 분할 과정은 merge_sort 함수를 좌, 우로 나눈 후 재귀적으로 호출하여 구현한다. 이 후 더 분할할 수 없을 경우 백트래킹이 발생하며 merge_two_area함수를 호출한다.
병합은 merge_two_area 함수를 통해 구현되었는데 다음은 설명을 위한 예시다.
위에서 [1, 5]과 [4, 6]을 병합하는 과정을 살펴본다.
arr 배열은 병합하기 전에 다음과 같이 되어 있다. 병합을 위한 포인터 l_idx = left, r_idx = mid+1로 초기화 한다. 또한 병합된 값을 임시로 저장하는 sorted_arr를 초기화하고 값을 저장할 포인터 s_idx = left로 초기화한다.
l_dx는 left~mid, r_idx는 mid+1~right를 순회한다. 각 포인터 중 더 작은 값을 sorted_arr에 추가하고 각각의 포인터를 증가시킨다. 값을 추가하면 s_idx 포인터를 증가한다.
진행하면 l_idx 또는 r_idx가 초과되는데 이 경우 r_idx가 아직 sorted_arr로 이동하지 않았다. 이처럼 병합 시 좌, 우의 포인터 중 하나가 저장을 다 못하고 종료된다.
이를 해결하기 위해 복사가 종료된 반대편 포인터를 순회하며 sorted_arr에 값을 저장한다.
병합이 완료되었으니 sorted_arr-> arr로 값을 복사한다.
이 과정을 반복하면 마지막에 병합 정렬이 완료된다.