more vector functions

This commit is contained in:
Jonathan Shook 2023-08-18 17:05:45 -05:00
parent 302213c3c6
commit 28348180c1
2 changed files with 25 additions and 6 deletions

View File

@ -23,26 +23,47 @@ import java.util.List;
public class VectorMath {
public static long[] rowsToLongArray(String fieldName, List<Row> rows) {
public static long[] rowFieldsToLongArray(String fieldName, List<Row> rows) {
return rows.stream().mapToLong(r -> r.getLong(fieldName)).toArray();
}
public static String[] rowFieldsToStringArray(String fieldName, List<Row> rows) {
return rows.stream().map(r -> r.getString(fieldName)).toArray(String[]::new);
}
public static long[] stringArrayAsALongArray(String[] strings) {
long[] longs = new long[strings.length];
for (int i = 0; i < longs.length; i++) {
longs[i]=Long.parseLong(strings[i]);
}
return longs;
}
public static int[] stringArrayAsIntArray(String[] strings) {
int[] ints = new int[strings.length];
for (int i = 0; i < ints.length; i++) {
ints[i]=Integer.parseInt(strings[i]);
}
return ints;
}
public static int[] rowListToIntArray(String fieldName, List<Row> rows) {
return rows.stream().mapToInt(r -> r.getInt(fieldName)).toArray();
}
public double computeRecall(long[] referenceIndexes, long[] sampleIndexes) {
public static double computeRecall(long[] referenceIndexes, long[] sampleIndexes) {
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
long[] intersection = Intersections.find(referenceIndexes,sampleIndexes);
return (double)intersection.length/(double)referenceIndexes.length;
}
public double computeRecall(int[] referenceIndexes, int[] sampleIndexes) {
public static double computeRecall(int[] referenceIndexes, int[] sampleIndexes) {
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
int[] intersection = Intersections.find(referenceIndexes,sampleIndexes);
return (double)intersection.length/(double)referenceIndexes.length;
}
}

View File

@ -17,15 +17,13 @@
package io.nosqlbench.engine.extensions.vectormath;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class VectorMathTest {
private VectorMath vm = new VectorMath();
@Test
void computeRecallForRowListVsLongIndexList() {
new VectorMath();
VectorMath.computeRecall(new long[]{}, new long[]{});
}
@Test