updating compute functions

This commit is contained in:
Mark Wolters 2023-12-18 17:36:50 -04:00
parent 84bc7bd4fb
commit 4b3c294443

View File

@ -18,6 +18,8 @@ package io.nosqlbench.engine.extensions.computefunctions;
import io.nosqlbench.nb.api.components.core.NBBaseComponent; import io.nosqlbench.nb.api.components.core.NBBaseComponent;
import io.nosqlbench.nb.api.components.core.NBComponent; import io.nosqlbench.nb.api.components.core.NBComponent;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.Arrays; import java.util.Arrays;
import java.util.DoubleSummaryStatistics; import java.util.DoubleSummaryStatistics;
@ -43,6 +45,7 @@ import java.util.HashSet;
* elide duplicates internally. * elide duplicates internally.
*/ */
public class ComputeFunctions extends NBBaseComponent { public class ComputeFunctions extends NBBaseComponent {
private final static Logger logger = LogManager.getLogger("RUNTIME");
public ComputeFunctions(NBComponent parentComponent) { public ComputeFunctions(NBComponent parentComponent) {
super(parentComponent); super(parentComponent);
@ -66,14 +69,16 @@ public class ComputeFunctions extends NBBaseComponent {
public static double recall(long[] relevant, long[] actual, int k) { public static double recall(long[] relevant, long[] actual, int k) {
if (actual.length < k) { if (actual.length < k) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k); logger.warn("Returned indices fewer than limit in recall calculation: index count=" + actual.length + ", limit=" + k);
} }
relevant = Arrays.copyOfRange(relevant,0,k); long divisor = Math.min(relevant.length, k);
actual = Arrays.copyOfRange(actual, 0, k); int arrayLength = Math.max(relevant.length, actual.length);
relevant = Arrays.copyOfRange(relevant,0,arrayLength);
actual = Arrays.copyOfRange(actual, 0, arrayLength);
Arrays.sort(relevant); Arrays.sort(relevant);
Arrays.sort(actual); Arrays.sort(actual);
long[] intersection = Intersections.find(relevant, actual); long[] intersection = Intersections.find(relevant, actual);
return (double) intersection.length / (double) relevant.length; return (double) intersection.length / (double) divisor;
} }
public static double precision(long[] relevant, long[] actual) { public static double precision(long[] relevant, long[] actual) {
@ -85,10 +90,11 @@ public class ComputeFunctions extends NBBaseComponent {
public static double precision(long[] relevant, long[] actual, int k) { public static double precision(long[] relevant, long[] actual, int k) {
if (actual.length < k) { if (actual.length < k) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k); logger.warn("Returned indices fewer than limit in recall calculation: index count=" + actual.length + ", limit=" + k);
} }
relevant = Arrays.copyOfRange(relevant,0,k); int arrayLength = Math.max(relevant.length, actual.length);
actual = Arrays.copyOfRange(actual, 0, k); relevant = Arrays.copyOfRange(relevant,0,arrayLength);
actual = Arrays.copyOfRange(actual, 0, arrayLength);
Arrays.sort(relevant); Arrays.sort(relevant);
Arrays.sort(actual); Arrays.sort(actual);
long[] intersection = Intersections.find(relevant, actual); long[] intersection = Intersections.find(relevant, actual);
@ -113,14 +119,16 @@ public class ComputeFunctions extends NBBaseComponent {
public static double recall(int[] relevant, int[] actual, int k) { public static double recall(int[] relevant, int[] actual, int k) {
if (actual.length < k) { if (actual.length < k) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k); logger.warn("Returned indices fewer than limit in recall calculation: index count=" + actual.length + ", limit=" + k);
} }
relevant = Arrays.copyOfRange(relevant,0,k); long divisor = Math.min(relevant.length, k);
actual = Arrays.copyOfRange(actual, 0, k); int arrayLength = Math.max(relevant.length, actual.length);
relevant = Arrays.copyOfRange(relevant,0,arrayLength);
actual = Arrays.copyOfRange(actual, 0, arrayLength);
Arrays.sort(relevant); Arrays.sort(relevant);
Arrays.sort(actual); Arrays.sort(actual);
int intersection = Intersections.count(relevant, actual); int intersection = Intersections.count(relevant, actual);
return (double) intersection / (double) relevant.length; return (double) intersection / (double) divisor;
} }
public static double precision(int[] relevant, int[] actual) { public static double precision(int[] relevant, int[] actual) {
@ -132,10 +140,11 @@ public class ComputeFunctions extends NBBaseComponent {
public static double precision(int[] relevant, int[] actual, int k) { public static double precision(int[] relevant, int[] actual, int k) {
if (actual.length < k) { if (actual.length < k) {
throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k); logger.warn("Returned indices fewer than limit in recall calculation: index count=" + actual.length + ", limit=" + k);
} }
relevant = Arrays.copyOfRange(relevant,0,k); int arrayLength = Math.max(relevant.length, actual.length);
actual = Arrays.copyOfRange(actual, 0, k); relevant = Arrays.copyOfRange(relevant,0,arrayLength);
actual = Arrays.copyOfRange(actual, 0, arrayLength);
Arrays.sort(relevant); Arrays.sort(relevant);
Arrays.sort(actual); Arrays.sort(actual);
int intersection = Intersections.count(relevant, actual); int intersection = Intersections.count(relevant, actual);