mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2024-12-28 17:51:06 -06:00
Refactor SSLKsFactory
Use javax.net.ssl.SSLContext instead of Netty's SslContext
This commit is contained in:
parent
2ee74355bd
commit
c99b609a59
@ -19,8 +19,6 @@ import org.slf4j.LoggerFactory;
|
||||
import com.datastax.driver.core.Cluster;
|
||||
import com.datastax.driver.core.ProtocolOptions;
|
||||
import com.datastax.driver.core.RemoteEndpointAwareJdkSSLOptions;
|
||||
import com.datastax.driver.core.RemoteEndpointAwareNettySSLOptions;
|
||||
import com.datastax.driver.core.SSLOptions;
|
||||
import com.datastax.driver.core.Session;
|
||||
import com.datastax.driver.core.policies.DefaultRetryPolicy;
|
||||
import com.datastax.driver.core.policies.LoadBalancingPolicy;
|
||||
@ -30,7 +28,6 @@ import com.datastax.driver.core.policies.RoundRobinPolicy;
|
||||
import com.datastax.driver.core.policies.SpeculativeExecutionPolicy;
|
||||
import com.datastax.driver.core.policies.WhiteListPolicy;
|
||||
import com.datastax.driver.dse.DseCluster;
|
||||
import io.netty.handler.ssl.SslContext;
|
||||
import io.nosqlbench.activitytype.cql.core.CQLOptions;
|
||||
import io.nosqlbench.activitytype.cql.core.ProxyTranslator;
|
||||
import io.nosqlbench.engine.api.activityapi.core.Shutdownable;
|
||||
@ -43,7 +40,7 @@ public class CQLSessionCache implements Shutdownable {
|
||||
|
||||
private final static Logger logger = LoggerFactory.getLogger(CQLSessionCache.class);
|
||||
private final static String DEFAULT_SESSION_ID = "default";
|
||||
private static CQLSessionCache instance = new CQLSessionCache();
|
||||
private static final CQLSessionCache instance = new CQLSessionCache();
|
||||
private Map<String, Session> sessionCache = new HashMap<>();
|
||||
|
||||
private CQLSessionCache() {
|
||||
@ -220,10 +217,9 @@ public class CQLSessionCache implements Shutdownable {
|
||||
.map(CQLOptions::withCompression)
|
||||
.ifPresent(builder::withCompression);
|
||||
|
||||
SslContext context = SSLKsFactory.get().getContext(activityDef);
|
||||
SSLContext context = SSLKsFactory.get().getContext(activityDef);
|
||||
if (context != null) {
|
||||
SSLOptions sslOptions = new RemoteEndpointAwareNettySSLOptions(context);
|
||||
builder.withSSL(sslOptions);
|
||||
builder.withSSL(RemoteEndpointAwareJdkSSLOptions.builder().withSSLContext(context).build());
|
||||
}
|
||||
|
||||
RetryPolicy retryPolicy = activityDef.getParams()
|
||||
|
@ -17,20 +17,31 @@
|
||||
|
||||
package io.nosqlbench.engine.api.util;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.KeyFactory;
|
||||
import java.security.KeyStore;
|
||||
import java.security.PrivateKey;
|
||||
import java.security.SecureRandom;
|
||||
import java.security.cert.Certificate;
|
||||
import java.security.cert.CertificateFactory;
|
||||
import java.security.spec.PKCS8EncodedKeySpec;
|
||||
import java.util.Base64;
|
||||
import java.util.Optional;
|
||||
import java.util.regex.Pattern;
|
||||
import javax.net.ServerSocketFactory;
|
||||
import javax.net.SocketFactory;
|
||||
import javax.net.ssl.KeyManagerFactory;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.TrustManagerFactory;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import io.netty.handler.ssl.JdkSslContext;
|
||||
import io.netty.handler.ssl.SslContext;
|
||||
import io.netty.handler.ssl.SslContextBuilder;
|
||||
import io.nosqlbench.engine.api.activityimpl.ActivityDef;
|
||||
|
||||
public class SSLKsFactory {
|
||||
@ -38,6 +49,9 @@ public class SSLKsFactory {
|
||||
|
||||
private static final SSLKsFactory instance = new SSLKsFactory();
|
||||
|
||||
private static final Pattern CERT_PATTERN = Pattern.compile("-+BEGIN\\s+.*CERTIFICATE[^-]*-+(?:\\s|\\r|\\n)+([a-z0-9+/=\\r\\n]+)-+END\\s+.*CERTIFICATE[^-]*-+", 2);
|
||||
private static final Pattern KEY_PATTERN = Pattern.compile("-+BEGIN\\s+.*PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+([a-z0-9+/=\\r\\n]+)-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", 2);
|
||||
|
||||
/**
|
||||
* Consider: https://gist.github.com/artem-smotrakov/bd14e4bde4d7238f7e5ab12c697a86a3
|
||||
*/
|
||||
@ -49,108 +63,165 @@ public class SSLKsFactory {
|
||||
}
|
||||
|
||||
public ServerSocketFactory createSSLServerSocketFactory(ActivityDef def) {
|
||||
SslContext context = getContext(def);
|
||||
SSLContext context = getContext(def);
|
||||
if (context == null) {
|
||||
throw new IllegalArgumentException("SSL is not enabled.");
|
||||
}
|
||||
// FIXME: potential incompatibility issue
|
||||
return ((JdkSslContext) context).context().getServerSocketFactory();
|
||||
return context.getServerSocketFactory();
|
||||
}
|
||||
|
||||
public SocketFactory createSocketFactory(ActivityDef def) {
|
||||
SslContext context = getContext(def);
|
||||
SSLContext context = getContext(def);
|
||||
if (context == null) {
|
||||
throw new IllegalArgumentException("SSL is not enabled.");
|
||||
}
|
||||
// FIXME: potential incompatibility issue
|
||||
return ((JdkSslContext) context).context().getSocketFactory();
|
||||
return context.getSocketFactory();
|
||||
}
|
||||
|
||||
public SslContext getContext(ActivityDef def) {
|
||||
public SSLContext getContext(ActivityDef def) {
|
||||
Optional<String> sslParam = def.getParams().getOptionalString("ssl");
|
||||
if (sslParam.isPresent()) {
|
||||
String tlsVersion = def.getParams().getOptionalString("tlsversion").orElse("TLSv1.2");
|
||||
|
||||
KeyStore keyStore;
|
||||
char[] keyPassword = null;
|
||||
KeyStore trustStore;
|
||||
|
||||
if (sslParam.get().equals("jdk") || sslParam.get().equals("true")) {
|
||||
if (sslParam.get().equals("true")) {
|
||||
logger.warn("Please update your 'ssl=true' parameter to 'ssl=jdk'");
|
||||
}
|
||||
|
||||
Optional<String> keystorePath = def.getParams().getOptionalString("keystore");
|
||||
Optional<String> keystorePass = def.getParams().getOptionalString("kspass");
|
||||
char[] keyPassword = def.getParams().getOptionalString("keyPassword")
|
||||
.map(String::toCharArray)
|
||||
.orElse(null);
|
||||
Optional<String> truststorePath = def.getParams().getOptionalString("truststore");
|
||||
Optional<String> truststorePass = def.getParams().getOptionalString("tspass");
|
||||
keyPassword = def.getParams().getOptionalString("keyPassword")
|
||||
.map(String::toCharArray)
|
||||
.orElse(null);
|
||||
|
||||
KeyStore ks = keystorePath.map(ksPath -> {
|
||||
keyStore = def.getParams().getOptionalString("keystore").map(ksPath -> {
|
||||
try {
|
||||
return KeyStore.getInstance(new File(ksPath),
|
||||
keystorePass.map(String::toCharArray).orElse(null));
|
||||
def.getParams().getOptionalString("kspass")
|
||||
.map(String::toCharArray)
|
||||
.orElse(null));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Unable to load the keystore. Please check.", e);
|
||||
}
|
||||
}).orElse(null);
|
||||
|
||||
KeyManagerFactory kmf;
|
||||
try {
|
||||
kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
|
||||
kmf.init(ks, keyPassword);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Unable to init KeyManagerFactory. Please check.", e);
|
||||
}
|
||||
|
||||
KeyStore ts = truststorePath.map(tsPath -> {
|
||||
trustStore = def.getParams().getOptionalString("truststore").map(tsPath -> {
|
||||
try {
|
||||
return KeyStore.getInstance(new File(tsPath),
|
||||
truststorePass.map(String::toCharArray).orElse(null));
|
||||
def.getParams().getOptionalString("tspass")
|
||||
.map(String::toCharArray)
|
||||
.orElse(null));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Unable to load the truststore. Please check.", e);
|
||||
}
|
||||
}).orElse(null);
|
||||
|
||||
TrustManagerFactory tmf;
|
||||
try {
|
||||
tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
|
||||
tmf.init(ts != null ? ts : ks);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Unable to init TrustManagerFactory. Please check.", e);
|
||||
}
|
||||
|
||||
try {
|
||||
return SslContextBuilder.forClient()
|
||||
.protocols(tlsVersion)
|
||||
.trustManager(tmf)
|
||||
.keyManager(kmf)
|
||||
.build();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else if (sslParam.get().equals("openssl")) {
|
||||
File caCertFileLocation = def.getParams().getOptionalString("caCertFilePath").map(File::new).orElse(null);
|
||||
File certFileLocation = def.getParams().getOptionalString("certFilePath").map(File::new).orElse(null);
|
||||
File keyFileLocation = def.getParams().getOptionalString("keyFilePath").map(File::new).orElse(null);
|
||||
|
||||
try {
|
||||
return SslContextBuilder.forClient()
|
||||
.protocols(tlsVersion)
|
||||
/* configured with the TrustManagerFactory that has the cert from the ca.cert
|
||||
* This tells the driver to trust the server during the SSL handshake */
|
||||
.trustManager(caCertFileLocation)
|
||||
/* These are needed if the server is configured with require_client_auth
|
||||
* In this case the client's public key must be in the truststore on each DSE
|
||||
* server node and the CA configured */
|
||||
.keyManager(certFileLocation, keyFileLocation)
|
||||
.build();
|
||||
CertificateFactory cf = CertificateFactory.getInstance("X.509");
|
||||
|
||||
keyStore = KeyStore.getInstance("JKS");
|
||||
keyStore.load(null, null);
|
||||
|
||||
Certificate cert = def.getParams().getOptionalString("certFilePath").map(certFilePath -> {
|
||||
try (InputStream is = new ByteArrayInputStream(loadCertFromPem(new File(certFilePath)))) {
|
||||
return cf.generateCertificate(is);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(String.format("Unable to load cert from %s. Please check.",
|
||||
certFilePath),
|
||||
e);
|
||||
}
|
||||
}).orElse(null);
|
||||
|
||||
if (cert != null)
|
||||
keyStore.setCertificateEntry("certFile", cert);
|
||||
|
||||
File keyFile = def.getParams().getOptionalString("keyFilePath").map(File::new)
|
||||
.orElse(null);
|
||||
if (keyFile != null) {
|
||||
try {
|
||||
keyPassword = def.getParams().getOptionalString("keyPassword")
|
||||
.map(String::toCharArray)
|
||||
.orElse("temp_key_password".toCharArray());
|
||||
|
||||
KeyFactory kf = KeyFactory.getInstance("RSA");
|
||||
PrivateKey key = kf.generatePrivate(new PKCS8EncodedKeySpec(loadKeyFromPem(keyFile)));
|
||||
keyStore.setKeyEntry("key", key, keyPassword,
|
||||
cert != null ? new Certificate[]{ cert } : null);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(String.format("Unable to load key from %s. Please check.",
|
||||
keyFile),
|
||||
e);
|
||||
}
|
||||
}
|
||||
|
||||
trustStore = def.getParams().getOptionalString("caCertFilePath").map(caCertFilePath -> {
|
||||
try (InputStream is = new FileInputStream(new File(caCertFilePath))) {
|
||||
KeyStore ts = KeyStore.getInstance("JKS");
|
||||
ts.load(null, null);
|
||||
|
||||
Certificate caCert = cf.generateCertificate(is);
|
||||
ts.setCertificateEntry("caCertFile", caCert);
|
||||
return ts;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(String.format("Unable to load caCert from %s. Please check.",
|
||||
caCertFilePath),
|
||||
e);
|
||||
}
|
||||
}).orElse(null);
|
||||
|
||||
} catch (RuntimeException re) {
|
||||
throw re;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else {
|
||||
throw new RuntimeException("The 'ssl' parameter must have one of jdk, or openssl");
|
||||
}
|
||||
|
||||
KeyManagerFactory kmf;
|
||||
try {
|
||||
kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
|
||||
kmf.init(keyStore, keyPassword);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Unable to init KeyManagerFactory. Please check.", e);
|
||||
}
|
||||
|
||||
TrustManagerFactory tmf;
|
||||
try {
|
||||
tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
|
||||
tmf.init(trustStore != null ? trustStore : keyStore);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Unable to init TrustManagerFactory. Please check.", e);
|
||||
}
|
||||
|
||||
try {
|
||||
SSLContext sslContext = SSLContext.getInstance(tlsVersion);
|
||||
sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
|
||||
return sslContext;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static byte[] loadPem(Pattern pattern, File pemFile) throws IOException {
|
||||
try (InputStream in = new FileInputStream(pemFile)) {
|
||||
String pem = new String(in.readAllBytes(), StandardCharsets.ISO_8859_1);
|
||||
String encoded = pattern.matcher(pem).replaceFirst("$1");
|
||||
return Base64.getMimeDecoder().decode(encoded);
|
||||
}
|
||||
}
|
||||
|
||||
private static byte[] loadKeyFromPem(File keyPemFile) throws IOException {
|
||||
return loadPem(KEY_PATTERN, keyPemFile);
|
||||
}
|
||||
|
||||
private static byte[] loadCertFromPem(File certPemFile) throws IOException {
|
||||
return loadPem(CERT_PATTERN, certPemFile);
|
||||
}
|
||||
}
|
||||
|
@ -17,6 +17,8 @@
|
||||
|
||||
package io.nosqlbench.engine.api.util;
|
||||
|
||||
import java.io.FileNotFoundException;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import io.nosqlbench.engine.api.activityimpl.ActivityDef;
|
||||
@ -24,8 +26,17 @@ import io.nosqlbench.engine.api.activityimpl.ActivityDef;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||
|
||||
public class SSLKsFactoryTest
|
||||
{
|
||||
public class SSLKsFactoryTest {
|
||||
@Test
|
||||
public void testJdkGetContext() {
|
||||
String[] params = {
|
||||
"ssl=jdk",
|
||||
"tlsversion=TLSv1.2",
|
||||
};
|
||||
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
|
||||
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJdkGetContextWithTruststoreAndKeystore() {
|
||||
String[] params = {
|
||||
@ -64,7 +75,17 @@ public class SSLKsFactoryTest
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOpenSSLGetContextWithCaCertAndClientCert() {
|
||||
public void testOpenSSLGetContext() {
|
||||
String[] params = {
|
||||
"ssl=openssl",
|
||||
"tlsversion=TLSv1.2",
|
||||
};
|
||||
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
|
||||
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOpenSSLGetContextWithCaCertAndCertAndKey() {
|
||||
String[] params = {
|
||||
"ssl=openssl",
|
||||
"caCertFilePath=src/test/resources/ssl/cacert.crt",
|
||||
@ -86,20 +107,11 @@ public class SSLKsFactoryTest
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJdkGetContext() {
|
||||
String[] params = {
|
||||
"ssl=jdk",
|
||||
"tlsversion=TLSv1.2",
|
||||
};
|
||||
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
|
||||
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOpenSSLGetContext() {
|
||||
public void testOpenSSLGetContextWithCertAndKey() {
|
||||
String[] params = {
|
||||
"ssl=openssl",
|
||||
"tlsversion=TLSv1.2",
|
||||
"certFilePath=src/test/resources/ssl/client_cert.pem",
|
||||
"keyFilePath=src/test/resources/ssl/client.key"
|
||||
};
|
||||
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
|
||||
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
|
||||
@ -155,8 +167,8 @@ public class SSLKsFactoryTest
|
||||
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
|
||||
assertThatExceptionOfType(RuntimeException.class)
|
||||
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
|
||||
.withMessageContaining("File does not contain valid certificates")
|
||||
.withCauseInstanceOf(IllegalArgumentException.class);
|
||||
.withMessageContaining("Unable to load caCert from")
|
||||
.withCauseInstanceOf(FileNotFoundException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -168,8 +180,8 @@ public class SSLKsFactoryTest
|
||||
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
|
||||
assertThatExceptionOfType(RuntimeException.class)
|
||||
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
|
||||
.withMessageContaining("File does not contain valid certificates")
|
||||
.withCauseInstanceOf(IllegalArgumentException.class);
|
||||
.withMessageContaining("Unable to load cert from")
|
||||
.withCauseInstanceOf(FileNotFoundException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -181,7 +193,21 @@ public class SSLKsFactoryTest
|
||||
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
|
||||
assertThatExceptionOfType(RuntimeException.class)
|
||||
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
|
||||
.withMessageContaining("File does not contain valid private key")
|
||||
.withMessageContaining("Unable to load key from")
|
||||
.withCauseInstanceOf(FileNotFoundException.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOpenSSLGetContextWithMissingCertError() {
|
||||
String[] params = {
|
||||
"ssl=openssl",
|
||||
"caCertFilePath=src/test/resources/ssl/cacert.crt",
|
||||
"keyFilePath=src/test/resources/ssl/client.key"
|
||||
};
|
||||
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
|
||||
assertThatExceptionOfType(RuntimeException.class)
|
||||
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
|
||||
.withMessageContaining("Unable to load key from")
|
||||
.withCauseInstanceOf(IllegalArgumentException.class);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user