mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2025-01-11 00:12:04 -06:00
support recall@K and precision@K within actual retrieved data
This commit is contained in:
parent
be4ef9c34f
commit
8eb3620caf
@ -18,6 +18,24 @@ package io.nosqlbench.engine.extensions.computefunctions;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* <P>A collection of compute functions related to vector search relevancy.
|
||||
* These are based on arrays of indices of vectors, where the expected data is from known KNN test data,
|
||||
* and the actual data is from a vector search query.</P>
|
||||
*
|
||||
* <P>Variations of these functions have a limit parameter, which allows for derivation of relevancy
|
||||
* measurements for a smaller query without having to run a separate test for each K value.
|
||||
* If you are using test vectors from a computed KNN test data with for K=100, you can compute
|
||||
* metrics "@K" for any size up to and including K=100. This simply uses a partial view of the result
|
||||
* to do exactly what would have been done for a test where you actually query for that K limit.
|
||||
* <STRONG>This assumes that the result rank is stable irrespective of the limit AND the results
|
||||
* are passed to these functions as ranked in results.</STRONG></P>
|
||||
*
|
||||
* <P>The array indices passed to these functions should not be sorted before-hand as a general rule.</P>
|
||||
* Yet, no provision is made for duplicate entries. If you have duplicate indices in either array,
|
||||
* these methods will yield incorrect results as they rely on the <EM>two-pointer</EM> method and do not
|
||||
* elide duplicates internally.
|
||||
*/
|
||||
public class ComputeFunctions {
|
||||
|
||||
/**
|
||||
@ -32,6 +50,34 @@ public class ComputeFunctions {
|
||||
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
|
||||
return (double) intersection.length / (double) referenceIndexes.length;
|
||||
}
|
||||
public static double recall(long[] referenceIndexes, long[] sampleIndexes, int limit) {
|
||||
if (sampleIndexes.length<limit) {
|
||||
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
|
||||
}
|
||||
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
|
||||
Arrays.sort(referenceIndexes);
|
||||
Arrays.sort(sampleIndexes);
|
||||
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
|
||||
return (double) intersection.length / (double) referenceIndexes.length;
|
||||
}
|
||||
|
||||
public static double precision(long[] referenceIndexes, long[] sampleIndexes) {
|
||||
Arrays.sort(referenceIndexes);
|
||||
Arrays.sort(sampleIndexes);
|
||||
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
|
||||
return (double) intersection.length / (double) sampleIndexes.length;
|
||||
}
|
||||
|
||||
public static double precision(long[] referenceIndexes, long[] sampleIndexes, int limit) {
|
||||
if (sampleIndexes.length<limit) {
|
||||
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
|
||||
}
|
||||
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
|
||||
Arrays.sort(referenceIndexes);
|
||||
Arrays.sort(sampleIndexes);
|
||||
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
|
||||
return (double) intersection.length / (double) sampleIndexes.length;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the recall as the proportion of matching indices divided by the expected indices
|
||||
@ -42,9 +88,36 @@ public class ComputeFunctions {
|
||||
public static double recall(int[] referenceIndexes, int[] sampleIndexes) {
|
||||
Arrays.sort(referenceIndexes);
|
||||
Arrays.sort(sampleIndexes);
|
||||
int intersection = Intersections.count(referenceIndexes, sampleIndexes);
|
||||
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length);
|
||||
return (double) intersection / (double) referenceIndexes.length;
|
||||
}
|
||||
public static double recall(int[] referenceIndexes, int[] sampleIndexes, int limit) {
|
||||
if (sampleIndexes.length<limit) {
|
||||
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
|
||||
}
|
||||
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
|
||||
Arrays.sort(referenceIndexes);
|
||||
Arrays.sort(sampleIndexes);
|
||||
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length);
|
||||
return (double) intersection / (double) referenceIndexes.length;
|
||||
}
|
||||
|
||||
public static double precision(int[] referenceIndexes, int[] sampleIndexes) {
|
||||
Arrays.sort(referenceIndexes);
|
||||
Arrays.sort(sampleIndexes);
|
||||
int intersection = Intersections.count(referenceIndexes, sampleIndexes);
|
||||
return (double) intersection / (double) sampleIndexes.length;
|
||||
}
|
||||
public static double precision(int[] referenceIndexes, int[] sampleIndexes, int limit) {
|
||||
if (sampleIndexes.length<limit) {
|
||||
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
|
||||
}
|
||||
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
|
||||
Arrays.sort(referenceIndexes);
|
||||
Arrays.sort(sampleIndexes);
|
||||
int intersection = Intersections.count(referenceIndexes, sampleIndexes);
|
||||
return (double) intersection / (double) sampleIndexes.length;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the intersection of two long arrays
|
||||
@ -52,6 +125,9 @@ public class ComputeFunctions {
|
||||
public static long[] intersection(long[] a, long[] b) {
|
||||
return Intersections.find(a,b);
|
||||
}
|
||||
public static long[] intersection(long[] a, long[] b, int limit) {
|
||||
return Intersections.find(a, b, limit);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the intersection of two int arrays
|
||||
@ -60,12 +136,24 @@ public class ComputeFunctions {
|
||||
return Intersections.find(reference, sample);
|
||||
}
|
||||
|
||||
public static int[] intersection(int[] reference, int[] sample, int limit) {
|
||||
return Intersections.find(reference,sample,limit);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the size of the intersection of two int arrays
|
||||
*/
|
||||
public static int intersectionSize(int[] reference, int[] sample) {
|
||||
return Intersections.count(reference, sample);
|
||||
}
|
||||
|
||||
public static int intersectionSize(int[] reference, int[] sample, int limit) {
|
||||
return Intersections.count(reference, sample, limit);
|
||||
}
|
||||
public static int intersectionSize(long[] reference, long[] sample) {
|
||||
return Intersections.count(reference, sample);
|
||||
}
|
||||
public static int intersectionSize(long[] reference, long[] sample, int limit) {
|
||||
return Intersections.count(reference, sample, limit);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -21,9 +21,12 @@ import java.util.Arrays;
|
||||
public class Intersections {
|
||||
|
||||
public static int count(int[] reference, int[] sample) {
|
||||
return count(reference,sample,reference.length);
|
||||
}
|
||||
public static int count(int[] reference, int[] sample, int limit) {
|
||||
int a_index = 0, b_index = 0, matches = 0;
|
||||
int a_element, b_element;
|
||||
while (a_index < reference.length && b_index < sample.length) {
|
||||
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
|
||||
a_element = reference[a_index];
|
||||
b_element = sample[b_index];
|
||||
if (a_element == b_element) {
|
||||
@ -39,10 +42,13 @@ public class Intersections {
|
||||
return matches;
|
||||
}
|
||||
|
||||
public static long count(long[] reference, long[] sample) {
|
||||
public static int count(long[] reference, long[] sample) {
|
||||
return count(reference, sample, reference.length);
|
||||
}
|
||||
public static int count(long[] reference, long[] sample, int limit) {
|
||||
int a_index = 0, b_index = 0, matches = 0;
|
||||
long a_element, b_element;
|
||||
while (a_index < reference.length && b_index < sample.length) {
|
||||
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
|
||||
a_element = reference[a_index];
|
||||
b_element = sample[b_index];
|
||||
if (a_element == b_element) {
|
||||
@ -59,10 +65,14 @@ public class Intersections {
|
||||
}
|
||||
|
||||
public static int[] find(int[] reference, int[] sample) {
|
||||
return find(reference,sample,reference.length);
|
||||
}
|
||||
|
||||
public static int[] find(int[] reference, int[] sample, int limit) {
|
||||
int[] result = new int[reference.length];
|
||||
int a_index = 0, b_index = 0, acc_index = -1;
|
||||
int a_element, b_element;
|
||||
while (a_index < reference.length && b_index < sample.length) {
|
||||
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
|
||||
a_element = reference[a_index];
|
||||
b_element = sample[b_index];
|
||||
if (a_element == b_element) {
|
||||
@ -78,11 +88,15 @@ public class Intersections {
|
||||
}
|
||||
return Arrays.copyOfRange(result,0,acc_index+1);
|
||||
}
|
||||
|
||||
public static long[] find(long[] reference, long[] sample) {
|
||||
return find(reference, sample, reference.length);
|
||||
}
|
||||
public static long[] find(long[] reference, long[] sample, int limit) {
|
||||
long[] result = new long[reference.length];
|
||||
int a_index = 0, b_index = 0, acc_index = -1;
|
||||
long a_element, b_element;
|
||||
while (a_index < reference.length && b_index < sample.length) {
|
||||
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
|
||||
a_element = reference[a_index];
|
||||
b_element = sample[b_index];
|
||||
if (a_element == b_element) {
|
||||
|
@ -0,0 +1,81 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.engine.extensions.computefunctions;
|
||||
|
||||
import org.assertj.core.data.Offset;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class ComputeFunctionsTest {
|
||||
|
||||
private final static Offset<Double> offset=Offset.offset(0.001d);
|
||||
private final static int[] allInts =new int[]{0,1,2,3,4,5,6,7,8,9};
|
||||
private final static int[] oddInts37195 = new int[]{3,7,1,9,5};
|
||||
private final static int[] evenInts86204 = new int[]{8,6,2,0,4};
|
||||
private final static int[] lowInts01234 = new int[]{0,1,2,3,4};
|
||||
private final static int[] highInts56789 = new int[]{5,6,7,8,9};
|
||||
|
||||
private final static int[] intsBy3_369 = new int[]{3,6,9};
|
||||
private final static int[] intsBy3_693 = new int[]{3,6,9};
|
||||
@Test
|
||||
void testRecallIntArrays() {
|
||||
assertThat(ComputeFunctions.recall(evenInts86204,oddInts37195))
|
||||
.as("finding 0 actual of any should yield recall=0.0")
|
||||
.isCloseTo(0.0d, offset);
|
||||
|
||||
assertThat(ComputeFunctions.recall(evenInts86204,intsBy3_369))
|
||||
.as("finding 1 actual of 5 relevant should yield recall=0.2")
|
||||
.isCloseTo(0.2d, offset);
|
||||
|
||||
assertThat(ComputeFunctions.recall(evenInts86204,intsBy3_369,1))
|
||||
.as("finding 0 (limited) actual of 5 relevant should yield recall=0.0")
|
||||
.isCloseTo(0.0d, offset);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testPrecisionIntArrays() {
|
||||
assertThat(ComputeFunctions.precision(evenInts86204,intsBy3_693))
|
||||
.as("one of three results being relevant should yield precision=0.333")
|
||||
.isCloseTo(0.333,offset);
|
||||
assertThat(ComputeFunctions.precision(evenInts86204,lowInts01234))
|
||||
.as("three of five results being relevant should yield precision=0.6")
|
||||
.isCloseTo(0.6,offset);
|
||||
assertThat(ComputeFunctions.precision(evenInts86204,oddInts37195))
|
||||
.as("none of the results being relevant should yield precision=0.0")
|
||||
.isCloseTo(0.0,offset);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void sanityCheckRecallAndLimitRatio() {
|
||||
int[] hundo = IntStream.range(0,100).toArray();
|
||||
|
||||
for (int i = 0; i < hundo.length; i++) {
|
||||
int[] partial=IntStream.range(0,i).toArray();
|
||||
int finalI = i;
|
||||
assertThat(ComputeFunctions.recall(hundo, partial))
|
||||
.as(() -> "for subset size " + finalI +", recall should be fractional/100")
|
||||
.isCloseTo((double)partial.length/(double)hundo.length,offset);
|
||||
assertThat(ComputeFunctions.recall(hundo, hundo, i))
|
||||
.as(() -> "for full intersection, limit " + finalI +" (K) recall should be fractional/100")
|
||||
.isCloseTo((double)partial.length/(double)hundo.length,offset);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user