distribution morphing support

This commit is contained in:
Jonathan Shook 2025-02-18 03:02:39 -06:00
parent fde7b556ec
commit 7952b7867e
9 changed files with 625 additions and 0 deletions

View File

@ -0,0 +1,124 @@
/*
* Copyright (c) nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.core.stathelpers;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.function.DoubleToIntFunction;
import java.util.function.DoubleToLongFunction;
import java.util.stream.Collectors;
/**
 * Uses the alias sampling method to encode and sample from discrete probabilities,
 * even over larger sets of data. This form requires a unit interval sample value
 * between 0.0 and 1.0, inclusive.
 *
 * <p>Each slot is stored as one record of {@link #RECORD_LEN} bytes: a double
 * divider followed by two long referents (8 + 8 + 8 = 24 bytes). A table of N
 * slots therefore needs a buffer of N*24 bytes, or about 24MB for 1M entries.</p>
 *
 * <p>This sampler is meant to be shared between threads, in order to avoid many
 * instances of a large buffer on heap. Sampling only performs absolute reads on
 * the buffer.</p>
 */
public class AliasSamplerDoubleLong implements DoubleToLongFunction {

    // records of (double, long, long): divider, bottom referent, top referent
    private final ByteBuffer stats;
    // the number of fair die-roll slots, each holding one unfair coin
    private final double slotCount;

    private static final int _r0 = 0;
    private static final int _r1 = _r0 + Double.BYTES; // skip the divider -> bottom referent
    private static final int _r2 = _r1 + Long.BYTES;   // skip bottom referent -> top referent

    /** Size in bytes of one (divider, bottom referent, top referent) record. */
    public static final int RECORD_LEN = _r2 + Long.BYTES;

    /**
     * Wraps an already encoded table. For testing.
     *
     * @param stats a buffer whose capacity is a whole number of records
     * @throws RuntimeException if the capacity is not a multiple of {@link #RECORD_LEN}
     */
    AliasSamplerDoubleLong(ByteBuffer stats) {
        this.stats = stats;
        if ((stats.capacity() % RECORD_LEN) != 0) {
            throw new RuntimeException("Misaligned ByteBuffer size, must be a multiple of " + RECORD_LEN);
        }
        slotCount = (stats.capacity() / RECORD_LEN);
    }

    /**
     * Builds the alias table for the given weighted events. The weights do not need
     * to sum to 1.0; they are normalized against their total before slots are built.
     *
     * @param events the event ids and their relative weights; must not be null or empty
     * @throws IllegalArgumentException if {@code events} is null or empty
     * @throws RuntimeException if the number of slots built differs from the number of events
     */
    public AliasSamplerDoubleLong(List<EvProbLongDouble> events) {
        if (events == null || events.isEmpty()) {
            throw new IllegalArgumentException("at least one event is required to build an alias sampler");
        }
        int size = events.size();

        // array-size normalization: scale weights so that the average weight is exactly 1.0
        double sumProbability = events.stream().mapToDouble(EvProbLongDouble::prob).sum();
        List<EvProbLongDouble> normalized = events.stream()
            .map(e -> new EvProbLongDouble(e.id(), (e.prob() / sumProbability) * size))
            .collect(Collectors.toList());

        // presort into below-average and at-or-above-average work lists
        LinkedList<EvProbLongDouble> small = new LinkedList<>();
        LinkedList<EvProbLongDouble> large = new LinkedList<>();
        for (EvProbLongDouble event : normalized) {
            (event.prob() < 1.0D ? small : large).addLast(event);
        }

        // pair each small event with a large one; the large one donates the rest of the slot
        List<Slot> slots = new ArrayList<>();
        while (small.peekFirst() != null && large.peekFirst() != null) {
            EvProbLongDouble l = small.removeFirst();
            EvProbLongDouble g = large.removeFirst();
            slots.add(new Slot(g.id(), l.id(), l.prob()));
            EvProbLongDouble remainder = new EvProbLongDouble(g.id(), (g.prob() + l.prob()) - 1);
            (remainder.prob() < 1.0D ? small : large).addLast(remainder);
        }
        // whatever is left over fills a whole slot by itself
        while (large.peekFirst() != null) {
            EvProbLongDouble g = large.removeFirst();
            slots.add(new Slot(g.id(), g.id(), 1.0));
        }
        while (small.peekFirst() != null) {
            EvProbLongDouble l = small.removeFirst();
            slots.add(new Slot(l.id(), l.id(), 1.0));
        }
        if (slots.size() != size) {
            throw new RuntimeException("basis for average probability is incorrect, because only " + slots.size() + " slotCount of " + size + " were created.");
        }

        // align to indexes: the divider of slot i becomes i + botProb, so it can be
        // compared directly against (sample * slotCount)
        for (int i = 0; i < slots.size(); i++) {
            slots.set(i, slots.get(i).rescale(i, i + 1));
        }

        this.stats = ByteBuffer.allocate(slots.size() * RECORD_LEN);
        for (Slot slot : slots) {
            stats.putDouble(slot.botProb());
            stats.putLong(slot.botId());
            stats.putLong(slot.topId());
        }
        stats.flip();
        this.slotCount = (stats.capacity() / RECORD_LEN);
    }

    /**
     * Selects an event id for the given unit interval sample.
     *
     * @param value a sample in the closed interval [0.0, 1.0]
     * @return the id of the selected event
     * @throws IndexOutOfBoundsException if the sample maps outside of the encoded table
     */
    @Override
    public long applyAsLong(double value) {
        double fractionalPoint = value * slotCount;
        int slotIndex = (int) fractionalPoint;
        // A sample of exactly 1.0 lands one slot past the end of the table;
        // fold it back into the final slot instead of reading out of bounds.
        if (slotIndex == (int) slotCount) {
            slotIndex--;
        }
        int offsetPoint = slotIndex * RECORD_LEN;
        double divider = stats.getDouble(offsetPoint);
        // above the divider selects the top referent, otherwise the bottom one
        int selector = offsetPoint + (fractionalPoint > divider ? _r2 : _r1);
        return stats.getLong(selector);
    }

    /** One slot of the table: two candidate ids and the divider which separates them. */
    private record Slot(long topId, long botId, double botProb) {
        /** Maps botProb from the unit interval onto the interval [min, max]. */
        public Slot rescale(int min, int max) {
            return new Slot(topId, botId, (min + (botProb * (max - min))));
        }
    }
}

View File

@ -0,0 +1,22 @@
/*
* Copyright (c) nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.core.stathelpers;
import java.util.Comparator;
/**
 * A discrete event, identified by a long id, paired with a double probability
 * or weight for that event.
 *
 * @param id   the identifier of the event
 * @param prob the probability or weight associated with the event
 */
public record EvProbLongDouble(long id, double prob) {}

View File

@ -0,0 +1,76 @@
package io.nosqlbench.virtdata.library.basics.shared.from_long.to_double;
/*
* Copyright (c) nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import io.nosqlbench.nb.api.errors.BasicError;
import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.Example;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.AliasSamplerDoubleLong;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.EvProbLongDouble;
import java.util.ArrayList;
import java.util.List;
/// Empirical Histribution is a portmanteau name to capture the
/// concept of an empirical distribution based on a discrete histogram.
/// This is in contrast to the other similar method [EmpiricalDistribution],
/// which uses a continuous density estimation. Both excel in specific ways.
///
/// Use this distribution when you have a set of label frequencies which you
/// want to represent accurately.
///
/// Frequencies may be given as a string of elements separated by commas,
/// semicolons or spaces. Elements are either all bare weights (labels are then
/// the element positions, starting at 0) or all `label:weight` pairs.
@ThreadSafeMapper
@Categories(Category.distributions)
public class EmpiricalHistribution extends AliasSamplerDoubleLong {

    /// Create a histribution from a textual frequency list.
    ///
    /// @param freqs weights, or label:weight pairs, separated by `,`, `;` or spaces
    /// @throws RuntimeException if labeled and unlabeled elements are mixed
    /// @throws NumberFormatException if a label or weight is not a valid long
    @Example({"EmpiricalHistribution('50 25 13 12')", "implied frequencies of 0:50 1:25 2:13 3:12"})
    @Example({
        "EmpiricalHistribution('234:50 33:25 17:13 3:12')",
        "labeled frequencies; 234,33,17,3 are labels, and 50,25,13,12 are weights"
    })
    public EmpiricalHistribution(String freqs) {
        // Parsing lives in a static helper so that nothing has to run before super(...).
        super(parseEvents(freqs));
    }

    /// Create a histribution from bare weights; the label of each weight is its position.
    ///
    /// @param freqs the weights for labels 0, 1, 2, ...
    public EmpiricalHistribution(long... freqs) {
        super(genEvents(freqs));
    }

    /// Parse a frequency spec into events, enforcing that labeling is all-or-nothing.
    private static List<EvProbLongDouble> parseEvents(String freqs) {
        boolean labeled = freqs.contains(":");
        // a run of separators (such as ", ") counts as a single delimiter
        String[] elems = freqs.trim().split("[,; ]+");
        List<EvProbLongDouble> events = new ArrayList<>(elems.length);
        for (int i = 0; i < elems.length; i++) {
            String[] parts = elems[i].split(":", 2);
            if (labeled != (parts.length == 2)) {
                throw new RuntimeException(
                    "If any elements are labeled, all elements must be:" + freqs);
            }
            long id = labeled ? Long.parseLong(parts[0]) : i;
            // the weight is the only part when unlabeled, and the second part when labeled
            String weight = parts[parts.length - 1];
            events.add(new EvProbLongDouble(id, Long.parseLong(weight)));
        }
        return events;
    }

    /// Turn positional weights into events labeled by their index.
    private static List<EvProbLongDouble> genEvents(long[] freqs) {
        List<EvProbLongDouble> events = new ArrayList<>(freqs.length);
        for (int i = 0; i < freqs.length; i++) {
            events.add(new EvProbLongDouble(i, freqs[i]));
        }
        return events;
    }
}

View File

@ -0,0 +1,147 @@
package io.nosqlbench.virtdata.library.basics.shared.from_long.to_double;
/*
* Copyright (c) nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.Example;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import io.nosqlbench.virtdata.api.bindings.VirtDataConversions;
import java.util.function.LongToDoubleFunction;
import java.util.function.LongUnaryOperator;
/// Blends two functions with a domain of 0..Long.MAX_VALUE as the input interval,
/// and a double output. The output value is interpolated between the output value
/// of the two according to the mix function. When the mix function yields a value
/// of 0.0, then the mix is turned _fully counter-clockwise_, or fully on the first provided
/// function. When the value is 1.0, the mix is turned all the way clockwise, or fully on the
/// second provided function.
///
/// If there are only two inner functions provided to HashMix, then it will default to
/// sampling random mixes at a randomized sample point. In other words, the variates
/// provided will be somewhere between the two curves on the unit interval. This is a simple way
/// to sample between two curves by default. The yielded value will be greater than or equal to
/// the lower of the two values at any point, and less than or equal to the greater of either.
///
/// If a third parameter is provided to control the mix, then the mix can be set directly as a
/// unit interval. (The dial goes from 0.0 to 1.0). Any double or float value here will suffice.
/// You can use this when you want to have a test parameter that slews between two modeled
/// shapes. You can alternately provide any other function which can be coerced to a LongToDouble
/// function as a dynamic mix control. IFF such a function is provided, it must also be responsible
/// for hashing the input value if pseudo-randomness is desired.
///
/// If a fourth parameter is provided, the sample point can also be controlled. By default, the
/// values on the provided curves will be sampled pseudo-randomly. However, a fourth parameter
/// can override this just like the mix ratio. As well, if you provide a value or function
/// to control the sample point, you are also responsible for any hashing needed to sample across
/// the whole space of possible values.
///
/// The flexibility of these two parameters provides a substantial amount of flexibility. You
/// can, for example:
///
/// - sample variates between two curves
/// - sample variates at a selected morphing step between the curves
/// - sample variates between two curves on a subsection of the unit interval
/// - sample variates within a defined band gap of the two curves
@ThreadSafeMapper
@Categories(Category.functional)
public class HashMix implements LongToDoubleFunction {

    private final LongToDoubleFunction f1;
    private final LongToDoubleFunction f2;
    private final LongToDoubleFunction mixF;
    private final LongUnaryOperator sampleF;

    /// Create a mix with explicit control over both the mix ratio and the sample point.
    ///
    /// @param curve1F      the first curve, adaptable to a LongToDoubleFunction
    /// @param curve2F      the second curve, adaptable to a LongToDoubleFunction
    /// @param mixPointF    a Double or Float in [0.0,1.0] used as a fixed mix, or a function
    ///                     adaptable to a LongToDoubleFunction which yields the mix per input
    /// @param samplePointF a function adaptable to a LongUnaryOperator which selects the point
    ///                     at which both curves are sampled
    /// @throws RuntimeException if a fixed mix value is NaN or outside of [0.0,1.0]
    @Example({
        "HashMix(Func1(),Func2())",
        "yield samples between func1 and func2 values at some random random sample point x"
    })
    @Example({
        "HashMix(Func1(),Func2(),0.25d)",
        "yield samples which are 25% from the sample values for func1 and func2 at some random "
            + "sample point x"
    })
    @Example({
        "HashMix(Func1(),Func2(),HashRange(0.25d,0.75d))",
        "yield samples between 25% and 75% from func1 to func2 values at some random sample point x"
    })
    @Example({
        "HashMix(Func1(),Func2(),0.0d,ScaledDouble())",
        "access Func1 values as if it were the only one provided. ScaledDouble adds no "
            + "randomization the input value, but it does map it to the sample domain of 0.0d-0.1d."
    })
    public HashMix(Object curve1F, Object curve2F, Object mixPointF, Object samplePointF) {
        if (mixPointF instanceof Double || mixPointF instanceof Float) {
            // unbox once, so the fixed mix is a primitive constant inside the lambda
            double fixedMix = ((Number) mixPointF).doubleValue();
            // written as a negated range test so that NaN is rejected as well
            if (!(fixedMix >= 0.0d && fixedMix <= 1.0d)) {
                throw new RuntimeException(
                    "mix value (" + mixPointF + ") must be within the unit" + " range [0.0d,1.0d]");
            }
            this.mixF = n -> fixedMix;
        } else {
            this.mixF = VirtDataConversions.adaptFunction(mixPointF, LongToDoubleFunction.class);
        }
        this.f1 = VirtDataConversions.adaptFunction(curve1F, LongToDoubleFunction.class);
        this.f2 = VirtDataConversions.adaptFunction(curve2F, LongToDoubleFunction.class);
        this.sampleF = VirtDataConversions.adaptFunction(samplePointF, LongUnaryOperator.class);
    }

    /// Create a mix with a caller-provided mix control and the default sample point function.
    public HashMix(Object curve1F, Object curve2F, Object mixPointF) {
        this(curve1F, curve2F, mixPointF, defaultSamplePoint());
    }

    /// Create a mix with the default mix control and the default sample point function.
    public HashMix(Object curve1F, Object curve2F) {
        this(curve1F, curve2F, defaultMixPoint(), defaultSamplePoint());
    }

    /// Same defaults as the two-argument Object form, for callers which already
    /// hold typed functions.
    public HashMix(LongToDoubleFunction f1, LongToDoubleFunction f2) {
        this(f1, f2, defaultMixPoint(), defaultSamplePoint());
    }

    /// The default mix control: a HashRange over the unit interval.
    private static Object defaultMixPoint() {
        return new HashRange(0.0d, 1.0d);
    }

    /// The default sample point control: a long-valued HashRange up to Long.MAX_VALUE.
    private static Object defaultSamplePoint() {
        return new io.nosqlbench.virtdata.library.basics.shared.from_long.to_long.HashRange(Long.MAX_VALUE);
    }

    /// Sample both curves at the point chosen by the sample function, then interpolate
    /// between the two results using the mix chosen by the mix function. Note that the
    /// mix function is given the original input, not the sample point.
    @Override
    public double applyAsDouble(long value) {
        long sampleAt = sampleF.applyAsLong(value);
        double v1 = f1.applyAsDouble(sampleAt);
        double v2 = f2.applyAsDouble(sampleAt);
        double mix = mixF.applyAsDouble(value);
        return LERP.lerp(v1, v2, mix);
    }
}

View File

@ -0,0 +1,25 @@
package io.nosqlbench.virtdata.library.basics.shared.from_long.to_double;
/*
* Copyright (c) nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
 * Linear interpolation between two double values.
 */
public class LERP {

    /**
     * Interpolates from {@code v1} (at {@code mix == 0.0}) to {@code v2}
     * (at {@code mix == 1.0}). Values of {@code mix} outside the unit interval
     * extrapolate along the same line.
     *
     * <p>The naive form {@code v1 + (v2 - v1) * mix} can miss {@code v2} at
     * {@code mix == 1.0} when the endpoints differ greatly in magnitude, and it
     * overflows when the endpoints are large and of opposite sign. This
     * implementation is exact at both endpoints and avoids that overflow.</p>
     *
     * @param v1  the value returned when mix is 0.0
     * @param v2  the value returned when mix is 1.0
     * @param mix the interpolation factor
     * @return the interpolated value, or NaN if any argument is NaN
     */
    public static double lerp(double v1, double v2, double mix) {
        // Endpoints on opposite sides of (or at) zero: the weighted sum cannot
        // overflow for mix in [0,1], and it is exact at mix == 0 and mix == 1.
        if ((v1 <= 0.0d && v2 >= 0.0d) || (v1 >= 0.0d && v2 <= 0.0d)) {
            return mix * v2 + (1.0d - mix) * v1;
        }
        // Same sign: v2 - v1 cannot overflow, but rounding can miss v2 at the end.
        if (mix == 1.0d) {
            return v2;
        }
        double x = v1 + mix * (v2 - v1);
        // Rounding must not push the result across v2 on the wrong side of mix == 1.
        return ((mix > 1.0d) == (v2 > v1)) ? Math.max(v2, x) : Math.min(v2, x);
    }
}

View File

@ -0,0 +1,82 @@
/*
* Copyright (c) nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.core.stathelpers.aliasmethod;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.AliasSamplerDoubleInt;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.AliasSamplerDoubleLong;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.EvProbD;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.EvProbLongDouble;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Checks that alias samplers reproduce the configured weights when they are
 * swept with evenly spaced unit interval samples.
 */
public class AliasSamplerDoubleLongTest {

    private final static Logger logger = LogManager.getLogger(AliasSamplerDoubleLongTest.class);

    /** Weights 1,1,2,4,...,64 on ids 1..8, sampled at 10000 evenly spaced points in [0,1). */
    @Test
    public void testAliasSamplerBinaryFractions() {
        List<EvProbLongDouble> events = new ArrayList<>();
        events.add(new EvProbLongDouble(1L, 1.0D));
        events.add(new EvProbLongDouble(2L, 1.0D));
        events.add(new EvProbLongDouble(3L, 2.0D));
        events.add(new EvProbLongDouble(4L, 4.0D));
        events.add(new EvProbLongDouble(5L, 8.0D));
        events.add(new EvProbLongDouble(6L, 16.0D));
        events.add(new EvProbLongDouble(7L, 32.0D));
        events.add(new EvProbLongDouble(8L, 64.0D));
        AliasSamplerDoubleLong as = new AliasSamplerDoubleLong(events);

        // index 0 stays empty because no event uses id 0
        int[] stats = new int[9];
        for (int i = 0; i < 10000; i++) {
            double v = (double) i / 10000D;
            long idx = as.applyAsLong(v);
            stats[(int) idx]++;
        }
        logger.debug(Arrays.toString(stats));
        assertThat(stats).containsExactly(0, 79, 79, 157, 313, 626, 1250, 2499, 4997);
    }

    /** Weights 1,2,3 on ids 1..3 through the int-valued sampler, sampled the same way. */
    @Test
    public void testAliasSamplerSimple() {
        List<EvProbD> events = new ArrayList<>();
        events.add(new EvProbD(1, 1D));
        events.add(new EvProbD(2, 2D));
        events.add(new EvProbD(3, 3D));
        AliasSamplerDoubleInt as = new AliasSamplerDoubleInt(events);

        int[] stats = new int[4];
        for (int i = 0; i < 10000; i++) {
            double v = (double) i / 10000D;
            int idx = as.applyAsInt(v);
            stats[idx]++;
        }
        logger.debug(Arrays.toString(stats));
        assertThat(stats).containsExactly(0, 1667, 3334, 4999);
    }
}

View File

@ -0,0 +1,56 @@
package io.nosqlbench.virtdata.library.basics.shared.from_long.to_double;
/*
* Copyright (c) nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import org.assertj.core.data.Offset;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
 * Exercises the string form of EmpiricalHistribution: syntax checking, and
 * agreement between the observed label frequencies and the configured weights.
 */
public class EmpiricalHistributionTest {

    /** Mixing a bare weight with label:weight pairs must be rejected. */
    @Test
    public void testUniformSyntaxRequired() {
        assertThatThrownBy(() -> new EmpiricalHistribution("1 2:2 3:3"))
            .hasMessageContaining("all elements must be");
    }

    /** Labels 1,2,3 with weights 1,2,3 should show up in the ratio 1/6 : 2/6 : 3/6. */
    @Test
    public void testBasicHistribution() {
        final int total = 1000000;
        EmpiricalHistribution histribution = new EmpiricalHistribution("1:1 2:2 3:3");
        HashRange unitHash = new HashRange(0.0d, 1.0d);

        long[] counts = new long[10];
        for (int cycle = 0; cycle < total; cycle++) {
            double sample = unitHash.applyAsDouble(cycle);
            int label = (int) histribution.applyAsLong(sample);
            counts[label]++;
        }

        // expected share of the samples per label, indexed by label
        double[] expectedShares = {0.0d, 0.16666666d, 0.33333333d, 0.5d};
        for (int label = 0; label < expectedShares.length; label++) {
            double observedShare = (double) counts[label] / (double) total;
            assertThat(observedShare).isEqualTo(expectedShares[label], Offset.offset(0.01));
        }
        System.out.println(Arrays.toString(counts));
    }
}

View File

@ -0,0 +1,56 @@
package io.nosqlbench.virtdata.library.basics.shared.from_long.to_double;
/*
* Copyright (c) nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import org.assertj.core.data.Offset;
import org.junit.jupiter.api.Test;
import java.util.function.LongToDoubleFunction;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Exercises HashMix with its default mix and sample point functions.
 */
public class HashMixTest {

    // maps 0..Long.MAX_VALUE linearly onto 0.0..1.0
    private final static LongToDoubleFunction TO_UNIT_INTERVAL =
        (l) -> ((double) l) / ((double) Long.MAX_VALUE);
    private final static Object TO_UNIT_INTERVAL_OBJ = (Object) TO_UNIT_INTERVAL;

    /** Mixing a curve with itself is compared against the curve value at each power of two. */
    @Test
    public void testLinearMix() {
        HashMix um1 = new HashMix(TO_UNIT_INTERVAL, TO_UNIT_INTERVAL);
        for (long i = 1; i < (Long.MAX_VALUE >> 1); i *= 2) {
            assertThat(um1.applyAsDouble(i)).isEqualTo(
                TO_UNIT_INTERVAL.applyAsDouble(i),
                Offset.offset(0.0000001d)
            );
        }
    }

    /**
     * Mixes two descending ramps which are exactly 1.0 apart and expects 1.0.
     * NOTE(review): this presumably relies on the default mix and sample point
     * functions being derived from the same hash of the input - confirm.
     */
    @Test
    public void testCrossfadeMix() {
        LongToDoubleFunction rampdown1 = l -> 1.0d - TO_UNIT_INTERVAL.applyAsDouble(l);
        LongToDoubleFunction rampdown2 = l -> 2.0d - TO_UNIT_INTERVAL.applyAsDouble(l);
        HashMix um1 = new HashMix(rampdown1, rampdown2);
        for (long i = 1 << 24; i <= Long.MAX_VALUE >> 1; i <<= 1) {
            // evaluate once and assert on that value, rather than recomputing it
            double value = um1.applyAsDouble(i);
            assertThat(value).isEqualTo(1.0d, Offset.offset(0.0000001d));
        }
    }
}

View File

@ -0,0 +1,37 @@
package io.nosqlbench.virtdata.library.basics.shared.from_long.to_double;
/*
* Copyright (c) nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import org.assertj.core.data.Offset;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;
/**
 * Spot checks for LERP.lerp at the endpoints, at the midpoint, and with equal inputs.
 */
public class LERPTest {

    @Test
    public void testDoubleLerp() {
        Offset<Double> tolerance = Offset.offset(0.00001d);

        // identical endpoints
        double same = LERP.lerp(10.0d, 10.0d, 1.0d);
        assertThat(same).isEqualTo(10.0d, tolerance);

        // mix of 0.0 selects the first value
        double atStart = LERP.lerp(10.0d, 0.0d, 0.0d);
        assertThat(atStart).isEqualTo(10.0d, tolerance);

        // mix of 1.0 selects the second value
        double atEnd = LERP.lerp(10.0d, 0.0d, 1.0d);
        assertThat(atEnd).isEqualTo(0.0d, tolerance);

        // halfway between the two values
        double halfway = LERP.lerp(10.0d, 5.0d, 0.5d);
        assertThat(halfway).isEqualTo(7.5d, tolerance);
    }
}