mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2025-02-25 18:55:28 -06:00
support optimo with polyglot
This commit is contained in:
@@ -19,6 +19,7 @@ package io.nosqlbench.engine.extensions.optimizers;
|
||||
|
||||
import com.codahale.metrics.MetricRegistry;
|
||||
import jdk.nashorn.api.scripting.ScriptObjectMirror;
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.apache.commons.math3.optim.*;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
|
||||
@@ -28,6 +29,7 @@ import org.slf4j.Logger;
|
||||
import javax.script.ScriptContext;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
|
||||
public class BobyqaOptimizerInstance {
|
||||
|
||||
@@ -41,7 +43,7 @@ public class BobyqaOptimizerInstance {
|
||||
|
||||
private MVParams params = new MVParams();
|
||||
|
||||
private MultivariateObjectScript objectiveScriptObject;
|
||||
private MultivariateFunction objectiveFunctionFromScript;
|
||||
private SimpleBounds bounds;
|
||||
private InitialGuess initialGuess;
|
||||
private PointValuePair result;
|
||||
@@ -99,9 +101,16 @@ public class BobyqaOptimizerInstance {
|
||||
if (!scriptObject.isFunction()) {
|
||||
throw new RuntimeException("Unable to setFunction with a non-function object");
|
||||
}
|
||||
this.objectiveScriptObject =
|
||||
new MultivariateObjectScript(logger, params, scriptObject);
|
||||
this.objectiveFunctionFromScript =
|
||||
new NashornMultivariateObjectScript(logger, params, scriptObject);
|
||||
}
|
||||
|
||||
if (f instanceof Function) {
|
||||
// Function<Object[],Object> function = (Function<Object[],Object>)f;
|
||||
this.objectiveFunctionFromScript =
|
||||
new PolyglotMultivariateObjectScript(logger, params, f);
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -121,7 +130,7 @@ public class BobyqaOptimizerInstance {
|
||||
this.stoppingTrustRegionRadius
|
||||
);
|
||||
|
||||
this.mvLogger = new MVLogger(this.objectiveScriptObject);
|
||||
this.mvLogger = new MVLogger(this.objectiveFunctionFromScript);
|
||||
ObjectiveFunction objective = new ObjectiveFunction(this.mvLogger);
|
||||
|
||||
List<OptimizationData> od = List.of(
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package io.nosqlbench.engine.extensions.optimizers;
|
||||
|
||||
import org.apache.commons.math3.analysis.MultivariateFunction;
|
||||
import org.graalvm.polyglot.proxy.ProxyObject;
|
||||
import org.slf4j.Logger;
|
||||
|
||||
import java.security.InvalidParameterException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
public class PolyglotMultivariateObjectScript implements MultivariateFunction {
|
||||
|
||||
private final MVParams params;
|
||||
private final Object function;
|
||||
private Logger logger;
|
||||
|
||||
public PolyglotMultivariateObjectScript(Logger logger, MVParams params, Object function) {
|
||||
this.logger = logger;
|
||||
this.function = function;
|
||||
this.params = params;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double value(double[] doubles) {
|
||||
if (doubles.length != params.size()) {
|
||||
throw new InvalidParameterException("Expected " + params.size() + " doubles, not " + doubles.length);
|
||||
}
|
||||
|
||||
Map<String,Object> fparams = new HashMap<>();
|
||||
for (int i = 0; i < params.size(); i++) {
|
||||
fparams.put(params.get(i).name, doubles[i]);
|
||||
}
|
||||
Object[] args = new Object[]{ProxyObject.fromMap(fparams)};
|
||||
|
||||
Object result = ((Function<Object[], Object>) function).apply(args);
|
||||
|
||||
if (result instanceof Double) {
|
||||
return (Double) result;
|
||||
} else {
|
||||
throw new RuntimeException(
|
||||
"Unable to case result of polyglot function return value as a double:" +
|
||||
result.getClass().getCanonicalName()+", toString=" + result.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user