update CqlVector usage to codec branch changes

This commit is contained in:
Jonathan Shook 2023-07-05 12:31:30 -05:00
parent ebf5286520
commit 96f9fc13ec
4 changed files with 37 additions and 70 deletions

View File

@ -20,7 +20,6 @@ import com.datastax.oss.driver.api.core.data.CqlVector;
import io.nosqlbench.virtdata.api.annotations.Categories; import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category; import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper; import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeDoubleListVector;
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeFloatListVector; import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeFloatListVector;
import java.util.ArrayList; import java.util.ArrayList;
@ -28,33 +27,22 @@ import java.util.List;
import java.util.function.Function; import java.util.function.Function;
/** /**
* Normalize a vector in List<Number> form, calling the appropriate conversion function * Normalize a vector in {@link CqlVector<Float>} form. This presumes that the input type is
* depending on the component (Class) type of the incoming List values. * Float, since we lose the type bounds on what is contained in the CQL type. If this doesn't match,
* then you will arbitrarily increase your storage cost, or otherwise have truncation errors
* in your values.
*/ */
@ThreadSafeMapper @ThreadSafeMapper
@Categories(Category.experimental) @Categories(Category.experimental)
public class NormalizeCqlVector implements Function<CqlVector, CqlVector> { public class NormalizeCqlFloatVector implements Function<CqlVector<? extends Number>, CqlVector<? extends Number>> {
private final NormalizeDoubleListVector ndv = new NormalizeDoubleListVector();
private final NormalizeFloatListVector nfv = new NormalizeFloatListVector(); private final NormalizeFloatListVector nfv = new NormalizeFloatListVector();
@Override @Override
public CqlVector apply(CqlVector cqlVector) { public CqlVector apply(CqlVector<? extends Number> cqlVector) {
int size = cqlVector.size();
List<Object> list = cqlVector.getValues(); final List<Float> newVector = new ArrayList<>(size);
if (list.isEmpty()) { cqlVector.forEach(v -> newVector.add(v.floatValue()));
return CqlVector.of(); List<Float> normalized = nfv.apply(newVector);
} else if (list.get(0) instanceof Float) { return CqlVector.newInstance(normalized);
List<Float> srcFloats = new ArrayList<>(list.size());
list.forEach(o -> srcFloats.add((Float) o));
List<Float> floats = nfv.apply(srcFloats);
return new CqlVector(floats);
} else if (list.get(0) instanceof Double) {
List<Double> srcDoubles = new ArrayList<>();
list.forEach(o -> srcDoubles.add((Double) o));
List<Double> doubles = ndv.apply(srcDoubles);
return new CqlVector(doubles);
} else {
throw new RuntimeException("Only Doubles and Floats are recognized.");
}
} }
} }

View File

@ -43,7 +43,7 @@ public class CqlVector implements LongFunction<com.datastax.oss.driver.api.core.
@Override @Override
public com.datastax.oss.driver.api.core.data.CqlVector apply(long cycle) { public com.datastax.oss.driver.api.core.data.CqlVector apply(long cycle) {
List components = func.apply(cycle); List components = func.apply(cycle);
com.datastax.oss.driver.api.core.data.CqlVector vector = new com.datastax.oss.driver.api.core.data.CqlVector<>(components); com.datastax.oss.driver.api.core.data.CqlVector vector =com.datastax.oss.driver.api.core.data.CqlVector.newInstance(components);
return vector; return vector;
} }
} }

View File

@ -26,26 +26,33 @@ import java.util.List;
/** /**
* Convert the incoming object List, Number, or Array to a CqlVector * Convert the incoming object List, Number, or Array to a CqlVector
* using {@link CqlVector.Builder#add(Object[])}. If any numeric value * using {@link CqlVector#newInstance(Number...)}. If any numeric value
* is passed in, then it becomes the only component of a 1D vector. * is passed in, then it becomes the only component of a 1D vector.
* Otherwise, the individual values are added as vector components. * Otherwise, the individual values are added as vector components.
*/ */
@ThreadSafeMapper @ThreadSafeMapper
@Categories(Category.experimental) @Categories(Category.experimental)
public class ToCqlVector implements Function<Object, CqlVector> { public class ToCqlVector implements Function<Object, CqlVector> {
@Override @Override
public CqlVector apply(Object object) { public CqlVector apply(Object object) {
Object[] ary = null;
if (object instanceof List list) { if (object instanceof List list) {
ary = list.toArray(); if (list.size()==0) {
} else if (object instanceof Number number) { return CqlVector.newInstance();
ary = new Object[]{number.floatValue()}; }
} else if (object.getClass().isArray()) { Class<?> componentType = list.get(0).getClass();
ary = (Object[]) object; if (componentType.equals(Float.TYPE)) {
return CqlVector.newInstance(((List<Float>) list).toArray(new Float[list.size()]));
} else if (componentType.equals(Double.TYPE)) {
return CqlVector.newInstance(((List<Double>)list).toArray(new Double[list.size()]));
} else if (componentType.equals(Long.TYPE)) {
return CqlVector.newInstance(((List<Long>)list).toArray(new Long[list.size()]));
} else if (componentType.equals(Integer.TYPE)) {
return CqlVector.newInstance(((List<Integer>)list).toArray(new Integer[list.size()]));
} else {
throw new RuntimeException("Unable to convert List of " + componentType.getSimpleName() + " to a CqlVector");
}
} else { } else {
throw new RuntimeException("Unsupported input type for CqlVector: " + object.getClass().getCanonicalName()); throw new RuntimeException("Unsupported input type for CqlVector: " + object.getClass().getCanonicalName());
} }
return CqlVector.of(ary);
} }
} }

View File

@ -27,12 +27,12 @@ import static org.assertj.core.api.Assertions.assertThat;
public class NormalizeCqlVectorTest { public class NormalizeCqlVectorTest {
@Test @Test
public void normalizeCqlVectorFloats() { public void normalizeCqlFloatVectorFloats() {
CqlVector square = CqlVector.of(1.0f, 1.0f); CqlVector square = CqlVector.newInstance(1.0f, 1.0f);
NormalizeCqlVector nv = new NormalizeCqlVector(); NormalizeCqlFloatVector nv = new NormalizeCqlFloatVector();
CqlVector normaled = nv.apply(square); CqlVector normalized = nv.apply(square);
List sides = normaled.getValues(); List sides = normalized.stream().toList();
assertThat(sides.size()).isEqualTo(2); assertThat(sides.size()).isEqualTo(2);
assertThat(sides.get(0)).isInstanceOf(Float.class); assertThat(sides.get(0)).isInstanceOf(Float.class);
assertThat(sides.get(1)).isInstanceOf(Float.class); assertThat(sides.get(1)).isInstanceOf(Float.class);
@ -42,39 +42,11 @@ public class NormalizeCqlVectorTest {
@Test @Test
public void normalizeCqlVectorDoubles() { public void normalizeCqlVectorDoubles() {
CqlVector square = CqlVector.of(1.0d, 1.0d); CqlVector square = CqlVector.newInstance(1.0d, 1.0d);
NormalizeCqlVector nv = new NormalizeCqlVector(); NormalizeCqlFloatVector nv = new NormalizeCqlFloatVector();
CqlVector normaled = nv.apply(square); CqlVector normalized = nv.apply(square);
List sides = normaled.getValues(); List sides = normalized.stream().toList();
assertThat(sides.size()).isEqualTo(2);
assertThat(sides.get(0)).isInstanceOf(Double.class);
assertThat(sides.get(1)).isInstanceOf(Double.class);
assertThat(((Double)sides.get(0)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
assertThat(((Double)sides.get(1)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
}
@Test
public void normalizeCqlVectorFloatsV1() {
CqlVector square = CqlVector.of(1.0f, 1.0f);
NormalizeCqlVectorV1 nv = new NormalizeCqlVectorV1();
CqlVector normaled = nv.apply(square);
List sides = normaled.getValues();
assertThat(sides.size()).isEqualTo(2);
assertThat(sides.get(0)).isInstanceOf(Float.class);
assertThat(sides.get(1)).isInstanceOf(Float.class);
assertThat(((Float)sides.get(0)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
assertThat(((Float)sides.get(1)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
}
@Test
public void normalizeCqlVectorDoublesV1() {
CqlVector square = CqlVector.of(1.0d, 1.0d);
NormalizeCqlVectorV1 nv = new NormalizeCqlVectorV1();
CqlVector normaled = nv.apply(square);
List sides = normaled.getValues();
assertThat(sides.size()).isEqualTo(2); assertThat(sides.size()).isEqualTo(2);
assertThat(sides.get(0)).isInstanceOf(Double.class); assertThat(sides.get(0)).isInstanceOf(Double.class);
assertThat(sides.get(1)).isInstanceOf(Double.class); assertThat(sides.get(1)).isInstanceOf(Double.class);