fixes for DNN angular1 tests

This commit is contained in:
Jonathan Shook 2023-12-21 14:25:37 -06:00
parent 3f6abf12f8
commit d2f302d02c

View File

@ -16,54 +16,54 @@
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
import org.jetbrains.annotations.TestOnly;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;
class DNN_angular1_vTest {
class DNNAngular1VTest {
@Test
public void testCosineSimilarity() {
assertThat(cosine_similarity(new float[]{1,2,3,4,5,6,7},new float[]{7,6,5,4,3,2,1})).isEqualTo(0.6);
assertThat(cosine_similarity(new float[]{1,2,3,4,5,6,7},new float[]{1,2,3,4,5,6,7})).isEqualTo(1.0);
assertThat(cosine_similarity(new float[]{1, 2, 3, 4, 5, 6, 7}, new float[]{7, 6, 5, 4, 3, 2, 1})).isEqualTo(0.6);
assertThat(cosine_similarity(new float[]{1, 2, 3, 4, 5, 6, 7}, new float[]{1, 2, 3, 4, 5, 6, 7})).isEqualTo(1.0);
}
@Test
public void testSimpleGeneration() {
DNN_angular1_v vs = new DNN_angular1_v(2,100,3);
assertThat(vs.apply(0)).isEqualTo(new float[]{1,0});
assertThat(vs.apply(1)).isEqualTo(new float[]{2,2});
assertThat(vs.apply(2)).isEqualTo(new float[]{3,6});
assertThat(vs.apply(3)).isEqualTo(new float[]{4,0});
assertThat(vs.apply(4)).isEqualTo(new float[]{5,5});
assertThat(vs.apply(5)).isEqualTo(new float[]{6,12});
assertThat(vs.apply(6)).isEqualTo(new float[]{7,0});
DNN_angular1_v vs = new DNN_angular1_v(2, 100, 3);
assertThat(vs.apply(0)).isEqualTo(new float[]{1, 0});
assertThat(vs.apply(1)).isEqualTo(new float[]{2, 2});
assertThat(vs.apply(2)).isEqualTo(new float[]{3, 6});
assertThat(vs.apply(3)).isEqualTo(new float[]{4, 0});
assertThat(vs.apply(4)).isEqualTo(new float[]{5, 5});
assertThat(vs.apply(5)).isEqualTo(new float[]{6, 12});
assertThat(vs.apply(6)).isEqualTo(new float[]{7, 0});
}
@Test
public void testBasicAngularVectors() {
DNN_angular1_v vf = new DNN_angular1_v(10, 100, 7);
int M = 7;
DNN_angular1_v vf = new DNN_angular1_v(10, 100, M);
float[][] vectors = new float[100][];
for (int i = 0; i < 100; i++) {
vectors[i] = vf.apply(i);
}
int[] same = new int[100];
Arrays.fill(same,-1);
Arrays.fill(same, -1);
for (int vidx = 0; vidx < same.length; vidx++) {
for (int compare_to = 0; compare_to < vidx; compare_to++) {
if (cosine_similarity(vectors[vidx],vectors[compare_to])==1.0) {
same[vidx]=compare_to;
for (int compare_to = 0; compare_to <= vidx; compare_to++) {
double similarity = cosine_similarity(vectors[vidx], vectors[compare_to]);
if (Math.abs(similarity - 1.0d) < 0.00000001d) {
same[vidx] = compare_to;
break;
}
}
}
for (int sameas = 0; sameas < same.length; sameas++) {
assertThat(same[sameas]==sameas%7);
for (int sameas = M; sameas < same.length; sameas++) {
// System.out.println("idx:" + sameas + ", same[sameas] -> " + same[sameas] + " sameas%7=" + sameas % M);
assertThat(same[sameas] % M).isEqualTo(sameas % M);
}
}
@ -76,8 +76,7 @@ class DNN_angular1_vTest {
as += (a[i] * a[i]);
bs += (b[i] * b[i]);
}
double similarity = dp / (Math.sqrt(as) * Math.sqrt(bs));
return similarity;
return dp / (Math.sqrt(as) * Math.sqrt(bs));
}
}