9
0
mirror of https://github.com/Winds-Studio/Leaf.git synced 2025-12-26 18:39:23 +00:00

update MpmcQueue

This commit is contained in:
hayanesuru
2025-07-17 04:15:47 +09:00
parent 1bf5f251ce
commit 390afdff39

View File

@@ -7,18 +7,24 @@ import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
public final class MpmcQueue<T> {
private static final int MAX_IN_PROGRESS = 16;
private static final long DONE_MASK = 0x0000_0000_0000_FF00L;
private static final long VERSION_MASK = 0xFFFF_0000_0000_0000L;
private static final long INDEX_MASK = 0x0000_FFFF_FFFF_0000L;
private static final long PENDING_MASK = 0x0000_0000_0000_00FFL;
private static final long FAST_PATH_MASK = 0x00FF_FFFF_FFFF_FF00L;
private static final long DONE_MASK = 0x0000_0000_0000_FF00L;
private static final long FAST_PATH_MASK = INDEX_MASK | DONE_MASK;
private static final long MAX_IN_PROGRESS = 16L;
private static final int MAX_CAPACITY = 1 << 30;
private static final int VERSION_SHIFT = 48;
private static final int INDEX_SHIFT = 16;
private static final int PARALLELISM = Runtime.getRuntime().availableProcessors();
private static final VarHandle READ;
private static final VarHandle WRITE;
private final int mask;
private final long mask;
private final long capacity;
private final T[] buffer;
@SuppressWarnings("unused")
private final Padded padded1 = new Padded();
@SuppressWarnings("FieldMayBeFinal")
@@ -31,10 +37,8 @@ public final class MpmcQueue<T> {
static {
try {
MethodHandles.Lookup l = MethodHandles.lookup();
READ = l.findVarHandle(MpmcQueue.class, "reads",
long.class);
WRITE = l.findVarHandle(MpmcQueue.class, "writes",
long.class);
READ = l.findVarHandle(MpmcQueue.class, "reads", long.class);
WRITE = l.findVarHandle(MpmcQueue.class, "writes", long.class);
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
@@ -45,11 +49,32 @@ public final class MpmcQueue<T> {
throw new IllegalArgumentException();
}
this.mask = (1 << (Integer.SIZE - Integer.numberOfLeadingZeros(capacity - 1))) - 1;
this.mask = (1L << (Integer.SIZE - Integer.numberOfLeadingZeros(capacity - 1))) - 1L;
//noinspection unchecked
this.buffer = (clazz == Object.class)
? (T[]) new Object[mask + 1]
: (T[]) java.lang.reflect.Array.newInstance(clazz, mask + 1);
? (T[]) new Object[(int) (mask + 1L)]
: (T[]) java.lang.reflect.Array.newInstance(clazz, (int) (mask + 1L));
this.capacity = mask + 1L;
}
private static long version(long state) {
return (state & VERSION_MASK) >>> VERSION_SHIFT;
}
private static long index(long state) {
return (state & INDEX_MASK) >>> INDEX_SHIFT;
}
private static long pending(long state) {
return state & PENDING_MASK;
}
private static long done(long state) {
return (state & DONE_MASK) >>> 8;
}
private static long createState(long version, long index, long done, long pending) {
return version << VERSION_SHIFT | index << INDEX_SHIFT | done << 8 | pending;
}
private void spinWait(final int attempts) {
@@ -62,32 +87,45 @@ public final class MpmcQueue<T> {
}
public boolean send(final T item) {
if (item == null) {
throw new IllegalArgumentException("Cannot enqueue null item");
}
long write = (long) WRITE.getAcquire(this);
boolean success;
long newWrite = 0L;
int index = 0;
long index = 0L;
int attempts = 0;
while (true) {
spinWait(attempts++);
final int inProgressCnt = (int) (write & PENDING_MASK);
if ((((int) (write >>> 16) + 1) & mask) == (int) ((long) READ.getAcquire(this) >>> 16)) {
final long writeVersion = version(write);
final long writePending = pending(write);
final long writeIndex = index(write);
final long currentRead = (long) READ.getAcquire(this);
final long readIndex = index(currentRead);
final long currentItems = readIndex <= writeIndex
? writeIndex - readIndex
: writeIndex + capacity - readIndex;
if (currentItems + writePending >= mask) {
success = false;
break;
}
if (inProgressCnt == MAX_IN_PROGRESS) {
if (writePending == MAX_IN_PROGRESS) {
write = (long) WRITE.getAcquire(this);
continue;
}
index = ((int) (write >>> 16) + inProgressCnt) & mask;
if (((index + 1) & mask) == (int) ((long) READ.getAcquire(this) >>> 16)) {
index = (writeIndex + writePending) & mask;
if (((writeIndex + writePending + 1L) & mask) == readIndex) {
success = false;
break;
}
newWrite = write + 1;
newWrite = createState(
writeVersion + 1L,
writeIndex,
done(write),
writePending + 1L
);
if (WRITE.weakCompareAndSetAcquire(this, write, newWrite)) {
success = true;
break;
@@ -97,60 +135,69 @@ public final class MpmcQueue<T> {
if (!success) {
return false;
}
buffer[index] = item;
if ((newWrite & FAST_PATH_MASK) == ((long) index << 16) && index < mask) {
WRITE.getAndAddRelease(this, (1L << 16) - 1);
} else {
write = newWrite;
while (true) {
final int inProcessCnt = (int) (write & PENDING_MASK);
final long n;
if (((int) ((write & DONE_MASK) >>> 8) + 1) == inProcessCnt) {
n = ((long) (((int) (write >>> 16) + inProcessCnt) & mask)) << 16;
} else if ((int) (write >>> 16) == index) {
n = (write + (1L << 16) - 1) & (((long) mask << 16) | 0xFFFFL);
} else {
n = write + (1L << 8);
}
if (WRITE.weakCompareAndSetRelease(this, write, n)) {
break;
}
write = (long) WRITE.getVolatile(this);
spinWait(attempts++);
}
buffer[(int) index] = item;
/*
if ((newWrite & FAST_PATH_MASK) == (index << INDEX_SHIFT) && index < mask) {
WRITE.getAndAddRelease(this, (1L << INDEX_SHIFT) - 1L);
return true;
}
*/
write = newWrite;
while (true) {
final long p = pending(write);
final long d = done(write);
final long i = index(write);
final long v = version(write);
final long n = d + 1L == p
? createState(v + 1L, (i + p) & mask, 0L, 0L)
: i == index
? createState(v, (i + 1L) & mask, d, p - 1L)
: createState(v, i, d + 1L, p);
if (WRITE.weakCompareAndSetRelease(this, write, n)) {
break;
}
write = (long) WRITE.getVolatile(this);
spinWait(attempts++);
}
return true;
}
public T recv() {
long read = (long) READ.getAcquire(this);
boolean success;
int index = 0;
long index = 0L;
long newRead = 0L;
int attempts = 0;
while (true) {
spinWait(attempts++);
final int inProgressCnt = (int) (read & PENDING_MASK);
if ((int) (read >>> 16) == (int) ((long) WRITE.getAcquire(this) >>> 16)) {
final long readVersion = version(read);
final long readPending = pending(read);
final long writeIndex = index((long) WRITE.getAcquire(this));
final long readIndex = index(read);
final long currentItems = readIndex <= writeIndex
? writeIndex - readIndex
: writeIndex + capacity - readIndex;
if (currentItems == 0L) {
success = false;
break;
}
if (inProgressCnt == MAX_IN_PROGRESS) {
if (readPending == MAX_IN_PROGRESS) {
read = (long) READ.getAcquire(this);
continue;
}
index = ((int) (read >>> 16) + inProgressCnt) & mask;
if (index == (int) ((long) WRITE.getAcquire(this) >>> 16)) {
index = (readIndex + readPending) & mask;
if (index == writeIndex) {
success = false;
break;
}
newRead = read + 1;
newRead = createState(
readVersion + 1L,
readIndex,
done(read),
readPending + 1L
);
if (READ.weakCompareAndSetAcquire(this, read, newRead)) {
success = true;
break;
@@ -160,29 +207,30 @@ public final class MpmcQueue<T> {
if (!success) {
return null;
}
final T result = buffer[index];
buffer[index] = null;
if ((newRead & FAST_PATH_MASK) == ((long) index << 16) && index < mask) {
READ.getAndAddRelease(this, (1L << 16) - 1);
} else {
read = newRead;
while (true) {
final int inProcessCnt = (int) (read & PENDING_MASK);
final long n;
if (((int) ((read & DONE_MASK) >>> 8) + 1) == inProcessCnt) {
n = ((long) (((int) (read >>> 16) + inProcessCnt) & mask)) << 16;
} else if ((int) (read >>> 16) == index) {
n = (read + (1L << 16) - 1) & (((long) mask << 16) | 0xFFFFL);
} else {
n = read + (1L << 8);
}
if (READ.weakCompareAndSetRelease(this, read, n)) {
break;
}
read = (long) READ.getVolatile(this);
spinWait(attempts++);
final T result = buffer[(int) index];
buffer[(int) index] = null;
/*
if ((newRead & FAST_PATH_MASK) == (index << INDEX_SHIFT) && index < mask) {
READ.getAndAddRelease(this, (1L << INDEX_SHIFT) - 1L);
return result;
}
*/
read = newRead;
while (true) {
final long p = pending(read);
final long d = done(read);
final long i = index(read);
final long v = version(read);
final long n = d + 1L == p
? createState(v + 1L, (i + p) & mask, 0L, 0L)
: i == index
? createState(v, (i + 1L) & mask, d, p - 1L)
: createState(v, i, d + 1L, p);
if (READ.weakCompareAndSetRelease(this, read, n)) {
break;
}
read = (long) READ.getVolatile(this);
spinWait(attempts++);
}
return result;
}
@@ -190,31 +238,34 @@ public final class MpmcQueue<T> {
public int length() {
final long readCounters = (long) READ.getVolatile(this);
final long writeCounters = (long) WRITE.getVolatile(this);
final int readIndex = (int) (readCounters >>> 16);
final int writeIndex = (int) (writeCounters >>> 16);
return (readIndex <= writeIndex ?
writeIndex - readIndex :
writeIndex + capacity() - readIndex) - (int) (readCounters & PENDING_MASK);
final long readIndex = index(readCounters);
final long writeIndex = index(writeCounters);
return (int) ((readIndex <= writeIndex
? writeIndex - readIndex
: writeIndex + capacity - readIndex) - pending(readCounters));
}
public boolean isEmpty() {
return length() == 0;
}
public int capacity() {
return buffer.length;
final long readCounters = (long) READ.getVolatile(this);
final long writeCounters = (long) WRITE.getVolatile(this);
final long readIndex = index(readCounters);
final long writeIndex = index(writeCounters);
final long writePending = pending(writeCounters);
final long currentItems = readIndex <= writeIndex
? writeIndex - readIndex
: writeIndex + capacity - readIndex;
return currentItems == 0L && writePending == 0L;
}
public int remaining() {
final long readCounters = (long) READ.getVolatile(this);
final long writeCounters = (long) WRITE.getVolatile(this);
final int cap = capacity();
final int readIndex = (int) (readCounters >>> 16);
final int writeIndex = (int) (writeCounters >>> 16);
final int len = readIndex <= writeIndex ?
final long readIndex = index(readCounters);
final long writeIndex = index(writeCounters);
final long len = readIndex <= writeIndex ?
writeIndex - readIndex :
writeIndex + cap - readIndex;
return cap - 1 - len - (int) (writeCounters & PENDING_MASK);
writeIndex + capacity - readIndex;
return (int) (mask - len - pending(writeCounters));
}
@SuppressWarnings("unused")