simplify CSV sampling usage
@@ -0,0 +1,169 @@
package io.nosqlbench.virtdata.library.basics.shared.distributions;

import io.nosqlbench.nb.api.content.NBIO;
import io.nosqlbench.nb.api.errors.BasicError;
import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.Example;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.AliasElementSampler;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.ElemProbD;
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_long.Hash;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;

import java.util.*;
import java.util.function.Function;
import java.util.function.LongFunction;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;

/**
 * This function is a toolkit version of the {@link WeightedStringsFromCSV} function.
 * It is more capable and should be the preferred function for alias sampling over any CSV data.
 * This sampler uses a named column in the CSV data as the value, also referred to as the
 * <em>labelColumn</em>. The frequency of each label depends on the weight assigned to it in
 * another named CSV column, known as the <em>weightColumn</em>.
 *
 * <H3>Combining duplicate labels</H3>
 * When you have CSV data which is not organized around the specific identifier that you want to
 * sample by, you can use a combining function to tabulate duplicate labels before sampling. Any
 * of "sum", "avg", "count", "min", or "max" may be used as the reducing function on the values
 * in the weight column. If none is specified, "sum" is used by default. All modes except "count"
 * and "name" require a valid weight column to be specified.
 *
 * <UL>
 * <LI>sum, avg, min, max - takes the given statistic over the weights of each distinct label</LI>
 * <LI>count - takes the number of occurrences of a given label as the weight</LI>
 * <LI>name - sets the weight of every distinct label to 1.0d</LI>
 * </UL>
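 * <p>For example, given the rows {@code alpha,1}, {@code alpha,6}, and {@code beta,2} in the
 * label and weight columns, "sum" weights alpha at 7.0 and beta at 2.0; "avg" weights alpha at
 * 3.5 and beta at 2.0; and "count" weights alpha at 2.0 and beta at 1.0.</p>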
*
|
||||
* <H3>Map vs Hash mode</H3>
|
||||
* As with some of the other statistical functions, you can use this one to pick through the sample values
|
||||
* by using the <em>map</em> mode. This is distinct from the default <em>hash</em> mode. When map mode is used,
|
||||
* the values will appear monotonically as you scan through the unit interval of all long values.
|
||||
* Specifically, 0L represents 0.0d in the unit interval on input, and Long.MAX_VALUE represents
|
||||
* 1.0 on the unit interval.) This mode is only recommended for advanced scenarios and should otherwise be
|
||||
* avoided. You will know if you need this mode.
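 * <p>For example, {@code CSVSampler('name','weight','map','basicdata')} maps the inputs 0L,
 * Long.MAX_VALUE/2, and Long.MAX_VALUE to the unit-interval positions 0.0, ~0.5, and 1.0,
 * sampling deterministically along the cumulative weight distribution rather than at hashed
 * positions.</p>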
 *
 */
@Categories(Category.general)
@ThreadSafeMapper
public class CSVSampler implements LongFunction<String> {

    private final AliasElementSampler<String> sampler;
    private final LongUnaryOperator prefunc;
    private final static Set<String> MODES = Set.of("map", "hash", "sum", "avg", "count", "min", "name", "max");

    /**
     * Build an efficient O(1) sampler for the given column values with respect to the weights,
     * combining equal values by summing the weights.
     *
     * @param labelColumn  The CSV column name containing the value
     * @param weightColumn The CSV column name containing a double weight
     * @param data         Sampling modes or file names. Any of map, hash, sum, avg, count, min,
     *                     max, or name are taken as configuration modes, and all others are
     *                     taken as CSV filenames.
     */
    @Example({"CSVSampler('USPS','n/a','name','census_state_abbrev')","sample US state abbreviations, weighting each distinct name equally"})
    public CSVSampler(String labelColumn, String weightColumn, String... data) {

        Function<LabeledStatistic, Double> weightFunc = LabeledStatistic::sum;
        LongUnaryOperator prefunc = new Hash();
        boolean weightRequired = false;

        // Consume leading mode verbs from the argument list. Everything remaining
        // after the last recognized mode is treated as a CSV filename.
        while (data.length > 0 && MODES.contains(data[0])) {
            String cfg = data[0];
            data = Arrays.copyOfRange(data, 1, data.length);
            switch (cfg) {
                case "map":
                    prefunc = i -> i;
                    break;
                case "hash":
                    prefunc = new Hash();
                    break;
                case "sum":
                    weightFunc = LabeledStatistic::sum;
                    weightRequired = true;
                    break;
                case "min":
                    weightFunc = LabeledStatistic::min;
                    weightRequired = true;
                    break;
                case "max":
                    weightFunc = LabeledStatistic::max;
                    weightRequired = true;
                    break;
                case "avg":
                    weightFunc = LabeledStatistic::avg;
                    weightRequired = true;
                    break;
                case "count":
                    weightFunc = LabeledStatistic::count;
                    weightRequired = false;
                    break;
                case "name":
                    weightFunc = (v) -> 1.0d;
                    weightRequired = false;
                    break;
                default:
                    throw new BasicError("Unknown cfg verb '" + cfg + "'");
            }
        }
        this.prefunc = prefunc;

        final Function<LabeledStatistic, Double> valFunc = weightFunc;

        Map<String, LabeledStatistic> entries = new HashMap<>();

        for (String filename : data) {
            if (!filename.endsWith(".csv")) {
                filename = filename + ".csv";
            }
            CSVParser csvdata = NBIO.readFileCSV(filename);

            // Resolve the label and weight headers case-insensitively, failing fast
            // if a required column is missing.
            String labelName = csvdata.getHeaderNames().stream()
                .filter(labelColumn::equalsIgnoreCase)
                .findAny().orElseThrow();

            String weightName = "none";
            if (weightRequired) {
                weightName = csvdata.getHeaderNames().stream()
                    .filter(weightColumn::equalsIgnoreCase)
                    .findAny().orElseThrow();
            }

            double weight = 1.0d;
            for (CSVRecord csvdatum : csvdata) {
                if (csvdatum.get(labelName) != null) {
                    String label = csvdatum.get(labelName);
                    if (weightRequired) {
                        String weightString = csvdatum.get(weightName);
                        weight = weightString.isEmpty() ? 1.0d : Double.parseDouble(weightString);
                    }
                    LabeledStatistic entry = new LabeledStatistic(label, weight);
                    // Fold duplicate labels together; the reducing function is applied below.
                    entries.merge(label, entry, LabeledStatistic::merge);
                }
            }
        }

        List<ElemProbD<String>> elemList = entries.values()
            .stream()
            .map(t -> new ElemProbD<>(t.label, valFunc.apply(t)))
            .collect(Collectors.toList());

        this.sampler = new AliasElementSampler<>(elemList);
    }

    @Override
    public String apply(long value) {
        value = prefunc.applyAsLong(value);
        // Map the (possibly hashed) long onto the unit interval for the alias sampler.
        double unitValue = (double) value / (double) Long.MAX_VALUE;
        String val = sampler.apply(unitValue);
        return val;
    }
}
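
A minimal usage sketch of the sampler above. The class name CSVSamplerUsage is illustrative
only; census_state_abbrev.csv is the file referenced by the @Example annotation, and is
assumed here to be resolvable by NBIO.

    import io.nosqlbench.virtdata.library.basics.shared.distributions.CSVSampler;

    import java.util.function.LongFunction;

    public class CSVSamplerUsage {
        public static void main(String[] args) {
            // "name" mode weights every distinct USPS label at 1.0, so the weight
            // column argument ("n/a") is never read.
            LongFunction<String> states =
                new CSVSampler("USPS", "n/a", "name", "census_state_abbrev");
            for (long cycle = 0; cycle < 5; cycle++) {
                // Default hash mode: each input is hashed before being mapped onto
                // the unit interval, so successive cycles are decorrelated.
                System.out.println(states.apply(cycle));
            }
        }
    }
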
@@ -0,0 +1,64 @@
package io.nosqlbench.virtdata.library.basics.shared.distributions;

/**
 * A per-label aggregate of total, min, max, and count, merged immutably as duplicate
 * labels are folded together.
 */
class LabeledStatistic {
    public final String label;
    public final double total;
    public final int count;
    public final double min;
    public final double max;

    public LabeledStatistic(String label, double weight) {
        this.label = label;
        this.total = weight;
        this.min = weight;
        this.max = weight;
        this.count = 1;
    }

    private LabeledStatistic(String label, double total, double min, double max, int count) {
        this.label = label;
        this.total = total;
        this.min = min;
        this.max = max;
        this.count = count;
    }

    public LabeledStatistic merge(LabeledStatistic tuple) {
        return new LabeledStatistic(
            this.label,
            this.total + tuple.total,
            Math.min(this.min, tuple.min),
            Math.max(this.max, tuple.max),
            this.count + tuple.count
        );
    }

    public double count() {
        return count;
    }

    public double avg() {
        return total / count;
    }

    public double sum() {
        return total;
    }

    public double min() {
        return this.min;
    }

    public double max() {
        return this.max;
    }

    @Override
    public String toString() {
        return "LabeledStatistic{" +
            "label='" + label + '\'' +
            ", total=" + total +
            ", count=" + count +
            '}';
    }
}
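
A quick sketch of how duplicate labels fold together under merge, using the two alpha rows
from the test data below. LabeledStatistic is package-private, so this illustrative
LabeledStatisticDemo class would need to live in the same package.

    package io.nosqlbench.virtdata.library.basics.shared.distributions;

    public class LabeledStatisticDemo {
        public static void main(String[] args) {
            LabeledStatistic alpha1 = new LabeledStatistic("alpha", 1.0d);
            LabeledStatistic alpha6 = new LabeledStatistic("alpha", 6.0d);
            LabeledStatistic merged = alpha1.merge(alpha6);
            System.out.println(merged.sum());   // 7.0 = 1.0 + 6.0
            System.out.println(merged.avg());   // 3.5 = 7.0 / 2
            System.out.println(merged.min());   // 1.0
            System.out.println(merged.max());   // 6.0
            System.out.println(merged.count()); // 2.0
        }
    }
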
@@ -0,0 +1,134 @@
package io.nosqlbench.virtdata.library.basics.shared.distributions;

import org.assertj.core.data.Percentage;
import org.junit.Test;

import java.util.HashMap;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * All tests in this class are based on a CSV file with the following contents.
 *
 * <pre> {@code
 * NAME,WEIGHT,MEMO
 * alpha,1,this is sparta
 * beta,2,this is sparta
 * gamma,3,this is sparta
 * delta,4,this is sparta
 * epsilon,5,this is sparta
 * alpha,6,this is sparta
 * } </pre>
 */
public class CSVSamplerTest {

    /**
     * In this test, alpha appears twice and every other label appears once, so alpha should
     * appear roughly twice as frequently as each of the others.
     */
    @Test
    public void testByCount() {
        CSVSampler sampler = new CSVSampler("name", "weightfoo", "count", "basicdata");
        String value = sampler.apply(1);

        Map<String, Double> results = new HashMap<>();
        for (int i = 0; i < 100000; i++) {
            String name = sampler.apply(i);
            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
        }
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(results.get("beta") * 2, Percentage.withPercentage(5.0d));
        assertThat(results.get("alpha")).isCloseTo(results.get("gamma") * 2, Percentage.withPercentage(5.0d));
        assertThat(results.get("alpha")).isCloseTo(results.get("delta") * 2, Percentage.withPercentage(5.0d));
        assertThat(results.get("alpha")).isCloseTo(results.get("epsilon") * 2, Percentage.withPercentage(5.0d));
    }

    /**
     * In this test, alpha's weights sum to 7 of a total 21, or 1/3, so it should appear
     * roughly 1/3 of the time.
     */
    @Test
    public void testBySum() {
        CSVSampler sampler = new CSVSampler("name", "weight", "sum", "basicdata");
        String value = sampler.apply(1);

        Map<String, Double> results = new HashMap<>();
        for (int i = 0; i < 100000; i++) {
            String name = sampler.apply(i);
            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
        }
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(33333, Percentage.withPercentage(2.0d));
    }

    /**
     * In this test, alpha's weights average to 3.5 of a total 17.5, or 20%, so it should
     * appear roughly 20% of the time.
     */
    @Test
    public void testByAvgs() {
        CSVSampler sampler = new CSVSampler("name", "weight", "avg", "basicdata");
        String value = sampler.apply(1);

        Map<String, Double> results = new HashMap<>();
        for (int i = 0; i < 100000; i++) {
            String name = sampler.apply(i);
            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
        }
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(20000, Percentage.withPercentage(2.0d));
    }

    /**
     * In this test, alpha's minimum weight is 1 of a total 15, for an expected frequency
     * of about 6.7%.
     */
    @Test
    public void testByMin() {
        CSVSampler sampler = new CSVSampler("name", "weight", "min", "basicdata");
        String value = sampler.apply(1);

        Map<String, Double> results = new HashMap<>();
        for (int i = 0; i < 100000; i++) {
            String name = sampler.apply(i);
            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
        }
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(6666, Percentage.withPercentage(2.0d));
    }

    /**
     * In this test, alpha's maximum weight is 6 of a total 20, for an expected frequency of 30%.
     */
    @Test
    public void testByMax() {
        CSVSampler sampler = new CSVSampler("name", "weight", "max", "basicdata");
        String value = sampler.apply(1);

        Map<String, Double> results = new HashMap<>();
        for (int i = 0; i < 100000; i++) {
            String name = sampler.apply(i);
            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
        }
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(30000, Percentage.withPercentage(2.0d));
    }

    /**
     * In this test, alpha is 1 of 5 distinct names, for an expected frequency of 20%.
     */
    @Test
    public void testByName() {
        CSVSampler sampler = new CSVSampler("name", "does not matter", "name", "basicdata");
        String value = sampler.apply(1);

        Map<String, Double> results = new HashMap<>();
        for (int i = 0; i < 100000; i++) {
            String name = sampler.apply(i);
            results.compute(name, (k, v) -> v == null ? 1d : v + 1d);
        }
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(20000, Percentage.withPercentage(2.0d));
    }
}
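
For reference, every expected alpha frequency asserted above follows from the six rows of
basicdata.csv (alpha carries weights 1 and 6; beta, gamma, delta, and epsilon carry 2, 3,
4, and 5):

    mode    alpha weight         total weight                 expected alpha share
    sum     1 + 6 = 7            21                           7/21 = 1/3      (~33333 of 100000)
    avg     (1 + 6) / 2 = 3.5    3.5 + 2 + 3 + 4 + 5 = 17.5   3.5/17.5 = 20%  (~20000)
    count   2                    6                            2/6 = 1/3       (asserted as 2x each other label)
    min     min(1, 6) = 1        1 + 2 + 3 + 4 + 5 = 15       1/15 = ~6.7%    (~6666)
    max     max(1, 6) = 6        6 + 2 + 3 + 4 + 5 = 20       6/20 = 30%      (~30000)
    name    1.0 (fixed)          5.0                          1/5 = 20%       (~20000)
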
virtdata-lib-basics/src/test/resources/basicdata.csv (new file, 7 lines)
@@ -0,0 +1,7 @@
NAME,WEIGHT,MEMO
alpha,1,this is sparta
beta,2,this is sparta
gamma,3,this is sparta
delta,4,this is sparta
epsilon,5,this is sparta
alpha,6,this is sparta