mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2025-01-11 00:12:04 -06:00
add (more) correct MAP implementation
This commit is contained in:
parent
a0006f18a7
commit
4484d22ff1
@ -17,6 +17,8 @@
|
||||
package io.nosqlbench.engine.extensions.computefunctions;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.DoubleSummaryStatistics;
|
||||
import java.util.HashSet;
|
||||
|
||||
/**
|
||||
* <P>A collection of compute functions related to vector search relevancy.
|
||||
@ -29,7 +31,8 @@ import java.util.Arrays;
|
||||
* metrics "@K" for any size up to and including K=100. This simply uses a partial view of the result
|
||||
* to do exactly what would have been done for a test where you actually query for that K limit.
|
||||
* <STRONG>This assumes that the result rank is stable irrespective of the limit AND the results
|
||||
* are passed to these functions as ranked in results.</STRONG></P>
|
||||
* are passed to these functions as ranked in results.</STRONG></P> Some of the methods apply K to the
|
||||
* expected (relevant) indices, others to the actual (response) indices, depending on what is appropriate.
|
||||
*
|
||||
* <P>The array indices passed to these functions should not be sorted before-hand as a general rule.</P>
|
||||
* Yet, no provision is made for duplicate entries. If you have duplicate indices in either array,
|
||||
@ -192,7 +195,7 @@ public class ComputeFunctions {
|
||||
/**
|
||||
* Reciprocal Rank - The multiplicative inverse of the first rank which is relevant.
|
||||
*/
|
||||
public static double RR(long[] reference, long[] sample, int limit) {
|
||||
public static double reciprocal_rank(long[] reference, long[] sample, int limit) {
|
||||
int firstRank = Intersections.firstMatchingIndex(reference, sample, limit);
|
||||
if (firstRank >= 0) {
|
||||
return 1.0d / (firstRank+1);
|
||||
@ -201,11 +204,11 @@ public class ComputeFunctions {
|
||||
}
|
||||
}
|
||||
|
||||
public static double RR(long[] reference, long[] sample) {
|
||||
return RR(reference, sample, reference.length);
|
||||
public static double reciprocal_rank(long[] reference, long[] sample) {
|
||||
return reciprocal_rank(reference, sample, reference.length);
|
||||
}
|
||||
|
||||
public static double RR(int[] reference, int[] sample, int limit) {
|
||||
public static double reciprocal_rank(int[] reference, int[] sample, int limit) {
|
||||
int firstRank = Intersections.firstMatchingIndex(reference, sample, limit);
|
||||
if (firstRank >= 0) {
|
||||
return 1.0d / (firstRank+1);
|
||||
@ -214,8 +217,47 @@ public class ComputeFunctions {
|
||||
}
|
||||
}
|
||||
|
||||
public static double RR(int[] reference, int[] sample) {
|
||||
return RR(reference, sample, reference.length);
|
||||
public static double reciprocal_rank(int[] reference, int[] sample) {
|
||||
return reciprocal_rank(reference, sample, reference.length);
|
||||
}
|
||||
|
||||
public static double average_precision(int[] reference, int[] sample) {
|
||||
return average_precision(reference,sample,reference.length);
|
||||
}
|
||||
|
||||
public static double average_precision(int[] reference, int[] sample, int k) {
|
||||
int maxK = Math.min(k,sample.length);
|
||||
HashSet<Integer> refset = new HashSet<>(reference.length);
|
||||
for (Integer i : reference) {
|
||||
refset.add(i);
|
||||
}
|
||||
int relevant=0;
|
||||
DoubleSummaryStatistics stats = new DoubleSummaryStatistics();
|
||||
for (int i = 0; i < maxK; i++) {
|
||||
if (refset.contains(sample[i])){
|
||||
relevant++;
|
||||
double precisionAtIdx = (double) relevant / (i+1);
|
||||
stats.accept(precisionAtIdx);
|
||||
}
|
||||
}
|
||||
return stats.getAverage();
|
||||
}
|
||||
|
||||
public static double average_precision(long[] reference, long[] sample, int k) {
|
||||
int maxK = Math.min(k,sample.length);
|
||||
HashSet<Long> refset = new HashSet<>(reference.length);
|
||||
for (Long i : reference) {
|
||||
refset.add(i);
|
||||
}
|
||||
int relevant=0;
|
||||
DoubleSummaryStatistics stats = new DoubleSummaryStatistics();
|
||||
for (int i = 0; i < maxK; i++) {
|
||||
if (refset.contains(sample[i])){
|
||||
relevant++;
|
||||
double precisionAtIdx = (double) relevant / (i+1);
|
||||
stats.accept(precisionAtIdx);
|
||||
}
|
||||
}
|
||||
return stats.getAverage();
|
||||
}
|
||||
}
|
||||
|
@ -27,10 +27,10 @@ public class Intersections {
|
||||
public static int firstMatchingIndex(long[] reference, long[] sample, int limit) {
|
||||
Arrays.sort(reference);
|
||||
int maxIndex = Math.min(sample.length, limit);
|
||||
int foundAt=-1;
|
||||
int foundAt = -1;
|
||||
for (int index = 0; index < maxIndex; index++) {
|
||||
foundAt = Arrays.binarySearch(reference, sample[index]);
|
||||
if (foundAt>=0) break;
|
||||
if (foundAt >= 0) break;
|
||||
}
|
||||
return foundAt;
|
||||
}
|
||||
@ -38,17 +38,18 @@ public class Intersections {
|
||||
public static int firstMatchingIndex(int[] reference, int[] sample, int limit) {
|
||||
Arrays.sort(reference);
|
||||
int maxIndex = Math.min(sample.length, limit);
|
||||
int foundAt=-1;
|
||||
int foundAt = -1;
|
||||
for (int index = 0; index < maxIndex; index++) {
|
||||
foundAt = Arrays.binarySearch(reference, sample[index]);
|
||||
if (foundAt>=0) break;
|
||||
if (foundAt >= 0) break;
|
||||
}
|
||||
return foundAt;
|
||||
}
|
||||
|
||||
public static int count(int[] reference, int[] sample) {
|
||||
return count(reference,sample,reference.length);
|
||||
return count(reference, sample, reference.length);
|
||||
}
|
||||
|
||||
public static int count(int[] reference, int[] sample, int limit) {
|
||||
int a_index = 0, b_index = 0, matches = 0;
|
||||
int a_element, b_element;
|
||||
@ -71,6 +72,7 @@ public class Intersections {
|
||||
public static int count(long[] reference, long[] sample) {
|
||||
return count(reference, sample, reference.length);
|
||||
}
|
||||
|
||||
public static int count(long[] reference, long[] sample, int limit) {
|
||||
int a_index = 0, b_index = 0, matches = 0;
|
||||
long a_element, b_element;
|
||||
@ -91,9 +93,63 @@ public class Intersections {
|
||||
}
|
||||
|
||||
public static int[] find(int[] reference, int[] sample) {
|
||||
return find(reference,sample,reference.length);
|
||||
return find(reference, sample, reference.length);
|
||||
}
|
||||
|
||||
public static int[] mask(int[] reference, int[] sample) {
|
||||
return mask(reference,sample,sample.length);
|
||||
}
|
||||
public static int[] mask(int[] reference, int[] sample, int limit) {
|
||||
int[] mask = new int[sample.length];
|
||||
int relevant_idx = 0, actual_idx = 0, acc_index = -1;
|
||||
int relevant_element, actual_element;
|
||||
|
||||
while (relevant_idx < reference.length && relevant_idx < limit && actual_idx < sample.length && actual_idx < limit) {
|
||||
relevant_element = reference[relevant_idx];
|
||||
actual_element = sample[actual_idx];
|
||||
if (relevant_element == actual_element) {
|
||||
mask[actual_idx] = 1;
|
||||
relevant_idx++;
|
||||
actual_idx++;
|
||||
} else if (actual_element < relevant_element) {
|
||||
actual_idx++;
|
||||
} else {
|
||||
relevant_idx++;
|
||||
}
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare the actual indices to the relevant indices, and return an array
|
||||
* containing the ordered set of indices of the actual array which appear
|
||||
* in the relevant array. A perfect result looks like counting from zero.
|
||||
* @param relevant The array of relevant indices
|
||||
* @param actual The array of actual indices
|
||||
* @param limit limit the indices to the first [limit] items
|
||||
* @return An array of relevant indices in the actual array.
|
||||
*/
|
||||
public static int[] findIndirect(int[] relevant, int[] actual, int limit) {
|
||||
int[] result = new int[actual.length];
|
||||
int a_index = 0, b_index = 0, acc_index = -1;
|
||||
int a_element, b_element;
|
||||
while (a_index < relevant.length && a_index < limit && b_index < actual.length && b_index < limit) {
|
||||
a_element = relevant[a_index];
|
||||
b_element = actual[b_index];
|
||||
if (a_element == b_element) {
|
||||
result[++acc_index] = b_index;
|
||||
a_index++;
|
||||
b_index++;
|
||||
} else if (b_element < a_element) {
|
||||
b_index++;
|
||||
} else {
|
||||
a_index++;
|
||||
}
|
||||
}
|
||||
return Arrays.copyOfRange(result, 0, acc_index + 1);
|
||||
}
|
||||
|
||||
|
||||
public static int[] find(int[] reference, int[] sample, int limit) {
|
||||
int[] result = new int[limit];
|
||||
int a_index = 0, b_index = 0, acc_index = -1;
|
||||
@ -102,7 +158,6 @@ public class Intersections {
|
||||
a_element = reference[a_index];
|
||||
b_element = sample[b_index];
|
||||
if (a_element == b_element) {
|
||||
result = resize(result);
|
||||
result[++acc_index] = a_element;
|
||||
a_index++;
|
||||
b_index++;
|
||||
@ -112,12 +167,13 @@ public class Intersections {
|
||||
a_index++;
|
||||
}
|
||||
}
|
||||
return Arrays.copyOfRange(result,0,acc_index+1);
|
||||
return Arrays.copyOfRange(result, 0, acc_index + 1);
|
||||
}
|
||||
|
||||
public static long[] find(long[] reference, long[] sample) {
|
||||
return find(reference, sample, reference.length);
|
||||
}
|
||||
|
||||
public static long[] find(long[] reference, long[] sample, int limit) {
|
||||
long[] result = new long[limit];
|
||||
int a_index = 0, b_index = 0, acc_index = -1;
|
||||
@ -126,7 +182,6 @@ public class Intersections {
|
||||
a_element = reference[a_index];
|
||||
b_element = sample[b_index];
|
||||
if (a_element == b_element) {
|
||||
result = resize(result);
|
||||
result[++acc_index] = a_element;
|
||||
a_index++;
|
||||
b_index++;
|
||||
@ -136,22 +191,22 @@ public class Intersections {
|
||||
a_index++;
|
||||
}
|
||||
}
|
||||
return Arrays.copyOfRange(result,0,acc_index+1);
|
||||
}
|
||||
|
||||
private static int[] resize(int[] arr) {
|
||||
int len = arr.length;
|
||||
int[] copy = new int[len + 1];
|
||||
System.arraycopy(arr, 0, copy, 0, len);
|
||||
return copy;
|
||||
}
|
||||
|
||||
private static long[] resize(long[] arr) {
|
||||
int len = arr.length;
|
||||
long[] copy = new long[len + 1];
|
||||
System.arraycopy(arr, 0, copy, 0, len);
|
||||
return copy;
|
||||
return Arrays.copyOfRange(result, 0, acc_index + 1);
|
||||
}
|
||||
//
|
||||
// private static int[] resize(int[] arr) {
|
||||
// int len = arr.length;
|
||||
// int[] copy = new int[len + 1];
|
||||
// System.arraycopy(arr, 0, copy, 0, len);
|
||||
// return copy;
|
||||
// }
|
||||
//
|
||||
// private static long[] resize(long[] arr) {
|
||||
// int len = arr.length;
|
||||
// long[] copy = new long[len + 1];
|
||||
// System.arraycopy(arr, 0, copy, 0, len);
|
||||
// return copy;
|
||||
// }
|
||||
|
||||
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ import java.util.stream.IntStream;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class ComputeFunctionsTest {
|
||||
class ComputeFunctionsIntTest {
|
||||
|
||||
private final static Offset<Double> offset=Offset.offset(0.001d);
|
||||
private final static int[] allInts =new int[]{0,1,2,3,4,5,6,7,8,9};
|
||||
@ -33,7 +33,10 @@ class ComputeFunctionsTest {
|
||||
private final static int[] highInts56789 = new int[]{5,6,7,8,9};
|
||||
|
||||
private final static int[] intsBy3_369 = new int[]{3,6,9};
|
||||
private final static int[] intsBy3_693 = new int[]{3,6,9};
|
||||
private final static int[] intsBy3_693 = new int[]{6,9,3};
|
||||
private final static int[] midInts45678 = new int[]{4,5,6,7,8};
|
||||
private final static int[] ints12390 = new int[]{1,2,3,9,0};
|
||||
|
||||
@Test
|
||||
void testRecallIntArrays() {
|
||||
assertThat(ComputeFunctions.recall(evenInts86204,oddInts37195))
|
||||
@ -81,12 +84,49 @@ class ComputeFunctionsTest {
|
||||
|
||||
@Test
|
||||
public void testReciprocalRank() {
|
||||
assertThat(ComputeFunctions.RR(intsBy3_369,highInts56789))
|
||||
assertThat(ComputeFunctions.reciprocal_rank(intsBy3_369,highInts56789))
|
||||
.as("relevant results in rank 2 should yield RR=0.5")
|
||||
.isCloseTo(0.5d,offset);
|
||||
|
||||
assertThat(ComputeFunctions.RR(highInts56789,lowInts01234))
|
||||
assertThat(ComputeFunctions.reciprocal_rank(highInts56789,lowInts01234))
|
||||
.as("no relevant results should yield RR=0.0")
|
||||
.isCloseTo(0.0d,offset);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntegerIntersection() {
|
||||
int[] result = Intersections.find(lowInts01234,midInts45678);
|
||||
assertThat(result).isEqualTo(new int[]{4});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCountIntIntersection() {
|
||||
int result = Intersections.count(oddInts37195, ints12390);
|
||||
assertThat(result).isEqualTo(2L);
|
||||
}
|
||||
@Test
|
||||
public void testMasking() {
|
||||
assertThat(Intersections.mask(ints12390,highInts56789))
|
||||
.as("the last actual is relevant and should have a 1 in the mask")
|
||||
.isEqualTo(new int[]{0,0,0,0,1});
|
||||
|
||||
assertThat(Intersections.mask(allInts,allInts))
|
||||
.as("the last actual is relevant and should have a 1 in the mask")
|
||||
.isEqualTo(new int[]{1,1,1,1,1,1,1,1,1,1});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAP() {
|
||||
double ap1 = ComputeFunctions.average_precision(new int[]{1, 2, 3, 4, 5, 6}, new int[]{3, 11, 5, 12, 1});
|
||||
assertThat(ap1)
|
||||
.as("")
|
||||
.isCloseTo(0.755d,offset);
|
||||
|
||||
double ap2 = ComputeFunctions.average_precision(ints12390, intsBy3_369);
|
||||
assertThat(ap2)
|
||||
.as("")
|
||||
.isCloseTo(0.833,offset);
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -1,50 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.engine.extensions.computefunctions;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class IntersectionsTest {
|
||||
|
||||
|
||||
@Test
|
||||
public void testIntegerIntersection() {
|
||||
int[] result = Intersections.find(new int[]{1,2,3,4,5},new int[]{4,5,6,7,8});
|
||||
assertThat(result).isEqualTo(new int[]{4,5});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLongIntersection() {
|
||||
long[] result = Intersections.find(new long[]{1,2,3,4,5},new long[]{4,5,6,7,8});
|
||||
assertThat(result).isEqualTo(new long[]{4,5});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCountIntIntersection() {
|
||||
long result = Intersections.count(new int[]{1,3,5,7,9}, new int[]{1,2,3,9,10});
|
||||
assertThat(result).isEqualTo(3L);
|
||||
}
|
||||
@Test
|
||||
public void testCountLongIntersection() {
|
||||
long result = Intersections.count(new long[]{1,3,5,7,9}, new long[]{1,2,3,9,10});
|
||||
assertThat(result).isEqualTo(3);
|
||||
}
|
||||
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user