Poison

ConcurrentHashMap

本文主要记录 JDK 12 中 ConcurrentHashMap 的扩容实现。首先需要关注的为 sizeCtl 变量,其定义如下:

1
2
3
4
5
6
7
8
9
/**
* Table initialization and resizing control. When negative, the
* table is being initialized or resized: -1 for initialization,
* else -(1 + the number of active resizing threads). Otherwise,
* when table is null, holds the initial table size to use upon
* creation, or 0 for default. After initialization, holds the
* next element count value upon which to resize the table.
*/
private transient volatile int sizeCtl;

根据注释我们知道,sizeCtl 主要用于底层 table 数组的初始化和扩容控制。当该值为负数时,底层 table 数组正在初始化或扩容,否则,当 table 数组为空时,持有需要初始化 table 数组的大小,使用默认构造函数时 sizeCtl 为 0,在 table 数组初始化后,sizeCtl 存储的为触发扩容的元素个数阈值。我们根据源码来理解上述注释,首先看默认构造函数:

1
2
3
4
5
/**
* Creates a new, empty map with the default initial table size (16).
*/
public ConcurrentHashMap() {
}

此时实例化后 sizeCtl 被初始化为默认值 0,我们再看非默认构造函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
/**
* Creates a new, empty map with an initial table size
* accommodating the specified number of elements without the need
* to dynamically resize.
*
* @param initialCapacity The implementation performs internal
* sizing to accommodate this many elements.
* @throws IllegalArgumentException if the initial capacity of
* elements is negative
*/
public ConcurrentHashMap(int initialCapacity) {
this(initialCapacity, LOAD_FACTOR, 1);
}

/**
* Creates a new, empty map with an initial table size based on
* the given number of elements ({@code initialCapacity}), initial
* table density ({@code loadFactor}), and number of concurrently
* updating threads ({@code concurrencyLevel}).
*
* @param initialCapacity the initial capacity. The implementation
* performs internal sizing to accommodate this many elements,
* given the specified load factor.
* @param loadFactor the load factor (table density) for
* establishing the initial table size
* @param concurrencyLevel the estimated number of concurrently
* updating threads. The implementation may use this value as
* a sizing hint.
* @throws IllegalArgumentException if the initial capacity is
* negative or the load factor or concurrencyLevel are
* nonpositive
*/
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0.0f) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
if (initialCapacity < concurrencyLevel) // Use at least as many bins
initialCapacity = concurrencyLevel; // as estimated threads
long size = (long)(1.0 + (long)initialCapacity / loadFactor);
int cap = (size >= (long)MAXIMUM_CAPACITY) ?
MAXIMUM_CAPACITY : tableSizeFor((int)size);
this.sizeCtl = cap;
}

可以看出,会根据传入的参数将 sizeCtl 对齐至 2 的 n 次幂,比如我们使用 new ConcurrentHashMap(17) 进行实例化,此时 sizeCtl 会被赋值为 32。当第一个线程调用 put 方法时,会触发 initTable() 方法尝试初始化底层 table 数组:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/**
* Initializes table, using the size recorded in sizeCtl.
*/
private final Node<K,V>[] initTable() {
Node<K,V>[] tab; int sc;
while ((tab = table) == null || tab.length == 0) {
if ((sc = sizeCtl) < 0)
Thread.yield(); // lost initialization race; just spin
else if (U.compareAndSetInt(this, SIZECTL, sc, -1)) {
try {
if ((tab = table) == null || tab.length == 0) {
int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
@SuppressWarnings("unchecked")
Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
table = tab = nt;
sc = n - (n >>> 2);
}
} finally {
sizeCtl = sc;
}
break;
}
}
return tab;
}

此时会尝试 CAS 设置 sizeCtl 字段为 -1,即注释中提到的正在初始化的状态,CAS 成功的线程会进入 else if 代码块内部,对 table 数组是否为空进行双重检查,确认未被初始化后申请数组并赋值给 table 变量,然后通过 sc = n - (n >>> 2) 计算扩容阈值并在 finally 代码块中赋值给 sizeCtl,此时 sizeCtl 存储的为触发扩容的阈值。当 ConcurrentHashMap 由默认构造函数创建时,初始化的 table 数组大小为 DEFAULT_CAPACITY,该静态变量定义如下:

1
2
3
4
5
/**
* The default initial table capacity. Must be a power of 2
* (i.e., at least 1) and at most MAXIMUM_CAPACITY.
*/
private static final int DEFAULT_CAPACITY = 16;

可知扩容阈值 sc = n - (n >>> 2) = 16 - (16 >>> 2) = 16 - 4 = 12,即对应着负载因子 0.75。扩容操作是在添加元素后的 addCount 方法中触发的,源码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
/**
* Adds to count, and if table is too small and not already
* resizing, initiates transfer. If already resizing, helps
* perform transfer if work is available. Rechecks occupancy
* after a transfer to see if another resize is already needed
* because resizings are lagging additions.
*
* @param x the count to add
* @param check if <0, don't check resize, if <= 1 only check if uncontended
*/
private final void addCount(long x, int check) {
CounterCell[] cs; long b, s;
if ((cs = counterCells) != null ||
!U.compareAndSetLong(this, BASECOUNT, b = baseCount, s = b + x)) {
CounterCell c; long v; int m;
boolean uncontended = true;
if (cs == null || (m = cs.length - 1) < 0 ||
(c = cs[ThreadLocalRandom.getProbe() & m]) == null ||
!(uncontended =
U.compareAndSetLong(c, CELLVALUE, v = c.value, v + x))) {
fullAddCount(x, uncontended);
return;
}
if (check <= 1)
return;
s = sumCount();
}
if (check >= 0) {
Node<K,V>[] tab, nt; int n, sc;
while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
(n = tab.length) < MAXIMUM_CAPACITY) {
int rs = resizeStamp(n) << RESIZE_STAMP_SHIFT;
if (sc < 0) {
if (sc == rs + MAX_RESIZERS || sc == rs + 1 ||
(nt = nextTable) == null || transferIndex <= 0)
// 以下任意条件满足则跳出循环:
// 当并发扩容线程数达到规定的上限时
// 当前线程为最后一个扩容线程时,即低 16 位的数值由 2 减为 1 时
// 完成扩容设置 nextTable 为 null 后,或创建新数组异常时(内存不足等情况)
// bucket 范围分配完毕时
break;
if (U.compareAndSetInt(this, SIZECTL, sc, sc + 1))
transfer(tab, nt);
}
else if (U.compareAndSetInt(this, SIZECTL, sc, rs + 2))
transfer(tab, null);
s = sumCount();
}
}
}

该方法中的第一个 if 代码块用于计数,我们主要分析第二个 if 代码块的扩容实现,即首先比较当前哈希表实际元素数量 s 是否大于等于扩容阈值 sizeCtl,当哈希表实际元素数量达到 12 时,即会进入 while 循环内部,首先执行的是 resizeStamp(n) << RESIZE_STAMP_SHIFT 这段逻辑,我们先看 resizeStamp 方法实现:

1
2
3
4
5
6
7
/**
* Returns the stamp bits for resizing a table of size n.
* Must be negative when shifted left by RESIZE_STAMP_SHIFT.
*/
static final int resizeStamp(int n) {
return Integer.numberOfLeadingZeros(n) | (1 << (RESIZE_STAMP_BITS - 1));
}

以上代码是什么意思呢?我们结合二进制表示来看,在 ConcurrentHashMap 使用默认构造函数的情况下,创建的 table 数组长度为 DEFAULT_CAPACITY 即 16,即 resizeStamp 方法的参数 n 为 16,此时 Integer.numberOfLeadingZeros(n) 计算 n 的二进制表示中前导零的数量,这么做的目的是什么呢?根据我的理解,此处主要是压缩表示 n 值所需的比特数,当 nint 进行表示时,需要 32 个比特,因为 n 一定为 2 的 x 次幂这一特性,我们知道 n 的二进制表示中一定只有一位为 1,比如 n 为 16 时的二进制表示为 00000000 00000000 00000000 00010000,且易知 Integer.numberOfLeadingZeros(n) 计算出的 n 的二进制表示中前导零的数量最大为 30(table 数组容量最小为 2),当 n 为 16 时,Integer.numberOfLeadingZeros(n) 的返回值为 27,即 00000000 00000000 00000000 00010000 中前导零的数量,此时我们再使用 27 的二进制表示:00000000 00000000 00000000 00011011,即我们用前导零数量的二进制表示的方式把 n 这个数值压缩为使用 5 个比特即可表示,为何要进行压缩呢?因为 ConcurrentHashMap 提供了并发扩容最大线程数的限制,sizeCtl 还会存储当前参与扩容的线程数。我们回到 resizeStamp 方法,或运算右侧的表达式为 1 << (RESIZE_STAMP_BITS - 1),其中 RESIZE_STAMP_BITS 定义如下:

1
2
3
4
5
/**
* The number of bits used for generation stamp in sizeCtl.
* Must be at least 6 for 32bit arrays.
*/
private static int RESIZE_STAMP_BITS = 16;

容易得出 1 << (RESIZE_STAMP_BITS - 1) = 1 << (16 - 1) = 1 << 15,对应的二进制表示为 00000000 00000000 10000000 00000000。然后我们将两值做或运算:

1
2
3
4
5
6
7
8
Integer.numberOfLeadingZeros(16):
00000000 00000000 00000000 00011011

1 << (RESIZE_STAMP_BITS - 1):
00000000 00000000 10000000 00000000

resizeStamp(16) = Integer.numberOfLeadingZeros(16) | (1 << (RESIZE_STAMP_BITS - 1)):
00000000 00000000 10000000 00011011

可以看出 resizeStamp 方法主要做了两个操作,第一个操作为将当前 table 数组的大小用前导零数量的二进制表示进行存储,第二个操作为将第十六位比特设置为 1,这样做的目的是什么呢?我们接着往下分析,根据源码我们知道 resizeStamp(n) 方法返回后会左移 RESIZE_STAMP_SHIFT,执行的代码为:

1
2
3
4
5
6
/**
* The bit shift for recording size stamp in sizeCtl.
*/
private static final int RESIZE_STAMP_SHIFT = 32 - RESIZE_STAMP_BITS;

int rs = resizeStamp(n) << RESIZE_STAMP_SHIFT;

即左移 16 位,二进制表示如下:

1
2
3
4
5
resizeStamp(n):
00000000 00000000 10000000 00011011

int rs = resizeStamp(n) << 16:
10000000 00011011 00000000 00000000

rs 变量的二进制表示为 10000000 00011011 00000000 00000000。可以看出,符号位现在为 1,即 rs 为一个负数,且 table 数组的长度使用前导零数量的二进制表示存储在了高 16 位中,现在低 16 位全部为 0,我们回顾一下上面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
if (check >= 0) {
Node<K,V>[] tab, nt; int n, sc;
while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
(n = tab.length) < MAXIMUM_CAPACITY) {
int rs = resizeStamp(n) << RESIZE_STAMP_SHIFT;
if (sc < 0) {
if (sc == rs + MAX_RESIZERS || sc == rs + 1 ||
(nt = nextTable) == null || transferIndex <= 0)
break;
if (U.compareAndSetInt(this, SIZECTL, sc, sc + 1))
transfer(tab, nt);
}
else if (U.compareAndSetInt(this, SIZECTL, sc, rs + 2))
transfer(tab, null);
s = sumCount();
}
}

此时 sc = sizeCtl 即值为 12,会执行 else if 中的代码尝试使用 CAS 将 sizeCtl 更新为 rs + 2,在 sizeCtl 的注释中提到:When negative, the table is being initialized or resized: -1 for initialization, else -(1 + the number of active resizing threads). 从代码来看,个人认为这段注释描述不是很准确,如果我们不看高 16 位中存储的前导零数值的二进制部分,那么这段描述是正确的。因为高 16 位中存储了前导零数值的二进制部分,那么 -(1 + the number of active resizing threads) 这段描述就有误,比如首个 CAS 成功的线程将设置 sizeCtl 的值为 10000000 00011011 00000000 00000010,此时代表有一个线程正在进行扩容操作,sizeCtl 转换为负数时为 -2145714174 而并不为注释中提到的 -(1 + 1)。所以,个人认为这段注释想要表达的其实是当 sizeCtl 为负数时,即符号位为 1 时,此时哈希表正在初始化或者扩容,当正在初始化时,低 16 位对应的数值为 1,当正在扩容时,低 16 位对应的数值表示当前并发扩容的线程数量加上一,为什么要加上一呢?因为低 16 位为 00000000 00000001 这种情况被刚提到的正在初始化状态占用了,这就是首次 CAS 需要将 sizeCtl 更新为 rs + 2 的原因。当首个 CAS 成功的线程将 sizeCtl 更新为负数后,后续的线程在进行 CAS 操作时都是调用的 U.compareAndSetInt(this, SIZECTL, sc, sc + 1)sc 加一,即增加一个并发扩容的线程数。CAS 成功后将进入扩容的方法 transfer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
/**
* Moves and/or copies the nodes in each bin to new table. See
* above for explanation.
*/
private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
int n = tab.length, stride; // stride 为步长,即单个线程需要迁移的 bucket 数量
if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
stride = MIN_TRANSFER_STRIDE; // subdivide range, 根据 CPU 核心数计算步长,保证最低不小于 16 以避免过多的内存争用
if (nextTab == null) { // initiating
try {
@SuppressWarnings("unchecked")
Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1]; // 新数组大小为当前数组大小的两倍
nextTab = nt;
} catch (Throwable ex) { // try to cope with OOME
sizeCtl = Integer.MAX_VALUE;
return;
}
nextTable = nextTab;
transferIndex = n; // 扩容处理 bucket 是从右向左处理,处理完索引为 0 的 bucket 就结束整个哈希表的扩容,整个索引范围为 [0, transferIndex), 注意不含 transferIndex
}
int nextn = nextTab.length; // 新数组的大小
ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab); // 创建转发节点,指向新创建的 table 数组
boolean advance = true; // 是否继续向左推进的标志位
boolean finishing = false; // to ensure sweep before committing nextTab, 是否完成整个 table 数组扩容的标志位
for (int i = 0, bound = 0;;) {
Node<K,V> f; int fh;
while (advance) {
int nextIndex, nextBound;
if (--i >= bound || finishing) // 注意 --i >= bound 时会跳出当前 while 循环,且对 i 进行了减 1 操作,即准备迁移左侧的 bucket
advance = false;
else if ((nextIndex = transferIndex) <= 0) { // 当 transferIndex 小于等于 0 时跳出 while 循环
i = -1;
advance = false;
}
else if (U.compareAndSwapInt
(this, TRANSFERINDEX, nextIndex,
nextBound = (nextIndex > stride ?
nextIndex - stride : 0))) { // CAS 更新 transferIndex, 更新成功则表示当前线程负责迁移索引为 [Math.max(nextIndex - stride, 0), nextIndex) 这部分 bucket
bound = nextBound; // 将当前线程迁移的左边界赋值给 bound 变量供后续判断当前线程完成 [bound, nextIndex) 这段 bucket 的迁移
i = nextIndex - 1; // i 是当前迁移的 bucket 的索引,因为 nextIndex 右边界默认为 table.length, 即不包含, 所以迁移时从索引为 nextIndex - 1 的桶开始
advance = false; // 设置 advance 为 false 并跳出 while 循环
}
}
if (i < 0 || i >= n || i + n >= nextn) {
int sc;
if (finishing) { // table 数组迁移完成
nextTable = null;
table = nextTab;
sizeCtl = (n << 1) - (n >>> 1);
return;
}
if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) { // 当前线程完成扩容,并发扩容线程数减 1
if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT) // 当前线程为最后一个扩容的线程则返回
return;
finishing = advance = true;
i = n; // recheck before commit
}
}
else if ((f = tabAt(tab, i)) == null) // 如果索引 i 的 bucket 为空,就尝试将该 bucket 更新为转发节点
advance = casTabAt(tab, i, null, fwd); // 更新成功当前线程就继续向左推进
else if ((fh = f.hash) == MOVED)
advance = true; // already processed
else {
synchronized (f) { // 当 bucket 不为空且未被迁移过时,锁住该 bucket 开始执行迁移 bucket 操作
if (tabAt(tab, i) == f) { // 双重检查
Node<K,V> ln, hn;
if (fh >= 0) { // 如果当前 bucket 存储的链表则拆分链表,Dong Lea 采用的 两次遍历 + 头插法 实现
int runBit = fh & n;
Node<K,V> lastRun = f;
for (Node<K,V> p = f.next; p != null; p = p.next) {
int b = p.hash & n;
if (b != runBit) {
runBit = b;
lastRun = p; // 记录下最后一段高位相同节点的起始节点
}
}
if (runBit == 0) {
ln = lastRun;
hn = null;
}
else {
hn = lastRun;
ln = null;
}
for (Node<K,V> p = f; p != lastRun; p = p.next) {
int ph = p.hash; K pk = p.key; V pv = p.val;
if ((ph & n) == 0)
ln = new Node<K,V>(ph, pk, pv, ln); // 头插法
else
hn = new Node<K,V>(ph, pk, pv, hn); // 头插法
}
setTabAt(nextTab, i, ln);
setTabAt(nextTab, i + n, hn);
setTabAt(tab, i, fwd);
advance = true;
}
else if (f instanceof TreeBin) {
TreeBin<K,V> t = (TreeBin<K,V>)f;
TreeNode<K,V> lo = null, loTail = null;
TreeNode<K,V> hi = null, hiTail = null;
int lc = 0, hc = 0;
for (Node<K,V> e = t.first; e != null; e = e.next) {
int h = e.hash;
TreeNode<K,V> p = new TreeNode<K,V>
(h, e.key, e.val, null, null);
if ((h & n) == 0) {
if ((p.prev = loTail) == null)
lo = p;
else
loTail.next = p; // 尾插法
loTail = p;
++lc;
}
else {
if ((p.prev = hiTail) == null)
hi = p;
else
hiTail.next = p; // 尾插法
hiTail = p;
++hc;
}
}
ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
(hc != 0) ? new TreeBin<K,V>(lo) : t;
hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
(lc != 0) ? new TreeBin<K,V>(hi) : t;
setTabAt(nextTab, i, ln);
setTabAt(nextTab, i + n, hn);
setTabAt(tab, i, fwd);
advance = true;
}
}
}
}
}
}

在源码上已经加上了注释,DEBUG 跟踪一遍代码对运行机制理解更加深入,理解了 transfer 方法再回头去看 addCount 方法中触发 break 的判断逻辑就很容易理解了。为什么用 JDK 12 的源码来记录呢?因为 JDK 8 中的代码扩容时控制并发线程数的逻辑存在 bug,这个 bug 在 JDK 12 中已经被修复了,而一直没有人将这个 bug 反向移植至 JDK 8,所以至今 JDK 8 中的扩容代码都是存在 bug 的,个人认为这容易对学习源码的人造成障碍,特别是在不知情的情况下尝试理解代码的实现时,你会发现无法理解这段代码,为什么无法理解呢?因为这段代码本来就是错的。

2022-08-20

JDK-8214427 已反向移植至 JDK8,自 OpenJDK 8u352 起包含该修复,自 Oracle JDK 8u361(预计 2023-01-17 发布) 起包含该修复。

Reference

Why does ConcurrentHashMap prevent null keys and values? - Stack Overflow
Java Concurrent Hashmap initTable() Why the try/finally block? - Stack Overflow
8214427: probable bug in logic of ConcurrentHashMap.addCount() by tianshuang · Pull Request #18 · openjdk/jdk8u-dev · GitHub
How to contribute a fix - OpenJDK Wiki
JDK Releases
PARITY REPORT FOR JDK: 8