update CqlVector usage to codec branch changes

This commit is contained in:
Jonathan Shook 2023-07-05 12:31:30 -05:00
parent ebf5286520
commit 96f9fc13ec
4 changed files with 37 additions and 70 deletions

View File

@ -20,7 +20,6 @@ import com.datastax.oss.driver.api.core.data.CqlVector;
import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeDoubleListVector;
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeFloatListVector;
import java.util.ArrayList;
@ -28,33 +27,22 @@ import java.util.List;
import java.util.function.Function;
/**
* Normalize a vector in List<Number> form, calling the appropriate conversion function
* depending on the component (Class) type of the incoming List values.
* Normalize a vector in {@code CqlVector<Float>} form. This presumes that the input type is
* Float, since we lose the type bounds on what is contained in the CQL type. If this doesn't match,
* then you will arbitrarily increase your storage cost, or otherwise have truncation errors
* in your values.
*/
// NOTE(review): this region is a rendered diff whose +/- markers were stripped, so the
// removed class (NormalizeCqlVector) and the added class (NormalizeCqlFloatVector) appear
// interleaved below. It is not compilable as written; the comments mark which side each
// run of lines presumably belongs to.
@ThreadSafeMapper
@Categories(Category.experimental)
// Presumably the removed declaration: a raw-typed Function from CqlVector to CqlVector,
// with a Double normalizer that only the removed apply() below uses.
public class NormalizeCqlVector implements Function<CqlVector, CqlVector> {
private final NormalizeDoubleListVector ndv = new NormalizeDoubleListVector();
// Presumably the added declaration: accepts a CqlVector of any Number subtype.
public class NormalizeCqlFloatVector implements Function<CqlVector<? extends Number>, CqlVector<? extends Number>> {
// Float list normalizer; referenced by both apply() variants shown here.
private final NormalizeFloatListVector nfv = new NormalizeFloatListVector();
@Override
// Presumably the removed apply(): inspects the runtime type of the first element and
// routes the values through the Float or Double normalizer; an empty input yields an
// empty vector and any other element type raises a RuntimeException.
public CqlVector apply(CqlVector cqlVector) {
List<Object> list = cqlVector.getValues();
if (list.isEmpty()) {
return CqlVector.of();
} else if (list.get(0) instanceof Float) {
List<Float> srcFloats = new ArrayList<>(list.size());
list.forEach(o -> srcFloats.add((Float) o));
List<Float> floats = nfv.apply(srcFloats);
return new CqlVector(floats);
} else if (list.get(0) instanceof Double) {
List<Double> srcDoubles = new ArrayList<>();
list.forEach(o -> srcDoubles.add((Double) o));
List<Double> doubles = ndv.apply(srcDoubles);
return new CqlVector(doubles);
} else {
throw new RuntimeException("Only Doubles and Floats are recognized.");
}
// Presumably the added apply(): copies every component as a float, normalizes the copy
// with nfv, and wraps the result via CqlVector.newInstance.
// NOTE(review): the declared return type is the raw CqlVector, while the Function type
// argument on the class is CqlVector<? extends Number> -- confirm this is intentional.
public CqlVector apply(CqlVector<? extends Number> cqlVector) {
int size = cqlVector.size();
// Presized to the component count of the incoming vector.
final List<Float> newVector = new ArrayList<>(size);
// floatValue() narrows non-Float components (e.g. Double), so precision can be lost here.
cqlVector.forEach(v -> newVector.add(v.floatValue()));
List<Float> normalized = nfv.apply(newVector);
return CqlVector.newInstance(normalized);
}
}

View File

@ -43,7 +43,7 @@ public class CqlVector implements LongFunction<com.datastax.oss.driver.api.core.
// Produces a driver CqlVector from the component list that func yields for the given cycle.
@Override
public com.datastax.oss.driver.api.core.data.CqlVector apply(long cycle) {
// Raw List: the element type of the components is not visible in this excerpt.
List components = func.apply(cycle);
// NOTE(review): the next two statements both declare `vector`; they look like the removed
// (constructor) and added (newInstance factory) sides of a single changed line in the
// rendered diff, so only one of them would exist in the real file.
com.datastax.oss.driver.api.core.data.CqlVector vector = new com.datastax.oss.driver.api.core.data.CqlVector<>(components);
com.datastax.oss.driver.api.core.data.CqlVector vector =com.datastax.oss.driver.api.core.data.CqlVector.newInstance(components);
return vector;
}
}

View File

@ -26,26 +26,33 @@ import java.util.List;
/**
* Convert the incoming object List, Number, or Array to a CqlVector
* using {@link CqlVector.Builder#add(Object[])}}. If any numeric value
* using {@link CqlVector#newInstance(Number...)}. If any numeric value
* is passed in, then it becomes the only component of a 1D vector.
* Otherwise, the individual values are added as vector components.
*/
// NOTE(review): this region is a rendered diff whose +/- markers were stripped, so lines
// from the removed and the added implementation of apply() are interleaved below. It is
// not compilable as written; the comments mark which side each run presumably belongs to.
@ThreadSafeMapper
@Categories(Category.experimental)
public class ToCqlVector implements Function<Object, CqlVector> {
@Override
public CqlVector apply(Object object) {
// Presumably removed along with the array-based path: `ary` is only read by the
// CqlVector.of(ary) call at the end of this span.
Object[] ary = null;
if (object instanceof List list) {
// Presumably the removed branches: a List is converted with toArray(), a Number becomes
// a single float component, and an array is cast to Object[].
ary = list.toArray();
} else if (object instanceof Number number) {
ary = new Object[]{number.floatValue()};
} else if (object.getClass().isArray()) {
ary = (Object[]) object;
// Presumably the added List handling starts here: an empty list yields an empty vector.
if (list.size()==0) {
return CqlVector.newInstance();
}
// The first element's runtime class decides which typed array is built.
// NOTE(review): getClass() on a list element returns a wrapper class such as Float.class,
// whereas Float.TYPE / Double.TYPE / Long.TYPE / Integer.TYPE are the primitive class
// objects (float.class, ...). None of the equals() checks below can therefore match a
// boxed element, and every non-empty list would fall through to the "Unable to convert"
// exception. Comparing against Float.class etc. is presumably what was intended -- confirm.
// NOTE(review): a null first element would make getClass() throw NullPointerException.
Class<?> componentType = list.get(0).getClass();
if (componentType.equals(Float.TYPE)) {
return CqlVector.newInstance(((List<Float>) list).toArray(new Float[list.size()]));
} else if (componentType.equals(Double.TYPE)) {
return CqlVector.newInstance(((List<Double>)list).toArray(new Double[list.size()]));
} else if (componentType.equals(Long.TYPE)) {
return CqlVector.newInstance(((List<Long>)list).toArray(new Long[list.size()]));
} else if (componentType.equals(Integer.TYPE)) {
return CqlVector.newInstance(((List<Integer>)list).toArray(new Integer[list.size()]));
} else {
throw new RuntimeException("Unable to convert List of " + componentType.getSimpleName() + " to a CqlVector");
}
// Any input not handled by a branch above is rejected with its canonical class name.
// NOTE(review): on the added side only List appears to be handled, while the Javadoc
// above this class still mentions Number and array inputs -- confirm which is intended.
} else {
throw new RuntimeException("Unsupported input type for CqlVector: " + object.getClass().getCanonicalName());
}
// Presumably the removed return of the array-based implementation.
return CqlVector.of(ary);
}
}

View File

@ -27,12 +27,12 @@ import static org.assertj.core.api.Assertions.assertThat;
public class NormalizeCqlVectorTest {
@Test
public void normalizeCqlVectorFloats() {
CqlVector square = CqlVector.of(1.0f, 1.0f);
NormalizeCqlVector nv = new NormalizeCqlVector();
CqlVector normaled = nv.apply(square);
public void normalizeCqlFloatVectorFloats() {
CqlVector square = CqlVector.newInstance(1.0f, 1.0f);
NormalizeCqlFloatVector nv = new NormalizeCqlFloatVector();
CqlVector normalized = nv.apply(square);
List sides = normaled.getValues();
List sides = normalized.stream().toList();
assertThat(sides.size()).isEqualTo(2);
assertThat(sides.get(0)).isInstanceOf(Float.class);
assertThat(sides.get(1)).isInstanceOf(Float.class);
@ -42,39 +42,11 @@ public class NormalizeCqlVectorTest {
@Test
public void normalizeCqlVectorDoubles() {
CqlVector square = CqlVector.of(1.0d, 1.0d);
NormalizeCqlVector nv = new NormalizeCqlVector();
CqlVector normaled = nv.apply(square);
CqlVector square = CqlVector.newInstance(1.0d, 1.0d);
NormalizeCqlFloatVector nv = new NormalizeCqlFloatVector();
CqlVector normalized = nv.apply(square);
List sides = normaled.getValues();
assertThat(sides.size()).isEqualTo(2);
assertThat(sides.get(0)).isInstanceOf(Double.class);
assertThat(sides.get(1)).isInstanceOf(Double.class);
assertThat(((Double)sides.get(0)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
assertThat(((Double)sides.get(1)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
}
@Test
public void normalizeCqlVectorFloatsV1() {
CqlVector square = CqlVector.of(1.0f, 1.0f);
NormalizeCqlVectorV1 nv = new NormalizeCqlVectorV1();
CqlVector normaled = nv.apply(square);
List sides = normaled.getValues();
assertThat(sides.size()).isEqualTo(2);
assertThat(sides.get(0)).isInstanceOf(Float.class);
assertThat(sides.get(1)).isInstanceOf(Float.class);
assertThat(((Float)sides.get(0)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
assertThat(((Float)sides.get(1)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
}
@Test
public void normalizeCqlVectorDoublesV1() {
CqlVector square = CqlVector.of(1.0d, 1.0d);
NormalizeCqlVectorV1 nv = new NormalizeCqlVectorV1();
CqlVector normaled = nv.apply(square);
List sides = normaled.getValues();
List sides = normalized.stream().toList();
assertThat(sides.size()).isEqualTo(2);
assertThat(sides.get(0)).isInstanceOf(Double.class);
assertThat(sides.get(1)).isInstanceOf(Double.class);