harry-core/src/harry/generators/Surjections.java (125 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package harry.generators;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.function.LongFunction;
import java.util.function.Supplier;
import harry.core.VisibleForTesting;
public class Surjections
{
public static <T> Surjection<T> constant(Supplier<T> constant)
{
return (current) -> constant.get();
}
public static <T> Surjection<T> constant(T constant)
{
return (current) -> constant;
}
public static <T> Surjection<T> pick(List<T> ts)
{
return new Surjection<T>()
{
public T inflate(long current)
{
return ts.get(RngUtils.asInt(current, 0, ts.size() - 1));
}
@Override
public String toString()
{
return String.format("Surjection#pick{from=%s}", ts);
}
};
}
public static long[] weights(int... weights)
{
long[] res = new long[weights.length];
for (int i = 0; i < weights.length; i++)
{
long w = weights[i];
res[i] = w << 32 | i;
}
return res;
}
public static <T> Surjection<T> weighted(int[] weights, T... items)
{
return weighted(weights(weights), items);
}
public static <T> Surjection<T> weighted(long[] weights, T... items)
{
assert weights.length == items.length;
Arrays.sort(weights);
TreeMap<Integer, T> weightMap = new TreeMap<Integer, T>();
int prev = 0;
for (int i = 0; i < weights.length; i++)
{
long orig = weights[i];
int weight = (int) (orig >> 32);
int idx = (int) orig;
weightMap.put(prev, items[idx]);
prev += weight;
}
return (i) -> {
int weight = RngUtils.asInt(i, 0, 100);
return weightMap.floorEntry(weight).getValue();
};
}
public static <T> Surjection<T> weighted(Map<T, Integer> weights)
{
TreeMap<Integer, T> weightMap = new TreeMap<Integer, T>();
int sum = 0;
for (Map.Entry<T, Integer> entry : weights.entrySet())
{
sum += entry.getValue();
weightMap.put(sum, entry.getKey());
}
int max = sum;
return (i) -> {
int weight = RngUtils.asInt(i, 0, max);
return weightMap.ceilingEntry(weight).getValue();
};
}
public static <T> Surjection<T> pick(T... ts)
{
return pick(Arrays.asList(ts));
}
public static <T extends Enum<T>> Surjection<T> enumValues(Class<T> e)
{
return pick(Arrays.asList(e.getEnumConstants()));
}
public interface Surjection<T>
{
T inflate(long descriptor);
default <T1> Surjection<T1> map(Function<T, T1> map)
{
return (current) -> map.apply(inflate(current));
}
default LongFunction<T> toFn()
{
return new LongFunction<T>()
{
public T apply(long value)
{
return inflate(value);
}
};
}
default Generator<T> toGenerator()
{
return new Generator<T>()
{
public T generate(RandomGenerator rng)
{
return inflate(rng.next());
}
};
}
@VisibleForTesting
default Supplier<T> toSupplier()
{
RandomGenerator rng = new PcgRSUFast(0, 0);
return () -> inflate(rng.next());
}
}
}