Refactor SSLKsFactory

Use javax.net.ssl.SSLContext instead of Netty's SslContext
This commit is contained in:
Justin Chu 2020-05-16 17:16:48 -04:00
parent 2ee74355bd
commit c99b609a59
3 changed files with 181 additions and 88 deletions

View File

@ -19,8 +19,6 @@ import org.slf4j.LoggerFactory;
import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.ProtocolOptions;
import com.datastax.driver.core.RemoteEndpointAwareJdkSSLOptions;
import com.datastax.driver.core.RemoteEndpointAwareNettySSLOptions;
import com.datastax.driver.core.SSLOptions;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.policies.DefaultRetryPolicy;
import com.datastax.driver.core.policies.LoadBalancingPolicy;
@ -30,7 +28,6 @@ import com.datastax.driver.core.policies.RoundRobinPolicy;
import com.datastax.driver.core.policies.SpeculativeExecutionPolicy;
import com.datastax.driver.core.policies.WhiteListPolicy;
import com.datastax.driver.dse.DseCluster;
import io.netty.handler.ssl.SslContext;
import io.nosqlbench.activitytype.cql.core.CQLOptions;
import io.nosqlbench.activitytype.cql.core.ProxyTranslator;
import io.nosqlbench.engine.api.activityapi.core.Shutdownable;
@ -43,7 +40,7 @@ public class CQLSessionCache implements Shutdownable {
private final static Logger logger = LoggerFactory.getLogger(CQLSessionCache.class);
private final static String DEFAULT_SESSION_ID = "default";
private static CQLSessionCache instance = new CQLSessionCache();
private static final CQLSessionCache instance = new CQLSessionCache();
private Map<String, Session> sessionCache = new HashMap<>();
private CQLSessionCache() {
@ -220,10 +217,9 @@ public class CQLSessionCache implements Shutdownable {
.map(CQLOptions::withCompression)
.ifPresent(builder::withCompression);
SslContext context = SSLKsFactory.get().getContext(activityDef);
SSLContext context = SSLKsFactory.get().getContext(activityDef);
if (context != null) {
SSLOptions sslOptions = new RemoteEndpointAwareNettySSLOptions(context);
builder.withSSL(sslOptions);
builder.withSSL(RemoteEndpointAwareJdkSSLOptions.builder().withSSLContext(context).build());
}
RetryPolicy retryPolicy = activityDef.getParams()

View File

@ -17,20 +17,31 @@
package io.nosqlbench.engine.api.util;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.Optional;
import java.util.regex.Pattern;
import javax.net.ServerSocketFactory;
import javax.net.SocketFactory;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.netty.handler.ssl.JdkSslContext;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.nosqlbench.engine.api.activityimpl.ActivityDef;
public class SSLKsFactory {
@ -38,6 +49,9 @@ public class SSLKsFactory {
private static final SSLKsFactory instance = new SSLKsFactory();
private static final Pattern CERT_PATTERN = Pattern.compile("-+BEGIN\\s+.*CERTIFICATE[^-]*-+(?:\\s|\\r|\\n)+([a-z0-9+/=\\r\\n]+)-+END\\s+.*CERTIFICATE[^-]*-+", 2);
private static final Pattern KEY_PATTERN = Pattern.compile("-+BEGIN\\s+.*PRIVATE\\s+KEY[^-]*-+(?:\\s|\\r|\\n)+([a-z0-9+/=\\r\\n]+)-+END\\s+.*PRIVATE\\s+KEY[^-]*-+", 2);
/**
* Consider: https://gist.github.com/artem-smotrakov/bd14e4bde4d7238f7e5ab12c697a86a3
*/
@ -49,108 +63,165 @@ public class SSLKsFactory {
}
public ServerSocketFactory createSSLServerSocketFactory(ActivityDef def) {
SslContext context = getContext(def);
SSLContext context = getContext(def);
if (context == null) {
throw new IllegalArgumentException("SSL is not enabled.");
}
// FIXME: potential incompatibility issue
return ((JdkSslContext) context).context().getServerSocketFactory();
return context.getServerSocketFactory();
}
public SocketFactory createSocketFactory(ActivityDef def) {
SslContext context = getContext(def);
SSLContext context = getContext(def);
if (context == null) {
throw new IllegalArgumentException("SSL is not enabled.");
}
// FIXME: potential incompatibility issue
return ((JdkSslContext) context).context().getSocketFactory();
return context.getSocketFactory();
}
public SslContext getContext(ActivityDef def) {
public SSLContext getContext(ActivityDef def) {
Optional<String> sslParam = def.getParams().getOptionalString("ssl");
if (sslParam.isPresent()) {
String tlsVersion = def.getParams().getOptionalString("tlsversion").orElse("TLSv1.2");
KeyStore keyStore;
char[] keyPassword = null;
KeyStore trustStore;
if (sslParam.get().equals("jdk") || sslParam.get().equals("true")) {
if (sslParam.get().equals("true")) {
logger.warn("Please update your 'ssl=true' parameter to 'ssl=jdk'");
}
Optional<String> keystorePath = def.getParams().getOptionalString("keystore");
Optional<String> keystorePass = def.getParams().getOptionalString("kspass");
char[] keyPassword = def.getParams().getOptionalString("keyPassword")
.map(String::toCharArray)
.orElse(null);
Optional<String> truststorePath = def.getParams().getOptionalString("truststore");
Optional<String> truststorePass = def.getParams().getOptionalString("tspass");
keyPassword = def.getParams().getOptionalString("keyPassword")
.map(String::toCharArray)
.orElse(null);
KeyStore ks = keystorePath.map(ksPath -> {
keyStore = def.getParams().getOptionalString("keystore").map(ksPath -> {
try {
return KeyStore.getInstance(new File(ksPath),
keystorePass.map(String::toCharArray).orElse(null));
def.getParams().getOptionalString("kspass")
.map(String::toCharArray)
.orElse(null));
} catch (Exception e) {
throw new RuntimeException("Unable to load the keystore. Please check.", e);
}
}).orElse(null);
KeyManagerFactory kmf;
try {
kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
kmf.init(ks, keyPassword);
} catch (Exception e) {
throw new RuntimeException("Unable to init KeyManagerFactory. Please check.", e);
}
KeyStore ts = truststorePath.map(tsPath -> {
trustStore = def.getParams().getOptionalString("truststore").map(tsPath -> {
try {
return KeyStore.getInstance(new File(tsPath),
truststorePass.map(String::toCharArray).orElse(null));
def.getParams().getOptionalString("tspass")
.map(String::toCharArray)
.orElse(null));
} catch (Exception e) {
throw new RuntimeException("Unable to load the truststore. Please check.", e);
}
}).orElse(null);
TrustManagerFactory tmf;
try {
tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(ts != null ? ts : ks);
} catch (Exception e) {
throw new RuntimeException("Unable to init TrustManagerFactory. Please check.", e);
}
try {
return SslContextBuilder.forClient()
.protocols(tlsVersion)
.trustManager(tmf)
.keyManager(kmf)
.build();
} catch (Exception e) {
throw new RuntimeException(e);
}
} else if (sslParam.get().equals("openssl")) {
File caCertFileLocation = def.getParams().getOptionalString("caCertFilePath").map(File::new).orElse(null);
File certFileLocation = def.getParams().getOptionalString("certFilePath").map(File::new).orElse(null);
File keyFileLocation = def.getParams().getOptionalString("keyFilePath").map(File::new).orElse(null);
try {
return SslContextBuilder.forClient()
.protocols(tlsVersion)
/* configured with the TrustManagerFactory that has the cert from the ca.cert
* This tells the driver to trust the server during the SSL handshake */
.trustManager(caCertFileLocation)
/* These are needed if the server is configured with require_client_auth
* In this case the client's public key must be in the truststore on each DSE
* server node and the CA configured */
.keyManager(certFileLocation, keyFileLocation)
.build();
CertificateFactory cf = CertificateFactory.getInstance("X.509");
keyStore = KeyStore.getInstance("JKS");
keyStore.load(null, null);
Certificate cert = def.getParams().getOptionalString("certFilePath").map(certFilePath -> {
try (InputStream is = new ByteArrayInputStream(loadCertFromPem(new File(certFilePath)))) {
return cf.generateCertificate(is);
} catch (Exception e) {
throw new RuntimeException(String.format("Unable to load cert from %s. Please check.",
certFilePath),
e);
}
}).orElse(null);
if (cert != null)
keyStore.setCertificateEntry("certFile", cert);
File keyFile = def.getParams().getOptionalString("keyFilePath").map(File::new)
.orElse(null);
if (keyFile != null) {
try {
keyPassword = def.getParams().getOptionalString("keyPassword")
.map(String::toCharArray)
.orElse("temp_key_password".toCharArray());
KeyFactory kf = KeyFactory.getInstance("RSA");
PrivateKey key = kf.generatePrivate(new PKCS8EncodedKeySpec(loadKeyFromPem(keyFile)));
keyStore.setKeyEntry("key", key, keyPassword,
cert != null ? new Certificate[]{ cert } : null);
} catch (Exception e) {
throw new RuntimeException(String.format("Unable to load key from %s. Please check.",
keyFile),
e);
}
}
trustStore = def.getParams().getOptionalString("caCertFilePath").map(caCertFilePath -> {
try (InputStream is = new FileInputStream(new File(caCertFilePath))) {
KeyStore ts = KeyStore.getInstance("JKS");
ts.load(null, null);
Certificate caCert = cf.generateCertificate(is);
ts.setCertificateEntry("caCertFile", caCert);
return ts;
} catch (Exception e) {
throw new RuntimeException(String.format("Unable to load caCert from %s. Please check.",
caCertFilePath),
e);
}
}).orElse(null);
} catch (RuntimeException re) {
throw re;
} catch (Exception e) {
throw new RuntimeException(e);
}
} else {
throw new RuntimeException("The 'ssl' parameter must have one of jdk, or openssl");
}
KeyManagerFactory kmf;
try {
kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
kmf.init(keyStore, keyPassword);
} catch (Exception e) {
throw new RuntimeException("Unable to init KeyManagerFactory. Please check.", e);
}
TrustManagerFactory tmf;
try {
tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
tmf.init(trustStore != null ? trustStore : keyStore);
} catch (Exception e) {
throw new RuntimeException("Unable to init TrustManagerFactory. Please check.", e);
}
try {
SSLContext sslContext = SSLContext.getInstance(tlsVersion);
sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());
return sslContext;
} catch (Exception e) {
throw new RuntimeException(e);
}
} else {
return null;
}
}
private static byte[] loadPem(Pattern pattern, File pemFile) throws IOException {
try (InputStream in = new FileInputStream(pemFile)) {
String pem = new String(in.readAllBytes(), StandardCharsets.ISO_8859_1);
String encoded = pattern.matcher(pem).replaceFirst("$1");
return Base64.getMimeDecoder().decode(encoded);
}
}
private static byte[] loadKeyFromPem(File keyPemFile) throws IOException {
return loadPem(KEY_PATTERN, keyPemFile);
}
private static byte[] loadCertFromPem(File certPemFile) throws IOException {
return loadPem(CERT_PATTERN, certPemFile);
}
}

View File

@ -17,6 +17,8 @@
package io.nosqlbench.engine.api.util;
import java.io.FileNotFoundException;
import org.junit.Test;
import io.nosqlbench.engine.api.activityimpl.ActivityDef;
@ -24,8 +26,17 @@ import io.nosqlbench.engine.api.activityimpl.ActivityDef;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
public class SSLKsFactoryTest
{
public class SSLKsFactoryTest {
@Test
public void testJdkGetContext() {
String[] params = {
"ssl=jdk",
"tlsversion=TLSv1.2",
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
}
@Test
public void testJdkGetContextWithTruststoreAndKeystore() {
String[] params = {
@ -64,7 +75,17 @@ public class SSLKsFactoryTest
}
@Test
public void testOpenSSLGetContextWithCaCertAndClientCert() {
public void testOpenSSLGetContext() {
String[] params = {
"ssl=openssl",
"tlsversion=TLSv1.2",
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
}
@Test
public void testOpenSSLGetContextWithCaCertAndCertAndKey() {
String[] params = {
"ssl=openssl",
"caCertFilePath=src/test/resources/ssl/cacert.crt",
@ -86,20 +107,11 @@ public class SSLKsFactoryTest
}
@Test
public void testJdkGetContext() {
String[] params = {
"ssl=jdk",
"tlsversion=TLSv1.2",
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
}
@Test
public void testOpenSSLGetContext() {
public void testOpenSSLGetContextWithCertAndKey() {
String[] params = {
"ssl=openssl",
"tlsversion=TLSv1.2",
"certFilePath=src/test/resources/ssl/client_cert.pem",
"keyFilePath=src/test/resources/ssl/client.key"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThat(SSLKsFactory.get().getContext(activityDef)).isNotNull();
@ -155,8 +167,8 @@ public class SSLKsFactoryTest
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.withMessageContaining("File does not contain valid certificates")
.withCauseInstanceOf(IllegalArgumentException.class);
.withMessageContaining("Unable to load caCert from")
.withCauseInstanceOf(FileNotFoundException.class);
}
@Test
@ -168,8 +180,8 @@ public class SSLKsFactoryTest
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.withMessageContaining("File does not contain valid certificates")
.withCauseInstanceOf(IllegalArgumentException.class);
.withMessageContaining("Unable to load cert from")
.withCauseInstanceOf(FileNotFoundException.class);
}
@Test
@ -181,7 +193,21 @@ public class SSLKsFactoryTest
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.withMessageContaining("File does not contain valid private key")
.withMessageContaining("Unable to load key from")
.withCauseInstanceOf(FileNotFoundException.class);
}
@Test
public void testOpenSSLGetContextWithMissingCertError() {
String[] params = {
"ssl=openssl",
"caCertFilePath=src/test/resources/ssl/cacert.crt",
"keyFilePath=src/test/resources/ssl/client.key"
};
ActivityDef activityDef = ActivityDef.parseActivityDef(String.join(";", params));
assertThatExceptionOfType(RuntimeException.class)
.isThrownBy(() -> SSLKsFactory.get().getContext(activityDef))
.withMessageContaining("Unable to load key from")
.withCauseInstanceOf(IllegalArgumentException.class);
}
}
}