improved vectormath functions

This commit is contained in:
Jonathan Shook
2023-08-18 12:24:21 -05:00
parent b38543266d
commit 5d519d1984
4 changed files with 184 additions and 7 deletions

View File

@@ -0,0 +1,83 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.engine.extensions.vectormath;
import java.util.Arrays;
public class Intersections {
public static long[] find(long[] reference, long[] sample) {
long[] result = new long[reference.length];
int a_index = 0, b_index = 0, acc_index = -1;
long a_element, b_element;
while (a_index < reference.length && b_index < sample.length) {
a_element = reference[a_index];
b_element = sample[b_index];
if (a_element == b_element) {
result = resize(result);
result[++acc_index] = a_element;
a_index++;
b_index++;
} else if (b_element < a_element) {
b_index++;
} else {
a_index++;
}
}
return Arrays.copyOfRange(result,0,acc_index+1);
}
public static int[] find(int[] reference, int[] sample) {
int[] result = new int[reference.length];
int a_index = 0, b_index = 0, acc_index = -1;
int a_element, b_element;
while (a_index < reference.length && b_index < sample.length) {
a_element = reference[a_index];
b_element = sample[b_index];
if (a_element == b_element) {
result = resize(result);
result[++acc_index] = a_element;
a_index++;
b_index++;
} else if (b_element < a_element) {
b_index++;
} else {
a_index++;
}
}
return Arrays.copyOfRange(result,0,acc_index+1);
}
public static int[] resize(int[] arr) {
int len = arr.length;
int[] copy = new int[len + 1];
for (int i = 0; i < len; i++) {
copy[i] = arr[i];
}
return copy;
}
public static long[] resize(long[] arr) {
int len = arr.length;
long[] copy = new long[len + 1];
for (int i = 0; i < len; i++) {
copy[i] = arr[i];
}
return copy;
}
}

View File

@@ -17,16 +17,32 @@
package io.nosqlbench.engine.extensions.vectormath;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.shaded.guava.common.collect.Sets;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
public class VectorMath {
public double computeRecall(List<Row> rows, List<Long> expectedRowIds) {
Set<String> found = rows.stream().map(r -> r.getString("key")).collect(Collectors.toSet());
Set<String> expected = expectedRowIds.stream().map(String::valueOf).collect(Collectors.toSet());
return ((double)Sets.intersection(found,expected).size()/(double)expected.size());
public static long[] rowsToLongArray(String fieldName, List<Row> rows) {
return rows.stream().mapToLong(r -> r.getLong(fieldName)).toArray();
}
public static int[] rowListToIntArray(String fieldName, List<Row> rows) {
return rows.stream().mapToInt(r -> r.getInt(fieldName)).toArray();
}
public double computeRecall(long[] referenceIndexes, long[] sampleIndexes) {
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
long[] intersection = Intersections.find(referenceIndexes,sampleIndexes);
return (double)intersection.length/(double)referenceIndexes.length;
}
public double computeRecall(int[] referenceIndexes, int[] sampleIndexes) {
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
int[] intersection = Intersections.find(referenceIndexes,sampleIndexes);
return (double)intersection.length/(double)referenceIndexes.length;
}
}

View File

@@ -0,0 +1,36 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.engine.extensions.vectormath;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
class IntersectionsTest {
@Test
public void testIntegerIntersection() {
int[] result = Intersections.find(new int[]{1,2,3,4,5},new int[]{4,5,6,7,8});
assertThat(result).isEqualTo(new int[]{4,5});
}
@Test
public void testLongIntersection() {
long[] result = Intersections.find(new long[]{1,2,3,4,5},new long[]{4,5,6,7,8});
assertThat(result).isEqualTo(new long[]{4,5});
}
}

View File

@@ -0,0 +1,42 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.engine.extensions.vectormath;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class VectorMathTest {
private VectorMath vm = new VectorMath();
@Test
void computeRecallForRowListVsLongIndexList() {
new VectorMath();
}
@Test
void computeRecallForRowListVsIntIndexList() {
}
@Test
void computeRecallForRowListVsIntIndexArray() {
}
@Test
void computeRecallForRowListVsLongIndexArray() {
}
}