diff --git a/leaf-server/src/main/java/org/dreeam/leaf/async/FixedThreadExecutor.java b/leaf-server/src/main/java/org/dreeam/leaf/async/FixedThreadExecutor.java index 06d82fc9..9fa71e94 100644 --- a/leaf-server/src/main/java/org/dreeam/leaf/async/FixedThreadExecutor.java +++ b/leaf-server/src/main/java/org/dreeam/leaf/async/FixedThreadExecutor.java @@ -5,10 +5,12 @@ import org.dreeam.leaf.util.queue.MpmcQueue; import java.util.concurrent.Callable; import java.util.concurrent.FutureTask; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.LockSupport; public final class FixedThreadExecutor { - private final Thread[] workers; + private final Thread[] threads; + private final Worker[] workers; public final MpmcQueue channel; private static volatile boolean SHUTDOWN = false; @@ -16,15 +18,17 @@ public final class FixedThreadExecutor { if (numThreads <= 0) { throw new IllegalArgumentException(); } - this.workers = new Thread[numThreads]; + this.threads = new Thread[numThreads]; + this.workers = new Worker[numThreads]; this.channel = new MpmcQueue<>(Runnable.class, queue); for (int i = 0; i < numThreads; i++) { - workers[i] = Thread.ofPlatform() + workers[i] = new Worker(channel, new AtomicBoolean(false)); + threads[i] = Thread.ofPlatform() .uncaughtExceptionHandler(Util::onThreadException) .daemon(false) .priority(Thread.NORM_PRIORITY) .name(prefix + " - " + i) - .start(new Worker(channel)); + .start(workers[i]); } } @@ -41,15 +45,17 @@ public final class FixedThreadExecutor { } public void unpack() { - int size = Math.min(Math.max(1, channel.length()), workers.length); + int size = Math.min(Math.max(1, channel.length()), threads.length); for (int i = 0; i < size; i++) { - LockSupport.unpark(workers[i]); + if (workers[i].parked.get()) { + LockSupport.unpark(threads[i]); + } } } public void shutdown() { SHUTDOWN = true; - for (final Thread worker : workers) { + for (final Thread worker : threads) { LockSupport.unpark(worker); } } @@ -57,7 +63,7 @@ public final class FixedThreadExecutor { public void join(long timeoutMillis) throws InterruptedException { final long startTime = System.currentTimeMillis(); - for (final Thread worker : workers) { + for (final Thread worker : threads) { final long remaining = timeoutMillis - System.currentTimeMillis() + startTime; if (remaining <= 0) { return; @@ -69,19 +75,24 @@ public final class FixedThreadExecutor { } } - private record Worker(MpmcQueue channel) implements Runnable { + private record Worker(MpmcQueue channel, AtomicBoolean parked) implements Runnable { @Override public void run() { while (true) { final Runnable task = channel.recv(); if (task != null) { + parked.set(false); task.run(); } else if (SHUTDOWN) { break; } else if (channel.isEmpty()) { Thread.yield(); - if (channel.isEmpty()) { - LockSupport.park(); + if (parked.compareAndSet(false, true)) { + if (channel.isEmpty()) { + LockSupport.park(); + } else { + parked.set(false); + } } } }