mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2025-02-25 18:55:28 -06:00
Merge pull request #1703 from nosqlbench/nosqlbench-1691-dnn
Nosqlbench 1691 dnn
This commit is contained in:
commit
22c50b02cd
@ -39,7 +39,7 @@ public class ToDouble implements LongToDoubleFunction {
|
||||
|
||||
private final LongToDoubleFunction func;
|
||||
|
||||
ToDouble(Object func) {
|
||||
public ToDouble(Object func) {
|
||||
if (func instanceof Number number) {
|
||||
final double aDouble = number.doubleValue();
|
||||
this.func = l -> aDouble;
|
||||
|
@ -0,0 +1,104 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
|
||||
|
||||
import io.nosqlbench.virtdata.api.annotations.Categories;
|
||||
import io.nosqlbench.virtdata.api.annotations.Category;
|
||||
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
|
||||
|
||||
import java.util.function.IntFunction;
|
||||
|
||||
/**
|
||||
* Compute the indices of the neighbors of a given v using DNN mapping.
|
||||
* To avoid ambiguity on equidistant neighbors, odd neighborhood sizes are preferred.
|
||||
*/
|
||||
@ThreadSafeMapper
|
||||
@Categories(Category.experimental)
|
||||
public class DNN_euclidean_neighbors implements IntFunction<int[]> {
|
||||
|
||||
private final int D;
|
||||
private final int N;
|
||||
private final int k;
|
||||
|
||||
/**
|
||||
* @param k
|
||||
* The size of neighborhood
|
||||
* @param N
|
||||
* The number of total vectors, necessary for boundary conditions of defined vector
|
||||
* @param D
|
||||
* Number of dimensions in each vector
|
||||
*/
|
||||
public DNN_euclidean_neighbors(int k, int N, int D) {
|
||||
this.D = D;
|
||||
this.N = N;
|
||||
this.k = k;
|
||||
}
|
||||
|
||||
/**
|
||||
* <P>Compute neighbor indices with a (hopefully) fast implementation. There are surely some simplifications to be
|
||||
* made in the functions below, but even in the current form it avoids a significant number of branches.</P>
|
||||
*
|
||||
* <P>This code is not as simple as it could be. It was built more for speed than simplicity since it will be a hot
|
||||
* spot for testing. The unit tests for this are essential.</P>
|
||||
*
|
||||
* <P>The method is thus:
|
||||
* <OL>
|
||||
* <LI>Determine the sections of the neighborhood which aren't subject to boundary conditions,
|
||||
* starting at the central vector (the index of the query vector).</LI>
|
||||
* <LI>Layer these in rank order using closed-form index functions.</LI>
|
||||
* <LI>Layer in any zero-boundary values which were deferred from above.</LI>
|
||||
* <LI>Layer in an N-boundary values which were deferred from above.</LI>
|
||||
* </OL>
|
||||
* </P>
|
||||
*
|
||||
* <P>The boundary conditions for zero and N are mutually exclusive. Even though there is some amount of
|
||||
* ranging and book keeping in this approach, it should make the general case more stable, especially
|
||||
* when there are many dimensions and many neighbors.
|
||||
* </P>
|
||||
*
|
||||
* @param value
|
||||
* the function argument, or the index of the query vector for the DNN addressing scheme
|
||||
* @return A ranked neighborhood of vector indices, using the DNN addressing scheme
|
||||
*/
|
||||
@Override
|
||||
public int[] apply(int value) {
|
||||
value = Math.min(Math.max(0,value),N-1);
|
||||
int[] indices = new int[k];
|
||||
|
||||
int leftBoundary = (value << 1) + 1;
|
||||
int rightBoundary = ((N - (value + 1)) << 1) + 1;
|
||||
int insideNeighbors = Math.min(k, Math.min(leftBoundary, rightBoundary));
|
||||
for (int i = 0; i < insideNeighbors; i++) {
|
||||
// Leave this here as an explainer, please
|
||||
// int sign = ((((i + 1) & 1) << 1) - 1); // this gives us -1 or +1 depending on odd or even, and is inverted
|
||||
// int offset = ((i + 1)>>1); // half rounded down, shifted by 1
|
||||
// offset *= sign;
|
||||
// int v = value + (((((i + 1) & 1) << 1) - 1) * ((i + 1) >> 1));
|
||||
indices[i] = value + (((((i + 1) & 1) << 1) - 1) * ((i + 1) >> 1));
|
||||
}
|
||||
int leftFill = Math.max(0, k - leftBoundary);
|
||||
// TODO: Evaluate optimization from Dave2Wave for reducing additions
|
||||
for (int i = 0; i < leftFill; i++) {
|
||||
indices[insideNeighbors + i] = insideNeighbors + i;
|
||||
}
|
||||
int rightFill = Math.max(0, k - rightBoundary);
|
||||
for (int i = 0; i < rightFill; i++) {
|
||||
indices[insideNeighbors + i] = (N - 1) - (insideNeighbors + i);
|
||||
}
|
||||
return indices;
|
||||
}
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
|
||||
|
||||
import io.nosqlbench.virtdata.api.annotations.Categories;
|
||||
import io.nosqlbench.virtdata.api.annotations.Category;
|
||||
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
|
||||
|
||||
import java.util.function.LongFunction;
|
||||
|
||||
@ThreadSafeMapper
|
||||
@Categories(Category.experimental)
|
||||
public class DNN_euclidean_v implements LongFunction<float[]> {
|
||||
|
||||
private final int D;
|
||||
private final long N;
|
||||
private final double scale;
|
||||
|
||||
public DNN_euclidean_v(int D, long N) {
|
||||
this(D,N,1.0d);
|
||||
}
|
||||
|
||||
public DNN_euclidean_v(int D, long N, double scale) {
|
||||
this.D = D;
|
||||
this.N = N;
|
||||
this.scale = scale;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] apply(long value) {
|
||||
if (value>= N) {
|
||||
throw new RuntimeException("You can't generate a vector for ordinal " + value + " when your population is " + this.N);
|
||||
}
|
||||
float[] vector = new float[D];
|
||||
for (int idx = 0; idx < vector.length; idx++) {
|
||||
vector[idx]= (float)(value+(idx*scale));
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
|
||||
|
||||
import io.nosqlbench.virtdata.api.annotations.Categories;
|
||||
import io.nosqlbench.virtdata.api.annotations.Category;
|
||||
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.function.LongFunction;
|
||||
|
||||
@ThreadSafeMapper
|
||||
@Categories(Category.experimental)
|
||||
public class DNN_euclidean_v_series implements LongFunction<float[][]> {
|
||||
|
||||
private final int dimensions;
|
||||
private final long population;
|
||||
private final int k;
|
||||
|
||||
public DNN_euclidean_v_series(int dimensions, long population, int k) {
|
||||
this.dimensions = dimensions;
|
||||
this.population = population;
|
||||
this.k = k;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[][] apply(long value) {
|
||||
long nextInterval = value + k;
|
||||
if (nextInterval > population) {
|
||||
throw new RuntimeException("You can't generate a vector for ordinal " + value + " when your population is " + this.population);
|
||||
}
|
||||
int capacity = dimensions + k;
|
||||
float[] image = new float[capacity];
|
||||
for (int imgidx = 0; imgidx < capacity; imgidx++) {
|
||||
image[imgidx]=imgidx+value;
|
||||
}
|
||||
float[][] vectorSeq = new float[k][dimensions];
|
||||
for (int i = 0; i < vectorSeq.length; i++) {
|
||||
vectorSeq[i]=Arrays.copyOfRange(image,i,i+dimensions);
|
||||
}
|
||||
return vectorSeq;
|
||||
}
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
|
||||
|
||||
import io.nosqlbench.virtdata.api.annotations.Categories;
|
||||
import io.nosqlbench.virtdata.api.annotations.Category;
|
||||
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
|
||||
|
||||
import java.util.function.LongFunction;
|
||||
|
||||
@ThreadSafeMapper
|
||||
@Categories(Category.experimental)
|
||||
public class DNN_euclidean_v_wrap implements LongFunction<float[]> {
|
||||
|
||||
private final int D;
|
||||
private final long N;
|
||||
private final double scale;
|
||||
|
||||
public DNN_euclidean_v_wrap(int D, long N, double scale) {
|
||||
this.D = D;
|
||||
this.N = N;
|
||||
this.scale = scale;
|
||||
}
|
||||
|
||||
public DNN_euclidean_v_wrap(int D, long N) {
|
||||
this(D,N,1.0d);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] apply(long value) {
|
||||
value = value % N;
|
||||
float[] vector = new float[D];
|
||||
for (int idx = 0; idx < vector.length; idx++) {
|
||||
vector[idx]= (float)(value+(idx*scale));
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* This is an experimental package based on the DNN or "Das/Direct Nearest Neighbor" method.
|
||||
*/
|
||||
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
|
@ -0,0 +1,64 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class DNNEuclideanNeighborsTest {
|
||||
|
||||
@Test
|
||||
public void test_DNN_K3_N7_D5() {
|
||||
DNN_euclidean_neighbors idxF = new DNN_euclidean_neighbors(3, 7, 5);
|
||||
assertThat(idxF.apply(0)).isEqualTo(new int[]{0,1,2});
|
||||
assertThat(idxF.apply(1)).isEqualTo(new int[]{1,0,2});
|
||||
assertThat(idxF.apply(2)).isEqualTo(new int[]{2,1,3});
|
||||
assertThat(idxF.apply(3)).isEqualTo(new int[]{3,2,4});
|
||||
assertThat(idxF.apply(4)).isEqualTo(new int[]{4,3,5});
|
||||
assertThat(idxF.apply(5)).isEqualTo(new int[]{5,4,6});
|
||||
assertThat(idxF.apply(6)).isEqualTo(new int[]{6,5,4});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_DNN_k4_n7_d5() {
|
||||
DNN_euclidean_neighbors idxF = new DNN_euclidean_neighbors(4, 7, 5);
|
||||
assertThat(idxF.apply(0)).isEqualTo(new int[]{0,1,2,3});
|
||||
assertThat(idxF.apply(1)).isEqualTo(new int[]{1,0,2,3});
|
||||
assertThat(idxF.apply(2)).isEqualTo(new int[]{2,1,3,0});
|
||||
assertThat(idxF.apply(3)).isEqualTo(new int[]{3,2,4,1});
|
||||
assertThat(idxF.apply(4)).isEqualTo(new int[]{4,3,5,2});
|
||||
assertThat(idxF.apply(5)).isEqualTo(new int[]{5,4,6,3});
|
||||
assertThat(idxF.apply(6)).isEqualTo(new int[]{6,5,4,3});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_DNN_k6_n100_d10() {
|
||||
DNN_euclidean_neighbors idxF = new DNN_euclidean_neighbors(6, 100, 10);
|
||||
assertThat(idxF.apply(99)).isEqualTo(new int[]{99,98,97,96,95,94});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_DNN_K6_N101_D10() {
|
||||
DNN_euclidean_neighbors idxF = new DNN_euclidean_neighbors(6, 101, 10);
|
||||
assertThat(idxF.apply(101)).isEqualTo(new int[]{100,99,98,97,96,95});
|
||||
assertThat(idxF.apply(100)).isEqualTo(new int[]{100,99,98,97,96,95});
|
||||
assertThat(idxF.apply(99)).isEqualTo(new int[]{99,98,100,97,96,95});
|
||||
assertThat(idxF.apply(98)).isEqualTo(new int[]{98,97,99,96,100,95});
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class DNNEuclideanVTest {
|
||||
|
||||
@Test
|
||||
public void testBasicVectors() {
|
||||
DNN_euclidean_v vf = new DNN_euclidean_v(5, 7);
|
||||
assertThat(vf.apply(3L)).isEqualTo(new float[]{3f,4f,5f,6f,7f});
|
||||
assertThrows(RuntimeException.class, () -> vf.apply(7));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicVectorsScaled() {
|
||||
DNN_euclidean_v vf = new DNN_euclidean_v(5, 7, 3.0);
|
||||
assertThat(vf.apply(3L)).isEqualTo(new float[]{3f,6f,9f,12f,15f});
|
||||
assertThrows(RuntimeException.class, () -> vf.apply(7));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testWrappingVectors() {
|
||||
DNN_euclidean_v_wrap vf = new DNN_euclidean_v_wrap(5, 7);
|
||||
assertThat(vf.apply(3L)).isEqualTo(new float[]{3f,4f,5f,6f,7f});
|
||||
assertThat(vf.apply(0L)).isEqualTo(new float[]{0f,1f,2f,3f,4f});
|
||||
assertThat(vf.apply(7L)).isEqualTo(new float[]{0f,1f,2f,3f,4f});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testContiguousVectors() {
|
||||
DNN_euclidean_v_series vf = new DNN_euclidean_v_series(4,10,2);
|
||||
assertThat(vf.apply(7L)).isEqualTo(
|
||||
new float[][] {
|
||||
{7f,8f,9f,10f},
|
||||
{8f,9f,10f,11f}
|
||||
}
|
||||
);
|
||||
|
||||
assertThrows(RuntimeException.class, () -> vf.apply(10));
|
||||
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user