Poison

A Correct Binary Search Implementation

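While reading the Joshua Bloch interview More Effective Java With Google's Joshua Bloch, I came across his point that unit tests alone are not enough to ensure that code works, which he illustrated with binary search. The details are in the Google AI Blog post Extra, Extra - Read All About It: Nearly All Binary Searches and Mergesorts are Broken; this post is a brief record of it.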

In JDK 1.3.1_28, java.util.Arrays#binarySearch(int[], int) is implemented as follows:

/**
* Searches the specified array of ints for the specified value using the
* binary search algorithm. The array <strong>must</strong> be sorted (as
* by the <tt>sort</tt> method, above) prior to making this call. If it
* is not sorted, the results are undefined. If the array contains
* multiple elements with the specified value, there is no guarantee which
* one will be found.
*
* @param a the array to be searched.
* @param key the value to be searched for.
* @return index of the search key, if it is contained in the list;
* otherwise, <tt>(-(<i>insertion point</i>) - 1)</tt>. The
* <i>insertion point</i> is defined as the point at which the
* key would be inserted into the list: the index of the first
* element greater than the key, or <tt>list.size()</tt>, if all
* elements in the list are less than the specified key. Note
* that this guarantees that the return value will be &gt;= 0 if
* and only if the key is found.
* @see #sort(int[])
*/
public static int binarySearch(int[] a, int key) {
    int low = 0;
    int high = a.length - 1;

    while (low <= high) {
        int mid = (low + high) / 2;
        int midVal = a[mid];

        if (midVal < key)
            low = mid + 1;
        else if (midVal > key)
            high = mid - 1;
        else
            return mid; // key found
    }
    return -(low + 1); // key not found.
}
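The return-value contract in the Javadoc above, (-(insertion point) - 1) for a missing key, is easier to see with concrete numbers. The following is a minimal sketch that calls the current java.util.Arrays.binarySearch; the class name InsertionPointDemo and the helper insertionPoint are illustrative names, not part of the JDK.

```java
import java.util.Arrays;

public class InsertionPointDemo {

    // Hypothetical helper: turns the raw result into the index where the key
    // is, or where it would have to be inserted to keep the array sorted.
    static int insertionPoint(int[] sorted, int key) {
        int result = Arrays.binarySearch(sorted, key);
        return result >= 0 ? result : -(result + 1);
    }

    public static void main(String[] args) {
        int[] a = {10, 20, 30, 40};

        System.out.println(Arrays.binarySearch(a, 30)); // 2: found at index 2
        System.out.println(Arrays.binarySearch(a, 25)); // -3: insertion point 2
        System.out.println(Arrays.binarySearch(a, 50)); // -5: insertion point 4 (a.length)
        System.out.println(Arrays.binarySearch(a, 5));  // -1: insertion point 0

        System.out.println(insertionPoint(a, 25));      // 2
    }
}
```

Because a missing key always maps to a value of -1 or lower, a non-negative result means the key was found, which is the guarantee the Javadoc describes.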

In JDK 1.4.2_19 and 1.5.0_22, java.util.Arrays#binarySearch(int[], int) is implemented as follows:

/**
* Searches the specified array of ints for the specified value using the
* binary search algorithm. The array <strong>must</strong> be sorted (as
* by the <tt>sort</tt> method, above) prior to making this call. If it
* is not sorted, the results are undefined. If the array contains
* multiple elements with the specified value, there is no guarantee which
* one will be found.
*
* @param a the array to be searched.
* @param key the value to be searched for.
* @return index of the search key, if it is contained in the list;
* otherwise, <tt>(-(<i>insertion point</i>) - 1)</tt>. The
* <i>insertion point</i> is defined as the point at which the
* key would be inserted into the list: the index of the first
* element greater than the key, or <tt>list.size()</tt>, if all
* elements in the list are less than the specified key. Note
* that this guarantees that the return value will be &gt;= 0 if
* and only if the key is found.
* @see #sort(int[])
*/
public static int binarySearch(int[] a, int key) {
    int low = 0;
    int high = a.length - 1;

    while (low <= high) {
        int mid = (low + high) >> 1;
        int midVal = a[mid];

        if (midVal < key)
            low = mid + 1;
        else if (midVal > key)
            high = mid - 1;
        else
            return mid; // key found
    }
    return -(low + 1); // key not found.
}

In JDK 1.6.0_45, java.util.Arrays#binarySearch(int[], int) is implemented as follows:

/**
* Searches the specified array of ints for the specified value using the
* binary search algorithm. The array must be sorted (as
* by the {@link #sort(int[])} method) prior to making this call. If it
* is not sorted, the results are undefined. If the array contains
* multiple elements with the specified value, there is no guarantee which
* one will be found.
*
* @param a the array to be searched
* @param key the value to be searched for
* @return index of the search key, if it is contained in the array;
* otherwise, <tt>(-(<i>insertion point</i>) - 1)</tt>. The
* <i>insertion point</i> is defined as the point at which the
* key would be inserted into the array: the index of the first
* element greater than the key, or <tt>a.length</tt> if all
* elements in the array are less than the specified key. Note
* that this guarantees that the return value will be &gt;= 0 if
* and only if the key is found.
*/
public static int binarySearch(int[] a, int key) {
    return binarySearch0(a, 0, a.length, key);
}

// Like public version, but without range checks.
private static int binarySearch0(int[] a, int fromIndex, int toIndex,
                                 int key) {
    int low = fromIndex;
    int high = toIndex - 1;

    while (low <= high) {
        int mid = (low + high) >>> 1;
        int midVal = a[mid];

        if (midVal < key)
            low = mid + 1;
        else if (midVal > key)
            high = mid - 1;
        else
            return mid; // key found
    }
    return -(low + 1); // key not found.
}

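The versions above differ only in how mid is computed. In JDK 1.3.1_28 the expression is int mid = (low + high) / 2;. When low and high are both large, their sum can exceed Integer.MAX_VALUE and wrap around, setting the sign bit to 1. Dividing by 2 then yields a negative mid, which triggers an array index out-of-bounds exception, as in the following code: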

package me.tianshuang;

public class Test {

    public static void main(String[] args) {
        // Binary literal for readability; a byte array keeps memory use down
        byte[] array = new byte[0b01000000_00000000_00000000_00000001];
        binarySearch(array, (byte) 1);
    }

    public static int binarySearch(byte[] a, byte key) {
        int low = 0;
        int high = a.length - 1;

        while (low <= high) {
            int mid = (low + high) / 2; // wraps to a negative value once low + high > Integer.MAX_VALUE
            byte midVal = a[mid];

            if (midVal < key)
                low = mid + 1;
            else if (midVal > key)
                high = mid - 1;
            else
                return mid; // key found
        }
        return -(low + 1); // key not found.
    }

}

The array allocated above has length 0b01000000_00000000_00000000_00000001, so its valid index range is [0b00000000_00000000_00000000_00000000, 0b01000000_00000000_00000000_00000000]. All elements default to 0 and the target value is 1, so the search keeps moving toward the right end. Once low and high are both 0b01000000_00000000_00000000_00000000, low + high equals 0b10000000_00000000_00000000_00000000. The sign bit is now 1, so the value is negative. Converting this bit pattern with the two's complement method gives -1 × 2^31 = -2147483648, and dividing by 2 gives mid = -1073741824. Reading the array at that index throws java.lang.ArrayIndexOutOfBoundsException, with the following stack trace:

Exception in thread "main" java.lang.ArrayIndexOutOfBoundsException: -1073741824
at me.tianshuang.Test.binarySearch(Test.java:16)
at me.tianshuang.Test.main(Test.java:7)
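The walk-through above can be checked without allocating a 1 GiB array. Since every element is 0 and the key is 1, each comparison sends the search to the right, so only the index arithmetic needs to be simulated. The sketch below does that under this assumption; OverflowTrace and firstNegativeMid are illustrative names, and 0 is used as a "no overflow" sentinel.

```java
public class OverflowTrace {

    // Replays the index arithmetic of the JDK 1.3.1-style search on an
    // all-zero array of the given length while searching for 1.
    // Returns the first negative mid, or 0 if mid never goes negative.
    static int firstNegativeMid(int length) {
        int low = 0;
        int high = length - 1;

        while (low <= high) {
            int mid = (low + high) / 2; // same expression as the buggy version
            if (mid < 0) {
                return mid;             // a[mid] would throw at this point
            }
            low = mid + 1;              // element (0) < key (1): always go right
        }
        return 0;                       // the sum never overflowed
    }

    public static void main(String[] args) {
        int length = 0b01000000_00000000_00000000_00000001; // 2^30 + 1

        System.out.println(firstNegativeMid(length));  // -1073741824
        System.out.println(firstNegativeMid(1 << 30)); // 0
    }
}
```

With length 2^30 the largest possible sum is 2^31 - 2, which still fits in an int, so the failure only appears once the array has more than 2^30 elements. That matches the title of the bug report.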

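In JDK 1.4.2_19 and 1.5.0_22, the only change to the mid computation was replacing the division by 2 with a shift. The signed shift >> preserves the sign bit, so the test case above still throws an array index out-of-bounds exception. The problem was not fixed until JDK 1.6 changed the computation to int mid = (low + high) >>> 1;. The unsigned shift fills the top bit with 0, so the overflowed sum is effectively treated as an unsigned 32-bit value before it is halved. The bug report is JDK-5045582 : (coll) binarySearch() fails for size larger than 1<<30. Joshua Bloch said the bug sat in the JDK for nine years before it was discovered. In Programming Pearls, Bentley notes that although the first binary search was published in 1946, the first binary search that worked correctly for all values of n did not appear until 1962.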
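To make the difference between the operators concrete, here is a small sketch that evaluates each midpoint expression on the same overflowing inputs. The class and method names are illustrative. The last variant, low + (high - low) / 2, is a commonly used alternative that avoids the overflow instead of repairing it afterwards. Both safe variants assume 0 <= low <= high, which holds for array indices.

```java
public class MidpointComparison {

    static int midDivide(int low, int high)        { return (low + high) / 2; }       // JDK 1.3.1
    static int midSignedShift(int low, int high)   { return (low + high) >> 1; }      // JDK 1.4.2 / 1.5
    static int midUnsignedShift(int low, int high) { return (low + high) >>> 1; }     // JDK 1.6
    static int midOffset(int low, int high)        { return low + (high - low) / 2; } // never forms low + high

    public static void main(String[] args) {
        int low = 1 << 30;  // 1073741824
        int high = 1 << 30;

        // low + high wraps around to Integer.MIN_VALUE (-2147483648)
        System.out.println(midDivide(low, high));        // -1073741824
        System.out.println(midSignedShift(low, high));   // -1073741824
        System.out.println(midUnsignedShift(low, high)); // 1073741824
        System.out.println(midOffset(low, high));        // 1073741824

        // A second pair whose true midpoint is 1750000000
        System.out.println(midSignedShift(1_500_000_000, 2_000_000_000));   // -397483648
        System.out.println(midUnsignedShift(1_500_000_000, 2_000_000_000)); // 1750000000
    }
}
```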

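As for computing mid, a correct form is int mid = low + ((high - low) / 2);, and a probably faster form that is also correct is int mid = (low + high) >>> 1;. The lesson of this case is that to be sure a program is correct, you would have to test it against every possible input, which is rarely feasible. For concurrent programs it is worse: every internal state would have to be tested, which is practically impossible.

It also reminded me of several bugs I have run into where the unit-test inputs failed to cover a particular case. One involved a function a colleague wrote to get the Content-Length of a URL. The code called java.net.URLConnection#getContentLength, and after release we found that the business logic was not taking effect. Investigation showed that this method only returns Content-Length values no greater than Integer.MAX_VALUE and returns -1 for anything larger. The source is as follows: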

/**
* Returns the value of the {@code content-length} header field.
* <P>
* <B>Note</B>: {@link #getContentLengthLong() getContentLengthLong()}
* should be preferred over this method, since it returns a {@code long}
* instead and is therefore more portable.</P>
*
* @return the content length of the resource that this connection's URL
* references, {@code -1} if the content length is not known,
* or if the content length is greater than Integer.MAX_VALUE.
*/
public int getContentLength() {
    long l = getContentLengthLong();
    if (l > Integer.MAX_VALUE)
        return -1;
    return (int) l;
}
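The behavior in the source above can be reproduced without a network request. The sketch below uses a hypothetical stub, FixedLengthConnection, that overrides getContentLengthLong() to report a fixed length; the class names and the example.invalid URL are assumptions made for illustration. Because getContentLength() delegates to getContentLengthLong() as shown above, the stub is enough to exercise the large-value branch.

```java
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.net.URLConnection;

public class ContentLengthDemo {

    // Hypothetical stub: reports a fixed length and never opens a connection.
    static final class FixedLengthConnection extends URLConnection {
        private final long length;

        FixedLengthConnection(URL url, long length) {
            super(url);
            this.length = length;
        }

        @Override
        public void connect() throws IOException {
            // intentionally empty: no network access in this sketch
        }

        @Override
        public long getContentLengthLong() {
            return length;
        }
    }

    static URLConnection stub(long length) {
        try {
            URL url = URI.create("http://example.invalid/large.bin").toURL();
            return new FixedLengthConnection(url, length);
        } catch (MalformedURLException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        URLConnection small = stub(1_024L);
        URLConnection large = stub(3_000_000_000L); // larger than Integer.MAX_VALUE

        System.out.println(small.getContentLength());     // 1024
        System.out.println(large.getContentLength());     // -1
        System.out.println(large.getContentLengthLong()); // 3000000000
    }
}
```

A unit test built on a stub like this can cover lengths above Integer.MAX_VALUE, and calling getContentLengthLong() in production code avoids the -1 result for large resources.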

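As a result, whenever the Content-Length of a URL exceeded Integer.MAX_VALUE, the method returned -1 and the business logic did not take effect. My colleague's unit test had used only an ordinary URL and never covered a Content-Length above Integer.MAX_VALUE, so the problem surfaced only after release and affected normal business logic. This is a bug caused by incomplete unit-test coverage, much like the binary search bug above.

Similarly, while developing a distributed deduplication service based on GitHub - RoaringBitmap/RoaringBitmap: A better compressed bitset in Java, I found an error in how RoaringBitmap retrieves values, also caused by unit-test inputs that did not cover the case. The issue is: The previousValue method sometimes returns incorrect data · Issue #400 · RoaringBitmap/RoaringBitmap · GitHub. So, as Joshua Bloch said, even the smallest piece of code is hard to write correctly, and our whole world runs on big, complicated pieces of code.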

References

Underscores in Numeric Literals