Merge pull request #1525 from nosqlbench/nosqlbench-1522-moremath

Nosqlbench 1522 moremath
This commit is contained in:
Jonathan Shook 2023-09-08 11:10:14 -05:00 committed by GitHub
commit 6005aa65d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 335 additions and 194 deletions

View File

@ -17,6 +17,8 @@
package io.nosqlbench.engine.extensions.computefunctions; package io.nosqlbench.engine.extensions.computefunctions;
import java.util.Arrays; import java.util.Arrays;
import java.util.DoubleSummaryStatistics;
import java.util.HashSet;
/** /**
* <P>A collection of compute functions related to vector search relevancy. * <P>A collection of compute functions related to vector search relevancy.
@ -29,7 +31,8 @@ import java.util.Arrays;
* metrics "@K" for any size up to and including K=100. This simply uses a partial view of the result * metrics "@K" for any size up to and including K=100. This simply uses a partial view of the result
* to do exactly what would have been done for a test where you actually query for that K limit. * to do exactly what would have been done for a test where you actually query for that K limit.
* <STRONG>This assumes that the result rank is stable irrespective of the limit AND the results * <STRONG>This assumes that the result rank is stable irrespective of the limit AND the results
* are passed to these functions as ranked in results.</STRONG></P> * are passed to these functions as ranked in results.</STRONG></P> Some of the methods apply K to the
* expected (relevant) indices, others to the actual (response) indices, depending on what is appropriate.
* *
* <P>The array indices passed to these functions should not be sorted before-hand as a general rule.</P> * <P>The array indices passed to these functions should not be sorted before-hand as a general rule.</P>
* Yet, no provision is made for duplicate entries. If you have duplicate indices in either array, * Yet, no provision is made for duplicate entries. If you have duplicate indices in either array,
@ -41,91 +44,91 @@ public class ComputeFunctions {
/** /**
* Compute the recall as the proportion of matching indices divided by the expected indices * Compute the recall as the proportion of matching indices divided by the expected indices
* *
* @param referenceIndexes * @param relevant
* long array of indices * long array of indices
* @param sampleIndexes * @param actual
* long array of indices * long array of indices
* @return a fractional measure of matching vs expected indices * @return a fractional measure of matching vs expected indices
*/ */
public static double recall(long[] referenceIndexes, long[] sampleIndexes) { public static double recall(long[] relevant, long[] actual) {
Arrays.sort(referenceIndexes); Arrays.sort(relevant);
Arrays.sort(sampleIndexes); Arrays.sort(actual);
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes); long[] intersection = Intersections.find(relevant, actual);
return (double) intersection.length / (double) referenceIndexes.length; return (double) intersection.length / (double) relevant.length;
} }
public static double recall(long[] referenceIndexes, long[] sampleIndexes, int limit) { public static double recall(long[] relevant, long[] actual, int k) {
if (sampleIndexes.length < limit) { if (actual.length < k) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit); throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k);
} }
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit); actual = Arrays.copyOfRange(actual, 0, k);
Arrays.sort(referenceIndexes); Arrays.sort(relevant);
Arrays.sort(sampleIndexes); Arrays.sort(actual);
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes); long[] intersection = Intersections.find(relevant, actual);
return (double) intersection.length / (double) referenceIndexes.length; return (double) intersection.length / (double) relevant.length;
} }
public static double precision(long[] referenceIndexes, long[] sampleIndexes) { public static double precision(long[] relevant, long[] actual) {
Arrays.sort(referenceIndexes); Arrays.sort(relevant);
Arrays.sort(sampleIndexes); Arrays.sort(actual);
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes); long[] intersection = Intersections.find(relevant, actual);
return (double) intersection.length / (double) sampleIndexes.length; return (double) intersection.length / (double) actual.length;
} }
public static double precision(long[] referenceIndexes, long[] sampleIndexes, int limit) { public static double precision(long[] relevant, long[] actual, int k) {
if (sampleIndexes.length < limit) { if (actual.length < k) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit); throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k);
} }
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit); actual = Arrays.copyOfRange(actual, 0, k);
Arrays.sort(referenceIndexes); Arrays.sort(relevant);
Arrays.sort(sampleIndexes); Arrays.sort(actual);
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes); long[] intersection = Intersections.find(relevant, actual);
return (double) intersection.length / (double) sampleIndexes.length; return (double) intersection.length / (double) actual.length;
} }
/** /**
* Compute the recall as the proportion of matching indices divided by the expected indices * Compute the recall as the proportion of matching indices divided by the expected indices
* *
* @param referenceIndexes * @param relevant
* int array of indices * int array of indices
* @param sampleIndexes * @param actual
* int array of indices * int array of indices
* @return a fractional measure of matching vs expected indices * @return a fractional measure of matching vs expected indices
*/ */
public static double recall(int[] referenceIndexes, int[] sampleIndexes) { public static double recall(int[] relevant, int[] actual) {
Arrays.sort(referenceIndexes); Arrays.sort(relevant);
Arrays.sort(sampleIndexes); Arrays.sort(actual);
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length); int intersection = Intersections.count(relevant, actual, relevant.length);
return (double) intersection / (double) referenceIndexes.length; return (double) intersection / (double) relevant.length;
} }
public static double recall(int[] referenceIndexes, int[] sampleIndexes, int limit) { public static double recall(int[] relevant, int[] actual, int k) {
if (sampleIndexes.length < limit) { if (actual.length < k) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit); throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k);
} }
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit); actual = Arrays.copyOfRange(actual, 0, k);
Arrays.sort(referenceIndexes); Arrays.sort(relevant);
Arrays.sort(sampleIndexes); Arrays.sort(actual);
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length); int intersection = Intersections.count(relevant, actual, relevant.length);
return (double) intersection / (double) referenceIndexes.length; return (double) intersection / (double) relevant.length;
} }
public static double precision(int[] referenceIndexes, int[] sampleIndexes) { public static double precision(int[] relevant, int[] actual) {
Arrays.sort(referenceIndexes); Arrays.sort(relevant);
Arrays.sort(sampleIndexes); Arrays.sort(actual);
int intersection = Intersections.count(referenceIndexes, sampleIndexes); int intersection = Intersections.count(relevant, actual);
return (double) intersection / (double) sampleIndexes.length; return (double) intersection / (double) actual.length;
} }
public static double precision(int[] referenceIndexes, int[] sampleIndexes, int limit) { public static double precision(int[] relevant, int[] actual, int k) {
if (sampleIndexes.length < limit) { if (actual.length < k) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit); throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k);
} }
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit); actual = Arrays.copyOfRange(actual, 0, k);
Arrays.sort(referenceIndexes); Arrays.sort(relevant);
Arrays.sort(sampleIndexes); Arrays.sort(actual);
int intersection = Intersections.count(referenceIndexes, sampleIndexes); int intersection = Intersections.count(relevant, actual);
return (double) intersection / (double) sampleIndexes.length; return (double) intersection / (double) actual.length;
} }
/** /**
@ -135,65 +138,49 @@ public class ComputeFunctions {
return Intersections.find(a, b); return Intersections.find(a, b);
} }
public static long[] intersection(long[] a, long[] b, int limit) {
return Intersections.find(a, b, limit);
}
/** /**
* Compute the intersection of two int arrays * Compute the intersection of two int arrays
*/ */
public static int[] intersection(int[] reference, int[] sample) { public static int[] intersection(int[] a, int[] b) {
return Intersections.find(reference, sample); return Intersections.find(a, b);
}
public static int[] intersection(int[] reference, int[] sample, int limit) {
return Intersections.find(reference, sample, limit);
} }
/** /**
* Compute the size of the intersection of two int arrays * Compute the size of the intersection of two int arrays
*/ */
public static int intersectionSize(int[] reference, int[] sample) { public static int intersectionSize(int[] a, int[] b) {
return Intersections.count(reference, sample); return Intersections.count(a, b);
} }
public static int intersectionSize(int[] reference, int[] sample, int limit) { public static int intersectionSize(long[] a, long[] b) {
return Intersections.count(reference, sample, limit); return Intersections.count(a, b);
} }
public static int intersectionSize(long[] reference, long[] sample) { public static double F1(int[] relevant, int[] actual) {
return Intersections.count(reference, sample); return F1(relevant, actual, relevant.length);
} }
public static int intersectionSize(long[] reference, long[] sample, int limit) { public static double F1(int[] relevant, int[] actual, int k) {
return Intersections.count(reference, sample, limit); double recallAtK = recall(relevant, actual, k);
} double precisionAtK = precision(relevant, actual, k);
public static double F1(int[] reference, int[] sample) {
return F1(reference, sample, reference.length);
}
public static double F1(int[] reference, int[] sample, int limit) {
double recallAtK = recall(reference, sample, limit);
double precisionAtK = precision(reference, sample, limit);
return 2.0d * ((recallAtK * precisionAtK) / (recallAtK + precisionAtK)); return 2.0d * ((recallAtK * precisionAtK) / (recallAtK + precisionAtK));
} }
public static double F1(long[] reference, long[] sample) { public static double F1(long[] relevant, long[] actual) {
return F1(reference, sample, reference.length); return F1(relevant, actual, relevant.length);
} }
public static double F1(long[] reference, long[] sample, int limit) { public static double F1(long[] relevant, long[] actual, int k) {
double recallAtK = recall(reference, sample, limit); double recallAtK = recall(relevant, actual, k);
double precisionAtK = precision(reference, sample, limit); double precisionAtK = precision(relevant, actual, k);
return 2.0d * ((recallAtK * precisionAtK) / (recallAtK + precisionAtK)); return 2.0d * ((recallAtK * precisionAtK) / (recallAtK + precisionAtK));
} }
/** /**
* Reciprocal Rank - The multiplicative inverse of the first rank which is relevant. * Reciprocal Rank - The multiplicative inverse of the first rank which is relevant.
*/ */
public static double RR(long[] reference, long[] sample, int limit) { public static double reciprocal_rank(long[] relevant, long[] actual, int k) {
int firstRank = Intersections.firstMatchingIndex(reference, sample, limit); int firstRank = Intersections.firstMatchingIndex(relevant, actual, k);
if (firstRank >= 0) { if (firstRank >= 0) {
return 1.0d / (firstRank+1); return 1.0d / (firstRank+1);
} else { } else {
@ -201,21 +188,63 @@ public class ComputeFunctions {
} }
} }
public static double RR(long[] reference, long[] sample) { public static double reciprocal_rank(long[] relevant, long[] actual) {
return RR(reference, sample, reference.length); return reciprocal_rank(relevant, actual, relevant.length);
} }
public static double RR(int[] reference, int[] sample, int limit) { /**
int firstRank = Intersections.firstMatchingIndex(reference, sample, limit); * RR as in M(RR)
if (firstRank >= 0) { */
public static double reciprocal_rank(int[] relevant, int[] actual, int k) {
int firstRank = Intersections.firstMatchingIndex(relevant, actual, k);
if (firstRank<0) {
return 0;
}
return 1.0d / (firstRank+1); return 1.0d / (firstRank+1);
} else {
return 0.0;
}
} }
public static double RR(int[] reference, int[] sample) { public static double reciprocal_rank(int[] relevant, int[] actual) {
return RR(reference, sample, reference.length); return reciprocal_rank(relevant, actual, relevant.length);
}
public static double average_precision(int[] relevant, int[] actual) {
return average_precision(relevant,actual,relevant.length);
}
public static double average_precision(int[] relevant, int[] actual, int k) {
int maxK = Math.min(k,actual.length);
HashSet<Integer> relevantSet = new HashSet<>(relevant.length);
for (Integer i : relevant) {
relevantSet.add(i);
}
int relevantCount=0;
DoubleSummaryStatistics stats = new DoubleSummaryStatistics();
for (int i = 0; i < maxK; i++) {
if (relevantSet.contains(actual[i])){
relevantCount++;
double precisionAtIdx = (double) relevantCount / (i+1);
stats.accept(precisionAtIdx);
}
}
return stats.getAverage();
}
public static double average_precision(long[] relevant, long[] actual, int k) {
int maxK = Math.min(k,actual.length);
HashSet<Long> refset = new HashSet<>(relevant.length);
for (Long i : relevant) {
refset.add(i);
}
int relevantCount=0;
DoubleSummaryStatistics stats = new DoubleSummaryStatistics();
for (int i = 0; i < maxK; i++) {
if (refset.contains(actual[i])){
relevantCount++;
double precisionAtIdx = (double) relevantCount / (i+1);
stats.accept(precisionAtIdx);
}
}
return stats.getAverage();
} }
} }

View File

@ -27,10 +27,10 @@ public class Intersections {
public static int firstMatchingIndex(long[] reference, long[] sample, int limit) { public static int firstMatchingIndex(long[] reference, long[] sample, int limit) {
Arrays.sort(reference); Arrays.sort(reference);
int maxIndex = Math.min(sample.length, limit); int maxIndex = Math.min(sample.length, limit);
int foundAt=-1; int foundAt = -1;
for (int index = 0; index < maxIndex; index++) { for (int index = 0; index < maxIndex; index++) {
foundAt = Arrays.binarySearch(reference, sample[index]); foundAt = Arrays.binarySearch(reference, sample[index]);
if (foundAt>=0) break; if (foundAt >= 0) break;
} }
return foundAt; return foundAt;
} }
@ -38,21 +38,18 @@ public class Intersections {
public static int firstMatchingIndex(int[] reference, int[] sample, int limit) { public static int firstMatchingIndex(int[] reference, int[] sample, int limit) {
Arrays.sort(reference); Arrays.sort(reference);
int maxIndex = Math.min(sample.length, limit); int maxIndex = Math.min(sample.length, limit);
int foundAt=-1; int foundAt = -1;
for (int index = 0; index < maxIndex; index++) { for (int index = 0; index < maxIndex; index++) {
foundAt = Arrays.binarySearch(reference, sample[index]); foundAt = Arrays.binarySearch(reference, sample[index]);
if (foundAt>=0) break; if (foundAt >= 0) break;
} }
return foundAt; return foundAt;
} }
public static int count(int[] reference, int[] sample) { public static int count(int[] reference, int[] sample) {
return count(reference,sample,reference.length);
}
public static int count(int[] reference, int[] sample, int limit) {
int a_index = 0, b_index = 0, matches = 0; int a_index = 0, b_index = 0, matches = 0;
int a_element, b_element; int a_element, b_element;
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) { while (a_index < reference.length && b_index < sample.length) {
a_element = reference[a_index]; a_element = reference[a_index];
b_element = sample[b_index]; b_element = sample[b_index];
if (a_element == b_element) { if (a_element == b_element) {
@ -71,6 +68,7 @@ public class Intersections {
public static int count(long[] reference, long[] sample) { public static int count(long[] reference, long[] sample) {
return count(reference, sample, reference.length); return count(reference, sample, reference.length);
} }
public static int count(long[] reference, long[] sample, int limit) { public static int count(long[] reference, long[] sample, int limit) {
int a_index = 0, b_index = 0, matches = 0; int a_index = 0, b_index = 0, matches = 0;
long a_element, b_element; long a_element, b_element;
@ -91,18 +89,13 @@ public class Intersections {
} }
public static int[] find(int[] reference, int[] sample) { public static int[] find(int[] reference, int[] sample) {
return find(reference,sample,reference.length); int[] result = new int[sample.length];
}
public static int[] find(int[] reference, int[] sample, int limit) {
int[] result = new int[limit];
int a_index = 0, b_index = 0, acc_index = -1; int a_index = 0, b_index = 0, acc_index = -1;
int a_element, b_element; int a_element, b_element;
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) { while (a_index < reference.length && b_index < sample.length) {
a_element = reference[a_index]; a_element = reference[a_index];
b_element = sample[b_index]; b_element = sample[b_index];
if (a_element == b_element) { if (a_element == b_element) {
result = resize(result);
result[++acc_index] = a_element; result[++acc_index] = a_element;
a_index++; a_index++;
b_index++; b_index++;
@ -112,21 +105,17 @@ public class Intersections {
a_index++; a_index++;
} }
} }
return Arrays.copyOfRange(result,0,acc_index+1); return Arrays.copyOfRange(result, 0, acc_index + 1);
} }
public static long[] find(long[] reference, long[] sample) { public static long[] find(long[] reference, long[] sample) {
return find(reference, sample, reference.length); long[] result = new long[sample.length];
}
public static long[] find(long[] reference, long[] sample, int limit) {
long[] result = new long[limit];
int a_index = 0, b_index = 0, acc_index = -1; int a_index = 0, b_index = 0, acc_index = -1;
long a_element, b_element; long a_element, b_element;
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) { while (a_index < reference.length && b_index < sample.length) {
a_element = reference[a_index]; a_element = reference[a_index];
b_element = sample[b_index]; b_element = sample[b_index];
if (a_element == b_element) { if (a_element == b_element) {
result = resize(result);
result[++acc_index] = a_element; result[++acc_index] = a_element;
a_index++; a_index++;
b_index++; b_index++;
@ -136,22 +125,7 @@ public class Intersections {
a_index++; a_index++;
} }
} }
return Arrays.copyOfRange(result,0,acc_index+1); return Arrays.copyOfRange(result, 0, acc_index + 1);
} }
private static int[] resize(int[] arr) {
int len = arr.length;
int[] copy = new int[len + 1];
System.arraycopy(arr, 0, copy, 0, len);
return copy;
}
private static long[] resize(long[] arr) {
int len = arr.length;
long[] copy = new long[len + 1];
System.arraycopy(arr, 0, copy, 0, len);
return copy;
}
} }

View File

@ -23,7 +23,7 @@ import java.util.stream.IntStream;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
class ComputeFunctionsTest { class ComputeFunctionsIntTest {
private final static Offset<Double> offset=Offset.offset(0.001d); private final static Offset<Double> offset=Offset.offset(0.001d);
private final static int[] allInts =new int[]{0,1,2,3,4,5,6,7,8,9}; private final static int[] allInts =new int[]{0,1,2,3,4,5,6,7,8,9};
@ -33,7 +33,10 @@ class ComputeFunctionsTest {
private final static int[] highInts56789 = new int[]{5,6,7,8,9}; private final static int[] highInts56789 = new int[]{5,6,7,8,9};
private final static int[] intsBy3_369 = new int[]{3,6,9}; private final static int[] intsBy3_369 = new int[]{3,6,9};
private final static int[] intsBy3_693 = new int[]{3,6,9}; private final static int[] intsBy3_693 = new int[]{6,9,3};
private final static int[] midInts45678 = new int[]{4,5,6,7,8};
private final static int[] ints12390 = new int[]{1,2,3,9,0};
@Test @Test
void testRecallIntArrays() { void testRecallIntArrays() {
assertThat(ComputeFunctions.recall(evenInts86204,oddInts37195)) assertThat(ComputeFunctions.recall(evenInts86204,oddInts37195))
@ -81,12 +84,49 @@ class ComputeFunctionsTest {
@Test @Test
public void testReciprocalRank() { public void testReciprocalRank() {
assertThat(ComputeFunctions.RR(intsBy3_369,highInts56789)) assertThat(ComputeFunctions.reciprocal_rank(intsBy3_369,highInts56789))
.as("relevant results in rank 2 should yield RR=0.5") .as("relevant results in rank 2 should yield RR=0.5")
.isCloseTo(0.5d,offset); .isCloseTo(0.5d,offset);
assertThat(ComputeFunctions.RR(highInts56789,lowInts01234)) assertThat(ComputeFunctions.reciprocal_rank(highInts56789,lowInts01234))
.as("no relevant results should yield RR=0.0") .as("no relevant results should yield RR=0.0")
.isCloseTo(0.0d,offset); .isCloseTo(0.0d,offset);
} }
/**
 * The two fixtures share exactly one value, 4.
 */
@Test
public void testIntegerIntersection() {
    int[] expected = new int[]{4};
    assertThat(Intersections.find(lowInts01234, midInts45678)).isEqualTo(expected);
}
/**
 * Counts the common members of two index sets.
 *
 * <P>{@code Intersections.count} advances through both arrays with a two-pointer scan, so
 * it is given sorted copies here. Passing the shared fixtures directly made the result
 * depend on whether another test had already sorted them in place.</P>
 */
@Test
public void testCountIntIntersection() {
    // NOTE(review): expects oddInts37195 to hold {3,7,1,9,5}; its declaration is outside this hunk
    int[] sortedOdds = IntStream.of(oddInts37195).sorted().toArray();
    int[] sortedOthers = IntStream.of(ints12390).sorted().toArray();
    int result = Intersections.count(sortedOdds, sortedOthers);
    assertThat(result)
        .as("1, 3 and 9 are common to both sets")
        .isEqualTo(3);
}
/**
 * Checks the 0/1 mask produced by {@code Intersections.mask} for a partially matching
 * and a fully matching pair of arrays.
 */
@Test
public void testMasking() {
    assertThat(Intersections.mask(ints12390,highInts56789))
        .as("the last actual is relevant and should have a 1 in the mask")
        .isEqualTo(new int[]{0,0,0,0,1});
    // identical arrays: the description previously repeated the "last actual" wording
    assertThat(Intersections.mask(allInts,allInts))
        .as("every actual is relevant and should have a 1 in the mask")
        .isEqualTo(new int[]{1,1,1,1,1,1,1,1,1,1});
}
/**
 * Average precision for a five-result response with three relevant hits,
 * and for a three-result response with two relevant hits.
 */
@Test
public void testAP() {
    int[] relevant = new int[]{1, 2, 3, 4, 5, 6};
    int[] actual = new int[]{3, 11, 5, 12, 1};
    assertThat(ComputeFunctions.average_precision(relevant, actual))
        .as("")
        .isCloseTo(0.755d, offset);

    assertThat(ComputeFunctions.average_precision(ints12390, intsBy3_369))
        .as("")
        .isCloseTo(0.833, offset);
}
} }

View File

@ -0,0 +1,101 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.engine.extensions.computefunctions;
import org.assertj.core.data.Offset;
import org.junit.jupiter.api.Test;
import java.util.stream.LongStream;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Tests for the long[] variants of {@link ComputeFunctions} and {@link Intersections}.
 *
 * <P>The fixtures are shared static arrays. Functions under test may sort their arguments,
 * so every call receives a copy; this keeps each test independent of the order in which
 * the others run.</P>
 */
class ComputeFunctionsLongTest {

    private final static Offset<Double> offset = Offset.offset(0.001d);

    private final static long[] longs_0to9 = new long[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    private final static long[] longs_37195 = new long[]{3, 7, 1, 9, 5};
    private final static long[] longs_86204 = new long[]{8, 6, 2, 0, 4};
    private final static long[] longs_01234 = new long[]{0, 1, 2, 3, 4};
    private final static long[] longs_56789 = new long[]{5, 6, 7, 8, 9};
    private final static long[] longs_369 = new long[]{3, 6, 9};
    private final static long[] longs_693 = new long[]{6, 9, 3};
    private final static long[] longs_45678 = new long[]{4, 5, 6, 7, 8};
    private final static long[] longs_12390 = new long[]{1, 2, 3, 9, 0};

    /** Returns an ascending copy of the given values, leaving the fixture untouched. */
    private static long[] sorted(long[] values) {
        return LongStream.of(values).sorted().toArray();
    }

    @Test
    void testRecallLongArrays() {
        assertThat(ComputeFunctions.recall(longs_86204.clone(), longs_37195.clone()))
            .as("finding 0 actual of any should yield recall=0.0")
            .isCloseTo(0.0d, offset);

        assertThat(ComputeFunctions.recall(longs_86204.clone(), longs_369.clone()))
            .as("finding 1 actual of 5 relevant should yield recall=0.2")
            .isCloseTo(0.2d, offset);

        assertThat(ComputeFunctions.recall(longs_86204.clone(), longs_369.clone(), 1))
            .as("finding 0 (limited) actual of 5 relevant should yield recall=0.0")
            .isCloseTo(0.0d, offset);
    }

    @Test
    void testPrecisionLongArrays() {
        assertThat(ComputeFunctions.precision(longs_86204.clone(), longs_693.clone()))
            .as("one of three results being relevant should yield precision=0.333")
            .isCloseTo(0.333, offset);

        assertThat(ComputeFunctions.precision(longs_86204.clone(), longs_01234.clone()))
            .as("three of five results being relevant should yield precision=0.6")
            .isCloseTo(0.6, offset);

        assertThat(ComputeFunctions.precision(longs_86204.clone(), longs_37195.clone()))
            .as("none of the results being relevant should yield precision=0.0")
            .isCloseTo(0.0, offset);
    }

    @Test
    public void sanityCheckRecallAndLimitRatioLongs() {
        long[] hundo = LongStream.range(0, 100).toArray();

        for (int i = 0; i < hundo.length; i++) {
            long[] partial = LongStream.range(0, i).toArray();
            int finalI = i;
            assertThat(ComputeFunctions.recall(hundo, partial))
                .as(() -> "for subset size " + finalI + ", recall should be fractional/100")
                .isCloseTo((double) partial.length / (double) hundo.length, offset);
            assertThat(ComputeFunctions.recall(hundo, hundo, i))
                .as(() -> "for full intersection, limit " + finalI + " (K) recall should be fractional/100")
                .isCloseTo((double) partial.length / (double) hundo.length, offset);
        }
    }

    @Test
    public void testReciprocalRankLongs() {
        assertThat(ComputeFunctions.reciprocal_rank(longs_369.clone(), longs_56789.clone()))
            .as("relevant results in rank 2 should yield RR=0.5")
            .isCloseTo(0.5d, offset);

        assertThat(ComputeFunctions.reciprocal_rank(longs_56789.clone(), longs_01234.clone()))
            .as("no relevant results should yield RR=0.0")
            .isCloseTo(0.0d, offset);
    }

    @Test
    public void testCountLongIntersection() {
        // Intersections.count walks both arrays with a two-pointer scan, so the inputs are
        // sorted first; previously this relied on another test having sorted longs_37195.
        long result = Intersections.count(sorted(longs_37195), sorted(longs_12390));
        assertThat(result).isEqualTo(3L);
    }
}

View File

@ -1,50 +0,0 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.engine.extensions.computefunctions;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Checks {@link Intersections#find} and {@link Intersections#count} for int[] and long[]
 * inputs, using small ascending arrays with a known overlap.
 */
class IntersectionsTest {

    @Test
    public void testIntegerIntersection() {
        int[] reference = new int[]{1,2,3,4,5};
        int[] sample = new int[]{4,5,6,7,8};
        int[] common = Intersections.find(reference, sample);
        assertThat(common).isEqualTo(new int[]{4,5});
    }

    @Test
    public void testLongIntersection() {
        long[] reference = new long[]{1,2,3,4,5};
        long[] sample = new long[]{4,5,6,7,8};
        long[] common = Intersections.find(reference, sample);
        assertThat(common).isEqualTo(new long[]{4,5});
    }

    @Test
    public void testCountIntIntersection() {
        int[] reference = new int[]{1,3,5,7,9};
        int[] sample = new int[]{1,2,3,9,10};
        long matches = Intersections.count(reference, sample);
        assertThat(matches).isEqualTo(3L);
    }

    @Test
    public void testCountLongIntersection() {
        long[] reference = new long[]{1,3,5,7,9};
        long[] sample = new long[]{1,2,3,9,10};
        long matches = Intersections.count(reference, sample);
        assertThat(matches).isEqualTo(3);
    }
}

View File

@ -20,12 +20,13 @@ import io.nosqlbench.api.config.NBLabels;
import io.nosqlbench.api.engine.metrics.instruments.NBMetricGauge; import io.nosqlbench.api.engine.metrics.instruments.NBMetricGauge;
import java.util.DoubleSummaryStatistics; import java.util.DoubleSummaryStatistics;
import java.util.function.DoubleConsumer;
/** /**
* Create a discrete stat reservoir as a gauge. * Create a discrete stat reservoir as a gauge.
*/ */
public class DoubleSummaryGauge implements NBMetricGauge<Double> { public class DoubleSummaryGauge implements NBMetricGauge<Double>, DoubleConsumer {
private final NBLabels labels; private final NBLabels labels;
private final Stat stat; private final Stat stat;
private final DoubleSummaryStatistics stats; private final DoubleSummaryStatistics stats;

View File

@ -0,0 +1,46 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.api.engine.metrics.instruments;
import io.nosqlbench.api.config.NBLabels;
import java.util.function.DoubleConsumer;
/**
 * A gauge which is also a {@link DoubleConsumer}. In its current form it is a skeleton:
 * accepted values are discarded and no value is reported.
 */
public class CompoundGaugeFunction implements NBMetricGauge<Double>, DoubleConsumer {

    private final NBLabels labels;
    private final String name;

    /**
     * @param labels
     *     the labels which identify this gauge
     * @param name
     *     the name given to this gauge; stored but not otherwise used yet
     */
    public CompoundGaugeFunction(NBLabels labels, String name) {
        this.labels = labels;
        this.name = name;
    }

    /**
     * @return always {@code null}; no value computation is implemented yet, so callers
     *     must not unbox the result
     */
    @Override
    public Double getValue() {
        return null;
    }

    /**
     * @return the labels supplied at construction time
     */
    @Override
    public NBLabels getLabels() {
        return this.labels;
    }

    /**
     * Currently a no-op: the supplied value is ignored.
     */
    @Override
    public void accept(double value) {
    }
}