use NBConfiguration in some types

This commit is contained in:
Jonathan Shook 2021-07-20 18:26:20 -05:00
parent a12efd0db6
commit 69fc74c409
10 changed files with 184 additions and 172 deletions

View File

@ -4,10 +4,10 @@ import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.DataType;
import com.datastax.driver.core.TupleType;
import com.datastax.driver.core.TupleValue;
import io.nosqlbench.nb.api.config.standard.ConfigModel;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import io.nosqlbench.nb.api.config.ConfigAware;
import io.nosqlbench.nb.api.config.ConfigModel;
import io.nosqlbench.nb.api.config.MutableConfigModel;
import io.nosqlbench.nb.api.config.standard.NBMapConfigurable;
import io.nosqlbench.nb.api.config.standard.NBConfigModel;
import java.util.*;
import java.util.function.LongFunction;
@ -28,7 +28,7 @@ import java.util.function.LongUnaryOperator;
* </LI>
*/
@ThreadSafeMapper
public class CustomFunc955 implements LongFunction<Map<?,?>>, ConfigAware {
public class CustomFunc955 implements LongFunction<Map<?,?>>, NBMapConfigurable {
private final LongToIntFunction sizefunc;
private final LongFunction<Object> keyfunc;
@ -79,8 +79,8 @@ public class CustomFunc955 implements LongFunction<Map<?,?>>, ConfigAware {
}
@Override
public ConfigModel getConfigModel() {
return new MutableConfigModel(this)
public NBConfigModel getConfigModel() {
return ConfigModel.of(this.getClass())
.optional("<cluster>", Cluster.class)
.asReadOnly();
}

View File

@ -17,28 +17,18 @@ import io.nosqlbench.engine.api.scripting.ExprEvaluator;
import io.nosqlbench.engine.api.scripting.GraalJsEvaluator;
import io.nosqlbench.engine.api.util.SSLKsFactory;
import io.nosqlbench.nb.api.errors.BasicError;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import javax.net.ssl.SSLContext;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOError;
import java.io.IOException;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.UnknownHostException;
import java.io.*;
import java.net.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import org.apache.commons.codec.digest.DigestUtils;
public class CQLSessionCache implements Shutdownable {
private final static Logger logger = LogManager.getLogger(CQLSessionCache.class);
@ -246,7 +236,7 @@ public class CQLSessionCache implements Shutdownable {
.ifPresent(builder::withCompression);
SSLContext context = SSLKsFactory.get().getContext(activityDef);
SSLContext context = SSLKsFactory.get().getContext(activityDef.getParams());
if (context != null) {
builder.withSSL(RemoteEndpointAwareJdkSSLOptions.builder().withSSLContext(context).build());
}

View File

@ -238,7 +238,7 @@ public class CQLSessionCache implements Shutdownable {
.ifPresent(builder::withCompression);
SSLContext context = SSLKsFactory.get().getContext(activityDef);
SSLContext context = SSLKsFactory.get().getContext(activityDef.getParams());
if (context != null) {
builder.withSSL(RemoteEndpointAwareJdkSSLOptions.builder().withSSLContext(context).build());
}

View File

@ -24,7 +24,7 @@ public class JMXActivity extends SimpleActivity implements Activity {
super.initActivity();
this.sequence = createOpSequenceFromCommands(ReadyJmxOp::new);
setDefaultsFromOpSequence(sequence);
this.sslContext= SSLKsFactory.get().getContext(activityDef);
this.sslContext= SSLKsFactory.get().getContext(activityDef.getParams());
// TODO: Require qualified default with an op sequence as the input
}

View File

@ -46,7 +46,7 @@ public class TCPClientActivity extends StdoutActivity {
SocketFactory socketFactory = SocketFactory.getDefault();
boolean sslEnabled = activityDef.getParams().getOptionalBoolean("ssl").orElse(false);
if (sslEnabled) {
socketFactory = SSLKsFactory.get().createSocketFactory(activityDef);
socketFactory = SSLKsFactory.get().createSocketFactory(activityDef.getParams());
}
String host = getActivityDef().getParams().getOptionalString("host").orElse("localhost");

View File

@ -56,7 +56,7 @@ public class TCPServerActivity extends StdoutActivity {
queue = new LinkedBlockingQueue<>(capacity);
if (sslEnabled) {
socketFactory = SSLKsFactory.get().createSSLServerSocketFactory(activityDef);
socketFactory = SSLKsFactory.get().createSSLServerSocketFactory(activityDef.getParams());
} else {
socketFactory = ServerSocketFactory.getDefault();
}

View File

@ -33,7 +33,7 @@ public class SSLKsFactoryTest {
"tlsversion=TLSv1.2",
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
assertThat(SSLKsFactory.get().getContext(activityDef.getParams())).isNotNull();
}
@Test
@ -46,7 +46,7 @@ public class SSLKsFactoryTest {
"kspass=nosqlbench_client"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
assertThat(SSLKsFactory.get().getContext(activityDef.getParams())).isNotNull();
}
@Test
@ -60,7 +60,7 @@ public class SSLKsFactoryTest {
"keyPassword=nosqlbench"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
assertThat(SSLKsFactory.get().getContext(activityDef.getParams())).isNotNull();
}
@Test
@ -71,7 +71,7 @@ public class SSLKsFactoryTest {
"tspass=nosqlbench_server"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
assertThat(SSLKsFactory.get().getContext(activityDef.getParams())).isNotNull();
}
@Test
@ -82,7 +82,7 @@ public class SSLKsFactoryTest {
"kspass=nosqlbench_client"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
assertThat(SSLKsFactory.get().getContext(activityDef.getParams())).isNotNull();
}
@Test
@ -94,7 +94,7 @@ public class SSLKsFactoryTest {
"keyPassword=nosqlbench"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
assertThat(SSLKsFactory.get().getContext(activityDef.getParams())).isNotNull();
}
@Test
@ -104,7 +104,7 @@ public class SSLKsFactoryTest {
"tlsversion=TLSv1.2",
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
assertThat(SSLKsFactory.get().getContext(activityDef.getParams())).isNotNull();
}
@Test
@ -116,7 +116,7 @@ public class SSLKsFactoryTest {
"keyFilePath=src/test/resources/ssl/client.key"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
assertThat(SSLKsFactory.get().getContext(activityDef.getParams())).isNotNull();
}
@Test
@ -126,7 +126,7 @@ public class SSLKsFactoryTest {
"caCertFilePath=src/test/resources/ssl/cacert.crt"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
assertThat(SSLKsFactory.get().getContext(activityDef.getParams())).isNotNull();
}
@Test
@ -137,7 +137,7 @@ public class SSLKsFactoryTest {
"keyFilePath=src/test/resources/ssl/client.key"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
assertThat(SSLKsFactory.get().getContext(activityDef.getParams())).isNotNull();
}
@Test
@ -150,7 +150,7 @@ public class SSLKsFactoryTest {
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef.getParams()))
.withMessageMatching("Unable to load the keystore. Please check.");
}
@ -164,7 +164,7 @@ public class SSLKsFactoryTest {
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef.getParams()))
.withMessageMatching("Unable to init KeyManagerFactory. Please check.*");
}
@ -177,7 +177,7 @@ public class SSLKsFactoryTest {
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef.getParams()))
.withMessageMatching("Unable to load the truststore. Please check.");
}
@ -189,7 +189,7 @@ public class SSLKsFactoryTest {
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef.getParams()))
.withMessageContaining("Unable to load caCert from")
.withCauseInstanceOf(FileNotFoundException.class);
}
@ -202,7 +202,7 @@ public class SSLKsFactoryTest {
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef.getParams()))
.withMessageContaining("Unable to load cert from")
.withCauseInstanceOf(FileNotFoundException.class);
}
@ -215,7 +215,7 @@ public class SSLKsFactoryTest {
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef.getParams()))
.withMessageContaining("Unable to load key from")
.withCauseInstanceOf(FileNotFoundException.class);
}
@ -223,14 +223,15 @@ public class SSLKsFactoryTest {
@Test
public void testOpenSSLGetContextWithMissingCertError() {
String[] params = {
"ssl=openssl",
"caCertFilePath=src/test/resources/ssl/cacert.crt",
"keyFilePath=src/test/resources/ssl/client.key"
"ssl=openssl",
"caCertFilePath=src/test/resources/ssl/cacert.crt",
"keyFilePath=src/test/resources/ssl/client.key"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.withMessageContaining("Unable to load key from")
.withCauseInstanceOf(IllegalArgumentException.class);
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef.getParams()))
.withMessageContaining("Unable to load key from")
.withCauseInstanceOf(IllegalArgumentException.class);
}
}

View File

@ -4,12 +4,16 @@ import io.nosqlbench.engine.clients.grafana.GrafanaClient;
import io.nosqlbench.engine.clients.grafana.GrafanaClientConfig;
import io.nosqlbench.engine.clients.grafana.transfer.GAnnotation;
import io.nosqlbench.nb.annotations.Service;
import io.nosqlbench.nb.api.NBEnvironment;
import io.nosqlbench.nb.api.OnError;
import io.nosqlbench.nb.api.SystemId;
import io.nosqlbench.nb.api.annotations.Annotation;
import io.nosqlbench.nb.api.annotations.Annotator;
import io.nosqlbench.nb.api.config.*;
import io.nosqlbench.nb.api.config.params.ParamsParser;
import io.nosqlbench.nb.api.config.standard.ConfigModel;
import io.nosqlbench.nb.api.config.standard.NBConfigModel;
import io.nosqlbench.nb.api.config.standard.NBConfigurable;
import io.nosqlbench.nb.api.config.standard.NBConfiguration;
import io.nosqlbench.nb.api.errors.BasicError;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -21,7 +25,7 @@ import java.util.function.Function;
import java.util.function.Supplier;
@Service(value = Annotator.class, selector = "grafana")
public class GrafanaMetricsAnnotator implements Annotator, ConfigAware {
public class GrafanaMetricsAnnotator implements Annotator, NBConfigurable {
private final static Logger logger = LogManager.getLogger("ANNOTATORS");
//private final static Logger annotationsLog = LogManager.getLogger("ANNOTATIONS" );
@ -55,7 +59,7 @@ public class GrafanaMetricsAnnotator implements Annotator, ConfigAware {
Map<String, String> labels = annotation.getLabels();
Optional.ofNullable(labels.get("alertId"))
.map(Integer::parseInt).ifPresent(ga::setAlertId);
.map(Integer::parseInt).ifPresent(ga::setAlertId);
ga.setText(annotation.toString());
@ -64,28 +68,28 @@ public class GrafanaMetricsAnnotator implements Annotator, ConfigAware {
// Target
Optional.ofNullable(labels.get("type"))
.ifPresent(ga::setType);
.ifPresent(ga::setType);
Optional.ofNullable(labels.get("id")).map(Integer::valueOf)
.ifPresent(ga::setId);
.ifPresent(ga::setId);
Optional.ofNullable(labels.get("alertId")).map(Integer::valueOf)
.ifPresent(ga::setAlertId);
.ifPresent(ga::setAlertId);
Optional.ofNullable(labels.get("dashboardId")).map(Integer::valueOf)
.ifPresent(ga::setDashboardId);
.ifPresent(ga::setDashboardId);
Optional.ofNullable(labels.get("panelId")).map(Integer::valueOf)
.ifPresent(ga::setPanelId);
.ifPresent(ga::setPanelId);
Optional.ofNullable(labels.get("userId")).map(Integer::valueOf)
.ifPresent(ga::setUserId);
.ifPresent(ga::setUserId);
Optional.ofNullable(labels.get("userName"))
.ifPresent(ga::setUserName);
.ifPresent(ga::setUserName);
Optional.ofNullable(labels.get("metric"))
.ifPresent(ga::setMetric);
.ifPresent(ga::setMetric);
// Details
@ -104,66 +108,39 @@ public class GrafanaMetricsAnnotator implements Annotator, ConfigAware {
}
@Override
public void applyConfig(Map<String, ?> providedConfig) {
ConfigModel configModel = getConfigModel();
ConfigReader cfg = configModel.apply(providedConfig);
public void applyConfig(NBConfiguration cfg) {
GrafanaClientConfig gc = new GrafanaClientConfig();
gc.setBaseUri(cfg.param("baseurl", String.class));
if (cfg.containsKey("tags")) {
this.tags = ParamsParser.parse(cfg.param("tags", String.class), false);
}
cfg.getOptional("tags")
.map(t -> ParamsParser.parse(t, false))
.ifPresent(this::setTags);
cfg.getOptional("username")
.ifPresent(
username ->
gc.basicAuth(
username,
cfg.getOptional("password").orElse("")
)
);
if (cfg.containsKey("username")) {
if (cfg.containsKey("password")) {
gc.basicAuth(
cfg.param("username", String.class),
cfg.param("password", String.class)
);
} else {
gc.basicAuth(cfg.param("username", String.class), "");
}
}
Path keyfilePath = null;
if (cfg.containsKey("apikeyfile")) {
String apikeyfile = cfg.paramEnv("apikeyfile", String.class);
keyfilePath = Path.of(apikeyfile);
} else if (cfg.containsKey("apikey")) {
gc.addHeaderSource(() -> Map.of("Authorization", "Bearer " + cfg.param("apikey", String.class)));
Optional<String> optionalApikeyfile = cfg.getEnvOptional("apikeyfile");
Optional<String> optionalApikey = cfg.getOptional("apikey");
if (optionalApikeyfile.isPresent()) {
keyfilePath=optionalApikeyfile.map(Path::of).orElseThrow();
} else if (optionalApikey.isPresent()) {
gc.addHeaderSource(() -> Map.of("Authorization", "Bearer " + optionalApikey.get()));
} else {
Optional<String> apikeyLocation = NBEnvironment.INSTANCE
.interpolate(cfg.paramEnv("apikeyfile", String.class));
keyfilePath = apikeyLocation.map(Path::of).orElseThrow();
throw new BasicError("Undefined keyfile parameters.");
}
// if (!Files.exists(keyfilePath)) {
// logger.info("Auto-configuring grafana apikey.");
// GrafanaClientConfig apiClientConf = gc.copy().basicAuth("admin", "admin");
// GrafanaClient apiClient = new GrafanaClient(apiClientConf);
// try {
// String nodeId = SystemId.getNodeId();
//
// String keyName = "nosqlbench-" + nodeId + "-" + System.currentTimeMillis();
// ApiToken apiToken = apiClient.createApiToken(keyName, "Admin", Long.MAX_VALUE);
// Files.createDirectories(keyfilePath.getParent(),
// PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rwxrwx---")));
// Files.writeString(keyfilePath, apiToken.getKey());
// } catch (Exception e) {
// throw new RuntimeException(e);
// }
// }
//
// AuthWrapper authHeaderSupplier = new AuthWrapper(
// "Authorization",
// new GrafanaKeyFileReader(keyfilePath),
// s -> "Bearer " + s
// );
// gc.addHeaderSource(authHeaderSupplier);
this.onError = OnError.valueOfName(cfg.get("onerror").toString());
cfg.getOptional("onerror").map(OnError::valueOfName).ifPresent(this::setOnError);
this.client = new GrafanaClient(gc);
@ -173,27 +150,34 @@ public class GrafanaMetricsAnnotator implements Annotator, ConfigAware {
}
private void setOnError(OnError onError) {
this.onError=onError;
}
private void setTags(Map<String, String> tags) {
this.tags = tags;
}
@Override
public ConfigModel getConfigModel() {
return new MutableConfigModel(this)
.required("baseurl", String.class,
"The base url of the grafana node, like http://localhost:3000/")
.defaultto("apikeyfile", "$NBSTATEDIR/grafana/grafana_apikey",
"The file that contains the api key, supersedes apikey")
.optional("apikey", String.class,
"The api key to use, supersedes basic username and password")
.optional("username", String.class,
"The username to use for basic auth")
.optional("password", String.class,
"The password to use for basic auth")
.defaultto("tags", "source:nosqlbench",
"The tags that identify the annotations, in k:v,... form")
// .defaultto("onerror", OnError.Warn)
.defaultto("onerror", "warn",
"What to do when an error occurs while posting an annotation")
.defaultto("timeoutms", 5000,
"connect and transport timeout for the HTTP client")
.asReadOnly();
public NBConfigModel getConfigModel() {
return ConfigModel.of(this.getClass())
.required("baseurl", String.class,
"The base url of the grafana node, like http://localhost:3000/")
.defaults("apikeyfile", "$NBSTATEDIR/grafana/grafana_apikey",
"The file that contains the api key, supersedes apikey")
.optional("apikey", String.class,
"The api key to use, supersedes basic username and password")
.optional("username", String.class,
"The username to use for basic auth")
.optional("password", String.class,
"The password to use for basic auth")
.defaults("tags", "source:nosqlbench",
"The tags that identify the annotations, in k:v,... form")
.defaults("onerror", "warn",
"What to do when an error occurs while posting an annotation")
.defaults("timeoutms", 5000,
"connect and transport timeout for the HTTP client")
.asReadOnly();
}

View File

@ -5,8 +5,8 @@ import com.google.gson.GsonBuilder;
import io.nosqlbench.nb.annotations.Service;
import io.nosqlbench.nb.api.annotations.Annotation;
import io.nosqlbench.nb.api.annotations.Annotator;
import io.nosqlbench.nb.api.config.ConfigAware;
import io.nosqlbench.nb.api.config.ConfigLoader;
import io.nosqlbench.nb.api.config.standard.NBMapConfigurable;
import io.nosqlbench.nb.api.config.standard.ConfigLoader;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;
@ -53,9 +53,9 @@ public class Annotators {
}
Annotator annotator = annotatorProvider.get();
if (annotator instanceof ConfigAware) {
ConfigAware configAware = (ConfigAware) annotator;
configAware.applyConfig(cmap);
if (annotator instanceof NBMapConfigurable) {
NBMapConfigurable NBMapConfigurable = (NBMapConfigurable) annotator;
NBMapConfigurable.applyConfig(cmap);
}
annotators.add(annotator);

View File

@ -17,9 +17,11 @@
package io.nosqlbench.engine.api.util;
import io.nosqlbench.engine.api.activityimpl.ActivityDef;
import org.apache.logging.log4j.Logger;
import io.nosqlbench.nb.api.config.standard.*;
import io.nosqlbench.nb.api.config.standard.ConfigModel;
import io.nosqlbench.nb.api.config.standard.NBConfigModel;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import javax.net.ServerSocketFactory;
import javax.net.SocketFactory;
@ -36,16 +38,19 @@ import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;
public class SSLKsFactory {
public class SSLKsFactory implements NBMapConfigurable {
private final static Logger logger = LogManager.getLogger(SSLKsFactory.class);
private static final SSLKsFactory instance = new SSLKsFactory();
private static final Pattern CERT_PATTERN = Pattern.compile("-+BEGIN\\s+.*CERTIFICATE[^-]*-+(?:\\s|\\r|\\n)+([a-z0-9+/=\\r\\n]+)-+END\\s+.*CERTIFICATE[^-]*-+", 2);
private static final Pattern KEY_PATTERN = Pattern.compile("-+BEGIN\\s+.*PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+([a-z0-9+/=\\r\\n]+)-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", 2);
public static final String SSL = "ssl";
public static final String DEFAULT_TLSVERSION = "TLSv1.2";
/**
* Consider: https://gist.github.com/artem-smotrakov/bd14e4bde4d7238f7e5ab12c697a86a3
@ -57,44 +62,53 @@ public class SSLKsFactory {
return instance;
}
public ServerSocketFactory createSSLServerSocketFactory(ActivityDef def) {
SSLContext context = getContext(def);
public ServerSocketFactory createSSLServerSocketFactory(Map<String, Object> cfgmap) {
return createSSLServerSocketFactory(getConfigModel().apply(cfgmap));
}
public ServerSocketFactory createSSLServerSocketFactory(NBConfiguration cfg) {
SSLContext context = getContext(cfg);
if (context == null) {
throw new IllegalArgumentException("SSL is not enabled.");
}
return context.getServerSocketFactory();
}
public SocketFactory createSocketFactory(ActivityDef def) {
SSLContext context = getContext(def);
public SocketFactory createSocketFactory(Map<String, Object> cfgmap) {
return createSocketFactory(getConfigModel().apply(cfgmap));
}
public SocketFactory createSocketFactory(NBConfiguration cfg) {
SSLContext context = getContext(cfg);
if (context == null) {
throw new IllegalArgumentException("SSL is not enabled.");
}
return context.getSocketFactory();
}
public SSLContext getContext(ActivityDef def) {
Optional<String> sslParam = def.getParams().getOptionalString("ssl");
public SSLContext getContext(Map<String, Object> cfgmap) {
return getContext(getConfigModel().apply(cfgmap));
}
public SSLContext getContext(NBConfiguration cfg) {
Optional<String> sslParam = cfg.getOptional(SSL);
if (sslParam.isPresent()) {
String tlsVersion = def.getParams().getOptionalString("tlsversion").orElse("TLSv1.2");
String tlsVersion = cfg.getOptional("tlsversion").orElse(DEFAULT_TLSVERSION);
KeyStore keyStore;
char[] keyPassword = null;
KeyStore trustStore;
if (sslParam.get().equals("jdk") || sslParam.get().equals("true")) {
if (sslParam.get().equals("true")) {
logger.warn("Please update your 'ssl=true' parameter to 'ssl=jdk'");
}
if (sslParam.get().equals("jdk")) {
final char[] keyStorePassword = def.getParams().getOptionalString("kspass")
.map(String::toCharArray)
.orElse(null);
keyPassword = def.getParams().getOptionalString("keyPassword")
.map(String::toCharArray)
.orElse(keyStorePassword);
final char[] keyStorePassword = cfg.getOptional("kspass")
.map(String::toCharArray)
.orElse(null);
keyPassword = cfg.getOptional("keyPassword", "keypassword")
.map(String::toCharArray)
.orElse(keyStorePassword);
keyStore = def.getParams().getOptionalString("keystore").map(ksPath -> {
keyStore = cfg.getOptional("keystore").map(ksPath -> {
try {
return KeyStore.getInstance(new File(ksPath), keyStorePassword);
} catch (Exception e) {
@ -102,12 +116,12 @@ public class SSLKsFactory {
}
}).orElse(null);
trustStore = def.getParams().getOptionalString("truststore").map(tsPath -> {
trustStore = cfg.getOptional("truststore").map(tsPath -> {
try {
return KeyStore.getInstance(new File(tsPath),
def.getParams().getOptionalString("tspass")
.map(String::toCharArray)
.orElse(null));
cfg.getOptional("tspass")
.map(String::toCharArray)
.orElse(null));
} catch (Exception e) {
throw new RuntimeException("Unable to load the truststore. Please check.", e);
}
@ -120,39 +134,40 @@ public class SSLKsFactory {
keyStore = KeyStore.getInstance("JKS");
keyStore.load(null, null);
Certificate cert = def.getParams().getOptionalString("certFilePath").map(certFilePath -> {
Certificate cert = cfg.getOptional("certFilePath").map(certFilePath -> {
try (InputStream is = new ByteArrayInputStream(loadCertFromPem(new File(certFilePath)))) {
return cf.generateCertificate(is);
} catch (Exception e) {
throw new RuntimeException(String.format("Unable to load cert from %s. Please check.",
certFilePath),
e);
certFilePath),
e);
}
}).orElse(null);
if (cert != null)
keyStore.setCertificateEntry("certFile", cert);
File keyFile = def.getParams().getOptionalString("keyFilePath").map(File::new)
.orElse(null);
File keyFile = cfg.getOptional("keyFilePath", "keyfilepath").map(File::new)
.orElse(null);
if (keyFile != null) {
try {
keyPassword = def.getParams().getOptionalString("keyPassword")
.map(String::toCharArray)
.orElse("temp_key_password".toCharArray());
keyPassword = cfg.getOptional("keyPassword", "keypassword")
.map(String::toCharArray)
.orElse("temp_key_password".toCharArray());
KeyFactory kf = KeyFactory.getInstance("RSA");
PrivateKey key = kf.generatePrivate(new PKCS8EncodedKeySpec(loadKeyFromPem(keyFile)));
keyStore.setKeyEntry("key", key, keyPassword,
cert != null ? new Certificate[]{ cert } : null);
cert != null ? new Certificate[]{cert} : null);
} catch (Exception e) {
throw new RuntimeException(String.format("Unable to load key from %s. Please check.",
keyFile),
e);
keyFile),
e);
}
}
trustStore = def.getParams().getOptionalString("caCertFilePath").map(caCertFilePath -> {
trustStore = cfg.getOptional("caCertFilePath", "cacertfilepath").map(caCertFilePath -> {
try (InputStream is = new FileInputStream(new File(caCertFilePath))) {
KeyStore ts = KeyStore.getInstance("JKS");
ts.load(null, null);
@ -162,8 +177,8 @@ public class SSLKsFactory {
return ts;
} catch (Exception e) {
throw new RuntimeException(String.format("Unable to load caCert from %s. Please check.",
caCertFilePath),
e);
caCertFilePath),
e);
}
}).orElse(null);
@ -219,4 +234,26 @@ public class SSLKsFactory {
private static byte[] loadCertFromPem(File certPemFile) throws IOException {
return loadPem(CERT_PATTERN, certPemFile);
}
@Override
public void applyConfig(Map<String, ?> providedConfig) {
}
public NBConfigModel getConfigModel() {
return ConfigModel.of(SSLKsFactory.class,
Param.optional("ssl", String.class)
.setDescription("Enable ssl and set the mode")
.setRegex("jdk|openssl"),
Param.defaultTo("tlsversion", DEFAULT_TLSVERSION),
Param.optional("kspass"),
Param.optional("keyPassword"),
Param.optional("keystore"),
Param.optional("truststore"),
Param.optional("tspass"),
Param.optional("keyFilePath"),
Param.optional("caCertFilePath"),
Param.optional("certFilePath")
).asReadOnly();
}
}