mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2025-02-25 18:55:28 -06:00
Merge pull request #1525 from nosqlbench/nosqlbench-1522-moremath
Nosqlbench 1522 moremath
This commit is contained in:
commit
6005aa65d4
@ -17,6 +17,8 @@
|
|||||||
package io.nosqlbench.engine.extensions.computefunctions;
|
package io.nosqlbench.engine.extensions.computefunctions;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.DoubleSummaryStatistics;
|
||||||
|
import java.util.HashSet;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* <P>A collection of compute functions related to vector search relevancy.
|
* <P>A collection of compute functions related to vector search relevancy.
|
||||||
@ -29,7 +31,8 @@ import java.util.Arrays;
|
|||||||
* metrics "@K" for any size up to and including K=100. This simply uses a partial view of the result
|
* metrics "@K" for any size up to and including K=100. This simply uses a partial view of the result
|
||||||
* to do exactly what would have been done for a test where you actually query for that K limit.
|
* to do exactly what would have been done for a test where you actually query for that K limit.
|
||||||
* <STRONG>This assumes that the result rank is stable irrespective of the limit AND the results
|
* <STRONG>This assumes that the result rank is stable irrespective of the limit AND the results
|
||||||
* are passed to these functions as ranked in results.</STRONG></P>
|
* are passed to these functions as ranked in results.</STRONG></P> Some of the methods apply K to the
|
||||||
|
* expected (relevant) indices, others to the actual (response) indices, depending on what is appropriate.
|
||||||
*
|
*
|
||||||
* <P>The array indices passed to these functions should not be sorted before-hand as a general rule.</P>
|
* <P>The array indices passed to these functions should not be sorted before-hand as a general rule.</P>
|
||||||
* Yet, no provision is made for duplicate entries. If you have duplicate indices in either array,
|
* Yet, no provision is made for duplicate entries. If you have duplicate indices in either array,
|
||||||
@ -41,91 +44,91 @@ public class ComputeFunctions {
|
|||||||
/**
|
/**
|
||||||
* Compute the recall as the proportion of matching indices divided by the expected indices
|
* Compute the recall as the proportion of matching indices divided by the expected indices
|
||||||
*
|
*
|
||||||
* @param referenceIndexes
|
* @param relevant
|
||||||
* long array of indices
|
* long array of indices
|
||||||
* @param sampleIndexes
|
* @param actual
|
||||||
* long array of indices
|
* long array of indices
|
||||||
* @return a fractional measure of matching vs expected indices
|
* @return a fractional measure of matching vs expected indices
|
||||||
*/
|
*/
|
||||||
public static double recall(long[] referenceIndexes, long[] sampleIndexes) {
|
public static double recall(long[] relevant, long[] actual) {
|
||||||
Arrays.sort(referenceIndexes);
|
Arrays.sort(relevant);
|
||||||
Arrays.sort(sampleIndexes);
|
Arrays.sort(actual);
|
||||||
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
|
long[] intersection = Intersections.find(relevant, actual);
|
||||||
return (double) intersection.length / (double) referenceIndexes.length;
|
return (double) intersection.length / (double) relevant.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double recall(long[] referenceIndexes, long[] sampleIndexes, int limit) {
|
public static double recall(long[] relevant, long[] actual, int k) {
|
||||||
if (sampleIndexes.length < limit) {
|
if (actual.length < k) {
|
||||||
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
|
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k);
|
||||||
}
|
}
|
||||||
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit);
|
actual = Arrays.copyOfRange(actual, 0, k);
|
||||||
Arrays.sort(referenceIndexes);
|
Arrays.sort(relevant);
|
||||||
Arrays.sort(sampleIndexes);
|
Arrays.sort(actual);
|
||||||
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
|
long[] intersection = Intersections.find(relevant, actual);
|
||||||
return (double) intersection.length / (double) referenceIndexes.length;
|
return (double) intersection.length / (double) relevant.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double precision(long[] referenceIndexes, long[] sampleIndexes) {
|
public static double precision(long[] relevant, long[] actual) {
|
||||||
Arrays.sort(referenceIndexes);
|
Arrays.sort(relevant);
|
||||||
Arrays.sort(sampleIndexes);
|
Arrays.sort(actual);
|
||||||
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
|
long[] intersection = Intersections.find(relevant, actual);
|
||||||
return (double) intersection.length / (double) sampleIndexes.length;
|
return (double) intersection.length / (double) actual.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double precision(long[] referenceIndexes, long[] sampleIndexes, int limit) {
|
public static double precision(long[] relevant, long[] actual, int k) {
|
||||||
if (sampleIndexes.length < limit) {
|
if (actual.length < k) {
|
||||||
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
|
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k);
|
||||||
}
|
}
|
||||||
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit);
|
actual = Arrays.copyOfRange(actual, 0, k);
|
||||||
Arrays.sort(referenceIndexes);
|
Arrays.sort(relevant);
|
||||||
Arrays.sort(sampleIndexes);
|
Arrays.sort(actual);
|
||||||
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
|
long[] intersection = Intersections.find(relevant, actual);
|
||||||
return (double) intersection.length / (double) sampleIndexes.length;
|
return (double) intersection.length / (double) actual.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute the recall as the proportion of matching indices divided by the expected indices
|
* Compute the recall as the proportion of matching indices divided by the expected indices
|
||||||
*
|
*
|
||||||
* @param referenceIndexes
|
* @param relevant
|
||||||
* int array of indices
|
* int array of indices
|
||||||
* @param sampleIndexes
|
* @param actual
|
||||||
* int array of indices
|
* int array of indices
|
||||||
* @return a fractional measure of matching vs expected indices
|
* @return a fractional measure of matching vs expected indices
|
||||||
*/
|
*/
|
||||||
public static double recall(int[] referenceIndexes, int[] sampleIndexes) {
|
public static double recall(int[] relevant, int[] actual) {
|
||||||
Arrays.sort(referenceIndexes);
|
Arrays.sort(relevant);
|
||||||
Arrays.sort(sampleIndexes);
|
Arrays.sort(actual);
|
||||||
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length);
|
int intersection = Intersections.count(relevant, actual, relevant.length);
|
||||||
return (double) intersection / (double) referenceIndexes.length;
|
return (double) intersection / (double) relevant.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double recall(int[] referenceIndexes, int[] sampleIndexes, int limit) {
|
public static double recall(int[] relevant, int[] actual, int k) {
|
||||||
if (sampleIndexes.length < limit) {
|
if (actual.length < k) {
|
||||||
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
|
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k);
|
||||||
}
|
}
|
||||||
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit);
|
actual = Arrays.copyOfRange(actual, 0, k);
|
||||||
Arrays.sort(referenceIndexes);
|
Arrays.sort(relevant);
|
||||||
Arrays.sort(sampleIndexes);
|
Arrays.sort(actual);
|
||||||
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length);
|
int intersection = Intersections.count(relevant, actual, relevant.length);
|
||||||
return (double) intersection / (double) referenceIndexes.length;
|
return (double) intersection / (double) relevant.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double precision(int[] referenceIndexes, int[] sampleIndexes) {
|
public static double precision(int[] relevant, int[] actual) {
|
||||||
Arrays.sort(referenceIndexes);
|
Arrays.sort(relevant);
|
||||||
Arrays.sort(sampleIndexes);
|
Arrays.sort(actual);
|
||||||
int intersection = Intersections.count(referenceIndexes, sampleIndexes);
|
int intersection = Intersections.count(relevant, actual);
|
||||||
return (double) intersection / (double) sampleIndexes.length;
|
return (double) intersection / (double) actual.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double precision(int[] referenceIndexes, int[] sampleIndexes, int limit) {
|
public static double precision(int[] relevant, int[] actual, int k) {
|
||||||
if (sampleIndexes.length < limit) {
|
if (actual.length < k) {
|
||||||
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
|
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k);
|
||||||
}
|
}
|
||||||
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit);
|
actual = Arrays.copyOfRange(actual, 0, k);
|
||||||
Arrays.sort(referenceIndexes);
|
Arrays.sort(relevant);
|
||||||
Arrays.sort(sampleIndexes);
|
Arrays.sort(actual);
|
||||||
int intersection = Intersections.count(referenceIndexes, sampleIndexes);
|
int intersection = Intersections.count(relevant, actual);
|
||||||
return (double) intersection / (double) sampleIndexes.length;
|
return (double) intersection / (double) actual.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -135,65 +138,49 @@ public class ComputeFunctions {
|
|||||||
return Intersections.find(a, b);
|
return Intersections.find(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static long[] intersection(long[] a, long[] b, int limit) {
|
|
||||||
return Intersections.find(a, b, limit);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute the intersection of two int arrays
|
* Compute the intersection of two int arrays
|
||||||
*/
|
*/
|
||||||
public static int[] intersection(int[] reference, int[] sample) {
|
public static int[] intersection(int[] a, int[] b) {
|
||||||
return Intersections.find(reference, sample);
|
return Intersections.find(a, b);
|
||||||
}
|
|
||||||
|
|
||||||
public static int[] intersection(int[] reference, int[] sample, int limit) {
|
|
||||||
return Intersections.find(reference, sample, limit);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute the size of the intersection of two int arrays
|
* Compute the size of the intersection of two int arrays
|
||||||
*/
|
*/
|
||||||
public static int intersectionSize(int[] reference, int[] sample) {
|
public static int intersectionSize(int[] a, int[] b) {
|
||||||
return Intersections.count(reference, sample);
|
return Intersections.count(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int intersectionSize(int[] reference, int[] sample, int limit) {
|
public static int intersectionSize(long[] a, long[] b) {
|
||||||
return Intersections.count(reference, sample, limit);
|
return Intersections.count(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int intersectionSize(long[] reference, long[] sample) {
|
public static double F1(int[] relevant, int[] actual) {
|
||||||
return Intersections.count(reference, sample);
|
return F1(relevant, actual, relevant.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int intersectionSize(long[] reference, long[] sample, int limit) {
|
public static double F1(int[] relevant, int[] actual, int k) {
|
||||||
return Intersections.count(reference, sample, limit);
|
double recallAtK = recall(relevant, actual, k);
|
||||||
}
|
double precisionAtK = precision(relevant, actual, k);
|
||||||
|
|
||||||
public static double F1(int[] reference, int[] sample) {
|
|
||||||
return F1(reference, sample, reference.length);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static double F1(int[] reference, int[] sample, int limit) {
|
|
||||||
double recallAtK = recall(reference, sample, limit);
|
|
||||||
double precisionAtK = precision(reference, sample, limit);
|
|
||||||
return 2.0d * ((recallAtK * precisionAtK) / (recallAtK + precisionAtK));
|
return 2.0d * ((recallAtK * precisionAtK) / (recallAtK + precisionAtK));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double F1(long[] reference, long[] sample) {
|
public static double F1(long[] relevant, long[] actual) {
|
||||||
return F1(reference, sample, reference.length);
|
return F1(relevant, actual, relevant.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double F1(long[] reference, long[] sample, int limit) {
|
public static double F1(long[] relevant, long[] actual, int k) {
|
||||||
double recallAtK = recall(reference, sample, limit);
|
double recallAtK = recall(relevant, actual, k);
|
||||||
double precisionAtK = precision(reference, sample, limit);
|
double precisionAtK = precision(relevant, actual, k);
|
||||||
return 2.0d * ((recallAtK * precisionAtK) / (recallAtK + precisionAtK));
|
return 2.0d * ((recallAtK * precisionAtK) / (recallAtK + precisionAtK));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reciprocal Rank - The multiplicative inverse of the first rank which is relevant.
|
* Reciprocal Rank - The multiplicative inverse of the first rank which is relevant.
|
||||||
*/
|
*/
|
||||||
public static double RR(long[] reference, long[] sample, int limit) {
|
public static double reciprocal_rank(long[] relevant, long[] actual, int k) {
|
||||||
int firstRank = Intersections.firstMatchingIndex(reference, sample, limit);
|
int firstRank = Intersections.firstMatchingIndex(relevant, actual, k);
|
||||||
if (firstRank >= 0) {
|
if (firstRank >= 0) {
|
||||||
return 1.0d / (firstRank+1);
|
return 1.0d / (firstRank+1);
|
||||||
} else {
|
} else {
|
||||||
@ -201,21 +188,63 @@ public class ComputeFunctions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double RR(long[] reference, long[] sample) {
|
public static double reciprocal_rank(long[] relevant, long[] actual) {
|
||||||
return RR(reference, sample, reference.length);
|
return reciprocal_rank(relevant, actual, relevant.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double RR(int[] reference, int[] sample, int limit) {
|
/**
|
||||||
int firstRank = Intersections.firstMatchingIndex(reference, sample, limit);
|
* RR as in M(RR)
|
||||||
if (firstRank >= 0) {
|
*/
|
||||||
|
public static double reciprocal_rank(int[] relevant, int[] actual, int k) {
|
||||||
|
int firstRank = Intersections.firstMatchingIndex(relevant, actual, k);
|
||||||
|
if (firstRank<0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return 1.0d / (firstRank+1);
|
return 1.0d / (firstRank+1);
|
||||||
} else {
|
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static double RR(int[] reference, int[] sample) {
|
public static double reciprocal_rank(int[] relevant, int[] actual) {
|
||||||
return RR(reference, sample, reference.length);
|
return reciprocal_rank(relevant, actual, relevant.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double average_precision(int[] relevant, int[] actual) {
|
||||||
|
return average_precision(relevant,actual,relevant.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double average_precision(int[] relevant, int[] actual, int k) {
|
||||||
|
int maxK = Math.min(k,actual.length);
|
||||||
|
HashSet<Integer> relevantSet = new HashSet<>(relevant.length);
|
||||||
|
for (Integer i : relevant) {
|
||||||
|
relevantSet.add(i);
|
||||||
|
}
|
||||||
|
int relevantCount=0;
|
||||||
|
DoubleSummaryStatistics stats = new DoubleSummaryStatistics();
|
||||||
|
for (int i = 0; i < maxK; i++) {
|
||||||
|
if (relevantSet.contains(actual[i])){
|
||||||
|
relevantCount++;
|
||||||
|
double precisionAtIdx = (double) relevantCount / (i+1);
|
||||||
|
stats.accept(precisionAtIdx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stats.getAverage();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static double average_precision(long[] relevant, long[] actual, int k) {
|
||||||
|
int maxK = Math.min(k,actual.length);
|
||||||
|
HashSet<Long> refset = new HashSet<>(relevant.length);
|
||||||
|
for (Long i : relevant) {
|
||||||
|
refset.add(i);
|
||||||
|
}
|
||||||
|
int relevantCount=0;
|
||||||
|
DoubleSummaryStatistics stats = new DoubleSummaryStatistics();
|
||||||
|
for (int i = 0; i < maxK; i++) {
|
||||||
|
if (refset.contains(actual[i])){
|
||||||
|
relevantCount++;
|
||||||
|
double precisionAtIdx = (double) relevantCount / (i+1);
|
||||||
|
stats.accept(precisionAtIdx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stats.getAverage();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -27,10 +27,10 @@ public class Intersections {
|
|||||||
public static int firstMatchingIndex(long[] reference, long[] sample, int limit) {
|
public static int firstMatchingIndex(long[] reference, long[] sample, int limit) {
|
||||||
Arrays.sort(reference);
|
Arrays.sort(reference);
|
||||||
int maxIndex = Math.min(sample.length, limit);
|
int maxIndex = Math.min(sample.length, limit);
|
||||||
int foundAt=-1;
|
int foundAt = -1;
|
||||||
for (int index = 0; index < maxIndex; index++) {
|
for (int index = 0; index < maxIndex; index++) {
|
||||||
foundAt = Arrays.binarySearch(reference, sample[index]);
|
foundAt = Arrays.binarySearch(reference, sample[index]);
|
||||||
if (foundAt>=0) break;
|
if (foundAt >= 0) break;
|
||||||
}
|
}
|
||||||
return foundAt;
|
return foundAt;
|
||||||
}
|
}
|
||||||
@ -38,21 +38,18 @@ public class Intersections {
|
|||||||
public static int firstMatchingIndex(int[] reference, int[] sample, int limit) {
|
public static int firstMatchingIndex(int[] reference, int[] sample, int limit) {
|
||||||
Arrays.sort(reference);
|
Arrays.sort(reference);
|
||||||
int maxIndex = Math.min(sample.length, limit);
|
int maxIndex = Math.min(sample.length, limit);
|
||||||
int foundAt=-1;
|
int foundAt = -1;
|
||||||
for (int index = 0; index < maxIndex; index++) {
|
for (int index = 0; index < maxIndex; index++) {
|
||||||
foundAt = Arrays.binarySearch(reference, sample[index]);
|
foundAt = Arrays.binarySearch(reference, sample[index]);
|
||||||
if (foundAt>=0) break;
|
if (foundAt >= 0) break;
|
||||||
}
|
}
|
||||||
return foundAt;
|
return foundAt;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int count(int[] reference, int[] sample) {
|
public static int count(int[] reference, int[] sample) {
|
||||||
return count(reference,sample,reference.length);
|
|
||||||
}
|
|
||||||
public static int count(int[] reference, int[] sample, int limit) {
|
|
||||||
int a_index = 0, b_index = 0, matches = 0;
|
int a_index = 0, b_index = 0, matches = 0;
|
||||||
int a_element, b_element;
|
int a_element, b_element;
|
||||||
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
|
while (a_index < reference.length && b_index < sample.length) {
|
||||||
a_element = reference[a_index];
|
a_element = reference[a_index];
|
||||||
b_element = sample[b_index];
|
b_element = sample[b_index];
|
||||||
if (a_element == b_element) {
|
if (a_element == b_element) {
|
||||||
@ -71,6 +68,7 @@ public class Intersections {
|
|||||||
public static int count(long[] reference, long[] sample) {
|
public static int count(long[] reference, long[] sample) {
|
||||||
return count(reference, sample, reference.length);
|
return count(reference, sample, reference.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int count(long[] reference, long[] sample, int limit) {
|
public static int count(long[] reference, long[] sample, int limit) {
|
||||||
int a_index = 0, b_index = 0, matches = 0;
|
int a_index = 0, b_index = 0, matches = 0;
|
||||||
long a_element, b_element;
|
long a_element, b_element;
|
||||||
@ -91,18 +89,13 @@ public class Intersections {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static int[] find(int[] reference, int[] sample) {
|
public static int[] find(int[] reference, int[] sample) {
|
||||||
return find(reference,sample,reference.length);
|
int[] result = new int[sample.length];
|
||||||
}
|
|
||||||
|
|
||||||
public static int[] find(int[] reference, int[] sample, int limit) {
|
|
||||||
int[] result = new int[limit];
|
|
||||||
int a_index = 0, b_index = 0, acc_index = -1;
|
int a_index = 0, b_index = 0, acc_index = -1;
|
||||||
int a_element, b_element;
|
int a_element, b_element;
|
||||||
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
|
while (a_index < reference.length && b_index < sample.length) {
|
||||||
a_element = reference[a_index];
|
a_element = reference[a_index];
|
||||||
b_element = sample[b_index];
|
b_element = sample[b_index];
|
||||||
if (a_element == b_element) {
|
if (a_element == b_element) {
|
||||||
result = resize(result);
|
|
||||||
result[++acc_index] = a_element;
|
result[++acc_index] = a_element;
|
||||||
a_index++;
|
a_index++;
|
||||||
b_index++;
|
b_index++;
|
||||||
@ -112,21 +105,17 @@ public class Intersections {
|
|||||||
a_index++;
|
a_index++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Arrays.copyOfRange(result,0,acc_index+1);
|
return Arrays.copyOfRange(result, 0, acc_index + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static long[] find(long[] reference, long[] sample) {
|
public static long[] find(long[] reference, long[] sample) {
|
||||||
return find(reference, sample, reference.length);
|
long[] result = new long[sample.length];
|
||||||
}
|
|
||||||
public static long[] find(long[] reference, long[] sample, int limit) {
|
|
||||||
long[] result = new long[limit];
|
|
||||||
int a_index = 0, b_index = 0, acc_index = -1;
|
int a_index = 0, b_index = 0, acc_index = -1;
|
||||||
long a_element, b_element;
|
long a_element, b_element;
|
||||||
while (a_index < reference.length && a_index < limit && b_index < sample.length && b_index < limit) {
|
while (a_index < reference.length && b_index < sample.length) {
|
||||||
a_element = reference[a_index];
|
a_element = reference[a_index];
|
||||||
b_element = sample[b_index];
|
b_element = sample[b_index];
|
||||||
if (a_element == b_element) {
|
if (a_element == b_element) {
|
||||||
result = resize(result);
|
|
||||||
result[++acc_index] = a_element;
|
result[++acc_index] = a_element;
|
||||||
a_index++;
|
a_index++;
|
||||||
b_index++;
|
b_index++;
|
||||||
@ -136,22 +125,7 @@ public class Intersections {
|
|||||||
a_index++;
|
a_index++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Arrays.copyOfRange(result,0,acc_index+1);
|
return Arrays.copyOfRange(result, 0, acc_index + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static int[] resize(int[] arr) {
|
|
||||||
int len = arr.length;
|
|
||||||
int[] copy = new int[len + 1];
|
|
||||||
System.arraycopy(arr, 0, copy, 0, len);
|
|
||||||
return copy;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static long[] resize(long[] arr) {
|
|
||||||
int len = arr.length;
|
|
||||||
long[] copy = new long[len + 1];
|
|
||||||
System.arraycopy(arr, 0, copy, 0, len);
|
|
||||||
return copy;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ import java.util.stream.IntStream;
|
|||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
class ComputeFunctionsTest {
|
class ComputeFunctionsIntTest {
|
||||||
|
|
||||||
private final static Offset<Double> offset=Offset.offset(0.001d);
|
private final static Offset<Double> offset=Offset.offset(0.001d);
|
||||||
private final static int[] allInts =new int[]{0,1,2,3,4,5,6,7,8,9};
|
private final static int[] allInts =new int[]{0,1,2,3,4,5,6,7,8,9};
|
||||||
@ -33,7 +33,10 @@ class ComputeFunctionsTest {
|
|||||||
private final static int[] highInts56789 = new int[]{5,6,7,8,9};
|
private final static int[] highInts56789 = new int[]{5,6,7,8,9};
|
||||||
|
|
||||||
private final static int[] intsBy3_369 = new int[]{3,6,9};
|
private final static int[] intsBy3_369 = new int[]{3,6,9};
|
||||||
private final static int[] intsBy3_693 = new int[]{3,6,9};
|
private final static int[] intsBy3_693 = new int[]{6,9,3};
|
||||||
|
private final static int[] midInts45678 = new int[]{4,5,6,7,8};
|
||||||
|
private final static int[] ints12390 = new int[]{1,2,3,9,0};
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testRecallIntArrays() {
|
void testRecallIntArrays() {
|
||||||
assertThat(ComputeFunctions.recall(evenInts86204,oddInts37195))
|
assertThat(ComputeFunctions.recall(evenInts86204,oddInts37195))
|
||||||
@ -81,12 +84,49 @@ class ComputeFunctionsTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReciprocalRank() {
|
public void testReciprocalRank() {
|
||||||
assertThat(ComputeFunctions.RR(intsBy3_369,highInts56789))
|
assertThat(ComputeFunctions.reciprocal_rank(intsBy3_369,highInts56789))
|
||||||
.as("relevant results in rank 2 should yield RR=0.5")
|
.as("relevant results in rank 2 should yield RR=0.5")
|
||||||
.isCloseTo(0.5d,offset);
|
.isCloseTo(0.5d,offset);
|
||||||
|
|
||||||
assertThat(ComputeFunctions.RR(highInts56789,lowInts01234))
|
assertThat(ComputeFunctions.reciprocal_rank(highInts56789,lowInts01234))
|
||||||
.as("no relevant results should yield RR=0.0")
|
.as("no relevant results should yield RR=0.0")
|
||||||
.isCloseTo(0.0d,offset);
|
.isCloseTo(0.0d,offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testIntegerIntersection() {
|
||||||
|
int[] result = Intersections.find(lowInts01234,midInts45678);
|
||||||
|
assertThat(result).isEqualTo(new int[]{4});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCountIntIntersection() {
|
||||||
|
int result = Intersections.count(oddInts37195, ints12390);
|
||||||
|
assertThat(result).isEqualTo(2L);
|
||||||
|
}
|
||||||
|
@Test
|
||||||
|
public void testMasking() {
|
||||||
|
assertThat(Intersections.mask(ints12390,highInts56789))
|
||||||
|
.as("the last actual is relevant and should have a 1 in the mask")
|
||||||
|
.isEqualTo(new int[]{0,0,0,0,1});
|
||||||
|
|
||||||
|
assertThat(Intersections.mask(allInts,allInts))
|
||||||
|
.as("the last actual is relevant and should have a 1 in the mask")
|
||||||
|
.isEqualTo(new int[]{1,1,1,1,1,1,1,1,1,1});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testAP() {
|
||||||
|
double ap1 = ComputeFunctions.average_precision(new int[]{1, 2, 3, 4, 5, 6}, new int[]{3, 11, 5, 12, 1});
|
||||||
|
assertThat(ap1)
|
||||||
|
.as("")
|
||||||
|
.isCloseTo(0.755d,offset);
|
||||||
|
|
||||||
|
double ap2 = ComputeFunctions.average_precision(ints12390, intsBy3_369);
|
||||||
|
assertThat(ap2)
|
||||||
|
.as("")
|
||||||
|
.isCloseTo(0.833,offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
@ -0,0 +1,101 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2023 nosqlbench
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package io.nosqlbench.engine.extensions.computefunctions;
|
||||||
|
|
||||||
|
import org.assertj.core.data.Offset;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.stream.LongStream;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
|
class ComputeFunctionsLongTest {
|
||||||
|
|
||||||
|
private final static Offset<Double> offset=Offset.offset(0.001d);
|
||||||
|
private final static long[] longs_0to9 =new long[]{0,1,2,3,4,5,6,7,8,9};
|
||||||
|
private final static long[] longs_37195 = new long[]{3,7,1,9,5};
|
||||||
|
private final static long[] longs_86204 = new long[]{8,6,2,0,4};
|
||||||
|
private final static long[] longs_01234 = new long[]{0,1,2,3,4};
|
||||||
|
private final static long[] longs_56789 = new long[]{5,6,7,8,9};
|
||||||
|
|
||||||
|
private final static long[] longs_369 = new long[]{3,6,9};
|
||||||
|
private final static long[] longs_693 = new long[]{6,9,3};
|
||||||
|
private final static long[] longs_45678 = new long[]{4,5,6,7,8};
|
||||||
|
private final static long[] longs_12390 = new long[]{1,2,3,9,0};
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testRecallLongArrays() {
|
||||||
|
assertThat(ComputeFunctions.recall(longs_86204, longs_37195))
|
||||||
|
.as("finding 0 actual of any should yield recall=0.0")
|
||||||
|
.isCloseTo(0.0d, offset);
|
||||||
|
|
||||||
|
assertThat(ComputeFunctions.recall(longs_86204, longs_369))
|
||||||
|
.as("finding 1 actual of 5 relevant should yield recall=0.2")
|
||||||
|
.isCloseTo(0.2d, offset);
|
||||||
|
|
||||||
|
assertThat(ComputeFunctions.recall(longs_86204, longs_369,1))
|
||||||
|
.as("finding 0 (limited) actual of 5 relevant should yield recall=0.0")
|
||||||
|
.isCloseTo(0.0d, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testPrecisionLongArrays() {
|
||||||
|
assertThat(ComputeFunctions.precision(longs_86204, longs_693))
|
||||||
|
.as("one of three results being relevant should yield precision=0.333")
|
||||||
|
.isCloseTo(0.333,offset);
|
||||||
|
assertThat(ComputeFunctions.precision(longs_86204, longs_01234))
|
||||||
|
.as("three of five results being relevant should yield precision=0.6")
|
||||||
|
.isCloseTo(0.6,offset);
|
||||||
|
assertThat(ComputeFunctions.precision(longs_86204, longs_37195))
|
||||||
|
.as("none of the results being relevant should yield precision=0.0")
|
||||||
|
.isCloseTo(0.0,offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void sanityCheckRecallAndLimitRatioLongs() {
|
||||||
|
long[] hundo = LongStream.range(0,100).toArray();
|
||||||
|
|
||||||
|
for (int i = 0; i < hundo.length; i++) {
|
||||||
|
long[] partial=LongStream.range(0,i).toArray();
|
||||||
|
int finalI = i;
|
||||||
|
assertThat(ComputeFunctions.recall(hundo, partial))
|
||||||
|
.as(() -> "for subset size " + finalI +", recall should be fractional/100")
|
||||||
|
.isCloseTo((double)partial.length/(double)hundo.length,offset);
|
||||||
|
assertThat(ComputeFunctions.recall(hundo, hundo, i))
|
||||||
|
.as(() -> "for full intersection, limit " + finalI +" (K) recall should be fractional/100")
|
||||||
|
.isCloseTo((double)partial.length/(double)hundo.length,offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testReciprocalRankLongs() {
|
||||||
|
assertThat(ComputeFunctions.reciprocal_rank(longs_369, longs_56789))
|
||||||
|
.as("relevant results in rank 2 should yield RR=0.5")
|
||||||
|
.isCloseTo(0.5d,offset);
|
||||||
|
|
||||||
|
assertThat(ComputeFunctions.reciprocal_rank(longs_56789, longs_01234))
|
||||||
|
.as("no relevant results should yield RR=0.0")
|
||||||
|
.isCloseTo(0.0d,offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCountLongIntersection() {
|
||||||
|
long result = Intersections.count(longs_37195, longs_12390);
|
||||||
|
assertThat(result).isEqualTo(3L);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,50 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright (c) 2023 nosqlbench
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package io.nosqlbench.engine.extensions.computefunctions;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
|
||||||
|
|
||||||
class IntersectionsTest {
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testIntegerIntersection() {
|
|
||||||
int[] result = Intersections.find(new int[]{1,2,3,4,5},new int[]{4,5,6,7,8});
|
|
||||||
assertThat(result).isEqualTo(new int[]{4,5});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testLongIntersection() {
|
|
||||||
long[] result = Intersections.find(new long[]{1,2,3,4,5},new long[]{4,5,6,7,8});
|
|
||||||
assertThat(result).isEqualTo(new long[]{4,5});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCountIntIntersection() {
|
|
||||||
long result = Intersections.count(new int[]{1,3,5,7,9}, new int[]{1,2,3,9,10});
|
|
||||||
assertThat(result).isEqualTo(3L);
|
|
||||||
}
|
|
||||||
@Test
|
|
||||||
public void testCountLongIntersection() {
|
|
||||||
long result = Intersections.count(new long[]{1,3,5,7,9}, new long[]{1,2,3,9,10});
|
|
||||||
assertThat(result).isEqualTo(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
@ -20,12 +20,13 @@ import io.nosqlbench.api.config.NBLabels;
|
|||||||
import io.nosqlbench.api.engine.metrics.instruments.NBMetricGauge;
|
import io.nosqlbench.api.engine.metrics.instruments.NBMetricGauge;
|
||||||
|
|
||||||
import java.util.DoubleSummaryStatistics;
|
import java.util.DoubleSummaryStatistics;
|
||||||
|
import java.util.function.DoubleConsumer;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a discrete stat reservoir as a gauge.
|
* Create a discrete stat reservoir as a gauge.
|
||||||
*/
|
*/
|
||||||
public class DoubleSummaryGauge implements NBMetricGauge<Double> {
|
public class DoubleSummaryGauge implements NBMetricGauge<Double>, DoubleConsumer {
|
||||||
private final NBLabels labels;
|
private final NBLabels labels;
|
||||||
private final Stat stat;
|
private final Stat stat;
|
||||||
private final DoubleSummaryStatistics stats;
|
private final DoubleSummaryStatistics stats;
|
||||||
|
@ -0,0 +1,46 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2023 nosqlbench
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package io.nosqlbench.api.engine.metrics.instruments;
|
||||||
|
|
||||||
|
import io.nosqlbench.api.config.NBLabels;
|
||||||
|
|
||||||
|
import java.util.function.DoubleConsumer;
|
||||||
|
|
||||||
|
public class CompoundGaugeFunction implements NBMetricGauge<Double>, DoubleConsumer {
|
||||||
|
|
||||||
|
private final NBLabels labels;
|
||||||
|
private final String name;
|
||||||
|
|
||||||
|
public CompoundGaugeFunction(NBLabels labels, String name) {
|
||||||
|
this.labels = labels;
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public Double getValue() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NBLabels getLabels() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void accept(double value) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user