support recall@K and precision@K within actual retrieved data

This commit is contained in:
Jonathan Shook 2023-09-07 10:04:42 -05:00
parent be4ef9c34f
commit 8eb3620caf
3 changed files with 190 additions and 7 deletions

View File

@ -18,6 +18,24 @@ package io.nosqlbench.engine.extensions.computefunctions;
import java.util.Arrays;
/**
* <P>A collection of compute functions related to vector search relevancy.
* These are based on arrays of indices of vectors, where the expected data is from known KNN test data,
* and the actual data is from a vector search query.</P>
*
* <P>Variations of these functions have a limit parameter, which allows for derivation of relevancy
* measurements for a smaller query without having to run a separate test for each K value.
* If you are using test vectors from a computed KNN test data with for K=100, you can compute
* metrics "@K" for any size up to and including K=100. This simply uses a partial view of the result
* to do exactly what would have been done for a test where you actually query for that K limit.
* <STRONG>This assumes that the result rank is stable irrespective of the limit AND the results
* are passed to these functions as ranked in results.</STRONG></P>
*
* <P>The array indices passed to these functions should not be sorted before-hand as a general rule.</P>
* Yet, no provision is made for duplicate entries. If you have duplicate indices in either array,
* these methods will yield incorrect results as they rely on the <EM>two-pointer</EM> method and do not
* elide duplicates internally.
*/
public class ComputeFunctions {
/**
@ -32,6 +50,34 @@ public class ComputeFunctions {
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
return (double) intersection.length / (double) referenceIndexes.length;
}
public static double recall(long[] referenceIndexes, long[] sampleIndexes, int limit) {
if (sampleIndexes.length<limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
}
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
return (double) intersection.length / (double) referenceIndexes.length;
}
public static double precision(long[] referenceIndexes, long[] sampleIndexes) {
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
return (double) intersection.length / (double) sampleIndexes.length;
}
public static double precision(long[] referenceIndexes, long[] sampleIndexes, int limit) {
if (sampleIndexes.length<limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
}
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
return (double) intersection.length / (double) sampleIndexes.length;
}
/**
* Compute the recall as the proportion of matching indices divided by the expected indices
@ -42,9 +88,36 @@ public class ComputeFunctions {
public static double recall(int[] referenceIndexes, int[] sampleIndexes) {
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
int intersection = Intersections.count(referenceIndexes, sampleIndexes);
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length);
return (double) intersection / (double) referenceIndexes.length;
}
public static double recall(int[] referenceIndexes, int[] sampleIndexes, int limit) {
if (sampleIndexes.length<limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
}
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length);
return (double) intersection / (double) referenceIndexes.length;
}
public static double precision(int[] referenceIndexes, int[] sampleIndexes) {
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
int intersection = Intersections.count(referenceIndexes, sampleIndexes);
return (double) intersection / (double) sampleIndexes.length;
}
public static double precision(int[] referenceIndexes, int[] sampleIndexes, int limit) {
if (sampleIndexes.length<limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
}
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
int intersection = Intersections.count(referenceIndexes, sampleIndexes);
return (double) intersection / (double) sampleIndexes.length;
}
/**
* Compute the intersection of two long arrays
@ -52,6 +125,9 @@ public class ComputeFunctions {
public static long[] intersection(long[] a, long[] b) {
return Intersections.find(a,b);
}
public static long[] intersection(long[] a, long[] b, int limit) {
return Intersections.find(a, b, limit);
}
/**
* Compute the intersection of two int arrays
@ -60,12 +136,24 @@ public class ComputeFunctions {
return Intersections.find(reference, sample);
}
public static int[] intersection(int[] reference, int[] sample, int limit) {
return Intersections.find(reference,sample,limit);
}
/**
* Compute the size of the intersection of two int arrays
*/
public static int intersectionSize(int[] reference, int[] sample) {
return Intersections.count(reference, sample);
}
public static int intersectionSize(int[] reference, int[] sample, int limit) {
return Intersections.count(reference, sample, limit);
}
public static int intersectionSize(long[] reference, long[] sample) {
return Intersections.count(reference, sample);
}
public static int intersectionSize(long[] reference, long[] sample, int limit) {
return Intersections.count(reference, sample, limit);
}
}

View File

@ -21,9 +21,12 @@ import java.util.Arrays;
public class Intersections {
public static int count(int[] reference, int[] sample) {
return count(reference,sample,reference.length);
}
public static int count(int[] reference, int[] sample, int limit) {
int a_index = 0, b_index = 0, matches = 0;
int a_element, b_element;
while (a_index < reference.length && b_index < sample.length) {
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
a_element = reference[a_index];
b_element = sample[b_index];
if (a_element == b_element) {
@ -39,10 +42,13 @@ public class Intersections {
return matches;
}
public static long count(long[] reference, long[] sample) {
public static int count(long[] reference, long[] sample) {
return count(reference, sample, reference.length);
}
public static int count(long[] reference, long[] sample, int limit) {
int a_index = 0, b_index = 0, matches = 0;
long a_element, b_element;
while (a_index < reference.length && b_index < sample.length) {
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
a_element = reference[a_index];
b_element = sample[b_index];
if (a_element == b_element) {
@ -59,10 +65,14 @@ public class Intersections {
}
public static int[] find(int[] reference, int[] sample) {
return find(reference,sample,reference.length);
}
public static int[] find(int[] reference, int[] sample, int limit) {
int[] result = new int[reference.length];
int a_index = 0, b_index = 0, acc_index = -1;
int a_element, b_element;
while (a_index < reference.length && b_index < sample.length) {
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
a_element = reference[a_index];
b_element = sample[b_index];
if (a_element == b_element) {
@ -78,11 +88,15 @@ public class Intersections {
}
return Arrays.copyOfRange(result,0,acc_index+1);
}
public static long[] find(long[] reference, long[] sample) {
return find(reference, sample, reference.length);
}
public static long[] find(long[] reference, long[] sample, int limit) {
long[] result = new long[reference.length];
int a_index = 0, b_index = 0, acc_index = -1;
long a_element, b_element;
while (a_index < reference.length && b_index < sample.length) {
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
a_element = reference[a_index];
b_element = sample[b_index];
if (a_element == b_element) {

View File

@ -0,0 +1,81 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.engine.extensions.computefunctions;
import org.assertj.core.data.Offset;
import org.junit.jupiter.api.Test;
import java.util.stream.IntStream;
import static org.assertj.core.api.Assertions.assertThat;
class ComputeFunctionsTest {
private final static Offset<Double> offset=Offset.offset(0.001d);
private final static int[] allInts =new int[]{0,1,2,3,4,5,6,7,8,9};
private final static int[] oddInts37195 = new int[]{3,7,1,9,5};
private final static int[] evenInts86204 = new int[]{8,6,2,0,4};
private final static int[] lowInts01234 = new int[]{0,1,2,3,4};
private final static int[] highInts56789 = new int[]{5,6,7,8,9};
private final static int[] intsBy3_369 = new int[]{3,6,9};
private final static int[] intsBy3_693 = new int[]{3,6,9};
@Test
void testRecallIntArrays() {
assertThat(ComputeFunctions.recall(evenInts86204,oddInts37195))
.as("finding 0 actual of any should yield recall=0.0")
.isCloseTo(0.0d, offset);
assertThat(ComputeFunctions.recall(evenInts86204,intsBy3_369))
.as("finding 1 actual of 5 relevant should yield recall=0.2")
.isCloseTo(0.2d, offset);
assertThat(ComputeFunctions.recall(evenInts86204,intsBy3_369,1))
.as("finding 0 (limited) actual of 5 relevant should yield recall=0.0")
.isCloseTo(0.0d, offset);
}
@Test
void testPrecisionIntArrays() {
assertThat(ComputeFunctions.precision(evenInts86204,intsBy3_693))
.as("one of three results being relevant should yield precision=0.333")
.isCloseTo(0.333,offset);
assertThat(ComputeFunctions.precision(evenInts86204,lowInts01234))
.as("three of five results being relevant should yield precision=0.6")
.isCloseTo(0.6,offset);
assertThat(ComputeFunctions.precision(evenInts86204,oddInts37195))
.as("none of the results being relevant should yield precision=0.0")
.isCloseTo(0.0,offset);
}
@Test
public void sanityCheckRecallAndLimitRatio() {
int[] hundo = IntStream.range(0,100).toArray();
for (int i = 0; i < hundo.length; i++) {
int[] partial=IntStream.range(0,i).toArray();
int finalI = i;
assertThat(ComputeFunctions.recall(hundo, partial))
.as(() -> "for subset size " + finalI +", recall should be fractional/100")
.isCloseTo((double)partial.length/(double)hundo.length,offset);
assertThat(ComputeFunctions.recall(hundo, hundo, i))
.as(() -> "for full intersection, limit " + finalI +" (K) recall should be fractional/100")
.isCloseTo((double)partial.length/(double)hundo.length,offset);
}
}
}