Merge pull request #1523 from nosqlbench/nosqlbench-1522-moremath

Nosqlbench 1522 moremath
This commit is contained in:
Jonathan Shook 2023-09-07 14:03:24 -05:00 committed by GitHub
commit a0006f18a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 14 deletions

View File

@ -40,8 +40,11 @@ public class ComputeFunctions {
/** /**
* Compute the recall as the proportion of matching indices divided by the expected indices * Compute the recall as the proportion of matching indices divided by the expected indices
* @param referenceIndexes long array of indices *
* @param sampleIndexes long array of indices * @param referenceIndexes
* long array of indices
* @param sampleIndexes
* long array of indices
* @return a fractional measure of matching vs expected indices * @return a fractional measure of matching vs expected indices
*/ */
public static double recall(long[] referenceIndexes, long[] sampleIndexes) { public static double recall(long[] referenceIndexes, long[] sampleIndexes) {
@ -50,6 +53,7 @@ public class ComputeFunctions {
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes); long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
return (double) intersection.length / (double) referenceIndexes.length; return (double) intersection.length / (double) referenceIndexes.length;
} }
public static double recall(long[] referenceIndexes, long[] sampleIndexes, int limit) { public static double recall(long[] referenceIndexes, long[] sampleIndexes, int limit) {
if (sampleIndexes.length < limit) { if (sampleIndexes.length < limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit); throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
@ -81,8 +85,11 @@ public class ComputeFunctions {
/** /**
* Compute the recall as the proportion of matching indices divided by the expected indices * Compute the recall as the proportion of matching indices divided by the expected indices
* @param referenceIndexes int array of indices *
* @param sampleIndexes int array of indices * @param referenceIndexes
* int array of indices
* @param sampleIndexes
* int array of indices
* @return a fractional measure of matching vs expected indices * @return a fractional measure of matching vs expected indices
*/ */
public static double recall(int[] referenceIndexes, int[] sampleIndexes) { public static double recall(int[] referenceIndexes, int[] sampleIndexes) {
@ -91,6 +98,7 @@ public class ComputeFunctions {
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length); int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length);
return (double) intersection / (double) referenceIndexes.length; return (double) intersection / (double) referenceIndexes.length;
} }
public static double recall(int[] referenceIndexes, int[] sampleIndexes, int limit) { public static double recall(int[] referenceIndexes, int[] sampleIndexes, int limit) {
if (sampleIndexes.length < limit) { if (sampleIndexes.length < limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit); throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
@ -108,6 +116,7 @@ public class ComputeFunctions {
int intersection = Intersections.count(referenceIndexes, sampleIndexes); int intersection = Intersections.count(referenceIndexes, sampleIndexes);
return (double) intersection / (double) sampleIndexes.length; return (double) intersection / (double) sampleIndexes.length;
} }
public static double precision(int[] referenceIndexes, int[] sampleIndexes, int limit) { public static double precision(int[] referenceIndexes, int[] sampleIndexes, int limit) {
if (sampleIndexes.length < limit) { if (sampleIndexes.length < limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit); throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
@ -125,6 +134,7 @@ public class ComputeFunctions {
public static long[] intersection(long[] a, long[] b) { public static long[] intersection(long[] a, long[] b) {
return Intersections.find(a, b); return Intersections.find(a, b);
} }
public static long[] intersection(long[] a, long[] b, int limit) { public static long[] intersection(long[] a, long[] b, int limit) {
return Intersections.find(a, b, limit); return Intersections.find(a, b, limit);
} }
@ -146,14 +156,66 @@ public class ComputeFunctions {
public static int intersectionSize(int[] reference, int[] sample) { public static int intersectionSize(int[] reference, int[] sample) {
return Intersections.count(reference, sample); return Intersections.count(reference, sample);
} }
public static int intersectionSize(int[] reference, int[] sample, int limit) { public static int intersectionSize(int[] reference, int[] sample, int limit) {
return Intersections.count(reference, sample, limit); return Intersections.count(reference, sample, limit);
} }
public static int intersectionSize(long[] reference, long[] sample) { public static int intersectionSize(long[] reference, long[] sample) {
return Intersections.count(reference, sample); return Intersections.count(reference, sample);
} }
public static int intersectionSize(long[] reference, long[] sample, int limit) { public static int intersectionSize(long[] reference, long[] sample, int limit) {
return Intersections.count(reference, sample, limit); return Intersections.count(reference, sample, limit);
} }
/**
 * Compute the F1 score using the length of the reference array as the limit.
 *
 * @param reference
 *     int array of reference indices
 * @param sample
 *     int array of sample indices
 * @return the value of {@code F1(reference, sample, reference.length)}
 */
public static double F1(int[] reference, int[] sample) {
    final int limit = reference.length;
    return F1(reference, sample, limit);
}
/**
 * Compute the F1 score at the given limit: the harmonic mean of recall and precision,
 * {@code 2 * (recall * precision) / (recall + precision)}.
 *
 * <p>When recall and precision are both zero the harmonic mean is 0/0; in that case this method
 * returns {@code 0.0} instead of {@code NaN}.</p>
 *
 * @param reference
 *     int array of reference indices
 * @param sample
 *     int array of sample indices
 * @param limit
 *     limit passed to both the recall and precision computations
 * @return the F1 score, or 0.0 when recall and precision are both zero
 * @throws RuntimeException
 *     from the recall/precision computations when the sample holds fewer entries than {@code limit}
 */
public static double F1(int[] reference, int[] sample, int limit) {
    double recallAtK = recall(reference, sample, limit);
    double precisionAtK = precision(reference, sample, limit);
    double denominator = recallAtK + precisionAtK;
    // Guard the 0/0 case, which would otherwise yield NaN.
    if (denominator == 0.0d) {
        return 0.0d;
    }
    return 2.0d * ((recallAtK * precisionAtK) / denominator);
}
/**
 * Compute the F1 score using the length of the reference array as the limit.
 *
 * @param reference
 *     long array of reference indices
 * @param sample
 *     long array of sample indices
 * @return the value of {@code F1(reference, sample, reference.length)}
 */
public static double F1(long[] reference, long[] sample) {
    final int limit = reference.length;
    return F1(reference, sample, limit);
}
/**
 * Compute the F1 score at the given limit: the harmonic mean of recall and precision,
 * {@code 2 * (recall * precision) / (recall + precision)}.
 *
 * <p>When recall and precision are both zero the harmonic mean is 0/0; in that case this method
 * returns {@code 0.0} instead of {@code NaN}.</p>
 *
 * @param reference
 *     long array of reference indices
 * @param sample
 *     long array of sample indices
 * @param limit
 *     limit passed to both the recall and precision computations
 * @return the F1 score, or 0.0 when recall and precision are both zero
 * @throws RuntimeException
 *     from the recall computation when the sample holds fewer entries than {@code limit}
 *     (the precision overload presumably behaves the same way; confirm)
 */
public static double F1(long[] reference, long[] sample, int limit) {
    double recallAtK = recall(reference, sample, limit);
    double precisionAtK = precision(reference, sample, limit);
    double denominator = recallAtK + precisionAtK;
    // Guard the 0/0 case, which would otherwise yield NaN.
    if (denominator == 0.0d) {
        return 0.0d;
    }
    return 2.0d * ((recallAtK * precisionAtK) / denominator);
}
/**
 * Reciprocal Rank - the multiplicative inverse of the first rank which is relevant.
 *
 * <p>The zero-based index reported by {@code Intersections.firstMatchingIndex} is converted to a
 * one-based rank before inversion. A negative index means nothing matched and yields 0.0.</p>
 *
 * @param reference
 *     long array of reference indices
 * @param sample
 *     long array of sample indices
 * @param limit
 *     bound forwarded to {@code Intersections.firstMatchingIndex}
 * @return {@code 1 / rank} of the first match, or 0.0 when there is no match
 */
public static double RR(long[] reference, long[] sample, int limit) {
    final int matchIndex = Intersections.firstMatchingIndex(reference, sample, limit);
    return (matchIndex < 0) ? 0.0 : 1.0d / (matchIndex + 1);
}
/**
 * Reciprocal Rank using the length of the reference array as the limit.
 *
 * @param reference
 *     long array of reference indices
 * @param sample
 *     long array of sample indices
 * @return the value of {@code RR(reference, sample, reference.length)}
 */
public static double RR(long[] reference, long[] sample) {
    final int limit = reference.length;
    return RR(reference, sample, limit);
}
/**
 * Reciprocal Rank - the multiplicative inverse of the first rank which is relevant.
 *
 * <p>The zero-based index reported by {@code Intersections.firstMatchingIndex} is converted to a
 * one-based rank before inversion. A negative index means nothing matched and yields 0.0.</p>
 *
 * @param reference
 *     int array of reference indices
 * @param sample
 *     int array of sample indices
 * @param limit
 *     bound forwarded to {@code Intersections.firstMatchingIndex}
 * @return {@code 1 / rank} of the first match, or 0.0 when there is no match
 */
public static double RR(int[] reference, int[] sample, int limit) {
    final int matchIndex = Intersections.firstMatchingIndex(reference, sample, limit);
    return (matchIndex < 0) ? 0.0 : 1.0d / (matchIndex + 1);
}
/**
 * Reciprocal Rank using the length of the reference array as the limit.
 *
 * @param reference
 *     int array of reference indices
 * @param sample
 *     int array of sample indices
 * @return the value of {@code RR(reference, sample, reference.length)}
 */
public static double RR(int[] reference, int[] sample) {
    final int limit = reference.length;
    return RR(reference, sample, limit);
}
} }

View File

@ -20,6 +20,32 @@ import java.util.Arrays;
public class Intersections { public class Intersections {
/**
 * Return the non-negative index of the first value in the sample array which is present in the
 * reference array, or a negative number when none of the examined sample values is present.
 * The result is an array index, which starts at 0, not a rank, which starts at 1.
 *
 * <p>Only the first {@code min(sample.length, limit)} sample values are examined. The reference
 * array is not modified; a sorted copy is searched instead.</p>
 *
 * @param reference
 *     values considered relevant, in any order
 * @param sample
 *     candidate values in ranked order
 * @param limit
 *     maximum number of leading sample values to examine
 * @return the zero-based position within {@code sample} of the first match, or -1 if there is none
 */
public static int firstMatchingIndex(long[] reference, long[] sample, int limit) {
    // Binary search needs sorted input; sort a copy so the caller's array keeps its order.
    long[] sortedReference = reference.clone();
    Arrays.sort(sortedReference);
    int maxIndex = Math.min(sample.length, limit);
    for (int index = 0; index < maxIndex; index++) {
        // Report the position in the sample, not the position binarySearch found in the reference.
        if (Arrays.binarySearch(sortedReference, sample[index]) >= 0) {
            return index;
        }
    }
    return -1;
}
/**
 * Return the non-negative index of the first value in the sample array which is present in the
 * reference array, or a negative number when none of the examined sample values is present.
 * The result is an array index, which starts at 0, not a rank, which starts at 1.
 *
 * <p>Only the first {@code min(sample.length, limit)} sample values are examined. The reference
 * array is not modified; a sorted copy is searched instead.</p>
 *
 * @param reference
 *     values considered relevant, in any order
 * @param sample
 *     candidate values in ranked order
 * @param limit
 *     maximum number of leading sample values to examine
 * @return the zero-based position within {@code sample} of the first match, or -1 if there is none
 */
public static int firstMatchingIndex(int[] reference, int[] sample, int limit) {
    // Binary search needs sorted input; sort a copy so the caller's array keeps its order.
    int[] sortedReference = reference.clone();
    Arrays.sort(sortedReference);
    int maxIndex = Math.min(sample.length, limit);
    for (int index = 0; index < maxIndex; index++) {
        // Report the position in the sample, not the position binarySearch found in the reference.
        if (Arrays.binarySearch(sortedReference, sample[index]) >= 0) {
            return index;
        }
    }
    return -1;
}
public static int count(int[] reference, int[] sample) { public static int count(int[] reference, int[] sample) {
return count(reference,sample,reference.length); return count(reference,sample,reference.length);
} }

View File

@ -78,4 +78,15 @@ class ComputeFunctionsTest {
} }
} }
/**
 * Checks reciprocal rank for one fixture pair whose first relevant result is at rank 2 and one
 * pair with no relevant results.
 */
@Test
public void testReciprocalRank() {
    double firstHitAtRankTwo = ComputeFunctions.RR(intsBy3_369, highInts56789);
    assertThat(firstHitAtRankTwo)
        .as("relevant results in rank 2 should yield RR=0.5")
        .isCloseTo(0.5d, offset);

    double nothingRelevant = ComputeFunctions.RR(highInts56789, lowInts01234);
    assertThat(nothingRelevant)
        .as("no relevant results should yield RR=0.0")
        .isCloseTo(0.0d, offset);
}
} }