mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2025-02-25 18:55:28 -06:00
improved vectormath functions
This commit is contained in:
@@ -0,0 +1,83 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.engine.extensions.vectormath;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
public class Intersections {
|
||||
|
||||
public static long[] find(long[] reference, long[] sample) {
|
||||
long[] result = new long[reference.length];
|
||||
int a_index = 0, b_index = 0, acc_index = -1;
|
||||
long a_element, b_element;
|
||||
while (a_index < reference.length && b_index < sample.length) {
|
||||
a_element = reference[a_index];
|
||||
b_element = sample[b_index];
|
||||
if (a_element == b_element) {
|
||||
result = resize(result);
|
||||
result[++acc_index] = a_element;
|
||||
a_index++;
|
||||
b_index++;
|
||||
} else if (b_element < a_element) {
|
||||
b_index++;
|
||||
} else {
|
||||
a_index++;
|
||||
}
|
||||
}
|
||||
return Arrays.copyOfRange(result,0,acc_index+1);
|
||||
}
|
||||
|
||||
public static int[] find(int[] reference, int[] sample) {
|
||||
int[] result = new int[reference.length];
|
||||
int a_index = 0, b_index = 0, acc_index = -1;
|
||||
int a_element, b_element;
|
||||
while (a_index < reference.length && b_index < sample.length) {
|
||||
a_element = reference[a_index];
|
||||
b_element = sample[b_index];
|
||||
if (a_element == b_element) {
|
||||
result = resize(result);
|
||||
result[++acc_index] = a_element;
|
||||
a_index++;
|
||||
b_index++;
|
||||
} else if (b_element < a_element) {
|
||||
b_index++;
|
||||
} else {
|
||||
a_index++;
|
||||
}
|
||||
}
|
||||
return Arrays.copyOfRange(result,0,acc_index+1);
|
||||
}
|
||||
|
||||
public static int[] resize(int[] arr) {
|
||||
int len = arr.length;
|
||||
int[] copy = new int[len + 1];
|
||||
for (int i = 0; i < len; i++) {
|
||||
copy[i] = arr[i];
|
||||
}
|
||||
return copy;
|
||||
}
|
||||
|
||||
public static long[] resize(long[] arr) {
|
||||
int len = arr.length;
|
||||
long[] copy = new long[len + 1];
|
||||
for (int i = 0; i < len; i++) {
|
||||
copy[i] = arr[i];
|
||||
}
|
||||
return copy;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -17,16 +17,32 @@
|
||||
package io.nosqlbench.engine.extensions.vectormath;
|
||||
|
||||
import com.datastax.oss.driver.api.core.cql.Row;
|
||||
import com.datastax.oss.driver.shaded.guava.common.collect.Sets;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class VectorMath {
|
||||
public double computeRecall(List<Row> rows, List<Long> expectedRowIds) {
|
||||
Set<String> found = rows.stream().map(r -> r.getString("key")).collect(Collectors.toSet());
|
||||
Set<String> expected = expectedRowIds.stream().map(String::valueOf).collect(Collectors.toSet());
|
||||
return ((double)Sets.intersection(found,expected).size()/(double)expected.size());
|
||||
|
||||
public static long[] rowsToLongArray(String fieldName, List<Row> rows) {
|
||||
return rows.stream().mapToLong(r -> r.getLong(fieldName)).toArray();
|
||||
}
|
||||
|
||||
public static int[] rowListToIntArray(String fieldName, List<Row> rows) {
|
||||
return rows.stream().mapToInt(r -> r.getInt(fieldName)).toArray();
|
||||
}
|
||||
|
||||
public double computeRecall(long[] referenceIndexes, long[] sampleIndexes) {
|
||||
Arrays.sort(referenceIndexes);
|
||||
Arrays.sort(sampleIndexes);
|
||||
long[] intersection = Intersections.find(referenceIndexes,sampleIndexes);
|
||||
return (double)intersection.length/(double)referenceIndexes.length;
|
||||
}
|
||||
|
||||
public double computeRecall(int[] referenceIndexes, int[] sampleIndexes) {
|
||||
Arrays.sort(referenceIndexes);
|
||||
Arrays.sort(sampleIndexes);
|
||||
int[] intersection = Intersections.find(referenceIndexes,sampleIndexes);
|
||||
return (double)intersection.length/(double)referenceIndexes.length;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.engine.extensions.vectormath;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
class IntersectionsTest {
|
||||
|
||||
@Test
|
||||
public void testIntegerIntersection() {
|
||||
int[] result = Intersections.find(new int[]{1,2,3,4,5},new int[]{4,5,6,7,8});
|
||||
assertThat(result).isEqualTo(new int[]{4,5});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLongIntersection() {
|
||||
long[] result = Intersections.find(new long[]{1,2,3,4,5},new long[]{4,5,6,7,8});
|
||||
assertThat(result).isEqualTo(new long[]{4,5});
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.engine.extensions.vectormath;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
class VectorMathTest {
|
||||
|
||||
private VectorMath vm = new VectorMath();
|
||||
|
||||
@Test
|
||||
void computeRecallForRowListVsLongIndexList() {
|
||||
new VectorMath();
|
||||
}
|
||||
|
||||
@Test
|
||||
void computeRecallForRowListVsIntIndexList() {
|
||||
}
|
||||
|
||||
@Test
|
||||
void computeRecallForRowListVsIntIndexArray() {
|
||||
}
|
||||
|
||||
@Test
|
||||
void computeRecallForRowListVsLongIndexArray() {
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user