adding vector find ops

This commit is contained in:
Mark Wolters
2024-05-10 18:22:01 -04:00
parent e6d2c35f5a
commit 101e9d93fa
6 changed files with 250 additions and 0 deletions

View File

@@ -0,0 +1,82 @@
/*
* Copyright (c) 2024 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.adapter.dataapi;
import com.datastax.astra.client.Database;
import com.datastax.astra.client.model.Filter;
import com.datastax.astra.client.model.FindOptions;
import com.datastax.astra.client.model.Projection;
import com.datastax.astra.client.model.Sort;
import io.nosqlbench.adapter.dataapi.opdispensers.DataApiOpDispenser;
import io.nosqlbench.adapter.dataapi.ops.DataApiBaseOp;
import io.nosqlbench.adapter.dataapi.ops.DataApiFindVectorFilterOp;
import io.nosqlbench.adapters.api.templating.ParsedOp;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.function.LongFunction;
public class DataApiFindVectorFilterOpDispenser extends DataApiOpDispenser {
private static final Logger logger = LogManager.getLogger(DataApiFindVectorFilterOpDispenser.class);
private final LongFunction<DataApiFindVectorFilterOp> opFunction;
public DataApiFindVectorFilterOpDispenser(DataApiDriverAdapter adapter, ParsedOp op, LongFunction<String> targetFunction) {
super(adapter, op, targetFunction);
this.opFunction = createOpFunction(op);
}
private LongFunction<DataApiFindVectorFilterOp> createOpFunction(ParsedOp op) {
return (l) -> {
Database db = spaceFunction.apply(l).getDatabase();
float[] vector = getVectorValues(op, l);
Filter filter = getFilterFromOp(op, l);
int limit = getLimit(op, l);
return new DataApiFindVectorFilterOp(
db,
db.getCollection(targetFunction.apply(l)),
vector,
limit,
filter
);
};
}
private int getLimit(ParsedOp op, long l) {
return op.getConfigOr("limit", 100, l);
}
private FindOptions getFindOptions(ParsedOp op, long l) {
FindOptions options = new FindOptions();
Sort sort = getSortFromOp(op, l);
float[] vector = getVectorValues(op, l);
if (sort != null) {
options = vector != null ? options.sort(vector, sort) : options.sort(sort);
} else if (vector != null) {
options = options.sort(vector);
}
Projection[] projection = getProjectionFromOp(op, l);
if (projection != null) {
options = options.projection(projection);
}
options.setIncludeSimilarity(true);
return options;
}
@Override
public DataApiBaseOp getOp(long value) {
return opFunction.apply(value);
}
}

View File

@@ -47,10 +47,13 @@ public class DataApiOpMapper implements OpMapper<DataApiBaseOp> {
case create_collection -> new DataApiCreateCollectionOpDispenser(adapter, op, typeAndTarget.targetFunction);
case insert_many -> new DataApiInsertManyOpDispenser(adapter, op, typeAndTarget.targetFunction);
case insert_one -> new DataApiInsertOneOpDispenser(adapter, op, typeAndTarget.targetFunction);
case insert_one_vector -> new DataApiInsertOneVectorOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find -> new DataApiFindOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find_one -> new DataApiFindOneOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find_one_and_delete -> new DataApiFindOneAndDeleteOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find_one_and_update -> new DataApiFindOneAndUpdateOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find_vector -> new DataApiFindVectorOpDispenser(adapter, op, typeAndTarget.targetFunction);
case find_vector_filter -> new DataApiFindVectorFilterOpDispenser(adapter, op, typeAndTarget.targetFunction);
case update_one -> new DataApiUpdateOneOpDispenser(adapter, op, typeAndTarget.targetFunction);
case update_many -> new DataApiUpdateManyOpDispenser(adapter, op, typeAndTarget.targetFunction);
case delete_one -> new DataApiDeleteOneOpDispenser(adapter, op, typeAndTarget.targetFunction);

View File

@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.adapter.dataapi.opdispensers;
import com.datastax.astra.client.Database;
import com.datastax.astra.client.model.FindOptions;
import com.datastax.astra.client.model.Projection;
import com.datastax.astra.client.model.Sort;
import io.nosqlbench.adapter.dataapi.DataApiDriverAdapter;
import io.nosqlbench.adapter.dataapi.ops.DataApiBaseOp;
import io.nosqlbench.adapter.dataapi.ops.DataApiFindVectorOp;
import io.nosqlbench.adapters.api.templating.ParsedOp;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.function.LongFunction;
public class DataApiFindVectorOpDispenser extends DataApiOpDispenser {
private static final Logger logger = LogManager.getLogger(DataApiFindVectorOpDispenser.class);
private final LongFunction<DataApiFindVectorOp> opFunction;
public DataApiFindVectorOpDispenser(DataApiDriverAdapter adapter, ParsedOp op, LongFunction<String> targetFunction) {
super(adapter, op, targetFunction);
this.opFunction = createOpFunction(op);
}
private LongFunction<DataApiFindVectorOp> createOpFunction(ParsedOp op) {
return (l) -> {
Database db = spaceFunction.apply(l).getDatabase();
float[] vector = getVectorValues(op, l);
int limit = getLimit(op, l);
return new DataApiFindVectorOp(
db,
db.getCollection(targetFunction.apply(l)),
vector,
limit
);
};
}
private int getLimit(ParsedOp op, long l) {
return op.getConfigOr("limit", 100, l);
}
private FindOptions getFindOptions(ParsedOp op, long l) {
FindOptions options = new FindOptions();
Sort sort = getSortFromOp(op, l);
float[] vector = getVectorValues(op, l);
if (sort != null) {
options = vector != null ? options.sort(vector, sort) : options.sort(sort);
} else if (vector != null) {
options = options.sort(vector);
}
Projection[] projection = getProjectionFromOp(op, l);
if (projection != null) {
options = options.projection(projection);
}
options.setIncludeSimilarity(true);
return options;
}
@Override
public DataApiBaseOp getOp(long value) {
return opFunction.apply(value);
}
}

View File

@@ -0,0 +1,42 @@
/*
* Copyright (c) 2024 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.adapter.dataapi.ops;
import com.datastax.astra.client.Collection;
import com.datastax.astra.client.Database;
import com.datastax.astra.client.model.Document;
import com.datastax.astra.client.model.Filter;
public class DataApiFindVectorFilterOp extends DataApiBaseOp {
private final Collection<Document> collection;
private final float[] vector;
private final int limit;
private final Filter filter;
public DataApiFindVectorFilterOp(Database db, Collection<Document> collection, float[] vector, int limit, Filter filter) {
super(db);
this.collection = collection;
this.vector = vector;
this.limit = limit;
this.filter = filter;
}
@Override
public Object apply(long value) {
return collection.find(filter, vector, limit);
}
}

View File

@@ -0,0 +1,41 @@
/*
* Copyright (c) 2024 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.adapter.dataapi.ops;
import com.datastax.astra.client.Collection;
import com.datastax.astra.client.Database;
import com.datastax.astra.client.model.Document;
import com.datastax.astra.client.model.Filter;
import com.datastax.astra.client.model.FindOptions;
public class DataApiFindVectorOp extends DataApiBaseOp {
private final Collection<Document> collection;
private final float[] vector;
private final int limit;
public DataApiFindVectorOp(Database db, Collection<Document> collection, float[] vector, int limit) {
super(db);
this.collection = collection;
this.vector = vector;
this.limit = limit;
}
@Override
public Object apply(long value) {
return collection.find(vector, limit);
}
}

View File

@@ -20,10 +20,13 @@ public enum DataApiOpType {
create_collection,
insert_many,
insert_one,
insert_one_vector,
find,
find_one,
find_one_and_delete,
find_one_and_update,
find_vector,
find_vector_filter,
update_one,
update_many,
delete_one,