mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2024-12-22 15:13:41 -06:00
update CqlVector usage to codec branch changes
This commit is contained in:
parent
ebf5286520
commit
96f9fc13ec
@ -20,7 +20,6 @@ import com.datastax.oss.driver.api.core.data.CqlVector;
|
||||
import io.nosqlbench.virtdata.api.annotations.Categories;
|
||||
import io.nosqlbench.virtdata.api.annotations.Category;
|
||||
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
|
||||
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeDoubleListVector;
|
||||
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeFloatListVector;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -28,33 +27,22 @@ import java.util.List;
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
|
||||
* Normalize a vector in List<Number> form, calling the appropriate conversion function
|
||||
* depending on the component (Class) type of the incoming List values.
|
||||
* Normalize a vector in {@link CqlVector<Float>} form. This presumes that the input type is
|
||||
* Float, since we lose the type bounds on what is contained in the CQL type. If this doesn't match,
|
||||
* then you will arbitrarily increase your storage cost, or otherwise haven truncation errors
|
||||
* in your values.
|
||||
*/
|
||||
@ThreadSafeMapper
|
||||
@Categories(Category.experimental)
|
||||
public class NormalizeCqlVector implements Function<CqlVector, CqlVector> {
|
||||
private final NormalizeDoubleListVector ndv = new NormalizeDoubleListVector();
|
||||
public class NormalizeCqlFloatVector implements Function<CqlVector<? extends Number>, CqlVector<? extends Number>> {
|
||||
private final NormalizeFloatListVector nfv = new NormalizeFloatListVector();
|
||||
|
||||
@Override
|
||||
public CqlVector apply(CqlVector cqlVector) {
|
||||
|
||||
List<Object> list = cqlVector.getValues();
|
||||
if (list.isEmpty()) {
|
||||
return CqlVector.of();
|
||||
} else if (list.get(0) instanceof Float) {
|
||||
List<Float> srcFloats = new ArrayList<>(list.size());
|
||||
list.forEach(o -> srcFloats.add((Float) o));
|
||||
List<Float> floats = nfv.apply(srcFloats);
|
||||
return new CqlVector(floats);
|
||||
} else if (list.get(0) instanceof Double) {
|
||||
List<Double> srcDoubles = new ArrayList<>();
|
||||
list.forEach(o -> srcDoubles.add((Double) o));
|
||||
List<Double> doubles = ndv.apply(srcDoubles);
|
||||
return new CqlVector(doubles);
|
||||
} else {
|
||||
throw new RuntimeException("Only Doubles and Floats are recognized.");
|
||||
}
|
||||
public CqlVector apply(CqlVector<? extends Number> cqlVector) {
|
||||
int size = cqlVector.size();
|
||||
final List<Float> newVector = new ArrayList<>(size);
|
||||
cqlVector.forEach(v -> newVector.add(v.floatValue()));
|
||||
List<Float> normalized = nfv.apply(newVector);
|
||||
return CqlVector.newInstance(normalized);
|
||||
}
|
||||
}
|
@ -43,7 +43,7 @@ public class CqlVector implements LongFunction<com.datastax.oss.driver.api.core.
|
||||
@Override
|
||||
public com.datastax.oss.driver.api.core.data.CqlVector apply(long cycle) {
|
||||
List components = func.apply(cycle);
|
||||
com.datastax.oss.driver.api.core.data.CqlVector vector = new com.datastax.oss.driver.api.core.data.CqlVector<>(components);
|
||||
com.datastax.oss.driver.api.core.data.CqlVector vector =com.datastax.oss.driver.api.core.data.CqlVector.newInstance(components);
|
||||
return vector;
|
||||
}
|
||||
}
|
||||
|
@ -26,26 +26,33 @@ import java.util.List;
|
||||
|
||||
/**
|
||||
* Convert the incoming object List, Number, or Array to a CqlVector
|
||||
* using {@link CqlVector.Builder#add(Object[])}}. If any numeric value
|
||||
* using {@link CqlVector#newInstance(Number...)}}. If any numeric value
|
||||
* is passed in, then it becomes the only component of a 1D vector.
|
||||
* Otherwise, the individual values are added as vector components.
|
||||
*/
|
||||
@ThreadSafeMapper
|
||||
@Categories(Category.experimental)
|
||||
public class ToCqlVector implements Function<Object, CqlVector> {
|
||||
|
||||
@Override
|
||||
public CqlVector apply(Object object) {
|
||||
Object[] ary = null;
|
||||
if (object instanceof List list) {
|
||||
ary = list.toArray();
|
||||
} else if (object instanceof Number number) {
|
||||
ary = new Object[]{number.floatValue()};
|
||||
} else if (object.getClass().isArray()) {
|
||||
ary = (Object[]) object;
|
||||
if (list.size()==0) {
|
||||
return CqlVector.newInstance();
|
||||
}
|
||||
Class<?> componentType = list.get(0).getClass();
|
||||
if (componentType.equals(Float.TYPE)) {
|
||||
return CqlVector.newInstance(((List<Float>) list).toArray(new Float[list.size()]));
|
||||
} else if (componentType.equals(Double.TYPE)) {
|
||||
return CqlVector.newInstance(((List<Double>)list).toArray(new Double[list.size()]));
|
||||
} else if (componentType.equals(Long.TYPE)) {
|
||||
return CqlVector.newInstance(((List<Long>)list).toArray(new Long[list.size()]));
|
||||
} else if (componentType.equals(Integer.TYPE)) {
|
||||
return CqlVector.newInstance(((List<Integer>)list).toArray(new Integer[list.size()]));
|
||||
} else {
|
||||
throw new RuntimeException("Unable to convert List of " + componentType.getSimpleName() + " to a CqlVector");
|
||||
}
|
||||
} else {
|
||||
throw new RuntimeException("Unsupported input type for CqlVector: " + object.getClass().getCanonicalName());
|
||||
}
|
||||
return CqlVector.of(ary);
|
||||
}
|
||||
}
|
||||
|
@ -27,12 +27,12 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
public class NormalizeCqlVectorTest {
|
||||
|
||||
@Test
|
||||
public void normalizeCqlVectorFloats() {
|
||||
CqlVector square = CqlVector.of(1.0f, 1.0f);
|
||||
NormalizeCqlVector nv = new NormalizeCqlVector();
|
||||
CqlVector normaled = nv.apply(square);
|
||||
public void normalizeCqlFloatVectorFloats() {
|
||||
CqlVector square = CqlVector.newInstance(1.0f, 1.0f);
|
||||
NormalizeCqlFloatVector nv = new NormalizeCqlFloatVector();
|
||||
CqlVector normalized = nv.apply(square);
|
||||
|
||||
List sides = normaled.getValues();
|
||||
List sides = normalized.stream().toList();
|
||||
assertThat(sides.size()).isEqualTo(2);
|
||||
assertThat(sides.get(0)).isInstanceOf(Float.class);
|
||||
assertThat(sides.get(1)).isInstanceOf(Float.class);
|
||||
@ -42,39 +42,11 @@ public class NormalizeCqlVectorTest {
|
||||
|
||||
@Test
|
||||
public void normalizeCqlVectorDoubles() {
|
||||
CqlVector square = CqlVector.of(1.0d, 1.0d);
|
||||
NormalizeCqlVector nv = new NormalizeCqlVector();
|
||||
CqlVector normaled = nv.apply(square);
|
||||
CqlVector square = CqlVector.newInstance(1.0d, 1.0d);
|
||||
NormalizeCqlFloatVector nv = new NormalizeCqlFloatVector();
|
||||
CqlVector normalized = nv.apply(square);
|
||||
|
||||
List sides = normaled.getValues();
|
||||
assertThat(sides.size()).isEqualTo(2);
|
||||
assertThat(sides.get(0)).isInstanceOf(Double.class);
|
||||
assertThat(sides.get(1)).isInstanceOf(Double.class);
|
||||
assertThat(((Double)sides.get(0)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
|
||||
assertThat(((Double)sides.get(1)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void normalizeCqlVectorFloatsV1() {
|
||||
CqlVector square = CqlVector.of(1.0f, 1.0f);
|
||||
NormalizeCqlVectorV1 nv = new NormalizeCqlVectorV1();
|
||||
CqlVector normaled = nv.apply(square);
|
||||
|
||||
List sides = normaled.getValues();
|
||||
assertThat(sides.size()).isEqualTo(2);
|
||||
assertThat(sides.get(0)).isInstanceOf(Float.class);
|
||||
assertThat(sides.get(1)).isInstanceOf(Float.class);
|
||||
assertThat(((Float)sides.get(0)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
|
||||
assertThat(((Float)sides.get(1)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void normalizeCqlVectorDoublesV1() {
|
||||
CqlVector square = CqlVector.of(1.0d, 1.0d);
|
||||
NormalizeCqlVectorV1 nv = new NormalizeCqlVectorV1();
|
||||
CqlVector normaled = nv.apply(square);
|
||||
|
||||
List sides = normaled.getValues();
|
||||
List sides = normalized.stream().toList();
|
||||
assertThat(sides.size()).isEqualTo(2);
|
||||
assertThat(sides.get(0)).isInstanceOf(Double.class);
|
||||
assertThat(sides.get(1)).isInstanceOf(Double.class);
|
||||
|
Loading…
Reference in New Issue
Block a user