# Recursive way to do merge sort

def merge_sort(arr):
    """Sort *arr* in place using recursive, top-down merge sort.

    The sort is stable: elements that compare equal keep their original
    relative order. Elements are compared only with ``<`` (the same protocol
    ``list.sort`` relies on), so any items defining ``__lt__`` can be sorted.

    Args:
        arr: A mutable sequence supporting ``len``, slicing and item
            assignment (e.g. a ``list``). It is modified in place.

    Returns:
        None. Like ``list.sort()``, the sorted result is written back into
        *arr* rather than returned.
    """
    if len(arr) <= 1:
        return  # Zero or one element is already sorted.

    mid = len(arr) // 2
    # Slicing copies each half, so the merge below can overwrite ``arr``
    # without clobbering elements that have not been consumed yet.
    left_half = arr[:mid]
    right_half = arr[mid:]

    merge_sort(left_half)
    merge_sort(right_half)

    _merge(left_half, right_half, arr)


def _merge(left, right, out):
    """Stably merge sorted sequences *left* and *right* into *out*.

    *out* is overwritten starting at index 0 and must have room for exactly
    ``len(left) + len(right)`` items.
    """
    i = j = k = 0

    # Both halves still have unconsumed elements.
    while i < len(left) and j < len(right):
        # Take from the right half only when it is strictly smaller; on ties
        # the left (earlier) element goes first, which keeps the sort stable.
        if right[j] < left[i]:
            out[k] = right[j]
            j += 1
        else:
            out[k] = left[i]
            i += 1
        k += 1

    # At most one of the two loops below runs: copy whatever remains.
    while i < len(left):
        out[k] = left[i]
        i += 1
        k += 1

    while j < len(right):
        out[k] = right[j]
        j += 1
        k += 1

if __name__ == "__main__":
    # Demo: sort a small unordered list in place and show the result.
    demo_values = [4, 1, 7, 6, 5, 3, 8, 2]
    merge_sort(demo_values)
    print(demo_values)
# Exercise: write the merge step as a separate function.