diff --git a/virtdata-lib-basics/src/main/java/io/nosqlbench/virtdata/library/basics/shared/distributions/CSVSampler.java b/virtdata-lib-basics/src/main/java/io/nosqlbench/virtdata/library/basics/shared/distributions/CSVSampler.java
new file mode 100644
index 000000000..15e4a059a
--- /dev/null
+++ b/virtdata-lib-basics/src/main/java/io/nosqlbench/virtdata/library/basics/shared/distributions/CSVSampler.java
@@ -0,0 +1,169 @@
+package io.nosqlbench.virtdata.library.basics.shared.distributions;
+
+import io.nosqlbench.nb.api.content.NBIO;
+import io.nosqlbench.nb.api.errors.BasicError;
+import io.nosqlbench.virtdata.api.annotations.Categories;
+import io.nosqlbench.virtdata.api.annotations.Category;
+import io.nosqlbench.virtdata.api.annotations.Example;
+import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
+import io.nosqlbench.virtdata.library.basics.core.stathelpers.AliasElementSampler;
+import io.nosqlbench.virtdata.library.basics.core.stathelpers.ElemProbD;
+import io.nosqlbench.virtdata.library.basics.core.stathelpers.EvProbD;
+import io.nosqlbench.virtdata.library.basics.shared.from_long.to_long.Hash;
+import org.apache.commons.csv.CSVParser;
+import org.apache.commons.csv.CSVRecord;
+
+import java.util.*;
+import java.util.function.Function;
+import java.util.function.LongFunction;
+import java.util.function.LongUnaryOperator;
+import java.util.stream.Collectors;
+
+/**
+ * This function is a toolkit version of the {@link WeightedStringsFromCSV} function.
+ * It is more capable and should be the preferred function for alias sampling over any CSV data.
+ * This sampler uses a named column in the CSV data as the value. This is also referred to as the
+ * labelColumn. The frequency of this label depends on the weight assigned to it in another named
+ * CSV column, known as the weightColumn.
+ *
+ * 
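+ * For example, with the basicdata.csv test file added in this change, NAME can serve as the
+ * labelColumn and WEIGHT as the weightColumn, as in {@code CSVSampler('NAME','WEIGHT','sum','basicdata')}.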

+ * 

Combining duplicate labels

 + * When you have CSV data which is not organized around the specific identifier that you want to sample by, + * you can use a combining function to tabulate the weights of duplicate labels before sampling. In that case, + * you can use any of "sum", "avg", "count", "min", "max", or "name" as the reducing function over the values + * in the weight column. If none is specified, then "sum" is used by default. All modes except "count" and + * "name" require a valid weight column to be specified. + * + * 
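+ * For example, two rows labeled "alpha" with weights 1 and 6 reduce to a single "alpha" entry whose
+ * weight is 7 under "sum", 3.5 under "avg", 2 under "count", 1 under "min", 6 under "max", and 1 under "name".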

+ * 

Map vs Hash mode

 + * As with some of the other statistical functions, you can use this one to pick through the sample values + * by using the map mode. This is distinct from the default hash mode. When map mode is used, + * the values will appear monotonically as you scan through the unit interval of all long values. + * Specifically, 0L represents 0.0d in the unit interval on input, and Long.MAX_VALUE represents + * 1.0 on the unit interval. This mode is only recommended for advanced scenarios and should otherwise be + * avoided. You will know if you need this mode. + */ +@Categories(Category.general) +@ThreadSafeMapper +public class CSVSampler implements LongFunction { + + private final AliasElementSampler sampler; + private final LongUnaryOperator prefunc; + private final static Set MODES = Set.of("map", "hash", "sum", "avg", "count", "min", "name", "max"); + + /** + * Build an efficient O(1) sampler for the given column values with respect to the weights, + * combining duplicate labels according to the selected reducing mode ("sum" by default). + * + * @param labelColumn The CSV column name containing the value + * @param weightColumn The CSV column name containing a double weight + * @param data Sampling modes or file names. Any of map, hash, sum, avg, min, max, count, or name + * are taken as configuration modes, and all others are taken as CSV filenames. + */ + @Example({"CSVSampler('USPS','n/a','name','census_state_abbrev')","Sample US state abbreviations, weighting each distinct one equally"}) + public CSVSampler(String labelColumn, String weightColumn, String... data) { + List events = new ArrayList<>(); + List values = new ArrayList<>(); + + Function weightFunc = LabeledStatistic::sum; + LongUnaryOperator prefunc = new Hash(); + boolean weightRequired = false; + + while (data.length > 0 && MODES.contains(data[0])) { + String cfg = data[0]; + data = Arrays.copyOfRange(data, 1, data.length); + switch (cfg) { + case "map": + prefunc = i -> i; + break; + case "hash": + prefunc = new Hash(); + break; + case "sum": + weightFunc = LabeledStatistic::sum; + weightRequired = true; + break; + case "min": + weightFunc = LabeledStatistic::min; + weightRequired = true; + break; + case "max": + weightFunc = LabeledStatistic::max; + weightRequired = true; + break; + case "avg": + weightFunc = LabeledStatistic::avg; + weightRequired = true; + break; + case "count": + weightFunc = LabeledStatistic::count; + weightRequired = false; + break; + case "name": + weightFunc = (v) -> 1.0d; + weightRequired = false; + break; + default: + throw new BasicError("Unknown cfg verb '" + cfg + "'"); + } + } + this.prefunc = prefunc; + + final Function valFunc = weightFunc; + + Map entries = new HashMap<>(); + + for (String filename : data) { + if (!filename.endsWith(".csv")) { + filename = filename + ".csv"; + } + CSVParser csvdata = NBIO.readFileCSV(filename); + + String labelName = csvdata.getHeaderNames().stream() + .filter(labelColumn::equalsIgnoreCase) + .findAny().orElseThrow(); + + String weightName = "none"; + if (weightRequired) { + weightName = csvdata.getHeaderNames().stream() + .filter(weightColumn::equalsIgnoreCase) + .findAny().orElseThrow(); + } + + double weight = 1.0d; + for (CSVRecord csvdatum : csvdata) { + if (csvdatum.get(labelName) != null) { + String label = csvdatum.get(labelName); + if (weightRequired) { + String weightString = csvdatum.get(weightName); + weight = weightString.isEmpty() ? 
 + 1.0d : Double.parseDouble(weightString); + } + LabeledStatistic entry = new LabeledStatistic(label, weight); + entries.merge(label, entry, LabeledStatistic::merge); + } + } + } + + List> elemList = entries.values() + .stream() + .map(t -> new ElemProbD<>(t.label, valFunc.apply(t))) + .collect(Collectors.toList()); + + this.sampler = new AliasElementSampler(elemList); + } + + @Override + public String apply(long value) { + value = prefunc.applyAsLong(value); + double unitValue = (double) value / (double) Long.MAX_VALUE; + String val = sampler.apply(unitValue); + return val; + } +} diff --git a/virtdata-lib-basics/src/main/java/io/nosqlbench/virtdata/library/basics/shared/distributions/LabeledStatistic.java b/virtdata-lib-basics/src/main/java/io/nosqlbench/virtdata/library/basics/shared/distributions/LabeledStatistic.java new file mode 100644 index 000000000..1a3a2e941 --- /dev/null +++ b/virtdata-lib-basics/src/main/java/io/nosqlbench/virtdata/library/basics/shared/distributions/LabeledStatistic.java @@ -0,0 +1,64 @@ +package io.nosqlbench.virtdata.library.basics.shared.distributions; + +class LabeledStatistic { + public final String label; + public final double total; + public final int count; + public final double min; + public final double max; + + public LabeledStatistic(String label, double weight) { + this.label = label; + this.total = weight; + this.min = weight; + this.max = weight; + this.count = 1; + } + + private LabeledStatistic(String label, double total, double min, double max, int count) { + this.label = label; + this.total = total; + this.min = min; + this.max = max; + this.count = count; + } + + public LabeledStatistic merge(LabeledStatistic tuple) { + return new LabeledStatistic( + this.label, + this.total + tuple.total, + Math.min(this.min, tuple.min), + Math.max(this.max, tuple.max), + this.count + tuple.count + ); + } + + public double count() { + return count; + } + + public double avg() { + return total / count; + } + + public double sum() { + return total; + } + + @Override + public String toString() { + return "LabeledStatistic{" + + "label='" + label + '\'' + + ", total=" + total + + ", count=" + count + + '}'; + } + + public double min() { + return this.min; + } + + public double max() { + return this.max; + } +} diff --git a/virtdata-lib-basics/src/test/java/io/nosqlbench/virtdata/library/basics/shared/distributions/CSVSamplerTest.java b/virtdata-lib-basics/src/test/java/io/nosqlbench/virtdata/library/basics/shared/distributions/CSVSamplerTest.java new file mode 100644 index 000000000..0e369f258 --- /dev/null +++ b/virtdata-lib-basics/src/test/java/io/nosqlbench/virtdata/library/basics/shared/distributions/CSVSamplerTest.java @@ -0,0 +1,134 @@ +package io.nosqlbench.virtdata.library.basics.shared.distributions; + +import org.assertj.core.data.Percentage; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * All tests in this class are based on a CSV file with the following contents. + * 
 + * 
 {@code
+ * NAME,WEIGHT,MEMO
+ * alpha,1,this is sparta
+ * beta,2,this is sparta
+ * gamma,3,this is sparta
+ * delta,4,this is sparta
+ * epsilon,5,this is sparta
+ * alpha,6,this is sparta
+ * } 
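+ *
+ * With these contents, the combined "alpha" entry has sum=7, avg=3.5, count=2, min=1, and max=6;
+ * the expected frequencies asserted in the tests below are derived from these values.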
+ */
+public class CSVSamplerTest {
+
+    /**
+     * In this test, alpha appears twice and all others once, so alpha should appear roughly twice as
+     * frequently as each of the other names.
+     */
+    @Test
+    public void testByCount() {
+        CSVSampler sampler = new CSVSampler("name", "weightfoo", "count", "basicdata");
+        String value = sampler.apply(1);
+
+        Map<String, Double> results = new HashMap<>();
+        for (int i = 0; i < 100000; i++) {
+            String name = sampler.apply(i);
+            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
+        }
+        System.out.println(results);
+        assertThat(results.get("alpha")).isCloseTo(results.get("beta") * 2, Percentage.withPercentage(5.0d));
+        assertThat(results.get("alpha")).isCloseTo(results.get("gamma") * 2, Percentage.withPercentage(5.0d));
+        assertThat(results.get("alpha")).isCloseTo(results.get("delta") * 2, Percentage.withPercentage(5.0d));
+        assertThat(results.get("alpha")).isCloseTo(results.get("epsilon") * 2, Percentage.withPercentage(5.0d));
+    }
+
+    /**
+     * In this test, alpha's weights sum to 7 (1+6) out of a total of 21, or 1/3 of the total weight,
+     * so it should appear roughly 1/3 of the time.
+     */
+    @Test
+    public void testBySum() {
+        CSVSampler sampler = new CSVSampler("name", "weight", "sum", "basicdata");
+        String value = sampler.apply(1);
+
+        Map<String, Double> results = new HashMap<>();
+        for (int i = 0; i < 100000; i++) {
+            String name = sampler.apply(i);
+            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
+        }
+        System.out.println(results);
+        assertThat(results.get("alpha")).isCloseTo(33333, Percentage.withPercentage(2.0d));
+    }
+
+    /**
+     * In this test, alpha's weights average to 3.5, or 3.5/17.5 = 20%, so it should appear roughly 20%
+     * of the time.
+     */
+    @Test
+    public void testByAvgs() {
+        CSVSampler sampler = new CSVSampler("name", "weight", "avg", "basicdata");
+        String value = sampler.apply(1);
+
+        Map<String, Double> results = new HashMap<>();
+        for (int i = 0; i < 100000; i++) {
+            String name = sampler.apply(i);
+            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
+        }
+        System.out.println(results);
+        assertThat(results.get("alpha")).isCloseTo(20000, Percentage.withPercentage(2.0d));
+    }
+
+    /**
+     * In this test, alpha's minimum weight is 1 out of a total of 15, so its expected frequency is about 6.7%.
+     */
+    @Test
+    public void testByMin() {
+        CSVSampler sampler = new CSVSampler("name", "weight", "min", "basicdata");
+        String value = sampler.apply(1);
+
+        Map<String, Double> results = new HashMap<>();
+        for (int i = 0; i < 100000; i++) {
+            String name = sampler.apply(i);
+            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
+        }
+        System.out.println(results);
+        assertThat(results.get("alpha")).isCloseTo(6666, Percentage.withPercentage(2.0d));
+    }
+
+    /**
+     * In this test, alpha's maximum weight is 6 out of a total of 20, so its expected frequency is 30%.
+     */
+    @Test
+    public void testByMax() {
+        CSVSampler sampler = new CSVSampler("name", "weight", "max", "basicdata");
+        String value = sampler.apply(1);
+
+        Map<String, Double> results = new HashMap<>();
+        for (int i = 0; i < 100000; i++) {
+            String name = sampler.apply(i);
+            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
+        }
+        System.out.println(results);
+        assertThat(results.get("alpha")).isCloseTo(30000, Percentage.withPercentage(2.0d));
+    }
+
+    /**
+     * In this test, alpha is 1 of 5 distinct names, so its expected frequency is 20%.
+     */
+    @Test
+    public void testByName() {
+        CSVSampler sampler = new CSVSampler("name", "does not matter", "name", "basicdata");
+        String value = sampler.apply(1);
+
+        Map<String, Double> results = new HashMap<>();
+        for (int i = 0; i < 100000; i++) {
+            String name = sampler.apply(i);
+            results.compute(name, (k, v) -> v == null ?
+                1d : v + 1d);
+        }
+        System.out.println(results);
+        assertThat(results.get("alpha")).isCloseTo(20000, Percentage.withPercentage(2.0d));
+    }
+
+}
diff --git a/virtdata-lib-basics/src/test/resources/basicdata.csv b/virtdata-lib-basics/src/test/resources/basicdata.csv
new file mode 100644
index 000000000..dc8f44ea9
--- /dev/null
+++ b/virtdata-lib-basics/src/test/resources/basicdata.csv
@@ -0,0 +1,7 @@
+NAME,WEIGHT,MEMO
+alpha,1,this is sparta
+beta,2,this is sparta
+gamma,3,this is sparta
+delta,4,this is sparta
+epsilon,5,this is sparta
+alpha,6,this is sparta
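Usage sketch (not part of the diff above): the class below mirrors testBySum and is only illustrative.
The CSVSamplerDemo name is mine, and it assumes basicdata.csv is resolvable by NBIO at runtime, as it is
for the unit tests in this change.

// Illustrative only: demonstrates constructing a CSVSampler and observing the sampled label frequencies.
package io.nosqlbench.virtdata.library.basics.shared.distributions;

import java.util.HashMap;
import java.util.Map;

public class CSVSamplerDemo {
    public static void main(String[] args) {
        // "sum" combines the two alpha rows (weights 1 and 6) into one entry with weight 7 of 21 total,
        // so roughly one third of the sampled labels should be "alpha".
        CSVSampler sampler = new CSVSampler("name", "weight", "sum", "basicdata");

        Map<String, Integer> counts = new HashMap<>();
        for (long i = 0; i < 100_000; i++) {
            // Each long input is hashed into the unit interval and alias-sampled into a NAME value.
            counts.merge(sampler.apply(i), 1, Integer::sum);
        }
        System.out.println(counts);
    }
}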