simplify CSV sampling usage

This commit is contained in:
Jonathan Shook
2021-06-01 17:35:40 -05:00
parent 4f2482ecb2
commit 0cd100d49f
4 changed files with 374 additions and 0 deletions

View File

@@ -0,0 +1,169 @@
package io.nosqlbench.virtdata.library.basics.shared.distributions;
import io.nosqlbench.nb.api.content.NBIO;
import io.nosqlbench.nb.api.errors.BasicError;
import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.Example;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.AliasElementSampler;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.ElemProbD;
import io.nosqlbench.virtdata.library.basics.core.stathelpers.EvProbD;
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_long.Hash;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import java.util.*;
import java.util.function.Function;
import java.util.function.LongFunction;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
/**
*
* This function is a toolkit version of the {@link WeightedStringsFromCSV} function.
* It is more capable and should be the preferred function for alias sampling over any CSV data.
* This sampler uses a named column in the CSV data as the value. This is also referred to as the
* <em>labelColumn</em>. The frequency of this label depends on the weight assigned to it in another named
* CSV column, known as the <em>weightColumn</em>.
*
* <H3>Combining duplicate labels</H3>
* When you have CSV data which is not organized around the specific identifier that you want to sample by,
* you can use some combining functions to tabulate these prior to sampling. In that case, you can use
any of "sum", "avg", "count", "min", "max", or "name" as the reducing function on the value in the weight column.
If none is specified, then "sum" is used by default. All modes except "count" and "name" require a valid weight
column to be specified.
*
* <UL>
* <LI>sum, avg, min, max - takes the given stat for the weight of each distinct label</LI>
* <LI>count - takes the number of occurrences of a given label as the weight</LI>
* <LI>name - sets the weight of all distinct labels to 1.0d</LI>
* </UL>
*
* <H3>Map vs Hash mode</H3>
* As with some of the other statistical functions, you can use this one to pick through the sample values
* by using the <em>map</em> mode. This is distinct from the default <em>hash</em> mode. When map mode is used,
* the values will appear monotonically as you scan through the unit interval of all long values.
Specifically, 0L represents 0.0d in the unit interval on input, and Long.MAX_VALUE represents
1.0 on the unit interval. This mode is only recommended for advanced scenarios and should otherwise be
* avoided. You will know if you need this mode.
*
*/
@Categories(Category.general)
@ThreadSafeMapper
public class CSVSampler implements LongFunction<String> {

    private final AliasElementSampler<String> sampler;
    private final LongUnaryOperator prefunc;

    // Recognized configuration verbs; any leading argument not in this set is taken as a filename.
    // Typed as Set<String> rather than a raw Set so contains() is type-checked.
    private final static Set<String> MODES =
        Set.of("map", "hash", "sum", "avg", "count", "min", "name", "max");

    /**
     * Build an efficient O(1) sampler for the given column values with respect to the weights,
     * combining equal values by merging their per-label statistics.
     *
     * @param labelColumn  The CSV column name containing the value
     * @param weightColumn The CSV column name containing a double weight (ignored by the
     *                     "count" and "name" modes)
     * @param data         Sampling modes or file names. Any of map, hash, sum, avg, count, min,
     *                     max, or name are taken as configuration modes, and all others are taken
     *                     as CSV filenames. A ".csv" extension is appended to filenames when absent.
     * @throws BasicError if a required label or weight column is not present in a file's header
     */
    @Example({"CSVSampler('USPS','n/a','name','census_state_abbrev')",""})
    public CSVSampler(String labelColumn, String weightColumn, String... data) {
        Function<LabeledStatistic, Double> weightFunc = LabeledStatistic::sum;
        LongUnaryOperator prefunc = new Hash();
        boolean weightRequired = false;

        // Consume leading configuration verbs; the first non-verb argument starts the file list.
        while (data.length > 0 && MODES.contains(data[0])) {
            String cfg = data[0];
            data = Arrays.copyOfRange(data, 1, data.length);
            switch (cfg) {
                case "map":
                    // identity pre-function: scan the unit interval monotonically
                    prefunc = i -> i;
                    break;
                case "hash":
                    prefunc = new Hash();
                    break;
                case "sum":
                    weightFunc = LabeledStatistic::sum;
                    weightRequired = true;
                    break;
                case "min":
                    weightFunc = LabeledStatistic::min;
                    weightRequired = true;
                    break;
                case "max":
                    weightFunc = LabeledStatistic::max;
                    weightRequired = true;
                    break;
                case "avg":
                    weightFunc = LabeledStatistic::avg;
                    weightRequired = true;
                    break;
                case "count":
                    // weight is the number of occurrences of each distinct label
                    weightFunc = LabeledStatistic::count;
                    weightRequired = false;
                    break;
                case "name":
                    // every distinct label gets an equal weight of 1.0
                    weightFunc = (v) -> 1.0d;
                    weightRequired = false;
                    break;
                default:
                    // unreachable while MODES and this switch agree; kept as a guard
                    throw new BasicError("Unknown cfg verb '" + cfg + "'");
            }
        }
        this.prefunc = prefunc;
        final Function<LabeledStatistic, Double> valFunc = weightFunc;

        // Tabulate one LabeledStatistic per distinct label across all files.
        Map<String, LabeledStatistic> entries = new HashMap<>();
        for (String filename : data) {
            final String file = filename.endsWith(".csv") ? filename : filename + ".csv";
            CSVParser csvdata = NBIO.readFileCSV(file);
            String labelName = csvdata.getHeaderNames().stream()
                .filter(labelColumn::equalsIgnoreCase)
                .findAny()
                .orElseThrow(() -> new BasicError(
                    "Unable to find label column '" + labelColumn + "' in '" + file + "'"));
            String weightName = "none";
            if (weightRequired) {
                weightName = csvdata.getHeaderNames().stream()
                    .filter(weightColumn::equalsIgnoreCase)
                    .findAny()
                    .orElseThrow(() -> new BasicError(
                        "Unable to find weight column '" + weightColumn + "' in '" + file + "'"));
            }
            double weight = 1.0d;
            for (CSVRecord csvdatum : csvdata) {
                if (csvdatum.get(labelName) != null) {
                    String label = csvdatum.get(labelName);
                    if (weightRequired) {
                        String weightString = csvdatum.get(weightName);
                        // empty cells fall back to a weight of 1.0 rather than failing
                        weight = weightString.isEmpty() ? 1.0d : Double.parseDouble(weightString);
                    }
                    entries.merge(label, new LabeledStatistic(label, weight), LabeledStatistic::merge);
                }
            }
        }
        List<ElemProbD<String>> elemList = entries.values()
            .stream()
            .map(t -> new ElemProbD<>(t.label, valFunc.apply(t)))
            .collect(Collectors.toList());
        this.sampler = new AliasElementSampler<>(elemList);
    }

    /**
     * Map the input long onto the unit interval (after applying the configured pre-function,
     * hashing by default) and return the label selected by alias sampling.
     *
     * @param value any long value
     * @return the sampled label
     */
    @Override
    public String apply(long value) {
        value = prefunc.applyAsLong(value);
        double unitValue = (double) value / (double) Long.MAX_VALUE;
        return sampler.apply(unitValue);
    }
}

View File

@@ -0,0 +1,64 @@
package io.nosqlbench.virtdata.library.basics.shared.distributions;
class LabeledStatistic {
public final String label;
public final double total;
public final int count;
public final double min;
public final double max;
public LabeledStatistic(String label, double weight) {
this.label = label;
this.total = weight;
this.min = weight;
this.max = weight;
this.count = 1;
}
private LabeledStatistic(String label, double total, double min, double max, int count) {
this.label = label;
this.total = total;
this.min = min;
this.max = max;
this.count = count;
}
public LabeledStatistic merge(LabeledStatistic tuple) {
return new LabeledStatistic(
this.label,
this.total + tuple.total,
Math.min(this.min, tuple.min),
Math.max(this.max, tuple.max),
this.count + tuple.count
);
}
public double count() {
return count;
}
public double avg() {
return total / count;
}
public double sum() {
return total;
}
@Override
public String toString() {
return "EntryTuple{" +
"label='" + label + '\'' +
", total=" + total +
", count=" + count +
'}';
}
public double min() {
return this.min;
}
public double max() {
return this.max;
}
}

View File

@@ -0,0 +1,134 @@
package io.nosqlbench.virtdata.library.basics.shared.distributions;
import org.assertj.core.data.Percentage;
import org.junit.Test;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
/**
* All tests in this class are based on a CSV file with the following contents.
*
* <pre> {@code
* NAME,WEIGHT,MEMO
* alpha,1,this is sparta
* beta,2,this is sparta
* gamma,3,this is sparta
* delta,4,this is sparta
* epsilon,5,this is sparta
* alpha,6,this is sparta
* } </pre>
*/
public class CSVSamplerTest {

    /** Number of samples drawn per test; large enough for stable empirical frequencies. */
    private static final int SAMPLE_COUNT = 100000;

    /**
     * Draw {@link #SAMPLE_COUNT} samples from the sampler and tally how many times each
     * label was returned. Extracted to remove the loop duplicated across every test.
     *
     * @param sampler the sampler under test
     * @return map of label to observed occurrence count
     */
    private static Map<String, Double> tally(CSVSampler sampler) {
        Map<String, Double> results = new HashMap<>();
        for (int i = 0; i < SAMPLE_COUNT; i++) {
            results.merge(sampler.apply(i), 1d, Double::sum);
        }
        return results;
    }

    /**
     * In this test, alpha appears twice, and all others once, so alpha should appear roughly 2x more frequently
     */
    @Test
    public void testByCount() {
        CSVSampler sampler = new CSVSampler("name", "weightfoo", "count", "basicdata");
        Map<String, Double> results = tally(sampler);
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(results.get("beta") * 2, Percentage.withPercentage(5.0d));
        assertThat(results.get("alpha")).isCloseTo(results.get("gamma") * 2, Percentage.withPercentage(5.0d));
        assertThat(results.get("alpha")).isCloseTo(results.get("delta") * 2, Percentage.withPercentage(5.0d));
        assertThat(results.get("alpha")).isCloseTo(results.get("epsilon") * 2, Percentage.withPercentage(5.0d));
    }

    /**
     * In this test, alpha's weights sum to 1/3 of the total weight, thus it should appear roughly 1/3 of the time
     */
    @Test
    public void testBySum() {
        CSVSampler sampler = new CSVSampler("name", "weight", "sum", "basicdata");
        Map<String, Double> results = tally(sampler);
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(33333, Percentage.withPercentage(2.0d));
    }

    /**
     * In this test, alpha's weights avg to 3.5, or 3.5/17.5 or 20%, so should appear 20% of the time.
     */
    @Test
    public void testByAvgs() {
        CSVSampler sampler = new CSVSampler("name", "weight", "avg", "basicdata");
        Map<String, Double> results = tally(sampler);
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(20000, Percentage.withPercentage(2.0d));
    }

    /**
     * In this test, alpha is 1/15 of the total weight, or 6.6% of expected frequency
     */
    @Test
    public void testByMin() {
        CSVSampler sampler = new CSVSampler("name", "weight", "min", "basicdata");
        Map<String, Double> results = tally(sampler);
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(6666, Percentage.withPercentage(2.0d));
    }

    /**
     * In this test, alpha is 6/20 of expected frequency or 30%
     */
    @Test
    public void testByMax() {
        CSVSampler sampler = new CSVSampler("name", "weight", "max", "basicdata");
        Map<String, Double> results = tally(sampler);
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(30000, Percentage.withPercentage(2.0d));
    }

    /**
     * In this test, alpha is 1/5 of the distinct names included.
     */
    @Test
    public void testByName() {
        CSVSampler sampler = new CSVSampler("name", "does not matter", "name", "basicdata");
        Map<String, Double> results = tally(sampler);
        System.out.println(results);
        assertThat(results.get("alpha")).isCloseTo(20000, Percentage.withPercentage(2.0d));
    }
}

View File

@@ -0,0 +1,7 @@
NAME,WEIGHT,MEMO
alpha,1,this is sparta
beta,2,this is sparta
gamma,3,this is sparta
delta,4,this is sparta
epsilon,5,this is sparta
alpha,6,this is sparta
1 NAME WEIGHT MEMO
2 alpha 1 this is sparta
3 beta 2 this is sparta
4 gamma 3 this is sparta
5 delta 4 this is sparta
6 epsilon 5 this is sparta
7 alpha 6 this is sparta