Fixes handling of List vs. Float

This commit is contained in:
jeffbanks
2023-06-02 14:47:51 -05:00
parent b9d5008c5a
commit 08989143ed
2 changed files with 71 additions and 76 deletions

View File

@@ -22,6 +22,7 @@ import com.datastax.oss.driver.api.core.cql.ColumnDefinition;
import com.datastax.oss.driver.api.core.cql.ColumnDefinitions;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.data.CqlDuration;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.data.TupleValue;
import com.datastax.oss.driver.api.core.data.UdtValue;
import com.datastax.oss.driver.api.core.type.DataType;
@@ -50,84 +51,76 @@ import static com.datastax.oss.protocol.internal.ProtocolConstants.DataType.*;
* explaining more specifically what the problem was that caused the original error to be thrown.
*/
public class CQLD4PreparedStmtDiagnostics {
private final static Logger logger = LogManager.getLogger(CQLD4PreparedStmtDiagnostics.class);
private static final Logger logger = LogManager.getLogger(CQLD4PreparedStmtDiagnostics.class);
public static BoundStatement bindStatement(BoundStatement bound, CqlIdentifier colname, Object colval, DataType coltype) {
public static BoundStatement bindStatement(BoundStatement bound, CqlIdentifier colname,
Object colval, DataType coltype) {
// if (coltype instanceof PrimitiveType pt) {
try {
BoundStatement reproduce_error_with_custom_dataType = switch (coltype.getProtocolCode()) {
case CUSTOM -> throw new OpConfigError("Error with Custom DataType");
case ASCII, VARCHAR -> bound.setString(colname, (String) colval);
case BIGINT, COUNTER -> bound.setLong(colname, (long) colval);
case BLOB -> bound.setByteBuffer(colname, (ByteBuffer) colval);
case BOOLEAN -> bound.setBoolean(colname, (boolean) colval);
case DECIMAL -> bound.setBigDecimal(colname, (BigDecimal) colval);
case DOUBLE -> bound.setDouble(colname, (double) colval);
case FLOAT -> bound.setFloat(colname, (float) colval);
case INT -> bound.setInt(colname, (int) colval);
case SMALLINT -> bound.setShort(colname, (short) colval);
case TINYINT -> bound.setByte(colname, (byte) colval);
case TIMESTAMP -> bound.setInstant(colname, (Instant) colval);
case TIMEUUID, UUID -> bound.setUuid(colname, (UUID) colval);
case VARINT -> bound.setBigInteger(colname, (BigInteger) colval);
case INET -> bound.setInetAddress(colname, (InetAddress) colval);
case DATE -> bound.setLocalDate(colname, (LocalDate) colval);
case TIME -> bound.setLocalTime(colname, (LocalTime) colval);
case DURATION -> bound.setCqlDuration(colname, (CqlDuration) colval);
case LIST -> bound.setList(colname, (List) colval, ((List) colval).get(0).getClass());
case MAP -> {
Map map = (Map) colval;
Set<Map.Entry> entries = map.entrySet();
Optional<Map.Entry> first = entries.stream().findFirst();
if (first.isPresent()) {
yield bound.setMap(colname, map, first.get().getKey().getClass(), first.get().getValue().getClass());
} else {
yield bound.setMap(colname, map, Object.class, Object.class);
}
return switch (coltype.getProtocolCode()) {
// TODO - need to handle 'custom' more specifically, interim change for vector search.
case CUSTOM -> bound.setCqlVector(colname, (CqlVector<?>) colval);
case ASCII, VARCHAR -> bound.setString(colname, (String) colval);
case BIGINT, COUNTER -> bound.setLong(colname, (long) colval);
case BLOB -> bound.setByteBuffer(colname, (ByteBuffer) colval);
case BOOLEAN -> bound.setBoolean(colname, (boolean) colval);
case DECIMAL -> bound.setBigDecimal(colname, (BigDecimal) colval);
case DOUBLE -> bound.setDouble(colname, (double) colval);
case FLOAT -> bound.setFloat(colname, (float) colval);
case INT -> bound.setInt(colname, (int) colval);
case SMALLINT -> bound.setShort(colname, (short) colval);
case TINYINT -> bound.setByte(colname, (byte) colval);
case TIMESTAMP -> bound.setInstant(colname, (Instant) colval);
case TIMEUUID, UUID -> bound.setUuid(colname, (UUID) colval);
case VARINT -> bound.setBigInteger(colname, (BigInteger) colval);
case INET -> bound.setInetAddress(colname, (InetAddress) colval);
case DATE -> bound.setLocalDate(colname, (LocalDate) colval);
case TIME -> bound.setLocalTime(colname, (LocalTime) colval);
case DURATION -> bound.setCqlDuration(colname, (CqlDuration) colval);
case LIST -> bound.setList(colname, (List) colval, ((List) colval).get(0).getClass());
case MAP -> {
Map map = (Map) colval;
Set<Map.Entry> entries = map.entrySet();
Optional<Map.Entry> first = entries.stream().findFirst();
if (first.isPresent()) {
yield bound.setMap(colname, map, first.get().getKey().getClass(), first.get().getValue().getClass());
} else {
yield bound.setMap(colname, map, Object.class, Object.class);
}
case SET -> {
Set set = (Set) colval;
Optional first = set.stream().findFirst();
if (first.isPresent()) {
yield bound.setSet(colname, set, first.get().getClass());
} else {
yield bound.setSet(colname, Set.of(), Object.class);
}
}
case SET -> {
Set set = (Set) colval;
Optional first = set.stream().findFirst();
if (first.isPresent()) {
yield bound.setSet(colname, set, first.get().getClass());
} else {
yield bound.setSet(colname, Set.of(), Object.class);
}
case UDT -> {
UdtValue udt = (UdtValue) colval;
yield bound.setUdtValue(colname, udt);
}
case TUPLE -> {
TupleValue tuple = (TupleValue) colval;
yield bound.setTupleValue(colname, tuple);
}
default -> throw new RuntimeException("Unknown CQL type for diagnostic (type:'" + coltype + "',code:'" + coltype.getProtocolCode() + "'");
};
return reproduce_error_with_custom_dataType;
} catch (Exception e) {
throw e;
}
// }
// throw new IllegalStateException("Unexpected value: " + coltype);
}
case UDT -> {
UdtValue udt = (UdtValue) colval;
yield bound.setUdtValue(colname, udt);
}
case TUPLE -> {
TupleValue tuple = (TupleValue) colval;
yield bound.setTupleValue(colname, tuple);
}
default -> throw new RuntimeException("Unknown CQL type for diagnostic " +
"(type:'" + coltype + "',code:'" + coltype.getProtocolCode() + "'");
};
}
public static Cqld4CqlOp rebindWithDiagnostics(
PreparedStatement preparedStmt,
LongFunction<Object[]> fieldsF,
long cycle,
Exception exception
PreparedStatement preparedStmt,
LongFunction<Object[]> fieldsF,
long cycle,
Exception exception
) {
logger.error(exception);
ColumnDefinitions defs = preparedStmt.getVariableDefinitions();
Object[] values = fieldsF.apply(cycle);
if (defs.size() != values.length) {
throw new OpConfigError("There are " + defs.size() + " anchors in statement '" + preparedStmt.getQuery() + "'" +
"but " + values.length + " values were provided. These must match.");
"but " + values.length + " values were provided. These must match.");
}
BoundStatement bound = preparedStmt.bind();
@@ -143,11 +136,11 @@ public class CQLD4PreparedStmtDiagnostics {
String fullValue = value.toString();
String valueToPrint = fullValue.length() > 100 ? fullValue.substring(0, 100) + " ... (abbreviated for console, since the size is " + fullValue.length() + ")" : fullValue;
String errormsg = String.format(
"Unable to bind column '%s' to cql type '%s' with value '%s' (class '%s')",
defname,
type.asCql(false, false),
valueToPrint,
value.getClass().getCanonicalName()
"Unable to bind column '%s' to cql type '%s' with value '%s' (class '%s')",
defname,
type.asCql(false, false),
valueToPrint,
value.getClass().getCanonicalName()
);
logger.error(errormsg);
throw new OpConfigError(errormsg, e);

View File

@@ -22,8 +22,6 @@ import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeDoubleListVector;
import io.nosqlbench.virtdata.library.basics.shared.from_long.to_vector.NormalizeFloatListVector;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.List;
@@ -39,8 +37,6 @@ public class NormalizeCqlVector implements Function<com.datastax.oss.driver.api.
private final NormalizeDoubleListVector ndv = new NormalizeDoubleListVector();
private final NormalizeFloatListVector nfv = new NormalizeFloatListVector();
private final static Logger logger = LogManager.getLogger(NormalizeCqlVector.class);
@Override
public com.datastax.oss.driver.api.core.data.CqlVector apply(CqlVector cqlVector) {
@@ -50,15 +46,21 @@ public class NormalizeCqlVector implements Function<com.datastax.oss.driver.api.
values.forEach(list::add);
if (list.isEmpty()) {
builder.add(List.of());
} else if (list.get(0) instanceof Float) {
List<Float> floats = new ArrayList<>();
list.forEach(o -> floats.add((Float) o));
builder.add(nfv.apply(floats));
for (Float fv : floats) {
builder.add(fv);
}
} else if (list.get(0) instanceof Double) {
List<Double> doubles = new ArrayList<>();
list.forEach(o -> doubles.add((Double) o));
builder.add(ndv.apply(doubles));
for (Double dv : doubles) {
builder.add(dv);
}
} else {
throw new RuntimeException("Only Doubles and Floats are recognized.");
}