mirror of
https://github.com/Winds-Studio/Leaf.git
synced 2025-12-25 18:09:17 +00:00
fix MpmcQueue memory order
This commit is contained in:
@@ -5,11 +5,11 @@ import org.dreeam.leaf.util.queue.MpmcQueue;
|
||||
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.FutureTask;
|
||||
import java.util.concurrent.locks.LockSupport;
|
||||
|
||||
public final class FixedThreadExecutor {
|
||||
private final Thread[] threads;
|
||||
public final MpmcQueue<Runnable> channel;
|
||||
public final Object sync;
|
||||
private static volatile boolean SHUTDOWN = false;
|
||||
|
||||
public FixedThreadExecutor(int numThreads, int queue, String prefix) {
|
||||
@@ -18,14 +18,13 @@ public final class FixedThreadExecutor {
|
||||
}
|
||||
this.threads = new Thread[numThreads];
|
||||
this.channel = new MpmcQueue<>(Runnable.class, queue);
|
||||
this.sync = new Object();
|
||||
for (int i = 0; i < numThreads; i++) {
|
||||
threads[i] = Thread.ofPlatform()
|
||||
.uncaughtExceptionHandler(Util::onThreadException)
|
||||
.daemon(false)
|
||||
.priority(Thread.NORM_PRIORITY)
|
||||
.name(prefix + " - " + i)
|
||||
.start(new Worker(channel, sync));
|
||||
.start(new Worker(channel));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,15 +41,16 @@ public final class FixedThreadExecutor {
|
||||
}
|
||||
|
||||
public void unpack() {
|
||||
synchronized (sync) {
|
||||
sync.notifyAll();
|
||||
final int len = Math.clamp(channel.length(), 1, threads.length);
|
||||
for (int i = 0; i < len; i++) {
|
||||
LockSupport.unpark(threads[i]);
|
||||
}
|
||||
}
|
||||
|
||||
public void shutdown() {
|
||||
SHUTDOWN = true;
|
||||
synchronized (sync) {
|
||||
sync.notifyAll();
|
||||
for (Thread thread : threads) {
|
||||
LockSupport.unpark(thread);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ public final class FixedThreadExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
private record Worker(MpmcQueue<Runnable> channel, Object sync) implements Runnable {
|
||||
private record Worker(MpmcQueue<Runnable> channel) implements Runnable {
|
||||
@Override
|
||||
public void run() {
|
||||
while (true) {
|
||||
@@ -78,12 +78,12 @@ public final class FixedThreadExecutor {
|
||||
task.run();
|
||||
} else if (SHUTDOWN) {
|
||||
break;
|
||||
} else if (channel.isEmpty()) {
|
||||
synchronized (sync) {
|
||||
try {
|
||||
sync.wait();
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
} else {
|
||||
Thread.yield();
|
||||
if (channel.isEmpty()) {
|
||||
LockSupport.park();
|
||||
if (Thread.interrupted()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,19 +3,20 @@
|
||||
|
||||
package org.dreeam.leaf.util.queue;
|
||||
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.lang.invoke.VarHandle;
|
||||
|
||||
public final class MpmcQueue<T> {
|
||||
private static final long VERSION_MASK = 0xFFFF_0000_0000_0000L;
|
||||
private static final long INDEX_MASK = 0x0000_FFFF_FFFF_0000L;
|
||||
private static final long PENDING_MASK = 0x0000_0000_0000_00FFL;
|
||||
private static final long DONE_MASK = 0x0000_0000_0000_FF00L;
|
||||
private static final long FAST_PATH_MASK = INDEX_MASK | DONE_MASK;
|
||||
private static final long MAX_IN_PROGRESS = 16L;
|
||||
private static final int MAX_CAPACITY = 1 << 30;
|
||||
private static final int VERSION_SHIFT = 48;
|
||||
private static final long PENDING_MASK = 0x0000_0000_0000_00FFL;
|
||||
private static final long DONE_PENDING_MASK = DONE_MASK | PENDING_MASK;
|
||||
private static final int INDEX_SHIFT = 16;
|
||||
private static final int DONE_SHIFT = 8;
|
||||
private static final long MAX_IN_PROGRESS = 16;
|
||||
private static final int MAX_CAPACITY = 1 << 30;
|
||||
private static final int PARALLELISM = Runtime.getRuntime().availableProcessors();
|
||||
|
||||
private static final VarHandle READ;
|
||||
@@ -23,6 +24,7 @@ public final class MpmcQueue<T> {
|
||||
|
||||
private final long mask;
|
||||
private final long capacity;
|
||||
@Nullable
|
||||
private final T[] buffer;
|
||||
|
||||
@SuppressWarnings("unused")
|
||||
@@ -49,35 +51,16 @@ public final class MpmcQueue<T> {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
this.mask = (1L << (Integer.SIZE - Integer.numberOfLeadingZeros(capacity - 1))) - 1L;
|
||||
this.capacity = Math.max(2, (1L << (Integer.SIZE - Integer.numberOfLeadingZeros(capacity - 1))));
|
||||
this.mask = this.capacity - 1L;
|
||||
//noinspection unchecked
|
||||
this.buffer = (clazz == Object.class)
|
||||
? (T[]) new Object[(int) (mask + 1L)]
|
||||
: (T[]) java.lang.reflect.Array.newInstance(clazz, (int) (mask + 1L));
|
||||
this.capacity = mask + 1L;
|
||||
}
|
||||
|
||||
private static long version(long state) {
|
||||
return (state & VERSION_MASK) >>> VERSION_SHIFT;
|
||||
}
|
||||
|
||||
private static long index(long state) {
|
||||
return (state & INDEX_MASK) >>> INDEX_SHIFT;
|
||||
}
|
||||
|
||||
private static long pending(long state) {
|
||||
return state & PENDING_MASK;
|
||||
}
|
||||
|
||||
private static long done(long state) {
|
||||
return (state & DONE_MASK) >>> 8;
|
||||
}
|
||||
|
||||
private static long createState(long version, long index, long done, long pending) {
|
||||
return version << VERSION_SHIFT | index << INDEX_SHIFT | done << 8 | pending;
|
||||
? (T[]) new Object[(int) this.capacity]
|
||||
: (T[]) java.lang.reflect.Array.newInstance(clazz, (int) this.capacity);
|
||||
}
|
||||
|
||||
private void spinWait(final int attempts) {
|
||||
//noinspection StatementWithEmptyBody
|
||||
if (attempts == 0) {
|
||||
} else if (PARALLELISM != 1 && (attempts & 31) != 31) {
|
||||
Thread.onSpinWait();
|
||||
@@ -86,46 +69,30 @@ public final class MpmcQueue<T> {
|
||||
}
|
||||
}
|
||||
|
||||
public boolean send(final T item) {
|
||||
if (item == null) {
|
||||
throw new IllegalArgumentException("Cannot enqueue null item");
|
||||
}
|
||||
|
||||
public boolean send(@NotNull final T item) {
|
||||
long write = (long) WRITE.getAcquire(this);
|
||||
boolean success;
|
||||
long newWrite = 0L;
|
||||
long index = 0L;
|
||||
int attempts = 0;
|
||||
|
||||
while (true) {
|
||||
spinWait(attempts++);
|
||||
final long writeVersion = version(write);
|
||||
final long writePending = pending(write);
|
||||
final long writeIndex = index(write);
|
||||
final long currentRead = (long) READ.getAcquire(this);
|
||||
final long readIndex = index(currentRead);
|
||||
final long currentItems = readIndex <= writeIndex
|
||||
? writeIndex - readIndex
|
||||
: writeIndex + capacity - readIndex;
|
||||
if (currentItems + writePending >= mask) {
|
||||
final long inProgressCnt = (write & PENDING_MASK);
|
||||
if ((((write >>> INDEX_SHIFT) + 1L) & mask) == ((long) READ.getVolatile(this) >>> INDEX_SHIFT)) {
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
if (writePending == MAX_IN_PROGRESS) {
|
||||
|
||||
if (inProgressCnt == MAX_IN_PROGRESS) {
|
||||
write = (long) WRITE.getAcquire(this);
|
||||
continue;
|
||||
}
|
||||
index = (writeIndex + writePending) & mask;
|
||||
if (((writeIndex + writePending + 1L) & mask) == readIndex) {
|
||||
index = ((write >>> INDEX_SHIFT) + inProgressCnt) & mask;
|
||||
if (((index + 1L) & mask) == ((long) READ.getVolatile(this) >>> INDEX_SHIFT)) {
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
newWrite = createState(
|
||||
writeVersion + 1L,
|
||||
writeIndex,
|
||||
done(write),
|
||||
writePending + 1L
|
||||
);
|
||||
newWrite = write + 1L;
|
||||
if (WRITE.weakCompareAndSetAcquire(this, write, newWrite)) {
|
||||
success = true;
|
||||
break;
|
||||
@@ -136,23 +103,13 @@ public final class MpmcQueue<T> {
|
||||
return false;
|
||||
}
|
||||
buffer[(int) index] = item;
|
||||
/*
|
||||
if ((newWrite & FAST_PATH_MASK) == (index << INDEX_SHIFT) && index < mask) {
|
||||
WRITE.getAndAddRelease(this, (1L << INDEX_SHIFT) - 1L);
|
||||
return true;
|
||||
}
|
||||
*/
|
||||
write = newWrite;
|
||||
while (true) {
|
||||
final long p = pending(write);
|
||||
final long d = done(write);
|
||||
final long i = index(write);
|
||||
final long v = version(write);
|
||||
final long n = d + 1L == p
|
||||
? createState(v + 1L, (i + p) & mask, 0L, 0L)
|
||||
: i == index
|
||||
? createState(v, (i + 1L) & mask, d, p - 1L)
|
||||
: createState(v, i, d + 1L, p);
|
||||
final long n = ((write & DONE_MASK) >>> DONE_SHIFT) + 1L == (write & PENDING_MASK)
|
||||
? ((write >>> INDEX_SHIFT) + (write & PENDING_MASK) & mask) << INDEX_SHIFT
|
||||
: write >>> INDEX_SHIFT == index
|
||||
? write + (1L << INDEX_SHIFT) - 1L & (mask << INDEX_SHIFT | DONE_PENDING_MASK)
|
||||
: write + (1L << DONE_SHIFT);
|
||||
if (WRITE.weakCompareAndSetRelease(this, write, n)) {
|
||||
break;
|
||||
}
|
||||
@@ -162,42 +119,29 @@ public final class MpmcQueue<T> {
|
||||
return true;
|
||||
}
|
||||
|
||||
public T recv() {
|
||||
public @Nullable T recv() {
|
||||
long read = (long) READ.getAcquire(this);
|
||||
boolean success;
|
||||
long index = 0L;
|
||||
long index = 0;
|
||||
long newRead = 0L;
|
||||
int attempts = 0;
|
||||
while (true) {
|
||||
spinWait(attempts++);
|
||||
final long readVersion = version(read);
|
||||
final long readPending = pending(read);
|
||||
final long writeIndex = index((long) WRITE.getAcquire(this));
|
||||
final long readIndex = index(read);
|
||||
final long currentItems = readIndex <= writeIndex
|
||||
? writeIndex - readIndex
|
||||
: writeIndex + capacity - readIndex;
|
||||
if (currentItems == 0L) {
|
||||
final long inProgressCnt = (read & PENDING_MASK);
|
||||
if ((read >>> INDEX_SHIFT) == ((long) WRITE.getVolatile(this) >>> INDEX_SHIFT)) {
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
|
||||
if (readPending == MAX_IN_PROGRESS) {
|
||||
if (inProgressCnt == MAX_IN_PROGRESS) {
|
||||
read = (long) READ.getAcquire(this);
|
||||
continue;
|
||||
}
|
||||
|
||||
index = (readIndex + readPending) & mask;
|
||||
if (index == writeIndex) {
|
||||
index = ((read >>> INDEX_SHIFT) + inProgressCnt) & mask;
|
||||
if ((index & mask) == ((long) WRITE.getVolatile(this) >>> INDEX_SHIFT)) {
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
newRead = createState(
|
||||
readVersion + 1L,
|
||||
readIndex,
|
||||
done(read),
|
||||
readPending + 1L
|
||||
);
|
||||
newRead = read + 1L;
|
||||
if (READ.weakCompareAndSetAcquire(this, read, newRead)) {
|
||||
success = true;
|
||||
break;
|
||||
@@ -209,23 +153,13 @@ public final class MpmcQueue<T> {
|
||||
}
|
||||
final T result = buffer[(int) index];
|
||||
buffer[(int) index] = null;
|
||||
/*
|
||||
if ((newRead & FAST_PATH_MASK) == (index << INDEX_SHIFT) && index < mask) {
|
||||
READ.getAndAddRelease(this, (1L << INDEX_SHIFT) - 1L);
|
||||
return result;
|
||||
}
|
||||
*/
|
||||
read = newRead;
|
||||
while (true) {
|
||||
final long p = pending(read);
|
||||
final long d = done(read);
|
||||
final long i = index(read);
|
||||
final long v = version(read);
|
||||
final long n = d + 1L == p
|
||||
? createState(v + 1L, (i + p) & mask, 0L, 0L)
|
||||
: i == index
|
||||
? createState(v, (i + 1L) & mask, d, p - 1L)
|
||||
: createState(v, i, d + 1L, p);
|
||||
final long n = ((read & DONE_MASK) >>> DONE_SHIFT) + 1L == (read & PENDING_MASK)
|
||||
? ((read >>> INDEX_SHIFT) + (read & PENDING_MASK) & mask) << INDEX_SHIFT
|
||||
: read >>> INDEX_SHIFT == index
|
||||
? read + (1L << INDEX_SHIFT) - 1L & (mask << INDEX_SHIFT | DONE_PENDING_MASK)
|
||||
: read + (1L << DONE_SHIFT);
|
||||
if (READ.weakCompareAndSetRelease(this, read, n)) {
|
||||
break;
|
||||
}
|
||||
@@ -236,36 +170,27 @@ public final class MpmcQueue<T> {
|
||||
}
|
||||
|
||||
public int length() {
|
||||
final long readCounters = (long) READ.getVolatile(this);
|
||||
final long writeCounters = (long) WRITE.getVolatile(this);
|
||||
final long readIndex = index(readCounters);
|
||||
final long writeIndex = index(writeCounters);
|
||||
return (int) ((readIndex <= writeIndex
|
||||
? writeIndex - readIndex
|
||||
: writeIndex + capacity - readIndex) - pending(readCounters));
|
||||
final long reads = (long) READ.getVolatile(this);
|
||||
final long writes = (long) WRITE.getVolatile(this);
|
||||
final long readIndex = (reads >>> INDEX_SHIFT);
|
||||
final long writeIndex = (writes >>> INDEX_SHIFT);
|
||||
return (int) (readIndex <= writeIndex ? writeIndex - readIndex : writeIndex + capacity - readIndex);
|
||||
// (readIndex <= writeIndex ? writeIndex - readIndex : writeIndex + capacity - readIndex) - (reads & PENDING_MASK)
|
||||
}
|
||||
|
||||
public boolean isEmpty() {
|
||||
final long readCounters = (long) READ.getVolatile(this);
|
||||
final long writeCounters = (long) WRITE.getVolatile(this);
|
||||
final long readIndex = index(readCounters);
|
||||
final long writeIndex = index(writeCounters);
|
||||
final long writePending = pending(writeCounters);
|
||||
final long currentItems = readIndex <= writeIndex
|
||||
? writeIndex - readIndex
|
||||
: writeIndex + capacity - readIndex;
|
||||
return currentItems == 0L && writePending == 0L;
|
||||
return length() == 0;
|
||||
}
|
||||
|
||||
public int remaining() {
|
||||
final long readCounters = (long) READ.getVolatile(this);
|
||||
final long writeCounters = (long) WRITE.getVolatile(this);
|
||||
final long readIndex = index(readCounters);
|
||||
final long writeIndex = index(writeCounters);
|
||||
final long reads = (long) READ.getVolatile(this);
|
||||
final long writes = (long) WRITE.getVolatile(this);
|
||||
final long readIndex = (reads >>> INDEX_SHIFT);
|
||||
final long writeIndex = (writes >>> INDEX_SHIFT);
|
||||
final long len = readIndex <= writeIndex ?
|
||||
writeIndex - readIndex :
|
||||
writeIndex + capacity - readIndex;
|
||||
return (int) (mask - len - pending(writeCounters));
|
||||
return (int) (mask - len - (writes & PENDING_MASK));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unused")
|
||||
|
||||
Reference in New Issue
Block a user