diff --git a/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/Intersections.java b/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/Intersections.java new file mode 100644 index 000000000..1e4b4919a --- /dev/null +++ b/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/Intersections.java @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2023 nosqlbench + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.nosqlbench.engine.extensions.vectormath; + +import java.util.Arrays; + +public class Intersections { + + public static long[] find(long[] reference, long[] sample) { + long[] result = new long[reference.length]; + int a_index = 0, b_index = 0, acc_index = -1; + long a_element, b_element; + while (a_index < reference.length && b_index < sample.length) { + a_element = reference[a_index]; + b_element = sample[b_index]; + if (a_element == b_element) { + result = resize(result); + result[++acc_index] = a_element; + a_index++; + b_index++; + } else if (b_element < a_element) { + b_index++; + } else { + a_index++; + } + } + return Arrays.copyOfRange(result,0,acc_index+1); + } + + public static int[] find(int[] reference, int[] sample) { + int[] result = new int[reference.length]; + int a_index = 0, b_index = 0, acc_index = -1; + int a_element, b_element; + while (a_index < reference.length && b_index < sample.length) { + a_element = reference[a_index]; + b_element = sample[b_index]; + if (a_element == b_element) { + result = resize(result); + result[++acc_index] = a_element; + a_index++; + b_index++; + } else if (b_element < a_element) { + b_index++; + } else { + a_index++; + } + } + return Arrays.copyOfRange(result,0,acc_index+1); + } + + public static int[] resize(int[] arr) { + int len = arr.length; + int[] copy = new int[len + 1]; + for (int i = 0; i < len; i++) { + copy[i] = arr[i]; + } + return copy; + } + + public static long[] resize(long[] arr) { + int len = arr.length; + long[] copy = new long[len + 1]; + for (int i = 0; i < len; i++) { + copy[i] = arr[i]; + } + return copy; + } + +} diff --git a/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/VectorMath.java b/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/VectorMath.java index ccb97d95b..3e94cf0b8 100644 --- a/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/VectorMath.java +++ b/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/VectorMath.java @@ -17,16 +17,32 @@ package io.nosqlbench.engine.extensions.vectormath; import com.datastax.oss.driver.api.core.cql.Row; -import com.datastax.oss.driver.shaded.guava.common.collect.Sets; +import java.util.Arrays; import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; public class VectorMath { - public double computeRecall(List rows, List expectedRowIds) { - Set found = rows.stream().map(r -> r.getString("key")).collect(Collectors.toSet()); - Set expected = expectedRowIds.stream().map(String::valueOf).collect(Collectors.toSet()); - return ((double)Sets.intersection(found,expected).size()/(double)expected.size()); + + public static long[] rowsToLongArray(String fieldName, List rows) { + return rows.stream().mapToLong(r -> r.getLong(fieldName)).toArray(); } + + public static int[] rowListToIntArray(String fieldName, List rows) { + return rows.stream().mapToInt(r -> r.getInt(fieldName)).toArray(); + } + + public double computeRecall(long[] referenceIndexes, long[] sampleIndexes) { + Arrays.sort(referenceIndexes); + Arrays.sort(sampleIndexes); + long[] intersection = Intersections.find(referenceIndexes,sampleIndexes); + return (double)intersection.length/(double)referenceIndexes.length; + } + + public double computeRecall(int[] referenceIndexes, int[] sampleIndexes) { + Arrays.sort(referenceIndexes); + Arrays.sort(sampleIndexes); + int[] intersection = Intersections.find(referenceIndexes,sampleIndexes); + return (double)intersection.length/(double)referenceIndexes.length; + } + } diff --git a/adapter-cqld4/src/test/java/io/nosqlbench/engine/extensions/vectormath/IntersectionsTest.java b/adapter-cqld4/src/test/java/io/nosqlbench/engine/extensions/vectormath/IntersectionsTest.java new file mode 100644 index 000000000..fd6752fc1 --- /dev/null +++ b/adapter-cqld4/src/test/java/io/nosqlbench/engine/extensions/vectormath/IntersectionsTest.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023 nosqlbench + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.nosqlbench.engine.extensions.vectormath; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +class IntersectionsTest { + + @Test + public void testIntegerIntersection() { + int[] result = Intersections.find(new int[]{1,2,3,4,5},new int[]{4,5,6,7,8}); + assertThat(result).isEqualTo(new int[]{4,5}); + } + + @Test + public void testLongIntersection() { + long[] result = Intersections.find(new long[]{1,2,3,4,5},new long[]{4,5,6,7,8}); + assertThat(result).isEqualTo(new long[]{4,5}); + } + +} diff --git a/adapter-cqld4/src/test/java/io/nosqlbench/engine/extensions/vectormath/VectorMathTest.java b/adapter-cqld4/src/test/java/io/nosqlbench/engine/extensions/vectormath/VectorMathTest.java new file mode 100644 index 000000000..e498a1f8a --- /dev/null +++ b/adapter-cqld4/src/test/java/io/nosqlbench/engine/extensions/vectormath/VectorMathTest.java @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2023 nosqlbench + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.nosqlbench.engine.extensions.vectormath; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; +class VectorMathTest { + + private VectorMath vm = new VectorMath(); + + @Test + void computeRecallForRowListVsLongIndexList() { + new VectorMath(); + } + + @Test + void computeRecallForRowListVsIntIndexList() { + } + + @Test + void computeRecallForRowListVsIntIndexArray() { + } + + @Test + void computeRecallForRowListVsLongIndexArray() { + } +}