How Is sort Implemented in the Go Standard Library?

2023-07-10

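This article explains how Go's standard library package sort sorts data, which optimizations it applies, and how the pdqsort[1] algorithm is implemented in Go's source code.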

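I had always assumed that Go's default sort was quicksort, and I wondered how the sort package would cope with worst-case inputs and what it could do to optimize for them. Then I recently came across an issue and learned that Go already uses pdqsort (since Go 1.19), and I realized my thinking was still stuck on the handful of single-algorithm sorts from textbooks. In fact, most unstable sorts used in industry today are already hybrid algorithms. In this article I go through the source code to see how pdqsort is implemented in Go.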

Prerequisites

quicksort

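Classic quicksort is divide and conquer: pick a pivot[2], move the elements smaller than the pivot to its left and the larger ones to its right, then recurse on the left and right subarrays.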

So the quality of the chosen pivot directly determines how efficient a quicksort run is.
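A minimal sketch of the classic scheme on a plain []int. Taking the middle element as the pivot and using the Lomuto partition scheme are assumptions made for brevity; this is not how Go's sort partitions. With an unlucky pivot sequence or many equal keys, this version degrades to O(n²), which is exactly what pdqsort guards against.

```go
package main

import "fmt"

// quickSort sorts s in place: partition around a pivot, then recurse on both sides.
func quickSort(s []int) {
	if len(s) <= 1 {
		return
	}
	p := lomutoPartition(s)
	quickSort(s[:p])   // elements smaller than the pivot
	quickSort(s[p+1:]) // elements greater than or equal to the pivot
}

// lomutoPartition uses the middle element as the pivot (an arbitrary choice for
// this sketch), moves smaller elements to the front and returns the pivot's final index.
func lomutoPartition(s []int) int {
	last := len(s) - 1
	mid := len(s) / 2
	s[mid], s[last] = s[last], s[mid] // park the pivot at the end
	pivot := s[last]
	i := 0
	for j := 0; j < last; j++ {
		if s[j] < pivot {
			s[i], s[j] = s[j], s[i]
			i++
		}
	}
	s[i], s[last] = s[last], s[i] // place the pivot between the two groups
	return i
}

func main() {
	s := []int{5, 2, 9, 1, 5, 6}
	quickSort(s)
	fmt.Println(s) // [1 2 5 5 6 9]
}
```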

Common ways to choose a pivot:

insertsort (insertion sort)

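Walk through the array and insert each element, one at a time, into the already-sorted part. It performs well when the array is short.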
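A minimal swap-based sketch on a []int. It has roughly the same shape as the insertionSort that pdqsort calls, except that the standard library version works through data.Less and data.Swap rather than on a slice directly.

```go
package main

import "fmt"

// insertionSort grows a sorted prefix s[:i] and sinks each new element s[i]
// into place by swapping it leftward while it is smaller than its neighbour.
func insertionSort(s []int) {
	for i := 1; i < len(s); i++ {
		for j := i; j > 0 && s[j] < s[j-1]; j-- {
			s[j], s[j-1] = s[j-1], s[j]
		}
	}
}

func main() {
	s := []int{4, 3, 1, 2}
	insertionSort(s)
	fmt.Println(s) // [1 2 3 4]
}
```

On short or nearly sorted input the inner loop exits almost immediately, which is why hybrid sorts hand small ranges to it.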

heapsort (heap sort)

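A sort built on the heap data structure. Its key property is that it still runs in O(n log n) even in the worst case, which is why many hybrid sorts use it as their fallback.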
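A minimal max-heap sketch on a []int using sift-down. The standard library's heapSort follows the same idea over data.Less and data.Swap; the helper names here are this sketch's own.

```go
package main

import "fmt"

// siftDown restores the max-heap property for the subtree rooted at root,
// considering only s[:n].
func siftDown(s []int, root, n int) {
	for {
		child := 2*root + 1
		if child >= n {
			return
		}
		if child+1 < n && s[child] < s[child+1] {
			child++ // pick the larger child
		}
		if s[root] >= s[child] {
			return
		}
		s[root], s[child] = s[child], s[root]
		root = child
	}
}

// heapSort builds a max-heap, then repeatedly moves the maximum to the end.
func heapSort(s []int) {
	n := len(s)
	for i := n/2 - 1; i >= 0; i-- {
		siftDown(s, i, n)
	}
	for end := n - 1; end > 0; end-- {
		s[0], s[end] = s[end], s[0]
		siftDown(s, 0, end)
	}
}

func main() {
	s := []int{3, 1, 4, 1, 5, 9, 2, 6}
	heapSort(s)
	fmt.Println(s) // [1 1 2 3 4 5 6 9]
}
```

Each of the n extractions costs O(log n) regardless of the input order, which is what makes heapsort a safe fallback.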

pdqsort (pattern-defeating quicksort)

A hybrid sorting algorithm that switches between different sorting strategies depending on the input.

Source code walkthrough

The general idea is as follows:

Repeat the process above in a loop until the whole range is sorted.

// pdqsort sorts data[a:b].
// The algorithm based on pattern-defeating quicksort(pdqsort),
// but without the optimizations from BlockQuicksort.
// pdqsort paper: https://arxiv.org/pdf/2106.05123.pdf
// C++ implementation: https://github.com/orlp/pdqsort
// Rust implementation: https://docs.rs/pdqsort/latest/pdqsort/
// limit is the number of allowed bad (very unbalanced) pivots
// before falling back to heapsort.
func pdqsort(data Interface, a, b, limit int) {
	const maxInsertion = 12

	var (
		wasBalanced    = true // whether the last partitioning was reasonably balanced
		wasPartitioned = true // whether the slice was already partitioned
	)

	for {
		length := b - a

		if length <= maxInsertion {
			insertionSort(data, a, b)
			return
		}

		// Fall back to heapsort if too many bad choices were made.
		if limit == 0 {
			heapSort(data, a, b)
			return
		}

		// If the last partitioning was imbalanced, we need to breakPatterns.
		if !wasBalanced {
			breakPatterns(data, a, b)
			limit--
		}

		pivot, hint := choosePivot(data, a, b)
		if hint == decreasingHint {
			reverseRange(data, a, b)
			// The chosen pivot was pivot-a elements after the start of the array.
			// After reversing it is pivot-a elements before the end of the array.
			// The idea came from Rust's implementation.
			pivot = (b - 1) - (pivot - a)
			hint = increasingHint
		}

		// The slice is likely already sorted.
		if wasBalanced && wasPartitioned && hint == increasingHint {
			if partialInsertionSort(data, a, b) {
				return
			}
		}

		// Probably the slice contains many duplicate elements,
		// partition the slice into
		// elements equal to and elements greater than the pivot.
		if a > 0 && !data.Less(a-1, pivot) {
			mid := partitionEqual(data, a, b, pivot)
			a = mid
			continue
		}

		mid, alreadyPartitioned := partition(data, a, b, pivot)
		wasPartitioned = alreadyPartitioned

		leftLen, rightLen := mid-a, b-mid
		balanceThreshold := length / 8
		if leftLen < rightLen {
			wasBalanced = leftLen >= balanceThreshold
			pdqsort(data, a, mid, limit)
			a = mid + 1
		} else {
			wasBalanced = rightLen >= balanceThreshold
			pdqsort(data, mid+1, b, limit)
			b = mid
		}
	}
}

insertionSort

maxInsertion defaults to 12: when the range has at most maxInsertion elements (length <= maxInsertion), insertion sort is used directly.

heapSort

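limit is the number of bad (very unbalanced) pivot choices allowed before falling back to heapsort. Sort sets it to bits.Len(uint(n)), the number of bits needed to represent n, i.e. ⌊log₂ n⌋ + 1: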

func Sort(data Interface) {
        n := data.Len()
        if n <= 1 {
                return
        }
        limit := bits.Len(uint(n))
        pdqsort(data, 0, n, limit)
}

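When limit reaches 0, the pivots have still not come close to the median despite repeated pattern-breaking attempts, so pdqsort switches to heapsort.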
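To get a feel for the budget, the sketch below prints the limit that Sort would compute for a few sizes. pdqLimit is a hypothetical helper that just wraps math/bits.Len the same way Sort does.

```go
package main

import (
	"fmt"
	"math/bits"
)

// pdqLimit mirrors how Sort computes limit: for n > 0,
// bits.Len(uint(n)) equals floor(log2(n)) + 1.
func pdqLimit(n int) int {
	return bits.Len(uint(n))
}

func main() {
	for _, n := range []int{13, 100, 1000, 1 << 20} {
		fmt.Printf("n=%d limit=%d\n", n, pdqLimit(n))
	}
	// n=13 limit=4
	// n=100 limit=7
	// n=1000 limit=10
	// n=1048576 limit=21
}
```

The budget grows logarithmically, so even for a million elements only about 21 bad pivots are tolerated before heapsort takes over, keeping the worst case at O(n log n).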

wasBalanced

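wasBalanced records whether the previous partition was balanced: it is true when the smaller of the two sides produced by the partition holds at least 1/8 of the range's length (length / 8, integer division).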

balanceThreshold := length / 8
if leftLen < rightLen {
        wasBalanced = leftLen >= balanceThreshold
        ...
} else {
        wasBalanced = rightLen >= balanceThreshold
        ...
}
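A worked example of the threshold. isBalanced is a hypothetical helper that reproduces the check above; note that the pivot belongs to neither side, so leftLen + rightLen = length - 1.

```go
package main

import "fmt"

// isBalanced reproduces pdqsort's test: the smaller side of a partition
// must contain at least length/8 elements.
func isBalanced(length, leftLen, rightLen int) bool {
	threshold := length / 8 // integer division, as in pdqsort
	smaller := leftLen
	if rightLen < smaller {
		smaller = rightLen
	}
	return smaller >= threshold
}

func main() {
	// length = 100, so threshold = 12; left + right = 99 because the pivot is excluded.
	fmt.Println(isBalanced(100, 10, 89)) // false
	fmt.Println(isBalanced(100, 12, 87)) // true
	fmt.Println(isBalanced(100, 49, 50)) // true
}
```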

breakPatterns

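When wasBalanced == false (the previous partition was unbalanced), breakPatterns swaps the three elements around the middle of the range with pseudo-random positions, to break up patterns that lead to bad pivots. limit is also decremented, recording that quicksort performed poorly last time (once limit reaches 0, heapsort is used instead).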

// breakPatterns scatters some elements around in an attempt to break some patterns
// that might cause imbalanced partitions in quicksort.
func breakPatterns(data Interface, a, b int) {
	length := b - a
	if length >= 8 {
		random := xorshift(length)
		modulus := nextPowerOfTwo(length)

		// Swap the three elements around the middle with pseudo-random positions in data[a:b].
		for idx := a + (length/4)*2 - 1; idx <= a+(length/4)*2+1; idx++ {
			other := int(uint(random.Next()) & (modulus - 1))
			if other >= length {
				other -= length
			}
			data.Swap(idx, a+other)
		}
	}
}
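The function above depends on two helpers that are not shown, xorshift and nextPowerOfTwo. Below is a self-contained []int sketch. The xorshift shift triple (13, 7, 17) is taken from Marsaglia's 64-bit xorshift and is an assumption here; the standard library's constants may differ. nextPowerOfTwo returns a power of two that is at most 2*length, so a single subtraction maps every draw into [0, length).

```go
package main

import (
	"fmt"
	"math/bits"
)

// xorshift is a tiny deterministic PRNG (shift triple 13, 7, 17: an assumption).
type xorshift uint64

func (r *xorshift) Next() uint64 {
	*r ^= *r << 13
	*r ^= *r >> 7
	*r ^= *r << 17
	return uint64(*r)
}

// nextPowerOfTwo returns 1 << bits.Len(uint(length)), a power of two greater than length.
func nextPowerOfTwo(length int) uint {
	return 1 << bits.Len(uint(length))
}

// breakPatterns swaps the three elements around the middle of s with
// pseudo-random positions, mirroring the Interface-based version above.
func breakPatterns(s []int) {
	length := len(s)
	if length < 8 {
		return
	}
	random := xorshift(length) // seeded by length, so the result is deterministic
	modulus := nextPowerOfTwo(length)
	mid := (length / 4) * 2
	for idx := mid - 1; idx <= mid+1; idx++ {
		other := int(uint(random.Next()) & (modulus - 1))
		if other >= length {
			other -= length // modulus <= 2*length, so one subtraction suffices
		}
		s[idx], s[other] = s[other], s[idx]
	}
}

func main() {
	s := make([]int, 16)
	for i := range s {
		s[i] = i
	}
	breakPatterns(s)
	fmt.Println(s) // still a permutation of 0..15, with a few middle elements moved
}
```

Seeding with the length keeps the shuffle reproducible, which makes sorting deterministic while still disturbing adversarial patterns.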

choosePivot

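choosePivot picks a suitable pivot from the range.

For ranges of at least shortestNinther (50) elements, which use Tukey's ninther, choosePivot also checks whether the range looks locally ordered. swaps counts how many of its comparisons found a pair out of order. If swaps == 0, every sampled triple was ascending (data[i-1] ≤ data[i] ≤ data[i+1], likewise around j and k, and data[i] ≤ data[j] ≤ data[k]), so the range is treated as locally increasing (increasingHint). If swaps == maxSwaps (12), every comparison found the pair in descending order, so the range is treated as locally decreasing (decreasingHint). Anything in between yields unknownHint. With fewer than 50 elements at most 3 comparisons are made, so only increasingHint or unknownHint can result.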

func choosePivot(data Interface, a, b int) (pivot int, hint sortedHint) {
	const (
		shortestNinther = 50
		maxSwaps        = 4 * 3
	)

	l := b - a

	var (
		swaps int
		i     = a + l/4*1
		j     = a + l/4*2
		k     = a + l/4*3
	)

	if l >= 8 {
		if l >= shortestNinther {
			// Tukey ninther method, the idea came from Rust's implementation.
			i = medianAdjacent(data, i, &swaps)
			j = medianAdjacent(data, j, &swaps)
			k = medianAdjacent(data, k, &swaps)
		}
		// Find the median among i, j, k and stores it into j.
		j = median(data, i, j, k, &swaps)
	}

	switch swaps {
	case 0:
		return j, increasingHint
	case maxSwaps:
		return j, decreasingHint
	default:
		return j, unknownHint
	}
}
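choosePivot relies on helpers that are not shown (median, medianAdjacent and a pairwise ordering step). The []int sketch below fills them in under the assumption that they only reorder indices and count out-of-order pairs, without moving data; the names order2 and sortedHint follow the standard library as far as I know, but treat them as this sketch's own. It also shows why decreasingHint is only reachable through the ninther path.

```go
package main

import "fmt"

// sortedHint mirrors the hint values pdqsort uses when choosing the pivot.
type sortedHint int

const (
	unknownHint sortedHint = iota
	increasingHint
	decreasingHint
)

// order2 returns (x, y) with s[x] <= s[y]; it counts one "swap" when the pair
// was out of order. Only the indices move; s is not modified.
func order2(s []int, a, b int, swaps *int) (int, int) {
	if s[b] < s[a] {
		*swaps++
		return b, a
	}
	return a, b
}

// median returns the index (a, b or c) holding the median of the three values.
func median(s []int, a, b, c int, swaps *int) int {
	a, b = order2(s, a, b, swaps)
	b, c = order2(s, b, c, swaps)
	a, b = order2(s, a, b, swaps)
	return b
}

// medianAdjacent returns the index of the median of s[a-1], s[a], s[a+1].
func medianAdjacent(s []int, a int, swaps *int) int {
	return median(s, a-1, a, a+1, swaps)
}

func choosePivot(s []int) (pivot int, hint sortedHint) {
	const (
		shortestNinther = 50
		maxSwaps        = 4 * 3 // 4 median calls x 3 comparisons each
	)
	l := len(s)
	var (
		swaps int
		i     = l / 4 * 1
		j     = l / 4 * 2
		k     = l / 4 * 3
	)
	if l >= 8 {
		if l >= shortestNinther {
			// Tukey's ninther: the median of three medians-of-three.
			i = medianAdjacent(s, i, &swaps)
			j = medianAdjacent(s, j, &swaps)
			k = medianAdjacent(s, k, &swaps)
		}
		j = median(s, i, j, k, &swaps)
	}
	switch swaps {
	case 0:
		return j, increasingHint
	case maxSwaps:
		return j, decreasingHint
	default:
		return j, unknownHint
	}
}

func main() {
	asc := make([]int, 100)
	desc := make([]int, 100)
	for x := range asc {
		asc[x] = x
		desc[x] = 99 - x
	}
	fmt.Println(choosePivot(asc))  // 50 1 (index 50, increasingHint)
	fmt.Println(choosePivot(desc)) // 50 2 (index 50, decreasingHint)
}
```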

reverseRange

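If the previous step judged the range to be locally decreasing (decreasingHint), reverseRange reverses it so it becomes locally increasing, and pdqsort mirrors the pivot index to match: pivot = (b - 1) - (pivot - a).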
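A small sketch of the reversal and the pivot remapping on a []int sub-range, to check that the mirrored index still points at the same pivot value. This reverseRange is a slice-based stand-in for the Interface-based one.

```go
package main

import "fmt"

// reverseRange reverses s[a:b] in place with pairwise swaps from both ends.
func reverseRange(s []int, a, b int) {
	for i, j := a, b-1; i < j; i, j = i+1, j-1 {
		s[i], s[j] = s[j], s[i]
	}
}

func main() {
	s := []int{0, 9, 8, 7, 6, 5, 100}
	a, b, pivot := 1, 6, 2 // sort range s[1:6]; pivot points at value 8
	reverseRange(s, a, b)
	pivot = (b - 1) - (pivot - a) // mirror the pivot index, as pdqsort does
	fmt.Println(s, pivot, s[pivot]) // [0 5 6 7 8 9 100] 4 8
}
```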

wasPartitioned

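wasPartitioned records whether the previous partition step found the range already partitioned around its pivot, i.e. no elements had to be swapped.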

partialInsertionSort

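If wasBalanced && wasPartitioned && hint == increasingHint, the range is very likely already sorted, so pdqsort tries partial insertion sort. Unlike plain insertion sort, it caps how many out-of-order adjacent pairs it will fix (maxSteps = 5), so a wrong guess cannot blow up into quadratic work. partialInsertionSort implements this: if the range is still not sorted after maxSteps fixes, it gives up and returns false; otherwise it returns true and the range is sorted. On ranges shorter than shortestShifting (50) it does not shift anything and returns false at the first inversion.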

// partialInsertionSort partially sorts a slice, returns true if the slice is sorted at the end.
func partialInsertionSort(data Interface, a, b int) bool {
	const (
		maxSteps         = 5  // maximum number of adjacent out-of-order pairs that will get shifted
		shortestShifting = 50 // don't shift any elements on short arrays
	)
	i := a + 1
	for j := 0; j < maxSteps; j++ {
		for i < b && !data.Less(i, i-1) {
			i++
		}

		if i == b {
			return true
		}

		if b-a < shortestShifting {
			return false
		}

		data.Swap(i, i-1)

		// Shift the smaller one to the left.
		if i-a >= 2 {
			for j := i - 1; j >= 1; j-- {
				if !data.Less(j, j-1) {
					break
				}
				data.Swap(j, j-1)
			}
		}
		// Shift the greater one to the right.
		if b-i >= 2 {
			for j := i + 1; j < b; j++ {
				if !data.Less(j, j-1) {
					break
				}
				data.Swap(j, j-1)
			}
		}
	}
	return false
}
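To see the three outcomes, here is a []int adaptation of the function above, restricted to the whole slice (a = 0). The `i-a >= 2` and `b-i >= 2` guards are dropped because the loop conditions already stop at the slice edges. It is run on a nearly sorted slice, a reversed slice and a short slice.

```go
package main

import "fmt"

// partialInsertionSort fixes at most maxSteps adjacent inversions and reports
// whether s ended up sorted. Adapted from the Interface-based version for []int.
func partialInsertionSort(s []int) bool {
	const (
		maxSteps         = 5
		shortestShifting = 50
	)
	n := len(s)
	i := 1
	for step := 0; step < maxSteps; step++ {
		// Skip over the part that is already in order.
		for i < n && s[i] >= s[i-1] {
			i++
		}
		if i == n {
			return true // no inversion left
		}
		if n < shortestShifting {
			return false // short slice: don't shift, leave it to pdqsort
		}
		s[i], s[i-1] = s[i-1], s[i]
		// Shift the smaller element to the left.
		for j := i - 1; j >= 1 && s[j] < s[j-1]; j-- {
			s[j], s[j-1] = s[j-1], s[j]
		}
		// Shift the greater element to the right.
		for j := i + 1; j < n && s[j] < s[j-1]; j++ {
			s[j], s[j-1] = s[j-1], s[j]
		}
	}
	return false
}

func main() {
	nearly := make([]int, 60)
	reversed := make([]int, 60)
	for i := range nearly {
		nearly[i] = i
		reversed[i] = 59 - i
	}
	nearly[30], nearly[31] = nearly[31], nearly[30] // a single adjacent inversion
	fmt.Println(partialInsertionSort(nearly))   // true
	fmt.Println(partialInsertionSort(reversed)) // false: too many inversions
	short := []int{0, 1, 2, 4, 3, 5, 6, 7, 8, 9}
	fmt.Println(partialInsertionSort(short)) // false: fewer than 50 elements
}
```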

partitionEqual

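The condition a > 0 && !data.Less(a-1, pivot) relies on data[a-1] being an element left over from an earlier partition step, one that is no greater than any element of data[a:b]. If it is also not less than the pivot, it must equal the pivot, which suggests the range contains many duplicates. partitionEqual handles this case: it groups every element equal to the pivot at the front and returns the index where the elements greater than the pivot begin. pdqsort then sets a to that index and continues the loop, choosing a new pivot for the remaining elements.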
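A []int sketch of the equal-elements partition, written in the same two-pointer style as partition. It assumes, as pdqsort guarantees on this path, that no element is smaller than the pivot, so "not greater than the pivot" means "equal to the pivot".

```go
package main

import "fmt"

// partitionEqual assumes every element of s is >= s[pivot]. It moves the
// elements equal to the pivot to the front and returns the index of the
// first element greater than the pivot.
func partitionEqual(s []int, pivot int) int {
	s[0], s[pivot] = s[pivot], s[0]
	p := s[0]
	i, j := 1, len(s)-1 // s[i..j] is still unclassified
	for {
		for i <= j && !(p < s[i]) { // s[i] <= p, i.e. equal to the pivot
			i++
		}
		for i <= j && p < s[j] { // s[j] > p stays on the right
			j--
		}
		if i > j {
			break
		}
		s[i], s[j] = s[j], s[i]
		i++
		j--
	}
	return i
}

func main() {
	s := []int{5, 7, 5, 9, 5, 6}
	mid := partitionEqual(s, 0)
	fmt.Println(mid, s) // 3 [5 5 5 9 7 6]
}
```

Everything in s[:mid] is now final, so the loop never has to look at those duplicates again.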

partition

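partition implements the quicksort partitioning step. It also reports whether the range was already partitioned (no swaps were needed), and pdqsort stores that result in wasPartitioned.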
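A []int sketch of this step: a two-pointer scan from both ends, modeled on the Interface-based partition as I understand it (treat the details as an assumption). The first pair of scans decides alreadyPartitioned: if the pointers cross before any swap is needed, the range was already split around the pivot.

```go
package main

import "fmt"

// partition reorders s around the value s[pivot] and returns the pivot's final
// index and whether s was already partitioned (no swap was needed).
func partition(s []int, pivot int) (mid int, alreadyPartitioned bool) {
	s[0], s[pivot] = s[pivot], s[0]
	p := s[0]
	i, j := 1, len(s)-1 // s[i..j] is still unclassified

	// First pass: find the first misplaced pair, if any.
	for i <= j && s[i] < p {
		i++
	}
	for i <= j && s[j] >= p {
		j--
	}
	if i > j {
		s[j], s[0] = s[0], s[j]
		return j, true
	}
	s[i], s[j] = s[j], s[i]
	i++
	j--

	// Remaining passes: keep swapping misplaced pairs until the pointers cross.
	for {
		for i <= j && s[i] < p {
			i++
		}
		for i <= j && s[j] >= p {
			j--
		}
		if i > j {
			break
		}
		s[i], s[j] = s[j], s[i]
		i++
		j--
	}
	s[j], s[0] = s[0], s[j] // move the pivot between the two groups
	return j, false
}

func main() {
	a := []int{3, 1, 2, 5, 4}
	fmt.Println(partition(a, 0)) // 2 true
	fmt.Println(a)               // [2 1 3 5 4]
	b := []int{3, 5, 1, 4, 2}
	fmt.Println(partition(b, 0)) // 2 false
	fmt.Println(b)               // [1 2 3 4 5]
}
```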

Summary

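There are countless ways to sort. The standard library's approach is largely a general-purpose solution built to cope with every kind of input, and in some specialized domains it is not necessarily faster than a single, well-chosen algorithm. The same point came up in my earlier article analyzing map. So when we run into performance problems, it can make sense to set the standard library aside and pick a solution that fits the workload; that is how we build up our own experience.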