add F1 functions

This commit is contained in:
Jonathan Shook 2023-09-07 10:35:10 -05:00
parent efd882a553
commit bc3ebfaa15

View File

@ -40,8 +40,11 @@ public class ComputeFunctions {
/**
* Compute the recall as the proportion of matching indices divided by the expected indices
* @param referenceIndexes long array of indices
* @param sampleIndexes long array of indices
*
* @param referenceIndexes
* long array of indices
* @param sampleIndexes
* long array of indices
* @return a fractional measure of matching vs expected indices
*/
public static double recall(long[] referenceIndexes, long[] sampleIndexes) {
@ -50,11 +53,12 @@ public class ComputeFunctions {
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
return (double) intersection.length / (double) referenceIndexes.length;
}
public static double recall(long[] referenceIndexes, long[] sampleIndexes, int limit) {
if (sampleIndexes.length<limit) {
if (sampleIndexes.length < limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
}
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit);
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
@ -69,10 +73,10 @@ public class ComputeFunctions {
}
public static double precision(long[] referenceIndexes, long[] sampleIndexes, int limit) {
if (sampleIndexes.length<limit) {
if (sampleIndexes.length < limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
}
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit);
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
long[] intersection = Intersections.find(referenceIndexes, sampleIndexes);
@ -81,8 +85,11 @@ public class ComputeFunctions {
/**
* Compute the recall as the proportion of matching indices divided by the expected indices
* @param referenceIndexes int array of indices
* @param sampleIndexes int array of indices
*
* @param referenceIndexes
* int array of indices
* @param sampleIndexes
* int array of indices
* @return a fractional measure of matching vs expected indices
*/
public static double recall(int[] referenceIndexes, int[] sampleIndexes) {
@ -91,11 +98,12 @@ public class ComputeFunctions {
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length);
return (double) intersection / (double) referenceIndexes.length;
}
public static double recall(int[] referenceIndexes, int[] sampleIndexes, int limit) {
if (sampleIndexes.length<limit) {
if (sampleIndexes.length < limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
}
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit);
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
int intersection = Intersections.count(referenceIndexes, sampleIndexes, referenceIndexes.length);
@ -108,11 +116,12 @@ public class ComputeFunctions {
int intersection = Intersections.count(referenceIndexes, sampleIndexes);
return (double) intersection / (double) sampleIndexes.length;
}
public static double precision(int[] referenceIndexes, int[] sampleIndexes, int limit) {
if (sampleIndexes.length<limit) {
if (sampleIndexes.length < limit) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + sampleIndexes.length + ", limit=" + limit);
}
sampleIndexes=Arrays.copyOfRange(sampleIndexes,0,limit);
sampleIndexes = Arrays.copyOfRange(sampleIndexes, 0, limit);
Arrays.sort(referenceIndexes);
Arrays.sort(sampleIndexes);
int intersection = Intersections.count(referenceIndexes, sampleIndexes);
@ -123,8 +132,9 @@ public class ComputeFunctions {
* Compute the intersection of two long arrays
*/
public static long[] intersection(long[] a, long[] b) {
return Intersections.find(a,b);
return Intersections.find(a, b);
}
public static long[] intersection(long[] a, long[] b, int limit) {
return Intersections.find(a, b, limit);
}
@ -137,7 +147,7 @@ public class ComputeFunctions {
}
public static int[] intersection(int[] reference, int[] sample, int limit) {
return Intersections.find(reference,sample,limit);
return Intersections.find(reference, sample, limit);
}
/**
@ -146,14 +156,37 @@ public class ComputeFunctions {
public static int intersectionSize(int[] reference, int[] sample) {
return Intersections.count(reference, sample);
}
public static int intersectionSize(int[] reference, int[] sample, int limit) {
return Intersections.count(reference, sample, limit);
}
public static int intersectionSize(long[] reference, long[] sample) {
return Intersections.count(reference, sample);
}
public static int intersectionSize(long[] reference, long[] sample, int limit) {
return Intersections.count(reference, sample, limit);
}
public static double F1(int[] reference, int[] sample) {
return F1(reference, sample, reference.length);
}
public static double F1(int[] reference, int[] sample, int limit) {
double recallAtK = recall(reference, sample, limit);
double precisionAtK = precision(reference, sample, limit);
return 2.0d * ((recallAtK * precisionAtK) / (recallAtK + precisionAtK));
}
public static double F1(long[] reference, long[] sample) {
return F1(reference, sample, reference.length);
}
public static double F1(long[] reference, long[] sample, int limit) {
double recallAtK = recall(reference, sample, limit);
double precisionAtK = precision(reference, sample, limit);
return 2.0d * ((recallAtK * precisionAtK) / (recallAtK + precisionAtK));
}
}