算法4 归并排序

归并排序是我们学习的第一个使用分治法实现的排序算法,它也是唯一一个时间复杂度为NlogN的常用的稳定排序算法,有比它效率更高的排序算法,但却不是稳定的。

分治法是将一个大的问题分成若干个规模较小的相同问题,从而降低问题的复杂度,它是很多高效算法的基础,例如归并排序、快速排序、快速傅里叶变换等。

归并排序将数组平分成若干个子数组,然后将再将小数组归并回一个完整的数组,而排序发生在归并过程中。

这里的平分数组不是真正的拆分成小数组,而是一个数组当成多个子数组看待。

切分数组

归并排序有两种切分数组的方式,分别是使用递归的自顶向下的方式和使用循环的自底向上的方式。

自顶向下:使用递归的方式将数组平分成两个子数组,它能均匀的切分数组,保证左右两边的子数组长度相同,最终切分成长度为1的N个子数组。在递归收拢的阶段对子数组进行归并操作。

自底向上:这种方式一开始就认为数组是以切分好的,直接把数组看成N个长度为1的子数组,然后直接对子数组数组进行归并操作。

归并数组

首先我们需要先实现归并的方法,给定两个已经排序的子数组,分别是lo~midmid+1 ~ hi,将他们归并成一个有序的数组。

需要借助一个辅助才能完成归并操作:首先将数组拷贝辅助数组,然后比较两个有序的子数组,按大小顺序一个一个的拷贝回原数组,操作如下图所示:

1
2
3
4
5
6
7
8
9
10
11
12
private static void merge(Comparable[] a, Comparable[] aux, int lo, int mid, int hi) {
for (int k = lo; k <= hi; k++) aux[k] = a[k];
int i = lo, j = mid + 1;
for (int k = lo; k <= hi; k++) {
// 左边的子数组已用完,无须比较,直接复制右边的元素
if (i > mid) a[k] = aux[j++];
// 右边的数组已用完,无须比较,直接复制左边的元素
else if (j > hi) a[k] = aux[i++];
else if (less(aux[j], aux[i])) a[k] = aux[j++];
else a[k] = aux[i++];
}
}

需要留意的是辅助数组我们从外部传入,而不是方法内部创建,否则因为新建数组的开销会导致算法整体效率变低。

归并操作是从长度为1的最小子数组开始的,两个长度为1的子数组可以归并成一个长度为2的有序子数组,再将长度为2的有序子数组归并成长度为4的有序子数组,最终能将数组完全排序,这就是分治法的核心思想。

自顶向下的归并排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
private static void merge(Comparable[] a, Comparable[] aux, int lo, int mid, int hi) {
// as before
}

private static void sort(Comparable[] a, Comparable[] aux, int lo, int hi) {
if (hi <= lo) return;
int mid = lo + (hi - lo) / 2;
sort(a, aux, lo, mid);
sort(a, aux, mid + 1, hi);
merge(a, aux, lo, mid, hi);
}

private static void sort(Comparable[] a) {
Comparable[] aux = new Comparable[a.length];
sort(a, aux, 0, a.length - 1);
}

merge方法的调用跟踪如下图所示:

性能分析:

递归的切分数组,因为每次切分后数组长度减半,所以一共会切分logN次,即递归层数为lgN层。所以有以下结论:

第k层的子数组个数为,每个子数组的长度为 ,即第k层最多需要做次比较。

一共有logN层,即一共需要NlogN次比较,所以归并排序的时间复杂度是线性对数级

从代码和排序动画可以分析出,归并排序的性能几乎不受数组初始顺序的影响,无论什么初始顺序归并排序都能在NlogN比例的时间内完成。

稳定性分析:

归并排序是稳定的排序算法,它不会打乱数组已经有序的部分。从merge()方法的less(aux[j], aux[i])就可以看出aux[j] < aux[i]时将右半边的子数组拷贝到原数组,而相等或者大于的情况下会优先拷贝左半边子数组的元素,这将不会改变原数组的有序部分。

所以less()方法中的参数顺序是不能变的,会打破归并排序的稳定性,这一点要非常的注意

自底向上的归并排序

1
2
3
4
5
6
7
8
9
public static void sort(Comparable[] a) {
Comparable[] aux = new Comparable[a.length];
for (int sz = 1; sz < a.length; sz += sz) { // sz = 1, 2, 4, 8, 16.....
for (int lo = 0; lo < a.length - sz; lo += sz + sz) {
// lo mid high(注意角标越界)
merge(a, aux, lo, lo + sz - 1, Math.min(lo + sz + sz - 1, a.length - 1));
}
}
}

merge方法的调用跟踪如下图所示:

性能分析:

外循环的sz成倍递增,所以外循环的循环次数为logN,而内循环最大循环次数为N,所以很容易得出自底向上的归并排序的增长数量级为NlgN。

归并排序的改进

上面归并排序的性能已经非常优秀了,但是我们还能继续进行一些小优化,最终能让归并排序的性能再提升20%到30%。

优化以自顶向下的归并排序作为优化对象,也可以应用到自底向上的归并排序,它们的优化方式是一样的(除了“改进3”),所以我不会贴两份基本相同的代码。

改进1:小规模子数组使用插入排序

递归会一直切分数组直到子数组长度为1,使得小规模问题中方法的调用过于频繁,所以使用插入排序处理小规模的子数组会更加高效。例如子数组长度小于7(7~15之间的任意整数都是不错的选择)时不继续切分数组,一般可以将归并排序的运行时间缩短10-15% 。

1
2
3
4
5
6
7
8
9
10
private static void sort(Comparable[] a, Comparable[] aux, int lo, int hi) {
if (hi - lo + 1 <= 7) {
insertionSort(a, lo, hi);
return;
}
int mid = lo + (hi - lo) / 2;
sort(a, aux, lo, mid);
sort(a, aux, mid + 1, hi);
merge(a, aux, lo, mid, hi);
}

改进2:检测数组是否已经有序

可以添加一个判断条件,如果a[mid] 小于等于 a[mid + 1],即表明子数组已经是有序的并跳过merge方法,这个改动不影响排序的递归调用。

1
2
3
4
5
6
7
8
private static void sort(Comparable[] a, Comparable[] aux, int lo, int hi) {
if (hi <= lo) return;
int mid = lo + (hi - lo) / 2;
sort(a, aux, lo, mid);
sort(a, aux, mid + 1, hi);
if (less(a[mid + 1], a[mid]))
merge(a, aux, lo, mid, hi);
}

改进3:不将元素复制到辅助数组

这个想法是让a和aux在每次递归时互换身份,互相作为对方的辅助数组,省去了每次merge都要复制数组的操作,代价则是需要预先将元素组拷贝到辅助数组。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
private static void sort(Comparable[] a) {
Comparable[] aux = a.clone();
sort(a, aux, 0, a.length - 1);
}

private static void sort(Comparable[] a, Comparable[] aux, int lo, int hi) {
if (hi <= lo) return;
int mid = lo + (hi - lo) / 2;
// 交换身份
sort(aux, a, lo, mid);
sort(aux, a, mid + 1, hi);
merge(a, aux, lo, mid, hi);
}

private static void merge(Comparable[] a, Comparable[] aux, int lo, int mid, int hi) {
// for (int k = lo; k <= hi; k++) aux[k] = a[k]; // 节省了复制数组的成本
int i = lo, j = mid + 1;
for (int k = lo; k <= hi; k++) {
if (i > mid) a[k] = aux[j++];
else if (j > hi) a[k] = aux[i++];
else if (less(aux[j], aux[i])) a[k] = aux[j++];
else a[k] = aux[i++];
}
}

对于N个元素的归并排序,递归的调用层数为logN层,我们在每一层递归中切分时交换身份,而在归并时又交换回来。每次归并时,都正好使用了另一个数组中未排定的部分作为参考,归并到另一个数组中。而且因为递归的特性(栈),最先入栈时是以(a, aux)的顺序入栈,那么出栈时必定也是以(a, aux)的顺序出栈,所以无论有多少层递归调用都不影响它的返回顺序。

例如,下面是长度为8的数组的递归调用过程,其中序号表示merge方法的执行顺序,用hash值表示数组。即使长度为16的数组,也只是在“树”的最下方新增一层递归,它们的参数的返回顺序在递归调用时就已经决定了。我们只要能保证第一层调用时的逻辑是对的,那么整个递归过程就是对的。

1
2
3
4
5
6
7
8
a[]:0x7a
b[]:0x8b
0x7a,0x8b
7
0x8b,0x7a 0x8b,0x7a
3 6
0x7a,0x8b 0x7a,0x8b 0x7a,0x8b 0x7a,0x8b
1 2 4 5

同时使用所有改进

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
private static void sort(Comparable[] a) {
Comparable[] aux = a.clone();
sort(a, aux, 0, a.length - 1);
}

private static void sort(Comparable[] a, Comparable[] aux, int lo, int hi) {
if (hi - lo + 1 <= 7) {
insertionSort(a, lo, hi);
return;
}
int mid = lo + (hi - lo) / 2;
sort(aux, a, lo, mid);
sort(aux, a, mid + 1, hi);
if (less(aux[mid + 1], aux[mid])) {
merge(a, aux, lo, mid, hi);
} else {
// 跳过合并时,需要将已经有序的部分拷贝到另一个数组
System.arraycopy(aux, lo, a, lo, hi - lo + 1);
}
}

private static void merge(Comparable[] a, Comparable[] aux, int lo, int mid, int hi) {
int i = lo, j = mid + 1;
for (int k = lo; k <= hi; k++) {
if (i > mid) a[k] = aux[j++];
else if (j > hi) a[k] = aux[i++];
else if (less(aux[j], aux[i])) a[k] = aux[j++];
else a[k] = aux[i++];
}
}

Coursera作业(Collinear points)

下面代码在Coursera作业上的得分为100。

CollinearPoints即共线点,在二维平面上拥有众多的点,位于同一条直线上的若干个点称为共线点,我们需要做的就是利用排序算法快速的找出所有包含至少4个点以上的线段,更具体的描述可到作业地址上查看并下载相关项目。

正确的解决方式是从平面上最低的点开始遍历,这样只需要往一个方向查找即可,所以一开始需要对这些点从低到高进行排序。

若多个点能组成一条线段,则这几个点之间的斜率相同,即(y1 - y0)/(x1 - x0) = (y2 - y0)/(x2 - x0)

假设我们现在遍历到p点,可以将其他点按照与p点间的斜率进行第二次排序,这样所有斜率相同的点都会被排在一起,且它是最长线段。假设我们使用的是稳定排序算法,那么共线点必然也是按顺序排序的,即p < q < r < s < t

需要注意的是,使用这种方式我们很可能找到重复且错误的线段,例如使用q点进行斜率排序后,我们将得到一条这样的线段:q-p-r-s-t

为了避免出现这种情况,我们只需要判断这条线段的第二个点是否大于第一个点即可,很显然q-p-r-s-t并不是。

点和线段的定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
public class Point implements Comparable<Point> {
   private final int x;     // x-coordinate of this point
   private final int y;     // y-coordinate of this point
   
   public Point(int x, int y) {
       this.x = x;
       this.y = y;
  }

   public void draw() {
       StdDraw.point(x, y);
  }
   
   public void drawTo(Point that) {
       StdDraw.line(this.x, this.y, that.x, that.y);
  }
   
   public double slopeTo(Point that) {
       if (x == that.x && y == that.y) return Double.NEGATIVE_INFINITY;
       double dy = that.y - y;
       double dx = that.x - x;
       if (dx == 0) return Double.POSITIVE_INFINITY;
       if (dy == 0) return 0;
       return dy / dx;
  }
   
   public int compareTo(Point that) {
       if (y < that.y) return -1;
       else if (y == that.y) return Integer.compare(x, that.x);
       else return 1;
  }
   
   public Comparator<Point> slopeOrder() {
       return (o1, o2) -> {
           double slope1 = slopeTo(o1);
           double slope2 = slopeTo(o2);
           return Double.compare(slope1, slope2);
      };
  }
   
   public String toString() {
       return "(" + x + ", " + y + ")";
  }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public class LineSegment {
   private final Point p;   // one endpoint of this line segment
   private final Point q;   // the other endpoint of this line segment
   
   public LineSegment(Point p, Point q) {
       if (p == null || q == null) throw new NullPointerException("argument is null");
       this.p = p;
       this.q = q;
  }
   
   public void draw() {
       p.drawTo(q);
  }
   
   public String toString() {
       return p + " -> " + q;
  }
   
   public int hashCode() {
       throw new UnsupportedOperationException();
  }
}

利用排序算法快速找出所有共线点组成的线段:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
public class FastCollinearPoints {

   private ArrayList<LineSegment> segments = new ArrayList<>();

   public FastCollinearPoints(Point[] data) {
       // check data
       if (data == null) throw new IllegalArgumentException("array is null");
       for (int i = 0; i < data.length; i++) {
           if (data[i] == null) throw new IllegalArgumentException("has null item");
      }
       Point[] sortedPoints = new Point[data.length];
       System.arraycopy(data, 0, sortedPoints, 0, data.length);
       MergeX.sort(sortedPoints);
       for (int i = 0; i < sortedPoints.length - 1; i++) {
           if (sortedPoints[i].compareTo(sortedPoints[i + 1]) == 0) {
               throw new IllegalArgumentException("has same item.");
          }
      }

       Point[] points = new Point[data.length];
       for (int i = 0; i < points.length - 1; i++) {
           // MergeX.sort(points); // 重新拷贝一份有序数组性能要优于重排序
           System.arraycopy(sortedPoints, 0, points, 0, points.length);
           MergeX.sort(points, points[i].slopeOrder());
           Point p = points[0];
           Point q = p;
           int count = 1;
           for (int j = 0; j < points.length - 1; j++) {
               if (points[j].slopeTo(p) == points[j + 1].slopeTo(p)) {
                   count++;
                   if (count == 2) {
                       q = points[j];
                       count++;
                  }
                   else if (count >= 4 && j == points.length - 2 && q.compareTo(p) > 0) {
                       segments.add(new LineSegment(p, points[j + 1]));
                  }
              }
               else if (count >= 4 && q.compareTo(p) > 0) {
                   segments.add(new LineSegment(p, points[j]));
                   count = 1;
              }
               else {
                   count = 1;
              }
          }
      }
  }

   public int numberOfSegments() {
       return segments.size();
  }

   public LineSegment[] segments() {
       return segments.toArray(new LineSegment[0]);
  }
}