diff --git a/src/main/java/io/weaviate/client/base/BaseClient.java b/src/main/java/io/weaviate/client/base/BaseClient.java index c3c0a528..2d26793d 100644 --- a/src/main/java/io/weaviate/client/base/BaseClient.java +++ b/src/main/java/io/weaviate/client/base/BaseClient.java @@ -1,10 +1,12 @@ package io.weaviate.client.base; +import io.weaviate.client.Config; import io.weaviate.client.base.http.HttpClient; import io.weaviate.client.base.http.HttpResponse; +import io.weaviate.client.v1.graphql.GraphQL; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; import java.util.Collections; - -import io.weaviate.client.Config; +import java.util.List; public abstract class BaseClient { private final HttpClient client; @@ -51,7 +53,12 @@ private Response sendRequest(String endpoint, Object payload, String method, if (statusCode < 399) { T body = toResponse(responseBody, classOfT); - return new Response<>(statusCode, body, null); + WeaviateErrorResponse errors = null; + + if (body != null && classOfT.equals(GraphQL.class)) { + errors = getWeaviateGraphQLErrorResponse((GraphQLResponse) body, statusCode); + } + return new Response<>(statusCode, body, errors); } WeaviateErrorResponse error = toResponse(responseBody, WeaviateErrorResponse.class); @@ -90,10 +97,22 @@ private String toJsonString(Object object) { } private WeaviateErrorResponse getWeaviateErrorResponse(Exception e) { - WeaviateErrorMessage error = WeaviateErrorMessage.builder() - .message(e.getMessage()) - .throwable(e) - .build(); + WeaviateErrorMessage error = WeaviateErrorMessage.builder().message(e.getMessage()).throwable(e).build(); return WeaviateErrorResponse.builder().error(Collections.singletonList(error)).build(); } + + /** + * Extract errors from {@link WeaviateErrorResponse} from a GraphQL response body. + * + * @param gql GraphQL response body. + * @param code HTTP status code to pass in the {@link WeaviateErrorResponse}. + * @return Error response to be returned to the caller. + */ + private WeaviateErrorResponse getWeaviateGraphQLErrorResponse(GraphQLResponse gql, int code) { + List messages = gql.errorMessages(); + if (messages == null || messages.isEmpty()) { + return null; + } + return WeaviateErrorResponse.builder().code(code).error(gql.errorMessages()).build(); + } } diff --git a/src/main/java/io/weaviate/client/base/WeaviateErrorMessage.java b/src/main/java/io/weaviate/client/base/WeaviateErrorMessage.java index ecad1bf0..bcf92bfb 100644 --- a/src/main/java/io/weaviate/client/base/WeaviateErrorMessage.java +++ b/src/main/java/io/weaviate/client/base/WeaviateErrorMessage.java @@ -1,6 +1,7 @@ package io.weaviate.client.base; import lombok.AccessLevel; +import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; import lombok.ToString; @@ -9,6 +10,7 @@ @Getter @Builder @ToString +@AllArgsConstructor(access = AccessLevel.PUBLIC) @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class WeaviateErrorMessage { String message; diff --git a/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLResponse.java b/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLResponse.java index 5899fc59..68b173ef 100644 --- a/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLResponse.java +++ b/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLResponse.java @@ -1,5 +1,9 @@ package io.weaviate.client.v1.graphql.model; +import io.weaviate.client.base.WeaviateErrorMessage; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; import lombok.AccessLevel; import lombok.Builder; import lombok.EqualsAndHashCode; @@ -15,4 +19,19 @@ public class GraphQLResponse { Object data; GraphQLError[] errors; + + + /** + * Extract the 'message' portion of every error in the response, omitting 'path' and 'location'. + * + * @return Non-throwable WeaviateErrorMessages + */ + public List errorMessages() { + if (errors == null || errors.length == 0) { + return null; + } + return Arrays.stream(errors) + .map(err -> new WeaviateErrorMessage(err.getMessage(), null)) + .collect(Collectors.toList()); + } } diff --git a/src/main/java/io/weaviate/client/v1/graphql/query/argument/NearVectorArgument.java b/src/main/java/io/weaviate/client/v1/graphql/query/argument/NearVectorArgument.java index 6fed0215..3714c370 100644 --- a/src/main/java/io/weaviate/client/v1/graphql/query/argument/NearVectorArgument.java +++ b/src/main/java/io/weaviate/client/v1/graphql/query/argument/NearVectorArgument.java @@ -1,7 +1,12 @@ package io.weaviate.client.v1.graphql.query.argument; import io.weaviate.client.v1.graphql.query.util.Serializer; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; +import java.util.Set; import lombok.AccessLevel; import lombok.Builder; import lombok.EqualsAndHashCode; @@ -9,10 +14,6 @@ import lombok.ToString; import lombok.experimental.FieldDefaults; import org.apache.commons.lang3.ArrayUtils; -import org.apache.commons.lang3.StringUtils; - -import java.util.LinkedHashSet; -import java.util.Set; @Getter @Builder @@ -24,7 +25,7 @@ public class NearVectorArgument implements Argument { Float certainty; Float distance; String[] targetVectors; - Map vectorPerTarget; + Map vectorsPerTarget; Targets targets; @Override @@ -41,19 +42,73 @@ public String build() { arg.add(String.format("distance:%s", distance)); } if (ArrayUtils.isNotEmpty(targetVectors)) { - arg.add(String.format("targetVectors:%s", Serializer.arrayWithQuotes(targetVectors))); + arg.add(String.format("targetVectors:%s", Serializer.arrayWithQuotes(targetVectors))); } - if (vectorPerTarget != null && !vectorPerTarget.isEmpty()) { + if (vectorsPerTarget != null && !vectorsPerTarget.isEmpty()) { Set vectorPerTargetArg = new LinkedHashSet<>(); - for (Map.Entry entry : vectorPerTarget.entrySet()) { - vectorPerTargetArg.add(String.format("%s:%s", entry.getKey(), Serializer.array(entry.getValue()))); + for (Map.Entry e : vectorsPerTarget.entrySet()) { + Float[][] vectors = e.getValue(); + vectorPerTargetArg.add(String.format("%s:%s", e.getKey(), vectors.length == 1 ? Serializer.array(vectors[0]) : Serializer.array(vectors))); } arg.add(String.format("vectorPerTarget:{%s}", String.join(" ", vectorPerTargetArg))); } if (targets != null) { - arg.add(String.format("%s", targets.build())); + arg.add(String.format("%s", withValidTargetVectors(this.targets).build())); } return String.format("nearVector:{%s}", String.join(" ", arg)); } + + /** + * withValidTargetVectors makes sure the target names are repeated for each target vector, + * which is required by server, but may be easily overlooked by the user. + * + *

+ * Note, too, that in case the user fails to pass a value in targetVectors altogether, it will not be added to the query. + * + * @return A copy of the Targets with validated target vectors. + */ + private Targets withValidTargetVectors(Targets targets) { + return Targets.builder(). + combinationMethod(targets.getCombinationMethod()). + weightsMulti(targets.getWeights()). + targetVectors(prepareTargetVectors(targets.getTargetVectors())). + build(); + } + + /** + * prepareTargetVectors adds appends the target name for each target vector associated with it. + */ + private String[] prepareTargetVectors(String[] targets) { + List out = new ArrayList<>(); + for (String target : targets) { + if (this.vectorsPerTarget.containsKey(target)) { + int l = this.vectorsPerTarget.get(target).length; + for (int i = 0; i < l; i++) { + out.add(target); + } + } else { + out.add(target); + } + } + return out.toArray(new String[0]); + } + + // Extend Lombok's builder to overload some methods. + public static class NearVectorArgumentBuilder { + Map vectorsPerTarget = new LinkedHashMap<>(); + + public NearVectorArgumentBuilder vectorPerTarget(Map vectors) { + this.vectorsPerTarget.clear(); // Overwrite the existing entries each time this is called. + for (Map.Entry e : vectors.entrySet()) { + this.vectorsPerTarget.put(e.getKey(), new Float[][]{e.getValue()}); + } + return this; + } + + public NearVectorArgumentBuilder vectorsPerTarget(Map vectors) { + this.vectorsPerTarget = vectors; + return this; + } + } } diff --git a/src/main/java/io/weaviate/client/v1/graphql/query/argument/Targets.java b/src/main/java/io/weaviate/client/v1/graphql/query/argument/Targets.java index 20cb1255..ffce7b58 100644 --- a/src/main/java/io/weaviate/client/v1/graphql/query/argument/Targets.java +++ b/src/main/java/io/weaviate/client/v1/graphql/query/argument/Targets.java @@ -1,6 +1,7 @@ package io.weaviate.client.v1.graphql.query.argument; import io.weaviate.client.v1.graphql.query.util.Serializer; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; @@ -20,7 +21,7 @@ public class Targets { CombinationMethod combinationMethod; String[] targetVectors; - Map weights; + Map weights; public enum CombinationMethod { minimum("minimum"), @@ -52,13 +53,38 @@ String build() { } if (weights != null && !weights.isEmpty()) { Set weightsArg = new LinkedHashSet<>(); - for (Map.Entry entry : weights.entrySet()) { - weightsArg.add(String.format("%s:%s", entry.getKey(), entry.getValue())); + for (Map.Entry e : weights.entrySet()) { + Float[] weightsPerTarget = e.getValue(); + String target = e.getKey(); + + String weight = Serializer.array(weightsPerTarget); + if (weightsPerTarget.length == 1) { + weight = weightsPerTarget[0].toString(); + } + weightsArg.add(String.format("%s:%s", target, weight)); } arg.add(String.format("weights:{%s}", String.join(" ", weightsArg))); } return String.format("targets:{%s}", String.join(" ", arg)); } + + // Extend lombok's builder to overload some methods. + public static class TargetsBuilder { + Map weights = new LinkedHashMap<>(); + + public TargetsBuilder weights(Map weights) { + this.weights.clear(); // Overwrite the existing entries each time this is called. + for (Map.Entry e : weights.entrySet()) { + this.weights.put(e.getKey(), new Float[]{e.getValue()}); + } + return this; + } + + public TargetsBuilder weightsMulti(Map weights) { + this.weights = weights; + return this; + } + } } diff --git a/src/main/java/io/weaviate/client/v1/graphql/query/util/Serializer.java b/src/main/java/io/weaviate/client/v1/graphql/query/util/Serializer.java index 8f930b92..b15a6f44 100644 --- a/src/main/java/io/weaviate/client/v1/graphql/query/util/Serializer.java +++ b/src/main/java/io/weaviate/client/v1/graphql/query/util/Serializer.java @@ -1,16 +1,15 @@ package io.weaviate.client.v1.graphql.query.util; -import org.apache.commons.lang3.StringEscapeUtils; -import org.apache.commons.lang3.StringUtils; - import java.util.Arrays; -import java.util.Objects; import java.util.function.Function; import java.util.stream.Collectors; +import org.apache.commons.lang3.StringEscapeUtils; +import org.apache.commons.lang3.StringUtils; public class Serializer { - private Serializer() {} + private Serializer() { + } /** * Creates graphql safe string @@ -69,7 +68,7 @@ public static String array(T[] input) { * Creates array string * It is up to user to make elements json safe * - * @param input array of arbitrary elements + * @param input array of arbitrary elements * @param mapper maps single element before building array * @return array string */ @@ -78,7 +77,12 @@ public static String array(T[] input, Function mapper) { if (input != null) { inner = Arrays.stream(input) .map(mapper) - .map(Objects::toString) + .map(obj -> { + if (obj.getClass().isArray()) { + return array((Object[]) obj); + } + return obj.toString(); + }) .collect(Collectors.joining(",")); } diff --git a/src/main/java/io/weaviate/client/v1/misc/model/ReplicationConfig.java b/src/main/java/io/weaviate/client/v1/misc/model/ReplicationConfig.java index 89fb933c..8e645ea1 100644 --- a/src/main/java/io/weaviate/client/v1/misc/model/ReplicationConfig.java +++ b/src/main/java/io/weaviate/client/v1/misc/model/ReplicationConfig.java @@ -1,5 +1,6 @@ package io.weaviate.client.v1.misc.model; +import com.google.gson.annotations.SerializedName; import lombok.AccessLevel; import lombok.Builder; import lombok.EqualsAndHashCode; @@ -15,4 +16,12 @@ public class ReplicationConfig { Boolean asyncEnabled; Integer factor; + DeletionStrategy deletionStrategy; + + public enum DeletionStrategy { + @SerializedName("DeleteOnConflict") + DELETE_ON_CONFLICT, + @SerializedName("NoAutomatedResolution") + NO_AUTOMATED_RESOLUTION; + } } diff --git a/src/main/java/io/weaviate/client/v1/misc/model/VectorIndexConfig.java b/src/main/java/io/weaviate/client/v1/misc/model/VectorIndexConfig.java index 9a74a204..4fb05b9e 100644 --- a/src/main/java/io/weaviate/client/v1/misc/model/VectorIndexConfig.java +++ b/src/main/java/io/weaviate/client/v1/misc/model/VectorIndexConfig.java @@ -1,5 +1,6 @@ package io.weaviate.client.v1.misc.model; +import com.google.gson.annotations.SerializedName; import lombok.AccessLevel; import lombok.Builder; import lombok.EqualsAndHashCode; @@ -20,6 +21,7 @@ public class VectorIndexConfig { Integer dynamicEfMin; Integer dynamicEfMax; Integer dynamicEfFactor; + FilterStrategy filterStrategy; Long vectorCacheMaxObjects; Integer flatSearchCutoff; Integer cleanupIntervalSeconds; @@ -27,4 +29,11 @@ public class VectorIndexConfig { PQConfig pq; BQConfig bq; SQConfig sq; + + public enum FilterStrategy { + @SerializedName("sweeping") + SWEEPING, + @SerializedName("acorn") + ACORN; + } } diff --git a/src/test/java/io/weaviate/client/v1/graphql/query/argument/NearVectorArgumentTest.java b/src/test/java/io/weaviate/client/v1/graphql/query/argument/NearVectorArgumentTest.java index 78074ad3..fa186dc6 100644 --- a/src/test/java/io/weaviate/client/v1/graphql/query/argument/NearVectorArgumentTest.java +++ b/src/test/java/io/weaviate/client/v1/graphql/query/argument/NearVectorArgumentTest.java @@ -1,10 +1,10 @@ package io.weaviate.client.v1.graphql.query.argument; import java.util.LinkedHashMap; -import org.junit.Test; - +import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; +import org.junit.Test; public class NearVectorArgumentTest { @@ -60,7 +60,7 @@ public void testBuildWithTargets() { weights.put("t1", 0.8f); weights.put("t2", 0.2f); Targets targets = Targets.builder() - .targetVectors(new String[]{ "t1", "t2" }) + .targetVectors(new String[]{"t1", "t2"}) .combinationMethod(Targets.CombinationMethod.sum) .weights(weights) .build(); @@ -71,6 +71,24 @@ public void testBuildWithTargets() { // when String arg = nearVector.build(); // then - assertEquals("nearVector:{vectorPerTarget:{t1:[1.0,2.0,3.0] t2:[0.1,0.2,0.3]} targets:{combinationMethod:sum targetVectors:[\"t1\",\"t2\"] weights:{t1:0.8 t2:0.2}}}", arg); + assertEquals( + "nearVector:{vectorPerTarget:{t1:[1.0,2.0,3.0] t2:[0.1,0.2,0.3]} targets:{combinationMethod:sum targetVectors:[\"t1\",\"t2\"] weights:{t1:0.8 t2:0.2}}}", + arg); + } + + @Test + public void testBuildWithMultipleVectorsPerTarget() { + Map vectorsPerTarget = new LinkedHashMap() { + { + this.put("t1", new Float[][]{new Float[]{1f, 2f, 3f}, new Float[]{4f, 5f, 6f}}); + this.put("t2", new Float[][]{new Float[]{.1f, .2f, .3f}}); + } + }; + NearVectorArgument nearVector = + NearVectorArgument.builder().targets(Targets.builder().targetVectors(new String[]{"t1", "t2"}).build()).vectorsPerTarget(vectorsPerTarget).build(); + + String got = nearVector.build(); + + assertEquals("nearVector:{vectorPerTarget:{t1:[[1.0,2.0,3.0],[4.0,5.0,6.0]] t2:[0.1,0.2,0.3]} targets:{targetVectors:[\"t1\",\"t1\",\"t2\"]}}", got); } } diff --git a/src/test/java/io/weaviate/client/v1/graphql/query/argument/TargetsTest.java b/src/test/java/io/weaviate/client/v1/graphql/query/argument/TargetsTest.java index ea207b81..14edff4f 100644 --- a/src/test/java/io/weaviate/client/v1/graphql/query/argument/TargetsTest.java +++ b/src/test/java/io/weaviate/client/v1/graphql/query/argument/TargetsTest.java @@ -1,10 +1,10 @@ package io.weaviate.client.v1.graphql.query.argument; import java.util.LinkedHashMap; -import org.junit.Test; - +import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import org.junit.Test; public class TargetsTest { @@ -12,7 +12,7 @@ public class TargetsTest { public void testBuild() { // given Targets targets = Targets.builder() - .targetVectors(new String[]{ "t1", "t2" }) + .targetVectors(new String[]{"t1", "t2"}) .combinationMethod(Targets.CombinationMethod.average) .build(); // when @@ -29,7 +29,7 @@ public void testBuildWithWeights() { weights.put("t1", 0.8f); weights.put("t2", 0.2f); Targets targets = Targets.builder() - .targetVectors(new String[]{ "t1", "t2" }) + .targetVectors(new String[]{"t1", "t2"}) .combinationMethod(Targets.CombinationMethod.manualWeights) .weights(weights) .build(); @@ -39,4 +39,21 @@ public void testBuildWithWeights() { assertNotNull(targetsStr); assertEquals("targets:{combinationMethod:manualWeights targetVectors:[\"t1\",\"t2\"] weights:{t1:0.8 t2:0.2}}", targetsStr); } + + @Test + public void testMultipleWeightsPerTargetVector() { + Map weights = new LinkedHashMap() { + { + this.put("t1", new Float[]{.8f, .34f}); + this.put("t2", new Float[]{.2f}); + } + }; + Targets targets = + Targets.builder().targetVectors(new String[]{"t1", "t2"}).combinationMethod(Targets.CombinationMethod.relativeScore).weightsMulti(weights).build(); + + String got = targets.build(); + + assertNotNull(got); + assertEquals("targets:{combinationMethod:relativeScore targetVectors:[\"t1\",\"t2\"] weights:{t1:[0.8,0.34] t2:0.2}}", got); + } } diff --git a/src/test/java/io/weaviate/client/v1/schema/model/WeaviateClassTest.java b/src/test/java/io/weaviate/client/v1/schema/model/WeaviateClassTest.java index e0b7cdd7..5ea55a56 100644 --- a/src/test/java/io/weaviate/client/v1/schema/model/WeaviateClassTest.java +++ b/src/test/java/io/weaviate/client/v1/schema/model/WeaviateClassTest.java @@ -3,13 +3,11 @@ import com.google.gson.GsonBuilder; import io.weaviate.client.v1.misc.model.BQConfig; import io.weaviate.client.v1.misc.model.VectorIndexConfig; -import org.junit.Test; - import java.util.HashMap; import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.InstanceOfAssertFactories.MAP; +import org.junit.Test; public class WeaviateClassTest { @@ -110,13 +108,15 @@ public void shouldSerializeClassWithVectorConfig() { "\"moduleConfig\":{\"text2vec-contextionary\":{\"vectorizeClassName\":false}}," + "\"vectorConfig\":{" + "\"hnswVector\":{\"vectorIndexType\":\"hnsw\",\"vectorizer\":{\"text2vec-contextionary\":\"some-setting\"}}," + - "\"flatVector\":{\"vectorIndexConfig\":{\"bq\":{\"enabled\":true,\"rescoreLimit\":100}},\"vectorIndexType\":\"flat\",\"vectorizer\":{\"text2vec-contextionary\":\"some-setting\"}}" + + "\"flatVector\":{\"vectorIndexConfig\":{\"bq\":{\"enabled\":true,\"rescoreLimit\":100}},\"vectorIndexType\":\"flat\"," + + "\"vectorizer\":{\"text2vec-contextionary\":\"some-setting\"}}" + "}}"), serialized -> assertThat(serialized).isEqualTo("{\"class\":\"Band\"," + "\"description\":\"Band that plays and produces music\"," + "\"moduleConfig\":{\"text2vec-contextionary\":{\"vectorizeClassName\":false}}," + "\"vectorConfig\":{" + - "\"flatVector\":{\"vectorIndexConfig\":{\"bq\":{\"enabled\":true,\"rescoreLimit\":100}},\"vectorIndexType\":\"flat\",\"vectorizer\":{\"text2vec-contextionary\":\"some-setting\"}}" + + "\"flatVector\":{\"vectorIndexConfig\":{\"bq\":{\"enabled\":true,\"rescoreLimit\":100}},\"vectorIndexType\":\"flat\"," + + "\"vectorizer\":{\"text2vec-contextionary\":\"some-setting\"}}" + "\"hnswVector\":{\"vectorIndexType\":\"hnsw\",\"vectorizer\":{\"text2vec-contextionary\":\"some-setting\"}}," + "}}") ); diff --git a/src/test/java/io/weaviate/integration/client/WeaviateVersion.java b/src/test/java/io/weaviate/integration/client/WeaviateVersion.java index d27c0a1f..72946175 100644 --- a/src/test/java/io/weaviate/integration/client/WeaviateVersion.java +++ b/src/test/java/io/weaviate/integration/client/WeaviateVersion.java @@ -3,12 +3,12 @@ public class WeaviateVersion { // docker image version - public static final String WEAVIATE_IMAGE = "stable-v1.26-6a411a4"; + public static final String WEAVIATE_IMAGE = "1.27.0"; // to be set according to weaviate docker image - public static final String EXPECTED_WEAVIATE_VERSION = "1.26.6"; + public static final String EXPECTED_WEAVIATE_VERSION = "1.27.0"; // to be set according to weaviate docker image - public static final String EXPECTED_WEAVIATE_GIT_HASH = "6a411a4"; + public static final String EXPECTED_WEAVIATE_GIT_HASH = "6c571ff"; private WeaviateVersion() { } diff --git a/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLMultiTargetSearchTest.java b/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLMultiTargetSearchTest.java index c9928aef..6ea4701d 100644 --- a/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLMultiTargetSearchTest.java +++ b/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLMultiTargetSearchTest.java @@ -27,6 +27,7 @@ import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.InstanceOfAssertFactories.ARRAY; +import static org.junit.Assert.assertNull; import org.junit.Before; import org.junit.ClassRule; import org.junit.Test; @@ -64,7 +65,7 @@ public void shouldPerformMultiTargetSearch() throws InterruptedException { setupDB(className); Field _additional = Field.builder() .name("_additional") - .fields(new Field[]{ Field.builder().name("id").build(), Field.builder().name("distance").build() }) + .fields(new Field[]{Field.builder().name("id").build(), Field.builder().name("distance").build()}) .build(); // nearText Map weights = new HashMap<>(); @@ -72,12 +73,12 @@ public void shouldPerformMultiTargetSearch() throws InterruptedException { weights.put(title1, 0.6f); weights.put(title2, 0.3f); Targets targets = Targets.builder() - .targetVectors(new String[]{ titleAndContent, title1, title2 }) + .targetVectors(new String[]{titleAndContent, title1, title2}) .combinationMethod(Targets.CombinationMethod.manualWeights) .weights(weights) .build(); NearTextArgument nearText = client.graphQL().arguments().nearTextArgBuilder() - .concepts(new String[]{ "Water black" }) + .concepts(new String[]{"Water black"}) .targets(targets) .build(); Result response = client.graphQL().get() @@ -86,15 +87,16 @@ public void shouldPerformMultiTargetSearch() throws InterruptedException { .withFields(_additional) .run(); assertGetContainsIds(response, className, id1, id2, id3); - // nearVector + // nearVector with single vector-per-target Map vectorPerTarget = new HashMap<>(); - vectorPerTarget.put(bringYourOwnVector, new Float[]{ .99f, .88f, .77f }); - vectorPerTarget.put(bringYourOwnVector2, new Float[]{ .11f, .22f, .33f }); - weights = new HashMap<>(); - weights.put(bringYourOwnVector, 0.1f); - weights.put(bringYourOwnVector2, 0.6f); + vectorPerTarget.put(bringYourOwnVector, new Float[]{.99f, .88f, .77f}); + vectorPerTarget.put(bringYourOwnVector2, new Float[]{.11f, .22f, .33f}); + weights = new HashMap() {{ + this.put(bringYourOwnVector, 0.1f); + this.put(bringYourOwnVector2, 0.6f); + }}; targets = Targets.builder() - .targetVectors(new String[]{ bringYourOwnVector, bringYourOwnVector2 }) + .targetVectors(new String[]{bringYourOwnVector, bringYourOwnVector2}) .combinationMethod(Targets.CombinationMethod.manualWeights) .weights(weights) .build(); @@ -106,10 +108,33 @@ public void shouldPerformMultiTargetSearch() throws InterruptedException { .withNearVector(nearVector) .withFields(_additional) .run(); + assertNull("check error in response:", response.getError()); + assertGetContainsIds(response, className, id2, id3); + // nearVector with multiple vector-per-target + Map vectorsPerTarget = new HashMap<>(); + vectorsPerTarget.put(bringYourOwnVector, new Float[][]{new Float[]{.99f, .88f, .77f}, new Float[]{.99f, .88f, .77f}}); + vectorsPerTarget.put(bringYourOwnVector2, new Float[][]{new Float[]{.11f, .22f, .33f}}); + Map weightsMulti = new HashMap<>(); + weightsMulti.put(bringYourOwnVector, new Float[]{0.5f, 0.5f}); + weightsMulti.put(bringYourOwnVector2, new Float[]{0.6f}); + targets = Targets.builder() + .targetVectors(new String[]{bringYourOwnVector, bringYourOwnVector2}) + .combinationMethod(Targets.CombinationMethod.manualWeights) + .weightsMulti(weightsMulti) + .build(); + nearVector = client.graphQL().arguments().nearVectorArgBuilder() + .vectorsPerTarget(vectorsPerTarget) + .targets(targets).build(); + response = client.graphQL().get() + .withClassName(className) + .withNearVector(nearVector) + .withFields(_additional) + .run(); + assertNull("check error in response:", response.getError()); assertGetContainsIds(response, className, id2, id3); // nearObject targets = Targets.builder() - .targetVectors(new String[]{ bringYourOwnVector, bringYourOwnVector2, titleAndContent, title1, title2 }) + .targetVectors(new String[]{bringYourOwnVector, bringYourOwnVector2, titleAndContent, title1, title2}) .combinationMethod(Targets.CombinationMethod.average) .build(); NearObjectArgument nearObject = client.graphQL().arguments().nearObjectArgBuilder() @@ -170,7 +195,7 @@ private void setupDB(String className) { props1.put("content", "A great fantasy novel"); props1.put("title1", "J.R.R. Tolkien The Lord of the Rings"); props1.put("title2", "Rings"); - Float[] vector1a = new Float[]{ 0.77f, 0.88f, 0.77f }; + Float[] vector1a = new Float[]{0.77f, 0.88f, 0.77f}; Map vectors1 = new HashMap<>(); vectors1.put("bringYourOwnVector", vector1a); // don't add vector for bringYourOwnVector2 @@ -180,8 +205,8 @@ private void setupDB(String className) { props2.put("content", "A great science fiction book"); props2.put("title1", "Jacek Dukaj Black Oceans"); props2.put("title2", "Water"); - Float[] vector2a = new Float[]{ 0.11f, 0.22f, 0.33f }; - Float[] vector2b = new Float[]{ 0.11f, 0.11f, 0.11f }; + Float[] vector2a = new Float[]{0.11f, 0.22f, 0.33f}; + Float[] vector2b = new Float[]{0.11f, 0.11f, 0.11f}; Map vectors2 = new HashMap<>(); vectors2.put("bringYourOwnVector", vector2a); vectors2.put("bringYourOwnVector2", vector2b); @@ -191,8 +216,8 @@ private void setupDB(String className) { props3.put("content", "New York Times bestseller and global phenomenon The Girl on the Train returns with Into the Water"); props3.put("title1", "Paula Hawkins Into the Water"); props3.put("title2", "Water go into it"); - Float[] vector3a = new Float[]{ 0.99f, 0.88f, 0.77f }; - Float[] vector3b = new Float[]{ 0.99f, 0.88f, 0.77f }; + Float[] vector3a = new Float[]{0.99f, 0.88f, 0.77f}; + Float[] vector3b = new Float[]{0.99f, 0.88f, 0.77f}; Map vectors3 = new HashMap<>(); vectors3.put("bringYourOwnVector", vector3a); vectors3.put("bringYourOwnVector2", vector3b); @@ -214,7 +239,7 @@ private WeaviateClass.VectorConfig getTitleAndContentVectorConfig() { Map titleAndContent = new HashMap<>(); Map text2vecContextionarySettings = new HashMap<>(); text2vecContextionarySettings.put("vectorizeClassName", false); - text2vecContextionarySettings.put("properties", new String[]{ "title", "content" }); + text2vecContextionarySettings.put("properties", new String[]{"title", "content"}); titleAndContent.put("text2vec-contextionary", text2vecContextionarySettings); return getHNSWSQVectorConfig(titleAndContent); } @@ -223,7 +248,7 @@ private WeaviateClass.VectorConfig getTitle1VectorConfig() { Map titleAndContent = new HashMap<>(); Map text2vecContextionarySettings = new HashMap<>(); text2vecContextionarySettings.put("vectorizeClassName", false); - text2vecContextionarySettings.put("properties", new String[]{ "title1" }); + text2vecContextionarySettings.put("properties", new String[]{"title1"}); titleAndContent.put("text2vec-contextionary", text2vecContextionarySettings); return getHNSWPQVectorConfig(titleAndContent); } @@ -232,7 +257,7 @@ private WeaviateClass.VectorConfig getTitle2VectorConfig() { Map titleAndContent = new HashMap<>(); Map text2vecContextionarySettings = new HashMap<>(); text2vecContextionarySettings.put("vectorizeClassName", false); - text2vecContextionarySettings.put("properties", new String[]{ "title2" }); + text2vecContextionarySettings.put("properties", new String[]{"title2"}); titleAndContent.put("text2vec-contextionary", text2vecContextionarySettings); return getHNSWVectorConfig(titleAndContent); } diff --git a/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTest.java b/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTest.java index 68704bb7..3325c896 100644 --- a/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTest.java +++ b/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTest.java @@ -1602,6 +1602,8 @@ public void testGraphQLGetWithGroupBy() { }).build(); // _additional Field _additional = Field.builder().name("_additional").fields(new Field[]{group}).build(); + // Property that we group by + Field ofDocument = Field.builder().name("ofDocument{__typename}").build(); // filter arguments GroupByArgument groupBy = client.graphQL().arguments().groupByArgBuilder() .path(new String[]{"ofDocument"}).groups(3).objectsPerGroup(10).build(); @@ -1612,7 +1614,7 @@ public void testGraphQLGetWithGroupBy() { .withClassName(testData.PASSAGE) .withNearObject(nearObject) .withGroupBy(groupBy) - .withFields(_additional).run(); + .withFields(ofDocument, _additional).run(); testData.cleanupWeaviate(client); // then assertThat(groupByResult).isNotNull(); @@ -1658,6 +1660,8 @@ public void testGraphQLGetWithGroupByWithHybrid() { }).build(); // _additional Field _additional = Field.builder().name("_additional").fields(new Field[]{group}).build(); + // Property that we group by + Field content = Field.builder().name("content").build(); // filter arguments GroupByArgument groupBy = client.graphQL().arguments().groupByArgBuilder() .path(new String[]{"content"}).groups(3).objectsPerGroup(10).build(); @@ -1673,7 +1677,7 @@ public void testGraphQLGetWithGroupByWithHybrid() { .withClassName(testData.PASSAGE) .withHybrid(hybrid) .withGroupBy(groupBy) - .withFields(_additional).run(); + .withFields(content, _additional).run(); testData.cleanupWeaviate(client); // then assertThat(groupByResult).isNotNull(); @@ -1892,9 +1896,9 @@ public void shouldSupportSearchByUUID() { .withFields(fieldId) .run(); - assertIds(className, resultUuid, new String[]{ id }); - assertIds(className, resultUuidArray1, new String[]{ id }); - assertIds(className, resultUuidArray2, new String[]{ id }); + assertIds(className, resultUuid, new String[]{id}); + assertIds(className, resultUuidArray1, new String[]{id}); + assertIds(className, resultUuidArray2, new String[]{id}); Result deleteStatus = client.schema().allDeleter().run(); @@ -2097,9 +2101,9 @@ public void shouldSupportSearchWithContains() { .withClassName(className) .withWhere(WhereArgument.builder().filter(filter).build()) .withFields(Field.builder() - .name("_additional") - .fields(Field.builder().name("id").build()) - .build(), + .name("_additional") + .fields(Field.builder().name("id").build()) + .build(), Field.builder().name("bool").build(), Field.builder().name("bools").build()) .run(); @@ -2108,29 +2112,29 @@ public void shouldSupportSearchWithContains() { }; // FIXME: 0 returned -// runAndAssertExpectedIds.accept( -// WhereFilter.builder().path("bools").operator(Operator.ContainsAll).valueBoolean(boolsArray[0]).build(), -// new String[]{id1, id2}); + // runAndAssertExpectedIds.accept( + // WhereFilter.builder().path("bools").operator(Operator.ContainsAll).valueBoolean(boolsArray[0]).build(), + // new String[]{id1, id2}); // FIXME: 0 returned -// runAndAssertExpectedIds.accept( -// WhereFilter.builder().path("bools").operator(Operator.ContainsAll).valueBoolean(boolsArray[1]).build(), -// new String[]{id1, id2}); + // runAndAssertExpectedIds.accept( + // WhereFilter.builder().path("bools").operator(Operator.ContainsAll).valueBoolean(boolsArray[1]).build(), + // new String[]{id1, id2}); // FIXME: 1 returned -// runAndAssertExpectedIds.accept( -// WhereFilter.builder().path("bools").operator(Operator.ContainsAll).valueBoolean(boolsArray[2]).build(), -// new String[]{id1, id2, id3}); + // runAndAssertExpectedIds.accept( + // WhereFilter.builder().path("bools").operator(Operator.ContainsAll).valueBoolean(boolsArray[2]).build(), + // new String[]{id1, id2, id3}); // FIXME: 1 returned -// runAndAssertExpectedIds.accept( -// WhereFilter.builder().path("bools").operator(Operator.ContainsAny).valueBoolean(boolsArray[0]).build(), -// new String[]{id1, id2, id3}); + // runAndAssertExpectedIds.accept( + // WhereFilter.builder().path("bools").operator(Operator.ContainsAny).valueBoolean(boolsArray[0]).build(), + // new String[]{id1, id2, id3}); // FIXME: 1 returned -// runAndAssertExpectedIds.accept( -// WhereFilter.builder().path("bools").operator(Operator.ContainsAny).valueBoolean(boolsArray[1]).build(), -// new String[]{id1, id2, id3}); + // runAndAssertExpectedIds.accept( + // WhereFilter.builder().path("bools").operator(Operator.ContainsAny).valueBoolean(boolsArray[1]).build(), + // new String[]{id1, id2, id3}); // FIXME: 1 returned -// runAndAssertExpectedIds.accept( -// WhereFilter.builder().path("bools").operator(Operator.ContainsAny).valueBoolean(boolsArray[2]).build(), -// new String[]{id1, id2, id3}); + // runAndAssertExpectedIds.accept( + // WhereFilter.builder().path("bools").operator(Operator.ContainsAny).valueBoolean(boolsArray[2]).build(), + // new String[]{id1, id2, id3}); runAndAssertExpectedIds.accept( WhereFilter.builder().path("ints").operator(Operator.ContainsAll).valueInt(intsArray[0]).build(), @@ -2256,7 +2260,8 @@ private void assertIds(String className, Result gqlResult, Stri .extracting(get -> ((Map) get).get(className)).isInstanceOf(List.class).asList() .hasSize(expectedIds.length); - List> results = (List>) ((Map) (((Map) (gqlResult.getResult().getData())).get("Get"))).get(className); + List> results = (List>) ((Map) (((Map) (gqlResult.getResult().getData())).get( + "Get"))).get(className); String[] resultIds = results.stream() .map(m -> m.get("_additional")) .map(a -> ((Map) a).get("id")) diff --git a/src/test/java/io/weaviate/integration/client/schema/ClientSchemaTest.java b/src/test/java/io/weaviate/integration/client/schema/ClientSchemaTest.java index 6a1f63c5..ba3a0660 100644 --- a/src/test/java/io/weaviate/integration/client/schema/ClientSchemaTest.java +++ b/src/test/java/io/weaviate/integration/client/schema/ClientSchemaTest.java @@ -11,6 +11,7 @@ import io.weaviate.client.v1.misc.model.DistanceType; import io.weaviate.client.v1.misc.model.InvertedIndexConfig; import io.weaviate.client.v1.misc.model.PQConfig; +import io.weaviate.client.v1.misc.model.ReplicationConfig; import io.weaviate.client.v1.misc.model.ShardingConfig; import io.weaviate.client.v1.misc.model.StopwordConfig; import io.weaviate.client.v1.misc.model.VectorIndexConfig; @@ -1466,12 +1467,16 @@ public void shouldUpdateClass() { .vectorizer("text2vec-contextionary") .properties(properties) .vectorIndexConfig(VectorIndexConfig.builder() + .filterStrategy(VectorIndexConfig.FilterStrategy.ACORN) .pq(PQConfig.builder() .enabled(true) .trainingLimit(99_999) .segments(96) .build()) .build()) + .replicationConfig(ReplicationConfig.builder() + .deletionStrategy(ReplicationConfig.DeletionStrategy.DELETE_ON_CONFLICT) + .build()) .build(); Result updateResult = client.schema().classUpdater() @@ -1494,14 +1499,19 @@ public void shouldUpdateClass() { .withFailMessage(null) .extracting(Result::getResult).isNotNull() .extracting(WeaviateClass::getVectorIndexConfig).isNotNull() + .returns(VectorIndexConfig.FilterStrategy.ACORN, VectorIndexConfig::getFilterStrategy) .extracting(VectorIndexConfig::getPq).isNotNull() .returns(true, PQConfig::getEnabled) .returns(96, PQConfig::getSegments) .returns(99_999, PQConfig::getTrainingLimit); + + assertThat(updatedClassResult.getResult()) + .extracting(WeaviateClass::getReplicationConfig).isNotNull() + .returns(ReplicationConfig.DeletionStrategy.DELETE_ON_CONFLICT, ReplicationConfig::getDeletionStrategy); } @Test - public void shouldCreateClassWithVectorConfig() { + public void shouldCreateClassWithVectorAndReplicationConfig() { Integer cleanupIntervalSeconds = 300; // vector index config Integer efConstruction = 128; @@ -1521,6 +1531,9 @@ public void shouldCreateClassWithVectorConfig() { Integer centroids = 8; String encoderType = "tile"; String encoderDistribution = "normal"; + // replication config + Boolean asyncEnabled = true; + Integer replicationFactor = 1; VectorIndexConfig vectorIndexConfig = VectorIndexConfig.builder() .cleanupIntervalSeconds(cleanupIntervalSeconds) @@ -1529,6 +1542,7 @@ public void shouldCreateClassWithVectorConfig() { .vectorCacheMaxObjects(vectorCacheMaxObjects) .ef(ef) .skip(skip) + .filterStrategy(VectorIndexConfig.FilterStrategy.SWEEPING) .dynamicEfFactor(dynamicEfFactor) .dynamicEfMax(dynamicEfMax) .dynamicEfMin(dynamicEfMin) @@ -1546,6 +1560,12 @@ public void shouldCreateClassWithVectorConfig() { .build()) .build(); + ReplicationConfig replicationConfig = ReplicationConfig.builder() + .factor(replicationFactor) + .asyncEnabled(asyncEnabled) + .deletionStrategy(ReplicationConfig.DeletionStrategy.NO_AUTOMATED_RESOLUTION) + .build(); + Map contextionaryVectorizerSettings = new HashMap<>(); contextionaryVectorizerSettings.put("vectorizeClassName", true); Map contextionaryVectorizer = new HashMap<>(); @@ -1562,6 +1582,7 @@ public void shouldCreateClassWithVectorConfig() { .className("Band") .description("Band that plays and produces music") .vectorConfig(vectorConfig) + .replicationConfig(replicationConfig) .build(); Result createStatus = client.schema().classCreator() @@ -1602,6 +1623,7 @@ public void shouldCreateClassWithVectorConfig() { .returns(cleanupIntervalSeconds, VectorIndexConfig::getCleanupIntervalSeconds) .returns(efConstruction, VectorIndexConfig::getEfConstruction) .returns(maxConnections, VectorIndexConfig::getMaxConnections) + .returns(VectorIndexConfig.FilterStrategy.SWEEPING, VectorIndexConfig::getFilterStrategy) .returns(vectorCacheMaxObjects, VectorIndexConfig::getVectorCacheMaxObjects) .returns(ef, VectorIndexConfig::getEf) .returns(skip, VectorIndexConfig::getSkip) @@ -1622,5 +1644,11 @@ public void shouldCreateClassWithVectorConfig() { .returns(encoderDistribution, PQConfig.Encoder::getDistribution); }) ); + + assertThat(bandClass.getResult()) + .extracting(WeaviateClass::getReplicationConfig).isNotNull() + .returns(replicationFactor, ReplicationConfig::getFactor) + .returns(asyncEnabled, ReplicationConfig::getAsyncEnabled) + .returns(ReplicationConfig.DeletionStrategy.NO_AUTOMATED_RESOLUTION, ReplicationConfig::getDeletionStrategy); } } diff --git a/src/test/java/io/weaviate/integration/client/schema/ClusterSchemaTest.java b/src/test/java/io/weaviate/integration/client/schema/ClusterSchemaTest.java index 246963db..a96a7d87 100644 --- a/src/test/java/io/weaviate/integration/client/schema/ClusterSchemaTest.java +++ b/src/test/java/io/weaviate/integration/client/schema/ClusterSchemaTest.java @@ -86,8 +86,10 @@ public void shouldNotCreateClassWithTooHighFactor() { .extracting(WeaviateError::getMessages).asList() .first() .extracting(m -> ((WeaviateErrorMessage) m).getMessage()).asInstanceOf(STRING) - .contains("not enough storage replicas"); - }@Test + .contains("could not find enough weaviate nodes for replication"); + } + + @Test public void shouldAddObjectsWithNestedProperties_EntireSchema() { WeaviateClass schemaClass; String className = "ClassWithObjectProperty";