more vector functions

This commit is contained in:
Jonathan Shook 2023-08-18 17:05:45 -05:00
parent 302213c3c6
commit 28348180c1
2 changed files with 25 additions and 6 deletions

View File

@ -23,26 +23,47 @@ import java.util.List;
public class VectorMath { public class VectorMath {
public static long[] rowsToLongArray(String fieldName, List<Row> rows) { public static long[] rowFieldsToLongArray(String fieldName, List<Row> rows) {
return rows.stream().mapToLong(r -> r.getLong(fieldName)).toArray(); return rows.stream().mapToLong(r -> r.getLong(fieldName)).toArray();
} }
public static String[] rowFieldsToStringArray(String fieldName, List<Row> rows) {
return rows.stream().map(r -> r.getString(fieldName)).toArray(String[]::new);
}
public static long[] stringArrayAsALongArray(String[] strings) {
long[] longs = new long[strings.length];
for (int i = 0; i < longs.length; i++) {
longs[i]=Long.parseLong(strings[i]);
}
return longs;
}
public static int[] stringArrayAsIntArray(String[] strings) {
int[] ints = new int[strings.length];
for (int i = 0; i < ints.length; i++) {
ints[i]=Integer.parseInt(strings[i]);
}
return ints;
}
public static int[] rowListToIntArray(String fieldName, List<Row> rows) { public static int[] rowListToIntArray(String fieldName, List<Row> rows) {
return rows.stream().mapToInt(r -> r.getInt(fieldName)).toArray(); return rows.stream().mapToInt(r -> r.getInt(fieldName)).toArray();
} }
public double computeRecall(long[] referenceIndexes, long[] sampleIndexes) { public static double computeRecall(long[] referenceIndexes, long[] sampleIndexes) {
Arrays.sort(referenceIndexes); Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes); Arrays.sort(sampleIndexes);
long[] intersection = Intersections.find(referenceIndexes,sampleIndexes); long[] intersection = Intersections.find(referenceIndexes,sampleIndexes);
return (double)intersection.length/(double)referenceIndexes.length; return (double)intersection.length/(double)referenceIndexes.length;
} }
public double computeRecall(int[] referenceIndexes, int[] sampleIndexes) { public static double computeRecall(int[] referenceIndexes, int[] sampleIndexes) {
Arrays.sort(referenceIndexes); Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes); Arrays.sort(sampleIndexes);
int[] intersection = Intersections.find(referenceIndexes,sampleIndexes); int[] intersection = Intersections.find(referenceIndexes,sampleIndexes);
return (double)intersection.length/(double)referenceIndexes.length; return (double)intersection.length/(double)referenceIndexes.length;
} }
} }

View File

@ -17,15 +17,13 @@
package io.nosqlbench.engine.extensions.vectormath; package io.nosqlbench.engine.extensions.vectormath;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class VectorMathTest { class VectorMathTest {
private VectorMath vm = new VectorMath(); private VectorMath vm = new VectorMath();
@Test @Test
void computeRecallForRowListVsLongIndexList() { void computeRecallForRowListVsLongIndexList() {
new VectorMath(); VectorMath.computeRecall(new long[]{}, new long[]{});
} }
@Test @Test