9
0
mirror of https://github.com/Winds-Studio/Leaf.git synced 2025-12-20 07:29:24 +00:00

fix unpark race

This commit is contained in:
hayanesuru
2025-07-16 15:24:59 +09:00
parent cdd379f424
commit cecd8d751a

View File

@@ -5,10 +5,12 @@ import org.dreeam.leaf.util.queue.MpmcQueue;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.FutureTask; import java.util.concurrent.FutureTask;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport; import java.util.concurrent.locks.LockSupport;
public final class FixedThreadExecutor { public final class FixedThreadExecutor {
private final Thread[] workers; private final Thread[] threads;
private final Worker[] workers;
public final MpmcQueue<Runnable> channel; public final MpmcQueue<Runnable> channel;
private static volatile boolean SHUTDOWN = false; private static volatile boolean SHUTDOWN = false;
@@ -16,15 +18,17 @@ public final class FixedThreadExecutor {
if (numThreads <= 0) { if (numThreads <= 0) {
throw new IllegalArgumentException(); throw new IllegalArgumentException();
} }
this.workers = new Thread[numThreads]; this.threads = new Thread[numThreads];
this.workers = new Worker[numThreads];
this.channel = new MpmcQueue<>(Runnable.class, queue); this.channel = new MpmcQueue<>(Runnable.class, queue);
for (int i = 0; i < numThreads; i++) { for (int i = 0; i < numThreads; i++) {
workers[i] = Thread.ofPlatform() workers[i] = new Worker(channel, new AtomicBoolean(false));
threads[i] = Thread.ofPlatform()
.uncaughtExceptionHandler(Util::onThreadException) .uncaughtExceptionHandler(Util::onThreadException)
.daemon(false) .daemon(false)
.priority(Thread.NORM_PRIORITY) .priority(Thread.NORM_PRIORITY)
.name(prefix + " - " + i) .name(prefix + " - " + i)
.start(new Worker(channel)); .start(workers[i]);
} }
} }
@@ -41,15 +45,17 @@ public final class FixedThreadExecutor {
} }
public void unpack() { public void unpack() {
int size = Math.min(Math.max(1, channel.length()), workers.length); int size = Math.min(Math.max(1, channel.length()), threads.length);
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
LockSupport.unpark(workers[i]); if (workers[i].parked.get()) {
LockSupport.unpark(threads[i]);
}
} }
} }
public void shutdown() { public void shutdown() {
SHUTDOWN = true; SHUTDOWN = true;
for (final Thread worker : workers) { for (final Thread worker : threads) {
LockSupport.unpark(worker); LockSupport.unpark(worker);
} }
} }
@@ -57,7 +63,7 @@ public final class FixedThreadExecutor {
public void join(long timeoutMillis) throws InterruptedException { public void join(long timeoutMillis) throws InterruptedException {
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
for (final Thread worker : workers) { for (final Thread worker : threads) {
final long remaining = timeoutMillis - System.currentTimeMillis() + startTime; final long remaining = timeoutMillis - System.currentTimeMillis() + startTime;
if (remaining <= 0) { if (remaining <= 0) {
return; return;
@@ -69,19 +75,24 @@ public final class FixedThreadExecutor {
} }
} }
private record Worker(MpmcQueue<Runnable> channel) implements Runnable { private record Worker(MpmcQueue<Runnable> channel, AtomicBoolean parked) implements Runnable {
@Override @Override
public void run() { public void run() {
while (true) { while (true) {
final Runnable task = channel.recv(); final Runnable task = channel.recv();
if (task != null) { if (task != null) {
parked.set(false);
task.run(); task.run();
} else if (SHUTDOWN) { } else if (SHUTDOWN) {
break; break;
} else if (channel.isEmpty()) { } else if (channel.isEmpty()) {
Thread.yield(); Thread.yield();
if (channel.isEmpty()) { if (parked.compareAndSet(false, true)) {
LockSupport.park(); if (channel.isEmpty()) {
LockSupport.park();
} else {
parked.set(false);
}
} }
} }
} }