From 28348180c18d8f1f2ce7bf822c6d90b22463bc1b Mon Sep 17 00:00:00 2001 From: Jonathan Shook Date: Fri, 18 Aug 2023 17:05:45 -0500 Subject: [PATCH] more vector functions --- .../extensions/vectormath/VectorMath.java | 27 ++++++++++++++++--- .../extensions/vectormath/VectorMathTest.java | 4 +-- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/VectorMath.java b/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/VectorMath.java index 3e94cf0b8..a529d0b21 100644 --- a/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/VectorMath.java +++ b/adapter-cqld4/src/main/java/io/nosqlbench/engine/extensions/vectormath/VectorMath.java @@ -23,26 +23,47 @@ import java.util.List; public class VectorMath { - public static long[] rowsToLongArray(String fieldName, List rows) { + public static long[] rowFieldsToLongArray(String fieldName, List rows) { return rows.stream().mapToLong(r -> r.getLong(fieldName)).toArray(); } + public static String[] rowFieldsToStringArray(String fieldName, List rows) { + return rows.stream().map(r -> r.getString(fieldName)).toArray(String[]::new); + } + + public static long[] stringArrayAsALongArray(String[] strings) { + long[] longs = new long[strings.length]; + for (int i = 0; i < longs.length; i++) { + longs[i]=Long.parseLong(strings[i]); + } + return longs; + } + + public static int[] stringArrayAsIntArray(String[] strings) { + int[] ints = new int[strings.length]; + for (int i = 0; i < ints.length; i++) { + ints[i]=Integer.parseInt(strings[i]); + } + return ints; + } + public static int[] rowListToIntArray(String fieldName, List rows) { return rows.stream().mapToInt(r -> r.getInt(fieldName)).toArray(); } - public double computeRecall(long[] referenceIndexes, long[] sampleIndexes) { + public static double computeRecall(long[] referenceIndexes, long[] sampleIndexes) { Arrays.sort(referenceIndexes); Arrays.sort(sampleIndexes); long[] intersection = Intersections.find(referenceIndexes,sampleIndexes); return (double)intersection.length/(double)referenceIndexes.length; } - public double computeRecall(int[] referenceIndexes, int[] sampleIndexes) { + public static double computeRecall(int[] referenceIndexes, int[] sampleIndexes) { Arrays.sort(referenceIndexes); Arrays.sort(sampleIndexes); int[] intersection = Intersections.find(referenceIndexes,sampleIndexes); return (double)intersection.length/(double)referenceIndexes.length; } + } diff --git a/adapter-cqld4/src/test/java/io/nosqlbench/engine/extensions/vectormath/VectorMathTest.java b/adapter-cqld4/src/test/java/io/nosqlbench/engine/extensions/vectormath/VectorMathTest.java index e498a1f8a..334a6c008 100644 --- a/adapter-cqld4/src/test/java/io/nosqlbench/engine/extensions/vectormath/VectorMathTest.java +++ b/adapter-cqld4/src/test/java/io/nosqlbench/engine/extensions/vectormath/VectorMathTest.java @@ -17,15 +17,13 @@ package io.nosqlbench.engine.extensions.vectormath; import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.*; class VectorMathTest { private VectorMath vm = new VectorMath(); @Test void computeRecallForRowListVsLongIndexList() { - new VectorMath(); + VectorMath.computeRecall(new long[]{}, new long[]{}); } @Test