9
0
mirror of https://github.com/Winds-Studio/Leaf.git synced 2025-12-25 09:59:15 +00:00

better locks, perf, readability on ConcurrentLongHashSet

This commit is contained in:
Taiyou06
2025-03-24 23:25:00 +01:00
parent 278b1d635b
commit f1a398f9b5

View File

@@ -1,265 +1,697 @@
package org.dreeam.leaf.util.map;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
import it.unimi.dsi.fastutil.longs.*;
import org.jetbrains.annotations.NotNull;
import java.util.Collection;
import java.util.Collections;
import java.util.NoSuchElementException;
import java.util.Iterator;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
/**
* Optimized thread-safe implementation of {@link LongSet} that uses striped locking
* and primitive long arrays to minimize boxing/unboxing overhead.
*/
@SuppressWarnings({"unused", "deprecation"})
public final class ConcurrentLongHashSet extends LongOpenHashSet implements LongSet {
private static final int DEFAULT_SEGMENTS = 16; // Should be power-of-two
// Number of lock stripes - higher number means more concurrency but more memory
private static final int DEFAULT_CONCURRENCY_LEVEL = 16;
// Load factor - when to resize the hash table
private static final float DEFAULT_LOAD_FACTOR = 0.75f;
// Initial capacity per stripe
private static final int DEFAULT_INITIAL_CAPACITY = 16;
// Array of segments (stripes)
private final Segment[] segments;
private final int segmentMask;
// Total size, cached for faster size() operation
private final AtomicInteger size;
/**
* Creates a new empty concurrent long set with default parameters.
*/
public ConcurrentLongHashSet() {
this(DEFAULT_SEGMENTS);
this(DEFAULT_CONCURRENCY_LEVEL * DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
@Override
public boolean removeAll(@NotNull Collection<?> c) {
Objects.requireNonNull(c, "Collection cannot be null");
boolean modified = false;
for (Object obj : c) {
if (obj instanceof Long) {
modified |= remove(obj);
}
/**
* Creates a new concurrent long set with the specified parameters.
*
* @param initialCapacity the initial capacity
* @param loadFactor the load factor
* @param concurrencyLevel the concurrency level
*/
public ConcurrentLongHashSet(int initialCapacity, float loadFactor, int concurrencyLevel) {
// Need to call super() even though we don't use its state
super();
// Validate parameters
if (initialCapacity < 0) {
throw new IllegalArgumentException("Initial capacity must be positive");
}
return modified;
}
@Override
public boolean retainAll(@NotNull Collection<?> c) {
Objects.requireNonNull(c, "Collection cannot be null");
boolean modified = false;
LongIterator iterator = iterator();
while (iterator.hasNext()) {
long key = iterator.nextLong();
if (!c.contains(key)) {
modified |= remove(key);
}
if (loadFactor <= 0 || Float.isNaN(loadFactor)) {
throw new IllegalArgumentException("Load factor must be positive");
}
return modified;
}
public ConcurrentLongHashSet(int concurrencyLevel) {
int numSegments = Integer.highestOneBit(concurrencyLevel) << 1;
this.segmentMask = numSegments - 1;
this.segments = new Segment[numSegments];
for (int i = 0; i < numSegments; i++) {
segments[i] = new Segment();
if (concurrencyLevel <= 0) {
throw new IllegalArgumentException("Concurrency level must be positive");
}
}
// ------------------- Core Methods -------------------
@Override
public boolean add(long key) {
Segment segment = getSegment(key);
segment.lock();
try {
return segment.set.add(key);
} finally {
segment.unlock();
// Calculate segment count (power of 2)
int segmentCount = 1;
while (segmentCount < concurrencyLevel) {
segmentCount <<= 1;
}
}
@Override
public boolean contains(long key) {
Segment segment = getSegment(key);
segment.lock();
try {
return segment.set.contains(key);
} finally {
segment.unlock();
// Calculate capacity per segment
int segmentCapacity = Math.max(initialCapacity / segmentCount, DEFAULT_INITIAL_CAPACITY);
// Create segments
this.segments = new Segment[segmentCount];
for (int i = 0; i < segmentCount; i++) {
this.segments[i] = new Segment(segmentCapacity, loadFactor);
}
this.size = new AtomicInteger(0);
}
@Override
public boolean remove(long key) {
Segment segment = getSegment(key);
segment.lock();
try {
return segment.set.remove(key);
} finally {
segment.unlock();
}
}
// ------------------- Bulk Operations -------------------
@Override
public boolean containsAll(@NotNull Collection<?> c) {
Objects.requireNonNull(c, "Collection cannot be null");
for (Object obj : c) {
if (obj == null || !(obj instanceof Long)) return false;
if (!contains(obj)) return false;
}
return true;
}
@Override
public boolean addAll(@NotNull Collection<? extends Long> c) {
Objects.requireNonNull(c, "Collection cannot be null");
boolean modified = false;
for (Long value : c) {
modified |= add(value);
}
return modified;
}
// ------------------- Locking Helpers -------------------
private Segment getSegment(long key) {
int hash = spreadHash(Long.hashCode(key));
return segments[hash & segmentMask];
}
private static int spreadHash(int h) {
return (h ^ (h >>> 16)) & 0x7fffffff; // Avoid negative indices
}
// ------------------- Size Stuff -------------------
@Override
public int size() {
int count = 0;
for (Segment segment : segments) {
segment.lock();
count += segment.set.size();
segment.unlock();
}
return count;
return size.get();
}
@Override
public boolean isEmpty() {
for (Segment segment : segments) {
segment.lock();
boolean empty = segment.set.isEmpty();
segment.unlock();
if (!empty) return false;
}
return true;
return size.get() == 0;
}
@Override
public boolean add(long key) {
Segment segment = segmentFor(key);
int delta = segment.add(key) ? 1 : 0;
if (delta > 0) {
size.addAndGet(delta);
}
return delta > 0;
}
@Override
public boolean contains(long key) {
return segmentFor(key).contains(key);
}
@Override
public boolean remove(long key) {
Segment segment = segmentFor(key);
int delta = segment.remove(key) ? -1 : 0;
if (delta < 0) {
size.addAndGet(delta);
}
return delta < 0;
}
// ------------------- Cleanup -------------------
@Override
public void clear() {
for (Segment segment : segments) {
segment.lock();
segment.set.clear();
segment.unlock();
segment.clear();
}
size.set(0);
}
// ------------------- Iteration -------------------
@Override
public LongIterator iterator() {
return new CompositeLongIterator();
public @NotNull LongIterator iterator() {
return new ConcurrentLongIterator();
}
private class CompositeLongIterator implements LongIterator {
private int currentSegment = 0;
private LongIterator currentIterator;
CompositeLongIterator() {
advanceSegment();
}
private void advanceSegment() {
while (currentSegment < segments.length) {
segments[currentSegment].lock();
currentIterator = segments[currentSegment].set.iterator();
if (currentIterator.hasNext()) break;
segments[currentSegment].unlock();
currentSegment++;
}
}
@Override
public boolean hasNext() {
if (currentIterator == null) return false;
if (currentIterator.hasNext()) return true;
segments[currentSegment].unlock();
currentSegment++;
advanceSegment();
return currentIterator != null && currentIterator.hasNext();
}
@Override
public long nextLong() {
if (!hasNext()) throw new NoSuchElementException();
return currentIterator.nextLong();
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
}
// ------------------- Segment (these nuts) -------------------
private static class Segment {
final LongOpenHashSet set = new LongOpenHashSet();
final ReentrantLock lock = new ReentrantLock();
void lock() {
lock.lock();
}
void unlock() {
lock.unlock();
}
}
// ignore
@Override
public long[] toLongArray() {
long[] result = new long[size()];
int i = 0;
LongIterator it = iterator();
while (it.hasNext()) {
result[i++] = it.nextLong();
int index = 0;
for (Segment segment : segments) {
index = segment.toLongArray(result, index);
}
return result;
}
@Override
public long[] toArray(long[] a) {
public long[] toArray(long[] array) {
Objects.requireNonNull(array, "Array cannot be null");
long[] result = toLongArray();
if (a.length < result.length) return result;
System.arraycopy(result, 0, a, 0, result.length);
return a;
if (array.length < result.length) {
return result;
}
System.arraycopy(result, 0, array, 0, result.length);
if (array.length > result.length) {
array[result.length] = 0;
}
return array;
}
@NotNull
@Override
public Object @NotNull [] toArray() {
Long[] result = new Long[size()];
int index = 0;
for (Segment segment : segments) {
index = segment.toObjectArray(result, index);
}
return result;
}
@NotNull
@Override
public <T> T @NotNull [] toArray(@NotNull T @NotNull [] array) {
Objects.requireNonNull(array, "Array cannot be null");
Long[] result = new Long[size()];
int index = 0;
for (Segment segment : segments) {
index = segment.toObjectArray(result, index);
}
if (array.length < result.length) {
return (T[]) result;
}
System.arraycopy(result, 0, array, 0, result.length);
if (array.length > result.length) {
array[result.length] = null;
}
return array;
}
@Override
public boolean containsAll(@NotNull Collection<?> collection) {
Objects.requireNonNull(collection, "Collection cannot be null");
for (Object o : collection) {
if (o instanceof Long) {
if (!contains(((Long) o).longValue())) {
return false;
}
} else {
return false;
}
}
return true;
}
@Override
public boolean addAll(@NotNull Collection<? extends Long> collection) {
Objects.requireNonNull(collection, "Collection cannot be null");
boolean modified = false;
for (Long value : collection) {
modified |= add(value);
}
return modified;
}
@Override
public boolean removeAll(@NotNull Collection<?> collection) {
Objects.requireNonNull(collection, "Collection cannot be null");
boolean modified = false;
for (Object o : collection) {
if (o instanceof Long) {
modified |= remove(((Long) o).longValue());
}
}
return modified;
}
@Override
public boolean retainAll(@NotNull Collection<?> collection) {
Objects.requireNonNull(collection, "Collection cannot be null");
// Convert collection to a set of longs for faster lookups
LongOpenHashSet toRetain = new LongOpenHashSet();
for (Object o : collection) {
if (o instanceof Long) {
toRetain.add(((Long) o).longValue());
}
}
boolean modified = false;
for (Segment segment : segments) {
modified |= segment.retainAll(toRetain);
}
if (modified) {
// Recalculate size
int newSize = 0;
for (Segment segment : segments) {
newSize += segment.size();
}
size.set(newSize);
}
return modified;
}
@Override
public boolean addAll(LongCollection c) {
Objects.requireNonNull(c, "Collection cannot be null");
boolean modified = false;
LongIterator iterator = c.iterator();
while (iterator.hasNext()) {
modified |= add(iterator.nextLong());
}
return modified;
}
@Override
public boolean containsAll(LongCollection c) {
Objects.requireNonNull(c, "Collection cannot be null");
LongIterator iterator = c.iterator();
while (iterator.hasNext()) {
if (!contains(iterator.nextLong())) {
return false;
}
}
return true;
}
@Override
public boolean removeAll(LongCollection c) {
Objects.requireNonNull(c, "Collection cannot be null");
boolean modified = false;
LongIterator iterator = c.iterator();
while (iterator.hasNext()) {
modified |= remove(iterator.nextLong());
}
return modified;
}
@Override
public boolean retainAll(LongCollection c) {
Objects.requireNonNull(c, "Collection cannot be null");
// For LongCollection we can directly use it
boolean modified = false;
for (Segment segment : segments) {
modified |= segment.retainAll(c);
}
if (modified) {
// Recalculate size
int newSize = 0;
for (Segment segment : segments) {
newSize += segment.size();
}
size.set(newSize);
}
return modified;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof LongSet that)) return false;
return size() == that.size() && containsAll(that);
if (size() != that.size()) return false;
return containsAll(that);
}
@Override
public int hashCode() {
int hash = 0;
LongIterator it = iterator();
while (it.hasNext()) {
hash += Long.hashCode(it.nextLong());
for (Segment segment : segments) {
hash += segment.hashCode();
}
return hash;
}
@Override
@NotNull
public Object[] toArray() {
return Collections.unmodifiableSet(this).toArray();
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append('[');
LongIterator it = iterator();
boolean hasNext = it.hasNext();
while (hasNext) {
sb.append(it.nextLong());
hasNext = it.hasNext();
if (hasNext) {
sb.append(", ");
}
}
sb.append(']');
return sb.toString();
}
@Override
@NotNull
public <T> T[] toArray(@NotNull T[] a) {
return Collections.unmodifiableSet(this).toArray(a);
/**
* Find the segment for a given key.
*/
private Segment segmentFor(long key) {
// Use high bits of hash to determine segment
// This helps spread keys more evenly across segments
return segments[(int) ((spread(key) >>> segmentShift()) & segmentMask())];
}
/**
* Spread bits to reduce clustering for keys with similar hash codes.
*/
private static long spread(long key) {
long h = key;
h ^= h >>> 32;
h ^= h >>> 16;
h ^= h >>> 8;
return h;
}
private int segmentShift() {
return Integer.numberOfLeadingZeros(segments.length);
}
private int segmentMask() {
return segments.length - 1;
}
/**
* A segment is a striped portion of the hash set with its own lock.
*/
private static class Segment {
private final ReentrantLock lock = new ReentrantLock();
private long[] keys;
private boolean[] used;
private int size;
private int threshold;
private final float loadFactor;
Segment(int initialCapacity, float loadFactor) {
int capacity = MathUtil.nextPowerOfTwo(initialCapacity);
this.keys = new long[capacity];
this.used = new boolean[capacity];
this.size = 0;
this.loadFactor = loadFactor;
this.threshold = (int) (capacity * loadFactor);
}
int size() {
lock.lock();
try {
return size;
} finally {
lock.unlock();
}
}
boolean contains(long key) {
lock.lock();
try {
int index = indexOf(key);
return used[index] && keys[index] == key;
} finally {
lock.unlock();
}
}
boolean add(long key) {
lock.lock();
try {
int index = indexOf(key);
// Key already exists
if (used[index] && keys[index] == key) {
return false;
}
// Insert key
keys[index] = key;
if (!used[index]) {
used[index] = true;
size++;
// Check if rehash is needed
if (size > threshold) {
rehash();
}
}
return true;
} finally {
lock.unlock();
}
}
boolean remove(long key) {
lock.lock();
try {
int index = indexOf(key);
// Key not found
if (!used[index] || keys[index] != key) {
return false;
}
// Mark slot as unused
used[index] = false;
size--;
// If the next slot is also used, we need to handle the removal properly
// to maintain the open addressing property
// This rehashing serves as a "cleanup" after removal
if (size > 0) {
rehashFromIndex(index);
}
return true;
} finally {
lock.unlock();
}
}
void clear() {
lock.lock();
try {
for (int i = 0; i < used.length; i++) {
used[i] = false;
}
size = 0;
} finally {
lock.unlock();
}
}
int toLongArray(long[] array, int offset) {
lock.lock();
try {
for (int i = 0; i < keys.length; i++) {
if (used[i]) {
array[offset++] = keys[i];
}
}
return offset;
} finally {
lock.unlock();
}
}
int toObjectArray(Long[] array, int offset) {
lock.lock();
try {
for (int i = 0; i < keys.length; i++) {
if (used[i]) {
array[offset++] = keys[i];
}
}
return offset;
} finally {
lock.unlock();
}
}
boolean retainAll(LongCollection toRetain) {
lock.lock();
try {
boolean modified = false;
for (int i = 0; i < keys.length; i++) {
if (used[i] && !toRetain.contains(keys[i])) {
used[i] = false;
size--;
modified = true;
}
}
// Rehash to clean up if needed
if (modified && size > 0) {
rehash();
}
return modified;
} finally {
lock.unlock();
}
}
/**
* Find the index where a key should be stored.
* Uses linear probing for collision resolution.
*/
private int indexOf(long key) {
int mask = keys.length - 1;
int index = (int) (spread(key) & mask);
while (used[index] && keys[index] != key) {
index = (index + 1) & mask;
}
return index;
}
/**
* Rehash the segment with a larger capacity.
*/
private void rehash() {
int oldCapacity = keys.length;
int newCapacity = oldCapacity * 2;
long[] oldKeys = keys;
boolean[] oldUsed = used;
keys = new long[newCapacity];
used = new boolean[newCapacity];
size = 0;
threshold = (int) (newCapacity * loadFactor);
// Re-add all keys
for (int i = 0; i < oldCapacity; i++) {
if (oldUsed[i]) {
add(oldKeys[i]);
}
}
}
/**
* Rehash from a specific index after removal to maintain proper open addressing.
*/
private void rehashFromIndex(int startIndex) {
int mask = keys.length - 1;
int currentIndex = startIndex;
int nextIndex = (currentIndex + 1) & mask;
// For each cluster of used slots following the removal point
while (used[nextIndex]) {
long key = keys[nextIndex];
int targetIndex = (int) (spread(key) & mask);
// If the key's ideal position is between the removal point and the current position,
// move it to the removal point
if ((targetIndex <= currentIndex && currentIndex < nextIndex) ||
(nextIndex < targetIndex && targetIndex <= currentIndex) ||
(currentIndex < nextIndex && nextIndex < targetIndex)) {
keys[currentIndex] = keys[nextIndex];
used[currentIndex] = true;
used[nextIndex] = false;
currentIndex = nextIndex;
}
nextIndex = (nextIndex + 1) & mask;
}
}
@Override
public int hashCode() {
lock.lock();
try {
int hash = 0;
for (int i = 0; i < keys.length; i++) {
if (used[i]) {
hash += Long.hashCode(keys[i]);
}
}
return hash;
} finally {
lock.unlock();
}
}
}
/**
* Concurrent iterator for the set.
*/
private class ConcurrentLongIterator implements LongIterator {
private int segmentIndex;
private int keyIndex;
private long lastReturned;
private boolean lastReturnedValid;
ConcurrentLongIterator() {
segmentIndex = 0;
keyIndex = 0;
lastReturnedValid = false;
advance();
}
@Override
public boolean hasNext() {
return segmentIndex < segments.length;
}
@Override
public long nextLong() {
if (!hasNext()) {
throw new java.util.NoSuchElementException();
}
lastReturned = segments[segmentIndex].keys[keyIndex];
lastReturnedValid = true;
advance();
return lastReturned;
}
@Override
public Long next() {
return nextLong();
}
@Override
public void remove() {
if (!lastReturnedValid) {
throw new IllegalStateException();
}
ConcurrentLongHashSet.this.remove(lastReturned);
lastReturnedValid = false;
}
private void advance() {
while (segmentIndex < segments.length) {
Segment segment = segments[segmentIndex];
// Lock the segment to get a consistent view
segment.lock.lock();
try {
while (keyIndex < segment.keys.length) {
if (segment.used[keyIndex]) {
// Found next element
return;
}
keyIndex++;
}
} finally {
segment.lock.unlock();
}
// Move to next segment
segmentIndex++;
keyIndex = 0;
}
}
}
/**
* Utility class for math operations.
*/
private static class MathUtil {
/**
* Returns the next power of two greater than or equal to the given value.
*/
static int nextPowerOfTwo(int value) {
int highestBit = Integer.highestOneBit(value);
return value > highestBit ? highestBit << 1 : value;
}
}
}