diff --git a/leaf-server/src/main/java/org/dreeam/leaf/async/FixedThreadExecutor.java b/leaf-server/src/main/java/org/dreeam/leaf/async/FixedThreadExecutor.java index 11af5a50..fef96034 100644 --- a/leaf-server/src/main/java/org/dreeam/leaf/async/FixedThreadExecutor.java +++ b/leaf-server/src/main/java/org/dreeam/leaf/async/FixedThreadExecutor.java @@ -5,11 +5,11 @@ import org.dreeam.leaf.util.queue.MpmcQueue; import java.util.concurrent.Callable; import java.util.concurrent.FutureTask; +import java.util.concurrent.locks.LockSupport; public final class FixedThreadExecutor { private final Thread[] threads; public final MpmcQueue channel; - public final Object sync; private static volatile boolean SHUTDOWN = false; public FixedThreadExecutor(int numThreads, int queue, String prefix) { @@ -18,14 +18,13 @@ public final class FixedThreadExecutor { } this.threads = new Thread[numThreads]; this.channel = new MpmcQueue<>(Runnable.class, queue); - this.sync = new Object(); for (int i = 0; i < numThreads; i++) { threads[i] = Thread.ofPlatform() .uncaughtExceptionHandler(Util::onThreadException) .daemon(false) .priority(Thread.NORM_PRIORITY) .name(prefix + " - " + i) - .start(new Worker(channel, sync)); + .start(new Worker(channel)); } } @@ -42,15 +41,16 @@ public final class FixedThreadExecutor { } public void unpack() { - synchronized (sync) { - sync.notifyAll(); + final int len = Math.clamp(channel.length(), 1, threads.length); + for (int i = 0; i < len; i++) { + LockSupport.unpark(threads[i]); } } public void shutdown() { SHUTDOWN = true; - synchronized (sync) { - sync.notifyAll(); + for (Thread thread : threads) { + LockSupport.unpark(thread); } } @@ -69,7 +69,7 @@ public final class FixedThreadExecutor { } } - private record Worker(MpmcQueue channel, Object sync) implements Runnable { + private record Worker(MpmcQueue channel) implements Runnable { @Override public void run() { while (true) { @@ -78,12 +78,12 @@ public final class FixedThreadExecutor { task.run(); } else if (SHUTDOWN) { break; - } else if (channel.isEmpty()) { - synchronized (sync) { - try { - sync.wait(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); + } else { + Thread.yield(); + if (channel.isEmpty()) { + LockSupport.park(); + if (Thread.interrupted()) { + return; } } } diff --git a/leaf-server/src/main/java/org/dreeam/leaf/util/queue/MpmcQueue.java b/leaf-server/src/main/java/org/dreeam/leaf/util/queue/MpmcQueue.java index 23710eb1..f35c33f1 100644 --- a/leaf-server/src/main/java/org/dreeam/leaf/util/queue/MpmcQueue.java +++ b/leaf-server/src/main/java/org/dreeam/leaf/util/queue/MpmcQueue.java @@ -3,19 +3,20 @@ package org.dreeam.leaf.util.queue; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + import java.lang.invoke.MethodHandles; import java.lang.invoke.VarHandle; public final class MpmcQueue { - private static final long VERSION_MASK = 0xFFFF_0000_0000_0000L; - private static final long INDEX_MASK = 0x0000_FFFF_FFFF_0000L; - private static final long PENDING_MASK = 0x0000_0000_0000_00FFL; private static final long DONE_MASK = 0x0000_0000_0000_FF00L; - private static final long FAST_PATH_MASK = INDEX_MASK | DONE_MASK; - private static final long MAX_IN_PROGRESS = 16L; - private static final int MAX_CAPACITY = 1 << 30; - private static final int VERSION_SHIFT = 48; + private static final long PENDING_MASK = 0x0000_0000_0000_00FFL; + private static final long DONE_PENDING_MASK = DONE_MASK | PENDING_MASK; private static final int INDEX_SHIFT = 16; + private static final int DONE_SHIFT = 8; + private static final long MAX_IN_PROGRESS = 16; + private static final int MAX_CAPACITY = 1 << 30; private static final int PARALLELISM = Runtime.getRuntime().availableProcessors(); private static final VarHandle READ; @@ -23,6 +24,7 @@ public final class MpmcQueue { private final long mask; private final long capacity; + @Nullable private final T[] buffer; @SuppressWarnings("unused") @@ -49,35 +51,16 @@ public final class MpmcQueue { throw new IllegalArgumentException(); } - this.mask = (1L << (Integer.SIZE - Integer.numberOfLeadingZeros(capacity - 1))) - 1L; + this.capacity = Math.max(2, (1L << (Integer.SIZE - Integer.numberOfLeadingZeros(capacity - 1)))); + this.mask = this.capacity - 1L; //noinspection unchecked this.buffer = (clazz == Object.class) - ? (T[]) new Object[(int) (mask + 1L)] - : (T[]) java.lang.reflect.Array.newInstance(clazz, (int) (mask + 1L)); - this.capacity = mask + 1L; - } - - private static long version(long state) { - return (state & VERSION_MASK) >>> VERSION_SHIFT; - } - - private static long index(long state) { - return (state & INDEX_MASK) >>> INDEX_SHIFT; - } - - private static long pending(long state) { - return state & PENDING_MASK; - } - - private static long done(long state) { - return (state & DONE_MASK) >>> 8; - } - - private static long createState(long version, long index, long done, long pending) { - return version << VERSION_SHIFT | index << INDEX_SHIFT | done << 8 | pending; + ? (T[]) new Object[(int) this.capacity] + : (T[]) java.lang.reflect.Array.newInstance(clazz, (int) this.capacity); } private void spinWait(final int attempts) { + //noinspection StatementWithEmptyBody if (attempts == 0) { } else if (PARALLELISM != 1 && (attempts & 31) != 31) { Thread.onSpinWait(); @@ -86,46 +69,30 @@ public final class MpmcQueue { } } - public boolean send(final T item) { - if (item == null) { - throw new IllegalArgumentException("Cannot enqueue null item"); - } - + public boolean send(@NotNull final T item) { long write = (long) WRITE.getAcquire(this); boolean success; long newWrite = 0L; long index = 0L; int attempts = 0; - while (true) { spinWait(attempts++); - final long writeVersion = version(write); - final long writePending = pending(write); - final long writeIndex = index(write); - final long currentRead = (long) READ.getAcquire(this); - final long readIndex = index(currentRead); - final long currentItems = readIndex <= writeIndex - ? writeIndex - readIndex - : writeIndex + capacity - readIndex; - if (currentItems + writePending >= mask) { + final long inProgressCnt = (write & PENDING_MASK); + if ((((write >>> INDEX_SHIFT) + 1L) & mask) == ((long) READ.getVolatile(this) >>> INDEX_SHIFT)) { success = false; break; } - if (writePending == MAX_IN_PROGRESS) { + + if (inProgressCnt == MAX_IN_PROGRESS) { write = (long) WRITE.getAcquire(this); continue; } - index = (writeIndex + writePending) & mask; - if (((writeIndex + writePending + 1L) & mask) == readIndex) { + index = ((write >>> INDEX_SHIFT) + inProgressCnt) & mask; + if (((index + 1L) & mask) == ((long) READ.getVolatile(this) >>> INDEX_SHIFT)) { success = false; break; } - newWrite = createState( - writeVersion + 1L, - writeIndex, - done(write), - writePending + 1L - ); + newWrite = write + 1L; if (WRITE.weakCompareAndSetAcquire(this, write, newWrite)) { success = true; break; @@ -136,23 +103,13 @@ public final class MpmcQueue { return false; } buffer[(int) index] = item; - /* - if ((newWrite & FAST_PATH_MASK) == (index << INDEX_SHIFT) && index < mask) { - WRITE.getAndAddRelease(this, (1L << INDEX_SHIFT) - 1L); - return true; - } - */ write = newWrite; while (true) { - final long p = pending(write); - final long d = done(write); - final long i = index(write); - final long v = version(write); - final long n = d + 1L == p - ? createState(v + 1L, (i + p) & mask, 0L, 0L) - : i == index - ? createState(v, (i + 1L) & mask, d, p - 1L) - : createState(v, i, d + 1L, p); + final long n = ((write & DONE_MASK) >>> DONE_SHIFT) + 1L == (write & PENDING_MASK) + ? ((write >>> INDEX_SHIFT) + (write & PENDING_MASK) & mask) << INDEX_SHIFT + : write >>> INDEX_SHIFT == index + ? write + (1L << INDEX_SHIFT) - 1L & (mask << INDEX_SHIFT | DONE_PENDING_MASK) + : write + (1L << DONE_SHIFT); if (WRITE.weakCompareAndSetRelease(this, write, n)) { break; } @@ -162,42 +119,29 @@ public final class MpmcQueue { return true; } - public T recv() { + public @Nullable T recv() { long read = (long) READ.getAcquire(this); boolean success; - long index = 0L; + long index = 0; long newRead = 0L; int attempts = 0; while (true) { spinWait(attempts++); - final long readVersion = version(read); - final long readPending = pending(read); - final long writeIndex = index((long) WRITE.getAcquire(this)); - final long readIndex = index(read); - final long currentItems = readIndex <= writeIndex - ? writeIndex - readIndex - : writeIndex + capacity - readIndex; - if (currentItems == 0L) { + final long inProgressCnt = (read & PENDING_MASK); + if ((read >>> INDEX_SHIFT) == ((long) WRITE.getVolatile(this) >>> INDEX_SHIFT)) { success = false; break; } - - if (readPending == MAX_IN_PROGRESS) { + if (inProgressCnt == MAX_IN_PROGRESS) { read = (long) READ.getAcquire(this); continue; } - - index = (readIndex + readPending) & mask; - if (index == writeIndex) { + index = ((read >>> INDEX_SHIFT) + inProgressCnt) & mask; + if ((index & mask) == ((long) WRITE.getVolatile(this) >>> INDEX_SHIFT)) { success = false; break; } - newRead = createState( - readVersion + 1L, - readIndex, - done(read), - readPending + 1L - ); + newRead = read + 1L; if (READ.weakCompareAndSetAcquire(this, read, newRead)) { success = true; break; @@ -209,23 +153,13 @@ public final class MpmcQueue { } final T result = buffer[(int) index]; buffer[(int) index] = null; - /* - if ((newRead & FAST_PATH_MASK) == (index << INDEX_SHIFT) && index < mask) { - READ.getAndAddRelease(this, (1L << INDEX_SHIFT) - 1L); - return result; - } - */ read = newRead; while (true) { - final long p = pending(read); - final long d = done(read); - final long i = index(read); - final long v = version(read); - final long n = d + 1L == p - ? createState(v + 1L, (i + p) & mask, 0L, 0L) - : i == index - ? createState(v, (i + 1L) & mask, d, p - 1L) - : createState(v, i, d + 1L, p); + final long n = ((read & DONE_MASK) >>> DONE_SHIFT) + 1L == (read & PENDING_MASK) + ? ((read >>> INDEX_SHIFT) + (read & PENDING_MASK) & mask) << INDEX_SHIFT + : read >>> INDEX_SHIFT == index + ? read + (1L << INDEX_SHIFT) - 1L & (mask << INDEX_SHIFT | DONE_PENDING_MASK) + : read + (1L << DONE_SHIFT); if (READ.weakCompareAndSetRelease(this, read, n)) { break; } @@ -236,36 +170,27 @@ public final class MpmcQueue { } public int length() { - final long readCounters = (long) READ.getVolatile(this); - final long writeCounters = (long) WRITE.getVolatile(this); - final long readIndex = index(readCounters); - final long writeIndex = index(writeCounters); - return (int) ((readIndex <= writeIndex - ? writeIndex - readIndex - : writeIndex + capacity - readIndex) - pending(readCounters)); + final long reads = (long) READ.getVolatile(this); + final long writes = (long) WRITE.getVolatile(this); + final long readIndex = (reads >>> INDEX_SHIFT); + final long writeIndex = (writes >>> INDEX_SHIFT); + return (int) (readIndex <= writeIndex ? writeIndex - readIndex : writeIndex + capacity - readIndex); + // (readIndex <= writeIndex ? writeIndex - readIndex : writeIndex + capacity - readIndex) - (reads & PENDING_MASK) } public boolean isEmpty() { - final long readCounters = (long) READ.getVolatile(this); - final long writeCounters = (long) WRITE.getVolatile(this); - final long readIndex = index(readCounters); - final long writeIndex = index(writeCounters); - final long writePending = pending(writeCounters); - final long currentItems = readIndex <= writeIndex - ? writeIndex - readIndex - : writeIndex + capacity - readIndex; - return currentItems == 0L && writePending == 0L; + return length() == 0; } public int remaining() { - final long readCounters = (long) READ.getVolatile(this); - final long writeCounters = (long) WRITE.getVolatile(this); - final long readIndex = index(readCounters); - final long writeIndex = index(writeCounters); + final long reads = (long) READ.getVolatile(this); + final long writes = (long) WRITE.getVolatile(this); + final long readIndex = (reads >>> INDEX_SHIFT); + final long writeIndex = (writes >>> INDEX_SHIFT); final long len = readIndex <= writeIndex ? writeIndex - readIndex : writeIndex + capacity - readIndex; - return (int) (mask - len - pending(writeCounters)); + return (int) (mask - len - (writes & PENDING_MASK)); } @SuppressWarnings("unused")