add (more) correct MAP implementation

This commit is contained in:
Jonathan Shook 2023-09-08 09:17:39 -05:00
parent a0006f18a7
commit 4484d22ff1
4 changed files with 172 additions and 85 deletions

View File

@ -17,6 +17,8 @@
package io.nosqlbench.engine.extensions.computefunctions;
import java.util.Arrays;
import java.util.DoubleSummaryStatistics;
import java.util.HashSet;
/**
* <P>A collection of compute functions related to vector search relevancy.
@ -29,7 +31,8 @@ import java.util.Arrays;
* metrics "@K" for any size up to and including K=100. This simply uses a partial view of the result
* to do exactly what would have been done for a test where you actually query for that K limit.
* <STRONG>This assumes that the result rank is stable irrespective of the limit AND the results
* are passed to these functions as ranked in results.</STRONG></P>
* are passed to these functions as ranked in results.</STRONG></P> Some of the methods apply K to the
* expected (relevant) indices, others to the actual (response) indices, depending on what is appropriate.
*
* <P>The array indices passed to these functions should not be sorted before-hand as a general rule.</P>
* Yet, no provision is made for duplicate entries. If you have duplicate indices in either array,
@ -192,7 +195,7 @@ public class ComputeFunctions {
/**
* Reciprocal Rank - The multiplicative inverse of the first rank which is relevant.
*/
public static double RR(long[] reference, long[] sample, int limit) {
public static double reciprocal_rank(long[] reference, long[] sample, int limit) {
int firstRank = Intersections.firstMatchingIndex(reference, sample, limit);
if (firstRank >= 0) {
return 1.0d / (firstRank+1);
@ -201,11 +204,11 @@ public class ComputeFunctions {
}
}
public static double RR(long[] reference, long[] sample) {
return RR(reference, sample, reference.length);
public static double reciprocal_rank(long[] reference, long[] sample) {
return reciprocal_rank(reference, sample, reference.length);
}
public static double RR(int[] reference, int[] sample, int limit) {
public static double reciprocal_rank(int[] reference, int[] sample, int limit) {
int firstRank = Intersections.firstMatchingIndex(reference, sample, limit);
if (firstRank >= 0) {
return 1.0d / (firstRank+1);
@ -214,8 +217,47 @@ public class ComputeFunctions {
}
}
public static double RR(int[] reference, int[] sample) {
return RR(reference, sample, reference.length);
public static double reciprocal_rank(int[] reference, int[] sample) {
return reciprocal_rank(reference, sample, reference.length);
}
public static double average_precision(int[] reference, int[] sample) {
return average_precision(reference,sample,reference.length);
}
public static double average_precision(int[] reference, int[] sample, int k) {
int maxK = Math.min(k,sample.length);
HashSet<Integer> refset = new HashSet<>(reference.length);
for (Integer i : reference) {
refset.add(i);
}
int relevant=0;
DoubleSummaryStatistics stats = new DoubleSummaryStatistics();
for (int i = 0; i < maxK; i++) {
if (refset.contains(sample[i])){
relevant++;
double precisionAtIdx = (double) relevant / (i+1);
stats.accept(precisionAtIdx);
}
}
return stats.getAverage();
}
public static double average_precision(long[] reference, long[] sample, int k) {
int maxK = Math.min(k,sample.length);
HashSet<Long> refset = new HashSet<>(reference.length);
for (Long i : reference) {
refset.add(i);
}
int relevant=0;
DoubleSummaryStatistics stats = new DoubleSummaryStatistics();
for (int i = 0; i < maxK; i++) {
if (refset.contains(sample[i])){
relevant++;
double precisionAtIdx = (double) relevant / (i+1);
stats.accept(precisionAtIdx);
}
}
return stats.getAverage();
}
}

View File

@ -27,10 +27,10 @@ public class Intersections {
public static int firstMatchingIndex(long[] reference, long[] sample, int limit) {
Arrays.sort(reference);
int maxIndex = Math.min(sample.length, limit);
int foundAt=-1;
int foundAt = -1;
for (int index = 0; index < maxIndex; index++) {
foundAt = Arrays.binarySearch(reference, sample[index]);
if (foundAt>=0) break;
if (foundAt >= 0) break;
}
return foundAt;
}
@ -38,17 +38,18 @@ public class Intersections {
public static int firstMatchingIndex(int[] reference, int[] sample, int limit) {
Arrays.sort(reference);
int maxIndex = Math.min(sample.length, limit);
int foundAt=-1;
int foundAt = -1;
for (int index = 0; index < maxIndex; index++) {
foundAt = Arrays.binarySearch(reference, sample[index]);
if (foundAt>=0) break;
if (foundAt >= 0) break;
}
return foundAt;
}
public static int count(int[] reference, int[] sample) {
return count(reference,sample,reference.length);
return count(reference, sample, reference.length);
}
public static int count(int[] reference, int[] sample, int limit) {
int a_index = 0, b_index = 0, matches = 0;
int a_element, b_element;
@ -71,6 +72,7 @@ public class Intersections {
public static int count(long[] reference, long[] sample) {
return count(reference, sample, reference.length);
}
public static int count(long[] reference, long[] sample, int limit) {
int a_index = 0, b_index = 0, matches = 0;
long a_element, b_element;
@ -91,9 +93,63 @@ public class Intersections {
}
public static int[] find(int[] reference, int[] sample) {
return find(reference,sample,reference.length);
return find(reference, sample, reference.length);
}
public static int[] mask(int[] reference, int[] sample) {
return mask(reference,sample,sample.length);
}
public static int[] mask(int[] reference, int[] sample, int limit) {
int[] mask = new int[sample.length];
int relevant_idx = 0, actual_idx = 0, acc_index = -1;
int relevant_element, actual_element;
while (relevant_idx < reference.length && relevant_idx < limit && actual_idx < sample.length && actual_idx < limit) {
relevant_element = reference[relevant_idx];
actual_element = sample[actual_idx];
if (relevant_element == actual_element) {
mask[actual_idx] = 1;
relevant_idx++;
actual_idx++;
} else if (actual_element < relevant_element) {
actual_idx++;
} else {
relevant_idx++;
}
}
return mask;
}
/**
* Compare the actual indices to the relevant indices, and return an array
* containing the ordered set of indices of the actual array which appear
* in the relevant array. A perfect result looks like counting from zero.
* @param relevant The array of relevant indices
* @param actual The array of actual indices
* @param limit limit the indices to the first [limit] items
* @return An array of relevant indices in the actual array.
*/
public static int[] findIndirect(int[] relevant, int[] actual, int limit) {
int[] result = new int[actual.length];
int a_index = 0, b_index = 0, acc_index = -1;
int a_element, b_element;
while (a_index < relevant.length && a_index < limit && b_index < actual.length && b_index < limit) {
a_element = relevant[a_index];
b_element = actual[b_index];
if (a_element == b_element) {
result[++acc_index] = b_index;
a_index++;
b_index++;
} else if (b_element < a_element) {
b_index++;
} else {
a_index++;
}
}
return Arrays.copyOfRange(result, 0, acc_index + 1);
}
public static int[] find(int[] reference, int[] sample, int limit) {
int[] result = new int[limit];
int a_index = 0, b_index = 0, acc_index = -1;
@ -102,7 +158,6 @@ public class Intersections {
a_element = reference[a_index];
b_element = sample[b_index];
if (a_element == b_element) {
result = resize(result);
result[++acc_index] = a_element;
a_index++;
b_index++;
@ -112,12 +167,13 @@ public class Intersections {
a_index++;
}
}
return Arrays.copyOfRange(result,0,acc_index+1);
return Arrays.copyOfRange(result, 0, acc_index + 1);
}
public static long[] find(long[] reference, long[] sample) {
return find(reference, sample, reference.length);
}
public static long[] find(long[] reference, long[] sample, int limit) {
long[] result = new long[limit];
int a_index = 0, b_index = 0, acc_index = -1;
@ -126,7 +182,6 @@ public class Intersections {
a_element = reference[a_index];
b_element = sample[b_index];
if (a_element == b_element) {
result = resize(result);
result[++acc_index] = a_element;
a_index++;
b_index++;
@ -136,22 +191,22 @@ public class Intersections {
a_index++;
}
}
return Arrays.copyOfRange(result,0,acc_index+1);
}
private static int[] resize(int[] arr) {
int len = arr.length;
int[] copy = new int[len + 1];
System.arraycopy(arr, 0, copy, 0, len);
return copy;
}
private static long[] resize(long[] arr) {
int len = arr.length;
long[] copy = new long[len + 1];
System.arraycopy(arr, 0, copy, 0, len);
return copy;
return Arrays.copyOfRange(result, 0, acc_index + 1);
}
//
// private static int[] resize(int[] arr) {
// int len = arr.length;
// int[] copy = new int[len + 1];
// System.arraycopy(arr, 0, copy, 0, len);
// return copy;
// }
//
// private static long[] resize(long[] arr) {
// int len = arr.length;
// long[] copy = new long[len + 1];
// System.arraycopy(arr, 0, copy, 0, len);
// return copy;
// }
}

View File

@ -23,7 +23,7 @@ import java.util.stream.IntStream;
import static org.assertj.core.api.Assertions.assertThat;
class ComputeFunctionsTest {
class ComputeFunctionsIntTest {
private final static Offset<Double> offset=Offset.offset(0.001d);
private final static int[] allInts =new int[]{0,1,2,3,4,5,6,7,8,9};
@ -33,7 +33,10 @@ class ComputeFunctionsTest {
private final static int[] highInts56789 = new int[]{5,6,7,8,9};
private final static int[] intsBy3_369 = new int[]{3,6,9};
private final static int[] intsBy3_693 = new int[]{3,6,9};
private final static int[] intsBy3_693 = new int[]{6,9,3};
private final static int[] midInts45678 = new int[]{4,5,6,7,8};
private final static int[] ints12390 = new int[]{1,2,3,9,0};
@Test
void testRecallIntArrays() {
assertThat(ComputeFunctions.recall(evenInts86204,oddInts37195))
@ -81,12 +84,49 @@ class ComputeFunctionsTest {
@Test
public void testReciprocalRank() {
assertThat(ComputeFunctions.RR(intsBy3_369,highInts56789))
assertThat(ComputeFunctions.reciprocal_rank(intsBy3_369,highInts56789))
.as("relevant results in rank 2 should yield RR=0.5")
.isCloseTo(0.5d,offset);
assertThat(ComputeFunctions.RR(highInts56789,lowInts01234))
assertThat(ComputeFunctions.reciprocal_rank(highInts56789,lowInts01234))
.as("no relevant results should yield RR=0.0")
.isCloseTo(0.0d,offset);
}
@Test
public void testIntegerIntersection() {
int[] result = Intersections.find(lowInts01234,midInts45678);
assertThat(result).isEqualTo(new int[]{4});
}
@Test
public void testCountIntIntersection() {
int result = Intersections.count(oddInts37195, ints12390);
assertThat(result).isEqualTo(2L);
}
@Test
public void testMasking() {
assertThat(Intersections.mask(ints12390,highInts56789))
.as("the last actual is relevant and should have a 1 in the mask")
.isEqualTo(new int[]{0,0,0,0,1});
assertThat(Intersections.mask(allInts,allInts))
.as("the last actual is relevant and should have a 1 in the mask")
.isEqualTo(new int[]{1,1,1,1,1,1,1,1,1,1});
}
@Test
public void testAP() {
double ap1 = ComputeFunctions.average_precision(new int[]{1, 2, 3, 4, 5, 6}, new int[]{3, 11, 5, 12, 1});
assertThat(ap1)
.as("")
.isCloseTo(0.755d,offset);
double ap2 = ComputeFunctions.average_precision(ints12390, intsBy3_369);
assertThat(ap2)
.as("")
.isCloseTo(0.833,offset);
}
}

View File

@ -1,50 +0,0 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.engine.extensions.computefunctions;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
class IntersectionsTest {
@Test
public void testIntegerIntersection() {
int[] result = Intersections.find(new int[]{1,2,3,4,5},new int[]{4,5,6,7,8});
assertThat(result).isEqualTo(new int[]{4,5});
}
@Test
public void testLongIntersection() {
long[] result = Intersections.find(new long[]{1,2,3,4,5},new long[]{4,5,6,7,8});
assertThat(result).isEqualTo(new long[]{4,5});
}
@Test
public void testCountIntIntersection() {
long result = Intersections.count(new int[]{1,3,5,7,9}, new int[]{1,2,3,9,10});
assertThat(result).isEqualTo(3L);
}
@Test
public void testCountLongIntersection() {
long result = Intersections.count(new long[]{1,3,5,7,9}, new long[]{1,2,3,9,10});
assertThat(result).isEqualTo(3);
}
}