mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2024-12-22 15:13:41 -06:00
update CqlVector usage to codec branch changes
This commit is contained in:
parent
ebf5286520
commit
96f9fc13ec
@ -20,7 +20,6 @@ import com.datastax.oss.driver.api.core.data.CqlVector;
|
|||||||
import io.nosqlbench.virtdata.api.annotations.Categories;
|
import io.nosqlbench.virtdata.api.annotations.Categories;
|
||||||
import io.nosqlbench.virtdata.api.annotations.Category;
|
import io.nosqlbench.virtdata.api.annotations.Category;
|
||||||
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
|
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
|
||||||
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeDoubleListVector;
|
|
||||||
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeFloatListVector;
|
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeFloatListVector;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@ -28,33 +27,22 @@ import java.util.List;
|
|||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Normalize a vector in List<Number> form, calling the appropriate conversion function
|
* Normalize a vector in {@link CqlVector<Float>} form. This presumes that the input type is
|
||||||
* depending on the component (Class) type of the incoming List values.
|
* Float, since we lose the type bounds on what is contained in the CQL type. If this doesn't match,
|
||||||
|
* then you will arbitrarily increase your storage cost, or otherwise haven truncation errors
|
||||||
|
* in your values.
|
||||||
*/
|
*/
|
||||||
@ThreadSafeMapper
|
@ThreadSafeMapper
|
||||||
@Categories(Category.experimental)
|
@Categories(Category.experimental)
|
||||||
public class NormalizeCqlVector implements Function<CqlVector, CqlVector> {
|
public class NormalizeCqlFloatVector implements Function<CqlVector<? extends Number>, CqlVector<? extends Number>> {
|
||||||
private final NormalizeDoubleListVector ndv = new NormalizeDoubleListVector();
|
|
||||||
private final NormalizeFloatListVector nfv = new NormalizeFloatListVector();
|
private final NormalizeFloatListVector nfv = new NormalizeFloatListVector();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CqlVector apply(CqlVector cqlVector) {
|
public CqlVector apply(CqlVector<? extends Number> cqlVector) {
|
||||||
|
int size = cqlVector.size();
|
||||||
List<Object> list = cqlVector.getValues();
|
final List<Float> newVector = new ArrayList<>(size);
|
||||||
if (list.isEmpty()) {
|
cqlVector.forEach(v -> newVector.add(v.floatValue()));
|
||||||
return CqlVector.of();
|
List<Float> normalized = nfv.apply(newVector);
|
||||||
} else if (list.get(0) instanceof Float) {
|
return CqlVector.newInstance(normalized);
|
||||||
List<Float> srcFloats = new ArrayList<>(list.size());
|
|
||||||
list.forEach(o -> srcFloats.add((Float) o));
|
|
||||||
List<Float> floats = nfv.apply(srcFloats);
|
|
||||||
return new CqlVector(floats);
|
|
||||||
} else if (list.get(0) instanceof Double) {
|
|
||||||
List<Double> srcDoubles = new ArrayList<>();
|
|
||||||
list.forEach(o -> srcDoubles.add((Double) o));
|
|
||||||
List<Double> doubles = ndv.apply(srcDoubles);
|
|
||||||
return new CqlVector(doubles);
|
|
||||||
} else {
|
|
||||||
throw new RuntimeException("Only Doubles and Floats are recognized.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -43,7 +43,7 @@ public class CqlVector implements LongFunction<com.datastax.oss.driver.api.core.
|
|||||||
@Override
|
@Override
|
||||||
public com.datastax.oss.driver.api.core.data.CqlVector apply(long cycle) {
|
public com.datastax.oss.driver.api.core.data.CqlVector apply(long cycle) {
|
||||||
List components = func.apply(cycle);
|
List components = func.apply(cycle);
|
||||||
com.datastax.oss.driver.api.core.data.CqlVector vector = new com.datastax.oss.driver.api.core.data.CqlVector<>(components);
|
com.datastax.oss.driver.api.core.data.CqlVector vector =com.datastax.oss.driver.api.core.data.CqlVector.newInstance(components);
|
||||||
return vector;
|
return vector;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,26 +26,33 @@ import java.util.List;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert the incoming object List, Number, or Array to a CqlVector
|
* Convert the incoming object List, Number, or Array to a CqlVector
|
||||||
* using {@link CqlVector.Builder#add(Object[])}}. If any numeric value
|
* using {@link CqlVector#newInstance(Number...)}}. If any numeric value
|
||||||
* is passed in, then it becomes the only component of a 1D vector.
|
* is passed in, then it becomes the only component of a 1D vector.
|
||||||
* Otherwise, the individual values are added as vector components.
|
* Otherwise, the individual values are added as vector components.
|
||||||
*/
|
*/
|
||||||
@ThreadSafeMapper
|
@ThreadSafeMapper
|
||||||
@Categories(Category.experimental)
|
@Categories(Category.experimental)
|
||||||
public class ToCqlVector implements Function<Object, CqlVector> {
|
public class ToCqlVector implements Function<Object, CqlVector> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CqlVector apply(Object object) {
|
public CqlVector apply(Object object) {
|
||||||
Object[] ary = null;
|
|
||||||
if (object instanceof List list) {
|
if (object instanceof List list) {
|
||||||
ary = list.toArray();
|
if (list.size()==0) {
|
||||||
} else if (object instanceof Number number) {
|
return CqlVector.newInstance();
|
||||||
ary = new Object[]{number.floatValue()};
|
}
|
||||||
} else if (object.getClass().isArray()) {
|
Class<?> componentType = list.get(0).getClass();
|
||||||
ary = (Object[]) object;
|
if (componentType.equals(Float.TYPE)) {
|
||||||
|
return CqlVector.newInstance(((List<Float>) list).toArray(new Float[list.size()]));
|
||||||
|
} else if (componentType.equals(Double.TYPE)) {
|
||||||
|
return CqlVector.newInstance(((List<Double>)list).toArray(new Double[list.size()]));
|
||||||
|
} else if (componentType.equals(Long.TYPE)) {
|
||||||
|
return CqlVector.newInstance(((List<Long>)list).toArray(new Long[list.size()]));
|
||||||
|
} else if (componentType.equals(Integer.TYPE)) {
|
||||||
|
return CqlVector.newInstance(((List<Integer>)list).toArray(new Integer[list.size()]));
|
||||||
|
} else {
|
||||||
|
throw new RuntimeException("Unable to convert List of " + componentType.getSimpleName() + " to a CqlVector");
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
throw new RuntimeException("Unsupported input type for CqlVector: " + object.getClass().getCanonicalName());
|
throw new RuntimeException("Unsupported input type for CqlVector: " + object.getClass().getCanonicalName());
|
||||||
}
|
}
|
||||||
return CqlVector.of(ary);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,12 +27,12 @@ import static org.assertj.core.api.Assertions.assertThat;
|
|||||||
public class NormalizeCqlVectorTest {
|
public class NormalizeCqlVectorTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void normalizeCqlVectorFloats() {
|
public void normalizeCqlFloatVectorFloats() {
|
||||||
CqlVector square = CqlVector.of(1.0f, 1.0f);
|
CqlVector square = CqlVector.newInstance(1.0f, 1.0f);
|
||||||
NormalizeCqlVector nv = new NormalizeCqlVector();
|
NormalizeCqlFloatVector nv = new NormalizeCqlFloatVector();
|
||||||
CqlVector normaled = nv.apply(square);
|
CqlVector normalized = nv.apply(square);
|
||||||
|
|
||||||
List sides = normaled.getValues();
|
List sides = normalized.stream().toList();
|
||||||
assertThat(sides.size()).isEqualTo(2);
|
assertThat(sides.size()).isEqualTo(2);
|
||||||
assertThat(sides.get(0)).isInstanceOf(Float.class);
|
assertThat(sides.get(0)).isInstanceOf(Float.class);
|
||||||
assertThat(sides.get(1)).isInstanceOf(Float.class);
|
assertThat(sides.get(1)).isInstanceOf(Float.class);
|
||||||
@ -42,39 +42,11 @@ public class NormalizeCqlVectorTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void normalizeCqlVectorDoubles() {
|
public void normalizeCqlVectorDoubles() {
|
||||||
CqlVector square = CqlVector.of(1.0d, 1.0d);
|
CqlVector square = CqlVector.newInstance(1.0d, 1.0d);
|
||||||
NormalizeCqlVector nv = new NormalizeCqlVector();
|
NormalizeCqlFloatVector nv = new NormalizeCqlFloatVector();
|
||||||
CqlVector normaled = nv.apply(square);
|
CqlVector normalized = nv.apply(square);
|
||||||
|
|
||||||
List sides = normaled.getValues();
|
List sides = normalized.stream().toList();
|
||||||
assertThat(sides.size()).isEqualTo(2);
|
|
||||||
assertThat(sides.get(0)).isInstanceOf(Double.class);
|
|
||||||
assertThat(sides.get(1)).isInstanceOf(Double.class);
|
|
||||||
assertThat(((Double)sides.get(0)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
|
|
||||||
assertThat(((Double)sides.get(1)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void normalizeCqlVectorFloatsV1() {
|
|
||||||
CqlVector square = CqlVector.of(1.0f, 1.0f);
|
|
||||||
NormalizeCqlVectorV1 nv = new NormalizeCqlVectorV1();
|
|
||||||
CqlVector normaled = nv.apply(square);
|
|
||||||
|
|
||||||
List sides = normaled.getValues();
|
|
||||||
assertThat(sides.size()).isEqualTo(2);
|
|
||||||
assertThat(sides.get(0)).isInstanceOf(Float.class);
|
|
||||||
assertThat(sides.get(1)).isInstanceOf(Float.class);
|
|
||||||
assertThat(((Float)sides.get(0)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
|
|
||||||
assertThat(((Float)sides.get(1)).doubleValue()).isCloseTo(0.707, Offset.offset(0.001d));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void normalizeCqlVectorDoublesV1() {
|
|
||||||
CqlVector square = CqlVector.of(1.0d, 1.0d);
|
|
||||||
NormalizeCqlVectorV1 nv = new NormalizeCqlVectorV1();
|
|
||||||
CqlVector normaled = nv.apply(square);
|
|
||||||
|
|
||||||
List sides = normaled.getValues();
|
|
||||||
assertThat(sides.size()).isEqualTo(2);
|
assertThat(sides.size()).isEqualTo(2);
|
||||||
assertThat(sides.get(0)).isInstanceOf(Double.class);
|
assertThat(sides.get(0)).isInstanceOf(Double.class);
|
||||||
assertThat(sides.get(1)).isInstanceOf(Double.class);
|
assertThat(sides.get(1)).isInstanceOf(Double.class);
|
||||||
|
Loading…
Reference in New Issue
Block a user