9
0
mirror of https://github.com/Xiao-MoMi/craft-engine.git synced 2025-12-31 21:06:31 +00:00

更新数字类型

This commit is contained in:
XiaoMoMi
2025-12-30 22:51:35 +08:00
parent 39f7df2fdc
commit 9e3ea149e5
7 changed files with 752 additions and 0 deletions

View File

@@ -0,0 +1,84 @@
package net.momirealms.craftengine.core.plugin.context.number;
import net.momirealms.craftengine.core.plugin.context.Context;
import net.momirealms.craftengine.core.util.ResourceConfigUtils;
import org.jetbrains.annotations.NotNull;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
/**
* 贝塔分布提供器
* 极其灵活的分布,通过 alpha 和 beta 参数控制区间 [min, max] 内的形状
*/
public record BetaNumberProvider(
double min,
double max,
double alpha, // 形状参数 α
double beta // 形状参数 β
) implements NumberProvider {
public static final NumberProviderFactory<BetaNumberProvider> FACTORY = new Factory();
public BetaNumberProvider {
if (min >= max) throw new IllegalArgumentException("min < max required");
if (alpha <= 0 || beta <= 0) throw new IllegalArgumentException("alpha, beta > 0 required");
}
@Override
public double getDouble(Context context) {
// 使用针对不同参数范围优化的生成算法
double x = generateStandardBeta(this.alpha, this.beta);
// 将 [0, 1] 映射到 [min, max]
return this.min + x * (this.max - this.min);
}
/**
* 生成标准 Beta(α, β) 分布 (范围 [0, 1])
* 采用受阻采样法 (Rejection Sampling)
*/
private double generateStandardBeta(double a, double b) {
ThreadLocalRandom random = ThreadLocalRandom.current();
// 特例优化:如果 α=1, β=1退化为均匀分布
if (Math.abs(a - 1.0) < 1e-6 && Math.abs(b - 1.0) < 1e-6) {
return random.nextDouble();
}
// 简化的受阻采样实现 (针对 a, b >= 1 的常见场景)
// 生产环境下如果 a, b < 1通常建议使用 Gamma 分布转换法
while (true) {
double u1 = random.nextDouble();
double u2 = random.nextDouble();
double x = Math.pow(u1, 1.0 / a);
double y = Math.pow(u2, 1.0 / b);
if (x + y <= 1.0) {
return x / (x + y);
}
}
}
@Override
public int getInt(Context context) {
return (int) Math.round(getDouble(context));
}
@Override
public float getFloat(Context context) {
return (float) getDouble(context);
}
private static class Factory implements NumberProviderFactory<BetaNumberProvider> {
@Override
public BetaNumberProvider create(Map<String, Object> arguments) {
double min = ResourceConfigUtils.getAsDouble(arguments.getOrDefault("min", 0.0), "min");
double max = ResourceConfigUtils.getAsDouble(arguments.getOrDefault("max", 1.0), "max");
// α 和 β 的默认值通常设为 2.0 (形成一个平滑的中间高两头低的弧线)
double alpha = ResourceConfigUtils.getAsDouble(arguments.getOrDefault("alpha", 2.0), "alpha");
double beta = ResourceConfigUtils.getAsDouble(arguments.getOrDefault("beta", 2.0), "beta");
return new BetaNumberProvider(min, max, alpha, beta);
}
}
}

View File

@@ -0,0 +1,101 @@
package net.momirealms.craftengine.core.plugin.context.number;
import net.momirealms.craftengine.core.plugin.context.Context;
import net.momirealms.craftengine.core.util.MiscUtils;
import net.momirealms.craftengine.core.util.ResourceConfigUtils;
import org.jetbrains.annotations.NotNull;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
/**
* 指数分布提供器
* 用于描述独立随机事件发生的时间间隔
* 参数 lambda (λ) 是单位时间内事件发生的平均次数 (率参数)
*/
public record ExponentialNumberProvider(
double min,
double max,
double lambda,
int maxAttempts
) implements NumberProvider {
public static final NumberProviderFactory<ExponentialNumberProvider> FACTORY = new Factory();
public ExponentialNumberProvider {
if (min >= max) {
throw new IllegalArgumentException("min must be less than max");
}
if (lambda <= 0) {
throw new IllegalArgumentException("lambda must be greater than 0");
}
if (maxAttempts <= 0) {
throw new IllegalArgumentException("max-attempts must be greater than 0");
}
}
@Override
public int getInt(Context context) {
return (int) Math.round(getDouble(context));
}
@Override
public float getFloat(Context context) {
return (float) getDouble(context);
}
@Override
public double getDouble(Context context) {
for (int i = 0; i < this.maxAttempts; i++) {
// 逆变换采样法 (Inverse Transform Sampling)
// 公式: X = -ln(1 - U) / λ 或者简单的 -ln(U) / λ
// 其中 U 是 [0, 1) 之间的均匀分布随机数
double u = ThreadLocalRandom.current().nextDouble();
// 防止 u 为 0 导致 ln(0) 出现负无穷
if (u < 1e-10) continue;
double value = -Math.log(u) / this.lambda;
if (value >= this.min && value <= this.max) {
return value;
}
}
// 失败回退:返回 1/lambda (分布的期望均值)
return MiscUtils.clamp(1.0 / this.lambda, this.min, this.max);
}
private static class Factory implements NumberProviderFactory<ExponentialNumberProvider> {
@Override
public ExponentialNumberProvider create(Map<String, Object> arguments) {
double min = ResourceConfigUtils.getAsDouble(
arguments.getOrDefault("min", 0.0), "min");
double max = ResourceConfigUtils.getAsDouble(
arguments.getOrDefault("max", Double.MAX_VALUE), "max");
// 如果用户没填 lambda尝试从 mean (均值) 转换
// 指数分布中: mean = 1/lambda
double lambda;
if (arguments.containsKey("mean")) {
double mean = ResourceConfigUtils.getAsDouble(arguments.get("mean"), "mean");
lambda = 1.0 / mean;
} else {
lambda = ResourceConfigUtils.getAsDouble(
ResourceConfigUtils.requireNonNullOrThrow(arguments.get("lambda"),
"warning.config.number.exponential.missing_lambda"), "lambda");
}
int maxAttempts = ResourceConfigUtils.getAsInt(
arguments.getOrDefault("max-attempts", 64), "max-attempts");
return new ExponentialNumberProvider(min, max, lambda, maxAttempts);
}
}
@Override
public @NotNull String toString() {
return String.format("ExponentialNumberProvider{min=%.2f, max=%.2f, lambda=%.4f, mean=%.2f}",
this.min, this.max, this.lambda, 1.0 / this.lambda);
}
}

View File

@@ -0,0 +1,190 @@
package net.momirealms.craftengine.core.plugin.context.number;
import net.momirealms.craftengine.core.plugin.context.Context;
import net.momirealms.craftengine.core.util.MiscUtils;
import net.momirealms.craftengine.core.util.ResourceConfigUtils;
import org.jetbrains.annotations.NotNull;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
/**
* 对数正态分布提供器 (Log-Normal Distribution)
* <p>
* 适用于描述诸如伤害值、金币掉落、经验获取等右偏分布的数据(大多数值较小,但偶尔有极高值)。
* <p>
* 参数说明:
* - location (μ): 对数变量的均值
* - scale (σ): 对数变量的标准差
* - 或者在配置中直接提供 mean (真实均值) 和 std-dev (真实标准差),工厂类会自动转换。
*/
public record LogNormalNumberProvider(
double min,
double max,
double location, // μ
double scale, // σ
int maxAttempts
) implements NumberProvider {
public static final NumberProviderFactory<LogNormalNumberProvider> FACTORY = new Factory();
private static final double EPSILON = 1e-6; // 防止 log(0) 的极小值
public LogNormalNumberProvider {
validateParameters(min, max, scale, maxAttempts);
}
private static void validateParameters(double min, double max, double scale, int maxAttempts) {
if (min >= max) {
throw new IllegalArgumentException("min must be less than max");
}
if (scale <= 0) {
throw new IllegalArgumentException("scale must be greater than 0");
}
if (maxAttempts <= 0) {
throw new IllegalArgumentException("max-attempts must be greater than 0");
}
// 对数正态分布定义域为 (0, +∞)min 必须大于 0
if (min <= 0) {
throw new IllegalArgumentException("min must be greater than 0 for log-normal distribution. If you need 0, consider shifting or clamping.");
}
}
@Override
public int getInt(Context context) {
return (int) Math.round(getDouble(context));
}
@Override
public float getFloat(Context context) {
return (float) getDouble(context);
}
@Override
public double getDouble(Context context) {
Random random = ThreadLocalRandom.current();
// 快速路径:如果范围极小,直接返回均值
if (max - min < EPSILON) {
return min;
}
for (int attempts = 0; attempts < this.maxAttempts; attempts++) {
// 核心算法X = exp(μ + σZ), 其中 Z ~ N(0, 1)
double normalValue = random.nextGaussian() * this.scale + this.location;
// 性能优化:在进行昂贵的 exp 运算前,先检查指数范围防止 Infinity
if (normalValue > 700) { // Math.exp(710) > Double.MAX_VALUE
continue;
}
double value = Math.exp(normalValue);
if (value >= this.min && value <= this.max) {
return value;
}
}
// 失败回退:返回限制在范围内的真实中位数
return MiscUtils.clamp(getRealMedian(), this.min, this.max);
}
/**
* 获取真实分布的均值 (Real Mean)
* E[X] = exp(μ + σ²/2)
*/
public double getRealMean() {
return Math.exp(this.location + (this.scale * this.scale) / 2.0);
}
/**
* 获取真实分布的中位数 (Real Median)
* Median[X] = exp(μ)
*/
public double getRealMedian() {
return Math.exp(this.location);
}
/**
* 获取真实分布的众数 (Real Mode)
* Mode[X] = exp(μ - σ²)
*/
public double getRealMode() {
return Math.exp(this.location - this.scale * this.scale);
}
/**
* 获取真实分布的标准差 (Real StdDev)
* SD[X] = sqrt( (exp(σ²)-1) * exp(2μ+σ²) )
*/
public double getRealStdDev() {
double var = (Math.exp(scale * scale) - 1) * Math.exp(2 * location + scale * scale);
return Math.sqrt(var);
}
private static class Factory implements NumberProviderFactory<LogNormalNumberProvider> {
@Override
public LogNormalNumberProvider create(Map<String, Object> arguments) {
double rawMin = ResourceConfigUtils.getAsDouble(
ResourceConfigUtils.requireNonNullOrThrow(arguments.get("min"),
"warning.config.number.log-normal.missing_min"), "min");
double max = ResourceConfigUtils.getAsDouble(
ResourceConfigUtils.requireNonNullOrThrow(arguments.get("max"),
"warning.config.number.log-normal.missing_max"), "max");
// 自动修正 min <= 0 的情况,防止 Log(0) 崩溃
// 如果用户配置 min=0我们将其修正为一个极小的正数
double min = Math.max(rawMin, EPSILON);
double location;
double scale;
// 优先检查用户是否直接配置了 mean (真实均值) 和 std-dev (真实标准差)
// 这对用户来说比配置 location/scale 直观得多
if (arguments.containsKey("mean") && arguments.containsKey("std-dev")) {
double realMean = ResourceConfigUtils.getAsDouble(arguments.get("mean"), "mean");
double realStdDev = ResourceConfigUtils.getAsDouble(arguments.get("std-dev"), "std-dev");
// 将真实均值/方差转换为对数正态分布参数 μ 和 σ
// μ = ln(mean^2 / sqrt(mean^2 + var))
// σ = sqrt(ln(1 + var/mean^2))
double meanSq = realMean * realMean;
double var = realStdDev * realStdDev;
scale = Math.sqrt(Math.log(1 + (var / meanSq)));
location = Math.log(meanSq / Math.sqrt(meanSq + var));
} else {
// 回退到使用 location/scale 或根据 min/max 估算
// 默认策略:假设 min 和 max 覆盖了大约 +/- 3个标准差的范围 (对数域)
// log(min) ≈ μ - 3σ
// log(max) ≈ μ + 3σ
double logMin = Math.log(min);
double logMax = Math.log(max);
double defaultLocation = (logMin + logMax) / 2.0;
double defaultScale = (logMax - logMin) / 6.0;
location = ResourceConfigUtils.getAsDouble(
arguments.getOrDefault("location", defaultLocation), "location");
scale = ResourceConfigUtils.getAsDouble(
arguments.getOrDefault("scale", defaultScale), "scale");
}
int maxAttempts = ResourceConfigUtils.getAsInt(
arguments.getOrDefault("max-attempts", 128), "max-attempts");
return new LogNormalNumberProvider(min, max, location, scale, maxAttempts);
}
}
@Override
public @NotNull String toString() {
return String.format(
"LogNormalNumberProvider{range=[%.2f, %.2f], location(μ)=%.2f, scale(σ)=%.2f, realMean≈%.2f, realStdDev≈%.2f}",
this.min, this.max, this.location, this.scale, getRealMean(), getRealStdDev()
);
}
}

View File

@@ -15,8 +15,15 @@ public final class NumberProviders {
public static final NumberProviderType<FixedNumberProvider> CONSTANT = register(Key.ce("constant"), FixedNumberProvider.FACTORY);
public static final NumberProviderType<UniformNumberProvider> UNIFORM = register(Key.ce("uniform"), UniformNumberProvider.FACTORY);
public static final NumberProviderType<ExpressionNumberProvider> EXPRESSION = register(Key.ce("expression"), ExpressionNumberProvider.FACTORY);
public static final NumberProviderType<GaussianNumberProvider> NORMAL = register(Key.ce("normal"), GaussianNumberProvider.FACTORY);
public static final NumberProviderType<GaussianNumberProvider> GAUSSIAN = register(Key.ce("gaussian"), GaussianNumberProvider.FACTORY);
public static final NumberProviderType<LogNormalNumberProvider> LOG_NORMAL = register(Key.ce("log_normal"), LogNormalNumberProvider.FACTORY);
public static final NumberProviderType<SkewNormalNumberProvider> SKEW_NORMAL = register(Key.ce("skew_normal"), SkewNormalNumberProvider.FACTORY);
public static final NumberProviderType<BinomialNumberProvider> BINOMIAL = register(Key.ce("binomial"), BinomialNumberProvider.FACTORY);
public static final NumberProviderType<WeightedNumberProvider> WEIGHTED = register(Key.ce("weighted"), WeightedNumberProvider.FACTORY);
public static final NumberProviderType<TriangleNumberProvider> TRIANGLE = register(Key.ce("triangle"), TriangleNumberProvider.FACTORY);
public static final NumberProviderType<ExponentialNumberProvider> EXPONENTIAL = register(Key.ce("exponential"), ExponentialNumberProvider.FACTORY);
public static final NumberProviderType<BetaNumberProvider> BETA = register(Key.ce("beta"), BetaNumberProvider.FACTORY);
private NumberProviders() {}

View File

@@ -0,0 +1,198 @@
package net.momirealms.craftengine.core.plugin.context.number;
import net.momirealms.craftengine.core.plugin.context.Context;
import net.momirealms.craftengine.core.util.MiscUtils;
import net.momirealms.craftengine.core.util.ResourceConfigUtils;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
/**
* 通用偏态分布提供器
* <p>
* 基于偏态正态分布Skew Normal Distribution实现。
* 通过 Azallini 的方法生成Z = δ|X| + sqrt(1-δ^2)Y
*/
public final class SkewNormalNumberProvider implements NumberProvider {
public static final NumberProviderFactory<SkewNormalNumberProvider> FACTORY = new Factory();
// 理论最大偏度 (approx 0.99527)
private static final double MAX_SKEWNESS = 0.995;
private final double min;
private final double max;
private final double targetMean;
private final double targetStdDev;
private final double skewness;
private final int maxAttempts;
// 预计算的分布参数
private final double delta; // 相关系数 δ
private final double sqrtOneMinusDeltaSq; // √(1 - δ²) 用于生成公式优化
private final double omega; // 尺度参数 ω (Scale)
private final double xi; // 位置参数 ξ (Location)
public SkewNormalNumberProvider(double min, double max, double mean, double stdDev, double skewness, int maxAttempts) {
this.min = min;
this.max = max;
this.targetMean = mean;
this.targetStdDev = stdDev;
this.skewness = skewness;
this.maxAttempts = maxAttempts;
validateParameters();
// 1. 根据偏度计算形状相关参数 δ
this.delta = calculateDelta(this.skewness);
// 2. 预计算生成公式中需要的常数,避免热点代码重复计算
this.sqrtOneMinusDeltaSq = Math.sqrt(1 - this.delta * this.delta);
// 3. 计算尺度参数 ω
// Var(X) = ω² * (1 - 2δ²/π) => ω = stdDev / sqrt(1 - 2δ²/π)
this.omega = stdDev / Math.sqrt(1 - (2 * this.delta * this.delta) / Math.PI);
// 4. 计算位置参数 ξ
// E[X] = ξ + ω * δ * sqrt(2/π) => ξ = mean - ω * δ * sqrt(2/π)
this.xi = mean - this.omega * this.delta * Math.sqrt(2.0 / Math.PI);
}
private void validateParameters() {
if (this.min >= this.max) {
throw new IllegalArgumentException("min must be less than max");
}
if (this.targetStdDev <= 0) {
throw new IllegalArgumentException("std-dev must be greater than 0");
}
if (this.maxAttempts <= 0) {
throw new IllegalArgumentException("max-attempts must be greater than 0");
}
// 严格限制偏度,防止数学计算错误
if (Math.abs(this.skewness) > MAX_SKEWNESS) {
throw new IllegalArgumentException("skewness absolute value must be <= " + MAX_SKEWNESS);
}
}
/**
* 根据目标偏度反推相关系数 δ
* 公式推导基于: |γ1| = (4-π)/2 * (δ*sqrt(2/π))^3 / (1 - 2δ²/π)^(3/2)
*/
private double calculateDelta(double skewness) {
if (Math.abs(skewness) < 1e-6) {
return 0.0;
}
double absGamma = Math.abs(skewness);
// 为了数值稳定性,再次钳制范围
absGamma = Math.min(absGamma, MAX_SKEWNESS);
double sign = skewness < 0 ? -1.0 : 1.0;
// 使用精确反函数解
double term1 = Math.pow(absGamma, 2.0 / 3.0);
double term2 = Math.pow((4.0 - Math.PI) / 2.0, 2.0 / 3.0);
double deltaAbs = Math.sqrt((Math.PI / 2.0) * term1 / (term1 + term2));
return sign * deltaAbs;
}
@Override
public int getInt(Context context) {
// 四舍五入取整
return (int) Math.round(getDouble(context));
}
@Override
public float getFloat(Context context) {
return (float) getDouble(context);
}
@Override
public double getDouble(Context context) {
// 如果没有偏度,直接使用更快的标准高斯生成
if (Math.abs(this.skewness) < 1e-6) {
return generateNormalBounded();
}
return generateSkewNormalBounded();
}
/**
* 生成有界偏态分布随机数
*/
private double generateSkewNormalBounded() {
Random random = ThreadLocalRandom.current();
for (int i = 0; i < this.maxAttempts; i++) {
// 生成标准正态变量
double u0 = random.nextGaussian();
double u1 = random.nextGaussian();
// 核心生成公式: Z = δ|U0| + √(1-δ²)U1
// 此时 Z 服从标准偏态正态分布 (Location=0, Scale=1, Shape=α)
double standardSkewed = this.delta * Math.abs(u0) + this.sqrtOneMinusDeltaSq * u1;
// 转换到目标均值和方差: X = ξ + ωZ
double value = this.xi + this.omega * standardSkewed;
if (value >= this.min && value <= this.max) {
return value;
}
}
// 失败回退:返回区间内受限的均值
return MiscUtils.clamp(this.targetMean, this.min, this.max);
}
/**
* 特例优化当偏度为0时正态分布使用更简单的逻辑
*/
private double generateNormalBounded() {
Random random = ThreadLocalRandom.current();
for (int i = 0; i < this.maxAttempts; i++) {
double value = this.targetMean + random.nextGaussian() * this.targetStdDev;
if (value >= this.min && value <= this.max) {
return value;
}
}
return MiscUtils.clamp(this.targetMean, this.min, this.max);
}
@Override
public String toString() {
return "SkewNormalNumberProvider{" +
"range=[" + this.min + ", " + this.max + "]" +
", mean=" + this.targetMean +
", stdDev=" + this.targetStdDev +
", skewness=" + this.skewness +
'}';
}
private static class Factory implements NumberProviderFactory<SkewNormalNumberProvider> {
@Override
public SkewNormalNumberProvider create(Map<String, Object> arguments) {
double min = ResourceConfigUtils.getAsDouble(
ResourceConfigUtils.requireNonNullOrThrow(arguments.get("min"),
"warning.config.number.skewed.missing_min"), "min");
double max = ResourceConfigUtils.getAsDouble(
ResourceConfigUtils.requireNonNullOrThrow(arguments.get("max"),
"warning.config.number.skewed.missing_max"), "max");
double defaultMean = (min + max) / 2.0;
double mean = ResourceConfigUtils.getAsDouble(
arguments.getOrDefault("mean", defaultMean), "mean");
// 默认标准差设为范围的 1/6 (类似 3-sigma 法则覆盖大部分范围)
double defaultStdDev = (max - min) / 6.0;
double stdDev = ResourceConfigUtils.getAsDouble(
arguments.getOrDefault("std-dev", defaultStdDev), "std-dev");
double skewness = ResourceConfigUtils.getAsDouble(
arguments.getOrDefault("skewness", 0.0), "skewness");
int maxAttempts = ResourceConfigUtils.getAsInt(
arguments.getOrDefault("max-attempts", 50), "max-attempts"); // 默认次数稍微降低通常128有点多
return new SkewNormalNumberProvider(min, max, mean, stdDev, skewness, maxAttempts);
}
}
}

View File

@@ -0,0 +1,83 @@
package net.momirealms.craftengine.core.plugin.context.number;
import net.momirealms.craftengine.core.plugin.context.Context;
import net.momirealms.craftengine.core.util.ResourceConfigUtils;
import org.jetbrains.annotations.NotNull;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
/**
* 三角形分布提供器
* 一种连续概率分布,其概率密度函数图像呈三角形
* 相比正态分布,它计算开销极低且天生有界
*/
public record TriangleNumberProvider(
double min,
double max,
double mode
) implements NumberProvider {
public static final NumberProviderFactory<TriangleNumberProvider> FACTORY = new Factory();
public TriangleNumberProvider {
if (min >= max) {
throw new IllegalArgumentException("min must be less than max");
}
if (mode < min || mode > max) {
throw new IllegalArgumentException("mode must be between min and max");
}
}
@Override
public int getInt(Context context) {
return (int) Math.round(getDouble(context));
}
@Override
public float getFloat(Context context) {
return (float) getDouble(context);
}
@Override
public double getDouble(Context context) {
double u = ThreadLocalRandom.current().nextDouble();
// 逆变换采样法 (Inverse Transform Sampling)
// 概率转折点F(mode) = (mode - min) / (max - min)
double fc = (this.mode - this.min) / (this.max - this.min);
if (u < fc) {
// 左半部分三角形
return this.min + Math.sqrt(u * (this.max - this.min) * (this.mode - this.min));
} else {
// 右半部分三角形
return this.max - Math.sqrt((1 - u) * (this.max - this.min) * (this.max - this.mode));
}
}
private static class Factory implements NumberProviderFactory<TriangleNumberProvider> {
@Override
public TriangleNumberProvider create(Map<String, Object> arguments) {
double min = ResourceConfigUtils.getAsDouble(
ResourceConfigUtils.requireNonNullOrThrow(arguments.get("min"),
"warning.config.number.triangle.missing_min"), "min");
double max = ResourceConfigUtils.getAsDouble(
ResourceConfigUtils.requireNonNullOrThrow(arguments.get("max"),
"warning.config.number.triangle.missing_max"), "max");
// 默认众数在正中间(等腰三角形)
double defaultMode = (min + max) / 2.0;
double mode = ResourceConfigUtils.getAsDouble(
arguments.getOrDefault("mode", defaultMode), "mode");
return new TriangleNumberProvider(min, max, mode);
}
}
@Override
public @NotNull String toString() {
return String.format("TriangleNumberProvider{min=%.2f, max=%.2f, mode=%.2f}", this.min, this.max, this.mode);
}
}

View File

@@ -0,0 +1,89 @@
package net.momirealms.craftengine.core.plugin.context.number;
import net.momirealms.craftengine.core.plugin.context.Context;
import net.momirealms.craftengine.core.util.ResourceConfigUtils;
import org.jetbrains.annotations.NotNull;
import java.util.HashMap;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;
/**
* 权重随机提供器
* 根据配置的权重比例随机选择一个数值
*/
public final class WeightedNumberProvider implements NumberProvider {
public static final NumberProviderFactory<WeightedNumberProvider> FACTORY = new Factory();
// 使用 TreeMap 存储前缀和,便于使用 higherEntry 进行二分查找
private final NavigableMap<Double, Double> weightMap = new TreeMap<>();
private final double totalWeight;
public WeightedNumberProvider(Map<Double, Double> inputWeights) {
double sum = 0;
for (Map.Entry<Double, Double> entry : inputWeights.entrySet()) {
double value = entry.getKey();
double weight = entry.getValue();
if (weight > 0) {
sum += weight;
// 存储累计权重 -> 目标值
this.weightMap.put(sum, value);
}
}
this.totalWeight = sum;
if (this.weightMap.isEmpty()) {
throw new IllegalArgumentException("Weighted provider must have at least one positive weight entry");
}
}
@Override
public int getInt(Context context) {
return (int) Math.round(getDouble(context));
}
@Override
public float getFloat(Context context) {
return (float) getDouble(context);
}
@Override
public double getDouble(Context context) {
// 生成 [0, totalWeight) 之间的随机数
double randomValue = ThreadLocalRandom.current().nextDouble() * totalWeight;
// 查找第一个累计权重值大于 randomValue 的条目 (二分查找O(log N))
Map.Entry<Double, Double> entry = weightMap.higherEntry(randomValue);
if (entry == null) {
return weightMap.lastEntry().getValue();
}
return entry.getValue();
}
private static class Factory implements NumberProviderFactory<WeightedNumberProvider> {
@Override
public WeightedNumberProvider create(Map<String, Object> arguments) {
// 期望配置格式:
// weights:
// "1.0": 50
// "2.0": 30
// "5.0": 20
Map<String, Object> weightsObj = ResourceConfigUtils.getAsMap(arguments.get("weights"), "weights");
Map<Double, Double> processedWeights = new HashMap<>();
for (Map.Entry<String, Object> entry : weightsObj.entrySet()) {
double value = Double.parseDouble(entry.getKey());
double weight = Double.parseDouble(String.valueOf(entry.getValue()));
processedWeights.put(value, weight);
}
return new WeightedNumberProvider(processedWeights);
}
}
@Override
public @NotNull String toString() {
return "WeightedNumberProvider{entries=" + this.weightMap.size() + ", totalWeight=" + this.totalWeight + "}";
}
}