diff --git a/engine-extensions/src/main/java/io/nosqlbench/engine/extensions/computefunctions/ComputeFunctions.java b/engine-extensions/src/main/java/io/nosqlbench/engine/extensions/computefunctions/ComputeFunctions.java index 647940ef9..062ca8ec0 100644 --- a/engine-extensions/src/main/java/io/nosqlbench/engine/extensions/computefunctions/ComputeFunctions.java +++ b/engine-extensions/src/main/java/io/nosqlbench/engine/extensions/computefunctions/ComputeFunctions.java @@ -18,6 +18,8 @@ package io.nosqlbench.engine.extensions.computefunctions; import io.nosqlbench.nb.api.components.core.NBBaseComponent; import io.nosqlbench.nb.api.components.core.NBComponent; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import java.util.Arrays; import java.util.DoubleSummaryStatistics; @@ -43,6 +45,7 @@ import java.util.HashSet; * elide duplicates internally. */ public class ComputeFunctions extends NBBaseComponent { + private final static Logger logger = LogManager.getLogger("RUNTIME"); public ComputeFunctions(NBComponent parentComponent) { super(parentComponent); @@ -66,14 +69,16 @@ public class ComputeFunctions extends NBBaseComponent { public static double recall(long[] relevant, long[] actual, int k) { if (actual.length < k) { - throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k); + logger.warn("Returned indices fewer than limit in recall calculation: index count=" + actual.length + ", limit=" + k); } - relevant = Arrays.copyOfRange(relevant,0,k); - actual = Arrays.copyOfRange(actual, 0, k); + long divisor = Math.min(relevant.length, k); + int arrayLength = Math.max(relevant.length, actual.length); + relevant = Arrays.copyOfRange(relevant,0,arrayLength); + actual = Arrays.copyOfRange(actual, 0, arrayLength); Arrays.sort(relevant); Arrays.sort(actual); long[] intersection = Intersections.find(relevant, actual); - return (double) intersection.length / (double) relevant.length; + return (double) intersection.length / (double) divisor; } public static double precision(long[] relevant, long[] actual) { @@ -85,10 +90,11 @@ public class ComputeFunctions extends NBBaseComponent { public static double precision(long[] relevant, long[] actual, int k) { if (actual.length < k) { - throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k); + logger.warn("Returned indices fewer than limit in recall calculation: index count=" + actual.length + ", limit=" + k); } - relevant = Arrays.copyOfRange(relevant,0,k); - actual = Arrays.copyOfRange(actual, 0, k); + int arrayLength = Math.max(relevant.length, actual.length); + relevant = Arrays.copyOfRange(relevant,0,arrayLength); + actual = Arrays.copyOfRange(actual, 0, arrayLength); Arrays.sort(relevant); Arrays.sort(actual); long[] intersection = Intersections.find(relevant, actual); @@ -113,14 +119,16 @@ public class ComputeFunctions extends NBBaseComponent { public static double recall(int[] relevant, int[] actual, int k) { if (actual.length < k) { - throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k); + logger.warn("Returned indices fewer than limit in recall calculation: index count=" + actual.length + ", limit=" + k); } - relevant = Arrays.copyOfRange(relevant,0,k); - actual = Arrays.copyOfRange(actual, 0, k); + long divisor = Math.min(relevant.length, k); + int arrayLength = Math.max(relevant.length, actual.length); + relevant = Arrays.copyOfRange(relevant,0,arrayLength); + actual = Arrays.copyOfRange(actual, 0, arrayLength); Arrays.sort(relevant); Arrays.sort(actual); int intersection = Intersections.count(relevant, actual); - return (double) intersection / (double) relevant.length; + return (double) intersection / (double) divisor; } public static double precision(int[] relevant, int[] actual) { @@ -132,10 +140,11 @@ public class ComputeFunctions extends NBBaseComponent { public static double precision(int[] relevant, int[] actual, int k) { if (actual.length < k) { - throw new RuntimeException("indices fewer than limit, invalid precision computation: index count=" + actual.length + ", limit=" + k); + logger.warn("Returned indices fewer than limit in recall calculation: index count=" + actual.length + ", limit=" + k); } - relevant = Arrays.copyOfRange(relevant,0,k); - actual = Arrays.copyOfRange(actual, 0, k); + int arrayLength = Math.max(relevant.length, actual.length); + relevant = Arrays.copyOfRange(relevant,0,arrayLength); + actual = Arrays.copyOfRange(actual, 0, arrayLength); Arrays.sort(relevant); Arrays.sort(actual); int intersection = Intersections.count(relevant, actual);