mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2025-02-25 18:55:28 -06:00
ninja fix cql driver changes
This commit is contained in:
parent
cb4553eb40
commit
b2da44bd10
@ -28,33 +28,35 @@ import java.util.List;
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
|
||||
* Normalize a vector in List<Number> form, calling the appropriate conversion function
|
||||
* depending on the component (Class) type of the incoming List values.
|
||||
*/
|
||||
Normalize a vector in List<Number> form, calling the appropriate conversion function
|
||||
depending on the component (Class) type of the incoming List values. */
|
||||
@ThreadSafeMapper
|
||||
@Categories(Category.experimental)
|
||||
public class NormalizeCqlVector implements Function<CqlVector, CqlVector> {
|
||||
public class NormalizeCqlVector<N extends Number> implements Function<CqlVector<N>, CqlVector<N>> {
|
||||
private final NormalizeDoubleListVector ndv = new NormalizeDoubleListVector();
|
||||
private final NormalizeFloatListVector nfv = new NormalizeFloatListVector();
|
||||
|
||||
@Override
|
||||
public CqlVector apply(CqlVector cqlVector) {
|
||||
public CqlVector apply(CqlVector<N> cqlVector) {
|
||||
double[] vals = new double[cqlVector.size()];
|
||||
double accumulator= 0.0d;
|
||||
double accumulator = 0.0d;
|
||||
for (int i = 0; i < vals.length; i++) {
|
||||
vals[i]=cqlVector.get(i).doubleValue();
|
||||
accumulator+=vals[i]*vals[i];
|
||||
vals[i] = cqlVector.get(i).doubleValue();
|
||||
accumulator += vals[i] * vals[i];
|
||||
}
|
||||
double factor = 1.0d/Math.sqrt(Arrays.stream(vals).map(d -> d * d).sum());
|
||||
double factor = 1.0d / Math.sqrt(Arrays.stream(vals).map(d -> d * d).sum());
|
||||
|
||||
if (cqlVector.get(0) instanceof Float) {
|
||||
List<Float> list = Arrays.stream(vals).mapToObj(d -> Float.valueOf((float) (d * factor))).toList();
|
||||
List<Float> list =
|
||||
Arrays.stream(vals).mapToObj(d -> Float.valueOf((float) (d * factor))).toList();
|
||||
return CqlVector.newInstance(list);
|
||||
} else if (cqlVector.get(0) instanceof Double) {
|
||||
List<Double> list = Arrays.stream(vals).mapToObj(d -> Double.valueOf((float) (d * factor))).toList();
|
||||
List<Double> list =
|
||||
Arrays.stream(vals).mapToObj(d -> Double.valueOf((float) (d * factor))).toList();
|
||||
return CqlVector.newInstance(list);
|
||||
} else {
|
||||
throw new RuntimeException(NormalizeCqlVector.class.getCanonicalName()+ " only supports Double and Float type");
|
||||
throw new RuntimeException(NormalizeCqlVector.class.getCanonicalName()
|
||||
+ " only supports Double and Float type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -46,9 +46,9 @@ public class NormalizeCqlVectorTest {
|
||||
|
||||
@Test
|
||||
public void normalizeCqlVectorDoubles() {
|
||||
CqlVector square = CqlVector.newInstance(1.0d, 1.0d);
|
||||
CqlVector<Number> square = CqlVector.newInstance(1.0d, 1.0d);
|
||||
NormalizeCqlVector nv = new NormalizeCqlVector();
|
||||
CqlVector normaled = nv.apply(square);
|
||||
CqlVector<Number> normaled = nv.apply(square);
|
||||
|
||||
assertThat(normaled.size()).isEqualTo(2);
|
||||
assertThat(normaled.get(0)).isInstanceOf(Double.class);
|
||||
|
Loading…
Reference in New Issue
Block a user