mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2025-02-25 18:55:28 -06:00
adding pytorch version
This commit is contained in:
parent
08f867d230
commit
6ee029c9eb
@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import h5py
|
||||
|
||||
def generate_knn_dataset(n, p, x, output_file, vectors_per_id=100, k=100, max_ids_per_query=5):
    """Generate a random id-filtered KNN dataset and write it to an HDF5 file.

    Random train and test vectors are drawn with ``np.random.rand``. Train
    vectors are grouped into contiguous runs that share an integer id, each
    test vector is assigned a random set of those ids, and the ground-truth
    neighbors of a test vector are its most cosine-similar train vectors
    among the rows whose id is in that set.

    The output file contains these datasets:
      * ``train``     - (n, p) float32 train vectors
      * ``train_ids`` - (n,) id of each train vector, starting at 1
      * ``test``      - (x, p) float32 query vectors
      * ``test_ids``  - (x,) variable-length int32 id lists, one per query
      * ``neighbors`` - (x, k) int32 row indices into ``train``, best match
                        first, padded with -1 when fewer than k rows match

    Args:
        n: Number of train vectors (must be >= 1).
        p: Dimensionality of every vector (must be >= 1).
        x: Number of test (query) vectors (must be >= 0).
        output_file: Path of the HDF5 file to create (overwritten if present).
        vectors_per_id: How many contiguous train vectors share one id. When
            ``n`` is not a multiple of this value the last id simply covers
            fewer vectors.
        k: Number of neighbors recorded per query.
        max_ids_per_query: Upper bound on the ids assigned to one query; it
            is capped at the number of ids that actually exist.

    Raises:
        ValueError: If any size argument is out of range.
    """
    if n < 1 or p < 1:
        raise ValueError("n and p must both be at least 1")
    if x < 0:
        raise ValueError("x must not be negative")
    if vectors_per_id < 1 or k < 1 or max_ids_per_query < 1:
        raise ValueError("vectors_per_id, k and max_ids_per_query must all be at least 1")

    # Step 1: Generate 'train' data (n vectors of size p) with associated ids.
    train_data = np.random.rand(n, p).astype(np.float32)
    # Integer division labels every one of the n rows, so train_ids always has
    # the same length as train_data even when n is not a multiple of
    # vectors_per_id (the last id then owns a shorter run).
    train_ids = np.arange(n) // vectors_per_id + 1
    id_count = int(train_ids[-1])
    id_pool = np.arange(1, id_count + 1)

    # Step 2: Generate 'test' data (x vectors of size p) with associated ids.
    test_data = np.random.rand(x, p).astype(np.float32)
    # Sampling without replacement cannot draw more ids than exist.
    pick_limit = min(max_ids_per_query, id_count)
    test_ids = []
    for _ in range(x):
        num_ids = np.random.randint(1, pick_limit + 1)
        associated_ids = np.random.choice(id_pool, size=num_ids, replace=False)
        test_ids.append(associated_ids.astype(np.int32))

    # Step 3: Compute KNN for 'test' data using 'train' data filtered by ids.
    # Row-wise normalisation does not depend on the query, so do it once here
    # instead of once per query; dot products of unit vectors are cosine
    # similarities.
    train_tensor = torch.tensor(train_data)
    train_tensor = train_tensor / train_tensor.norm(dim=1, keepdim=True)

    # Pre-filled with -1 so queries with fewer than k candidates stay padded
    # and the dataset remains rectangular.
    neighbors = np.full((x, k), -1, dtype=np.int32)
    for i, query_ids in enumerate(test_ids):
        query_vector = torch.tensor(test_data[i])
        query_vector = query_vector / query_vector.norm(dim=0, keepdim=True)

        # Keep only the train rows whose id was assigned to this query, and
        # remember their positions in the full train set.
        mask = torch.tensor(np.isin(train_ids, query_ids))
        global_indices = torch.where(mask)[0]

        # topk raises when asked for more entries than are available.
        top_count = min(k, int(global_indices.shape[0]))
        if top_count == 0:
            continue

        similarities = torch.mm(train_tensor[mask], query_vector.unsqueeze(1)).squeeze(1)
        knn_indices = torch.topk(similarities, k=top_count, largest=True).indices
        # Map positions within the filtered subset back to global row indices.
        neighbors[i, :top_count] = global_indices[knn_indices].numpy()

    # Step 4: Write data to HDF5 file.
    with h5py.File(output_file, 'w') as h5f:
        h5f.create_dataset('train', data=train_data)
        h5f.create_dataset('train_ids', data=train_ids)
        h5f.create_dataset('test', data=test_data)
        # Assign the ragged id lists row by row: converting a list of
        # equal-length arrays with np.array(..., dtype=object) would yield a
        # 2-D array instead of a 1-D array of variable-length entries.
        ids_dataset = h5f.create_dataset(
            'test_ids', shape=(x,), dtype=h5py.special_dtype(vlen=np.int32)
        )
        for row, ids in enumerate(test_ids):
            ids_dataset[row] = ids
        h5f.create_dataset('neighbors', data=neighbors)

    print(f"Dataset saved to {output_file}")
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
    # Ask for the three integer sizes in order (train count, dimensionality,
    # test count); the generator is consumed left to right, so the prompts
    # appear in exactly this sequence.
    n, p, x = (
        int(input(prompt))
        for prompt in (
            "Enter the number of train vectors (n): ",
            "Enter the dimensionality of each vector (p): ",
            "Enter the number of test vectors (x): ",
        )
    )
    # The output path is taken verbatim from the user.
    output_file = input("Enter the output HDF5 file name: ")

    generate_knn_dataset(n, p, x, output_file)
|
@ -0,0 +1,80 @@
|
||||
min_version: "5.21.1"
|
||||
|
||||
params:
|
||||
driver: cqld4
|
||||
instrument: true
|
||||
|
||||
bindings:
|
||||
account_sid: ToString()
|
||||
knowledge_sid_w: HdfFileToInt("TEMPLATE(dataset)", "/train_ids");
|
||||
knowledge_sid_r: HdfFileToVarLengthIntArray("TEMPLATE(dataset)", "/test_ids"); IntArrayToString();
|
||||
content: CharBufImage('A-Za-z0-9_,:{}[]',1000000,WeightedLongs("12288:100")); ToString()
|
||||
id: ToHashedUUID()
|
||||
# write embedding
|
||||
embed_base: HdfFileToFloatList("TEMPLATE(dataset)", "/train"); ToCqlVector();
|
||||
# read embedding
|
||||
embed_query: HdfFileToFloatList("TEMPLATE(dataset)", "/test"); ToCqlVector();
|
||||
# ground-truth KNN neighbor indices for each query, consumed by the relevancy verifier
|
||||
relevant_indices: HdfFileToIntArray("TEMPLATE(dataset)", "/neighbors");
|
||||
|
||||
|
||||
blocks:
|
||||
reset_schema:
|
||||
ops:
|
||||
drop_tbl: |
|
||||
DROP TABLE IF EXISTS TEMPLATE(keyspace, default_keyspace).TEMPLATE(table, embeddings);
|
||||
|
||||
schema:
|
||||
ops:
|
||||
op_tbl: |
|
||||
CREATE TABLE IF NOT EXISTS TEMPLATE(keyspace, default_keyspace).TEMPLATE(table, embeddings) (
|
||||
account_sid text,
|
||||
knowledge_sid int,
|
||||
content text,
|
||||
id uuid,
|
||||
embedding vector<float, TEMPLATE(dimension, 128)>,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
op_idx_1: |
|
||||
CREATE CUSTOM INDEX IF NOT EXISTS TEMPLATE(table,embeddings)_embeddings_index
|
||||
ON TEMPLATE(keyspace,default_keyspace).TEMPLATE(table, embeddings)(embedding)
|
||||
USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'
|
||||
WITH OPTIONS = {'source_model': 'TEMPLATE(source_model, openai_v3_large)'};
|
||||
op_idx_2: |
|
||||
CREATE CUSTOM INDEX IF NOT EXISTS TEMPLATE(table,embeddings)_ksid_index
|
||||
ON TEMPLATE(keyspace,default_keyspace).TEMPLATE(table, embeddings)(knowledge_sid)
|
||||
USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';
|
||||
|
||||
read_query_1:
|
||||
ops:
|
||||
op_r1_1:
|
||||
raw: |
|
||||
SELECT * FROM TEMPLATE(keyspace,default_keyspace).TEMPLATE(table, embeddings)
|
||||
WHERE knowledge_sid IN({knowledge_sid_r})
|
||||
ORDER BY embedding ANN OF {embed_query} LIMIT TEMPLATE(select_limit, 10);
|
||||
verifier-init: |
|
||||
topK=TEMPLATE(select_limit,10)
|
||||
relevancy=new io.nosqlbench.nb.api.engine.metrics.wrappers.RelevancyMeasures(_parsed_op)
|
||||
relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.recall("recall",topK));
|
||||
relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.precision("precision",topK));
|
||||
relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.F1("F1",topK));
|
||||
relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.reciprocal_rank("RR",topK));
|
||||
relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.average_precision("AP",topK));
|
||||
windowed_relevancy = new io.nosqlbench.nb.api.engine.metrics.wrappers.WindowedRelevancyMeasures(_parsed_op,10)
|
||||
windowed_relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.recall("recall",topK));
|
||||
verifier: |
|
||||
actual_indices=cql_utils.cqlStringColumnToIntArray("account_sid",result);
|
||||
relevant_indices={relevant_indices};
|
||||
relevancy.accept(relevant_indices,actual_indices);
|
||||
windowed_relevancy.accept(relevant_indices, actual_indices);
|
||||
return true;
|
||||
|
||||
|
||||
main_write:
|
||||
params:
|
||||
cl: TEMPLATE(cl,LOCAL_QUORUM)
|
||||
ops:
|
||||
op_w_1: |
|
||||
INSERT INTO TEMPLATE(keyspace,default_keyspace).TEMPLATE(table, embeddings)
|
||||
(account_sid, knowledge_sid, content, id, embedding)
|
||||
VALUES ({account_sid}, {knowledge_sid_w}, {content}, {id}, {embed_base});
|
Loading…
Reference in New Issue
Block a user