Fixes handling of List vs. Float

This commit is contained in:
jeffbanks
2023-06-02 14:47:51 -05:00
parent b9d5008c5a
commit 08989143ed
2 changed files with 71 additions and 76 deletions

View File

@@ -22,6 +22,7 @@ import com.datastax.oss.driver.api.core.cql.ColumnDefinition;
import com.datastax.oss.driver.api.core.cql.ColumnDefinitions; import com.datastax.oss.driver.api.core.cql.ColumnDefinitions;
import com.datastax.oss.driver.api.core.cql.PreparedStatement; import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.data.CqlDuration; import com.datastax.oss.driver.api.core.data.CqlDuration;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.data.TupleValue; import com.datastax.oss.driver.api.core.data.TupleValue;
import com.datastax.oss.driver.api.core.data.UdtValue; import com.datastax.oss.driver.api.core.data.UdtValue;
import com.datastax.oss.driver.api.core.type.DataType; import com.datastax.oss.driver.api.core.type.DataType;
@@ -50,84 +51,76 @@ import static com.datastax.oss.protocol.internal.ProtocolConstants.DataType.*;
* explaining more specifically what the problem was that caused the original error to be thrown. * explaining more specifically what the problem was that caused the original error to be thrown.
*/ */
public class CQLD4PreparedStmtDiagnostics { public class CQLD4PreparedStmtDiagnostics {
private final static Logger logger = LogManager.getLogger(CQLD4PreparedStmtDiagnostics.class); private static final Logger logger = LogManager.getLogger(CQLD4PreparedStmtDiagnostics.class);
public static BoundStatement bindStatement(BoundStatement bound, CqlIdentifier colname, Object colval, DataType coltype) { public static BoundStatement bindStatement(BoundStatement bound, CqlIdentifier colname,
Object colval, DataType coltype) {
// if (coltype instanceof PrimitiveType pt) { return switch (coltype.getProtocolCode()) {
try { // TODO - need to handle 'custom' more specifically, interim change for vector search.
BoundStatement reproduce_error_with_custom_dataType = switch (coltype.getProtocolCode()) { case CUSTOM -> bound.setCqlVector(colname, (CqlVector<?>) colval);
case CUSTOM -> throw new OpConfigError("Error with Custom DataType"); case ASCII, VARCHAR -> bound.setString(colname, (String) colval);
case ASCII, VARCHAR -> bound.setString(colname, (String) colval); case BIGINT, COUNTER -> bound.setLong(colname, (long) colval);
case BIGINT, COUNTER -> bound.setLong(colname, (long) colval); case BLOB -> bound.setByteBuffer(colname, (ByteBuffer) colval);
case BLOB -> bound.setByteBuffer(colname, (ByteBuffer) colval); case BOOLEAN -> bound.setBoolean(colname, (boolean) colval);
case BOOLEAN -> bound.setBoolean(colname, (boolean) colval); case DECIMAL -> bound.setBigDecimal(colname, (BigDecimal) colval);
case DECIMAL -> bound.setBigDecimal(colname, (BigDecimal) colval); case DOUBLE -> bound.setDouble(colname, (double) colval);
case DOUBLE -> bound.setDouble(colname, (double) colval); case FLOAT -> bound.setFloat(colname, (float) colval);
case FLOAT -> bound.setFloat(colname, (float) colval); case INT -> bound.setInt(colname, (int) colval);
case INT -> bound.setInt(colname, (int) colval); case SMALLINT -> bound.setShort(colname, (short) colval);
case SMALLINT -> bound.setShort(colname, (short) colval); case TINYINT -> bound.setByte(colname, (byte) colval);
case TINYINT -> bound.setByte(colname, (byte) colval); case TIMESTAMP -> bound.setInstant(colname, (Instant) colval);
case TIMESTAMP -> bound.setInstant(colname, (Instant) colval); case TIMEUUID, UUID -> bound.setUuid(colname, (UUID) colval);
case TIMEUUID, UUID -> bound.setUuid(colname, (UUID) colval); case VARINT -> bound.setBigInteger(colname, (BigInteger) colval);
case VARINT -> bound.setBigInteger(colname, (BigInteger) colval); case INET -> bound.setInetAddress(colname, (InetAddress) colval);
case INET -> bound.setInetAddress(colname, (InetAddress) colval); case DATE -> bound.setLocalDate(colname, (LocalDate) colval);
case DATE -> bound.setLocalDate(colname, (LocalDate) colval); case TIME -> bound.setLocalTime(colname, (LocalTime) colval);
case TIME -> bound.setLocalTime(colname, (LocalTime) colval); case DURATION -> bound.setCqlDuration(colname, (CqlDuration) colval);
case DURATION -> bound.setCqlDuration(colname, (CqlDuration) colval); case LIST -> bound.setList(colname, (List) colval, ((List) colval).get(0).getClass());
case LIST -> bound.setList(colname, (List) colval, ((List) colval).get(0).getClass()); case MAP -> {
case MAP -> { Map map = (Map) colval;
Map map = (Map) colval; Set<Map.Entry> entries = map.entrySet();
Set<Map.Entry> entries = map.entrySet(); Optional<Map.Entry> first = entries.stream().findFirst();
Optional<Map.Entry> first = entries.stream().findFirst(); if (first.isPresent()) {
if (first.isPresent()) { yield bound.setMap(colname, map, first.get().getKey().getClass(), first.get().getValue().getClass());
yield bound.setMap(colname, map, first.get().getKey().getClass(), first.get().getValue().getClass()); } else {
} else { yield bound.setMap(colname, map, Object.class, Object.class);
yield bound.setMap(colname, map, Object.class, Object.class);
}
} }
case SET -> { }
Set set = (Set) colval; case SET -> {
Optional first = set.stream().findFirst(); Set set = (Set) colval;
if (first.isPresent()) { Optional first = set.stream().findFirst();
yield bound.setSet(colname, set, first.get().getClass()); if (first.isPresent()) {
} else { yield bound.setSet(colname, set, first.get().getClass());
yield bound.setSet(colname, Set.of(), Object.class); } else {
} yield bound.setSet(colname, Set.of(), Object.class);
} }
case UDT -> { }
UdtValue udt = (UdtValue) colval; case UDT -> {
yield bound.setUdtValue(colname, udt); UdtValue udt = (UdtValue) colval;
} yield bound.setUdtValue(colname, udt);
case TUPLE -> { }
TupleValue tuple = (TupleValue) colval; case TUPLE -> {
yield bound.setTupleValue(colname, tuple); TupleValue tuple = (TupleValue) colval;
} yield bound.setTupleValue(colname, tuple);
default -> throw new RuntimeException("Unknown CQL type for diagnostic (type:'" + coltype + "',code:'" + coltype.getProtocolCode() + "'"); }
}; default -> throw new RuntimeException("Unknown CQL type for diagnostic " +
return reproduce_error_with_custom_dataType; "(type:'" + coltype + "',code:'" + coltype.getProtocolCode() + "'");
};
} catch (Exception e) {
throw e;
}
// }
// throw new IllegalStateException("Unexpected value: " + coltype);
} }
public static Cqld4CqlOp rebindWithDiagnostics( public static Cqld4CqlOp rebindWithDiagnostics(
PreparedStatement preparedStmt, PreparedStatement preparedStmt,
LongFunction<Object[]> fieldsF, LongFunction<Object[]> fieldsF,
long cycle, long cycle,
Exception exception Exception exception
) { ) {
logger.error(exception); logger.error(exception);
ColumnDefinitions defs = preparedStmt.getVariableDefinitions(); ColumnDefinitions defs = preparedStmt.getVariableDefinitions();
Object[] values = fieldsF.apply(cycle); Object[] values = fieldsF.apply(cycle);
if (defs.size() != values.length) { if (defs.size() != values.length) {
throw new OpConfigError("There are " + defs.size() + " anchors in statement '" + preparedStmt.getQuery() + "'" + throw new OpConfigError("There are " + defs.size() + " anchors in statement '" + preparedStmt.getQuery() + "'" +
"but " + values.length + " values were provided. These must match."); "but " + values.length + " values were provided. These must match.");
} }
BoundStatement bound = preparedStmt.bind(); BoundStatement bound = preparedStmt.bind();
@@ -143,11 +136,11 @@ public class CQLD4PreparedStmtDiagnostics {
String fullValue = value.toString(); String fullValue = value.toString();
String valueToPrint = fullValue.length() > 100 ? fullValue.substring(0, 100) + " ... (abbreviated for console, since the size is " + fullValue.length() + ")" : fullValue; String valueToPrint = fullValue.length() > 100 ? fullValue.substring(0, 100) + " ... (abbreviated for console, since the size is " + fullValue.length() + ")" : fullValue;
String errormsg = String.format( String errormsg = String.format(
"Unable to bind column '%s' to cql type '%s' with value '%s' (class '%s')", "Unable to bind column '%s' to cql type '%s' with value '%s' (class '%s')",
defname, defname,
type.asCql(false, false), type.asCql(false, false),
valueToPrint, valueToPrint,
value.getClass().getCanonicalName() value.getClass().getCanonicalName()
); );
logger.error(errormsg); logger.error(errormsg);
throw new OpConfigError(errormsg, e); throw new OpConfigError(errormsg, e);

View File

@@ -22,8 +22,6 @@ import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper; import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeDoubleListVector; import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeDoubleListVector;
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeFloatListVector; import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeFloatListVector;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@@ -39,8 +37,6 @@ public class NormalizeCqlVector implements Function<com.datastax.oss.driver.api.
private final NormalizeDoubleListVector ndv = new NormalizeDoubleListVector(); private final NormalizeDoubleListVector ndv = new NormalizeDoubleListVector();
private final NormalizeFloatListVector nfv = new NormalizeFloatListVector(); private final NormalizeFloatListVector nfv = new NormalizeFloatListVector();
private final static Logger logger = LogManager.getLogger(NormalizeCqlVector.class);
@Override @Override
public com.datastax.oss.driver.api.core.data.CqlVector apply(CqlVector cqlVector) { public com.datastax.oss.driver.api.core.data.CqlVector apply(CqlVector cqlVector) {
@@ -50,15 +46,21 @@ public class NormalizeCqlVector implements Function<com.datastax.oss.driver.api.
values.forEach(list::add); values.forEach(list::add);
if (list.isEmpty()) { if (list.isEmpty()) {
builder.add(List.of()); builder.add(List.of());
} else if (list.get(0) instanceof Float) { } else if (list.get(0) instanceof Float) {
List<Float> floats = new ArrayList<>(); List<Float> floats = new ArrayList<>();
list.forEach(o -> floats.add((Float) o)); list.forEach(o -> floats.add((Float) o));
builder.add(nfv.apply(floats)); for (Float fv : floats) {
builder.add(fv);
}
} else if (list.get(0) instanceof Double) { } else if (list.get(0) instanceof Double) {
List<Double> doubles = new ArrayList<>(); List<Double> doubles = new ArrayList<>();
list.forEach(o -> doubles.add((Double) o)); list.forEach(o -> doubles.add((Double) o));
builder.add(ndv.apply(doubles)); for (Double dv : doubles) {
builder.add(dv);
}
} else { } else {
throw new RuntimeException("Only Doubles and Floats are recognized."); throw new RuntimeException("Only Doubles and Floats are recognized.");
} }