Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gql: Search with multiple target vectors and weights #318

Merged
merged 21 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
8ab612a
gh-304: support multiple weights per target vector
bevzzz Oct 24, 2024
9bcfd1e
gh-304: support multiple vectors per target in near vector search
bevzzz Oct 24, 2024
c375772
ci: target v1.27 server
bevzzz Oct 25, 2024
4dfddc4
fix: report GraphQL errors returned with HTTP 200
bevzzz Oct 24, 2024
214b3a6
fix: update expected error message
bevzzz Oct 25, 2024
566844f
test: update groupBy query integration test
bevzzz Oct 25, 2024
650a2e9
test: fix nearObject with groupBy test
bevzzz Oct 25, 2024
69460ae
feat: ensure query has a target name for each target vector
bevzzz Oct 29, 2024
21ef147
refactor: use 'else' branch instead of 'continue'
bevzzz Oct 30, 2024
1715f9e
feat: add filter strategy to class's vector index config
bevzzz Oct 30, 2024
6197168
feat: add deletion strategy to replication config
bevzzz Oct 30, 2024
7465c91
test: fix expected result
bevzzz Oct 30, 2024
d9459f3
refactor: extract "sneaky errors" from gql responses
bevzzz Oct 30, 2024
1564b7f
chore: minimize formatting changes
bevzzz Oct 30, 2024
3d1b60e
chore: format code
bevzzz Oct 30, 2024
9a4c3e4
chore: remove formatting differences
bevzzz Oct 30, 2024
4b84a16
test: add test fur multi-target search with single vector-per-target
bevzzz Oct 30, 2024
6d4d973
chore: fix formatting differences
bevzzz Oct 30, 2024
4166cf9
refactor: drop SneakyErrors interface
bevzzz Oct 30, 2024
328beaa
chore: implement suggestions from review
bevzzz Oct 30, 2024
ca12123
test: remove defaults from expected values
bevzzz Oct 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions src/main/java/io/weaviate/client/base/BaseClient.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package io.weaviate.client.base;

import io.weaviate.client.Config;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.http.HttpResponse;
import io.weaviate.client.v1.graphql.GraphQL;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import java.util.Collections;

import io.weaviate.client.Config;
import java.util.List;

public abstract class BaseClient<T> {
private final HttpClient client;
Expand Down Expand Up @@ -51,7 +53,12 @@ private Response<T> sendRequest(String endpoint, Object payload, String method,

if (statusCode < 399) {
T body = toResponse(responseBody, classOfT);
return new Response<>(statusCode, body, null);
WeaviateErrorResponse errors = null;

if (body != null && classOfT.equals(GraphQL.class)) {
errors = getWeaviateGraphQLErrorResponse((GraphQLResponse) body, statusCode);
}
return new Response<>(statusCode, body, errors);
}

WeaviateErrorResponse error = toResponse(responseBody, WeaviateErrorResponse.class);
Expand Down Expand Up @@ -90,10 +97,22 @@ private String toJsonString(Object object) {
}

private WeaviateErrorResponse getWeaviateErrorResponse(Exception e) {
WeaviateErrorMessage error = WeaviateErrorMessage.builder()
.message(e.getMessage())
.throwable(e)
.build();
WeaviateErrorMessage error = WeaviateErrorMessage.builder().message(e.getMessage()).throwable(e).build();
return WeaviateErrorResponse.builder().error(Collections.singletonList(error)).build();
}

/**
* Extract errors from {@link WeaviateErrorResponse} from a GraphQL response body.
*
* @param gql GraphQL response body.
* @param code HTTP status code to pass in the {@link WeaviateErrorResponse}.
* @return Error response to be returned to the caller.
*/
private WeaviateErrorResponse getWeaviateGraphQLErrorResponse(GraphQLResponse gql, int code) {
List<WeaviateErrorMessage> messages = gql.errorMessages();
if (messages == null || messages.isEmpty()) {
return null;
}
return WeaviateErrorResponse.builder().code(code).error(gql.errorMessages()).build();
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.weaviate.client.base;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
Expand All @@ -9,6 +10,7 @@
@Getter
@Builder
@ToString
@AllArgsConstructor(access = AccessLevel.PUBLIC)
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class WeaviateErrorMessage {
String message;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package io.weaviate.client.v1.graphql.model;

import io.weaviate.client.base.WeaviateErrorMessage;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
Expand All @@ -15,4 +19,19 @@
public class GraphQLResponse {
Object data;
GraphQLError[] errors;


/**
* Extract the 'message' portion of every error in the response, omitting 'path' and 'location'.
*
* @return Non-throwable WeaviateErrorMessages
*/
public List<WeaviateErrorMessage> errorMessages() {
if (errors == null || errors.length == 0) {
return null;
}
return Arrays.stream(errors)
.map(err -> new WeaviateErrorMessage(err.getMessage(), null))
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
package io.weaviate.client.v1.graphql.query.argument;

import io.weaviate.client.v1.graphql.query.util.Serializer;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;

import java.util.LinkedHashSet;
import java.util.Set;

@Getter
@Builder
Expand All @@ -24,7 +25,7 @@ public class NearVectorArgument implements Argument {
Float certainty;
Float distance;
String[] targetVectors;
Map<String, Float[]> vectorPerTarget;
Map<String, Float[][]> vectorsPerTarget;
Targets targets;

@Override
Expand All @@ -41,19 +42,73 @@ public String build() {
arg.add(String.format("distance:%s", distance));
}
if (ArrayUtils.isNotEmpty(targetVectors)) {
arg.add(String.format("targetVectors:%s", Serializer.arrayWithQuotes(targetVectors)));
arg.add(String.format("targetVectors:%s", Serializer.arrayWithQuotes(targetVectors)));
}
if (vectorPerTarget != null && !vectorPerTarget.isEmpty()) {
if (vectorsPerTarget != null && !vectorsPerTarget.isEmpty()) {
Set<String> vectorPerTargetArg = new LinkedHashSet<>();
for (Map.Entry<String, Float[]> entry : vectorPerTarget.entrySet()) {
vectorPerTargetArg.add(String.format("%s:%s", entry.getKey(), Serializer.array(entry.getValue())));
for (Map.Entry<String, Float[][]> e : vectorsPerTarget.entrySet()) {
Float[][] vectors = e.getValue();
vectorPerTargetArg.add(String.format("%s:%s", e.getKey(), vectors.length == 1 ? Serializer.array(vectors[0]) : Serializer.array(vectors)));
}
arg.add(String.format("vectorPerTarget:{%s}", String.join(" ", vectorPerTargetArg)));
}
if (targets != null) {
arg.add(String.format("%s", targets.build()));
arg.add(String.format("%s", withValidTargetVectors(this.targets).build()));
}

return String.format("nearVector:{%s}", String.join(" ", arg));
}

/**
* withValidTargetVectors makes sure the target names are repeated for each target vector,
* which is required by server, but may be easily overlooked by the user.
*
* <p>
* Note, too, that in case the user fails to pass a value in targetVectors altogether, it will not be added to the query.
*
* @return A copy of the Targets with validated target vectors.
*/
private Targets withValidTargetVectors(Targets targets) {
return Targets.builder().
combinationMethod(targets.getCombinationMethod()).
weightsMulti(targets.getWeights()).
targetVectors(prepareTargetVectors(targets.getTargetVectors())).
build();
}

/**
* prepareTargetVectors adds appends the target name for each target vector associated with it.
*/
private String[] prepareTargetVectors(String[] targets) {
List<String> out = new ArrayList<>();
for (String target : targets) {
if (this.vectorsPerTarget.containsKey(target)) {
int l = this.vectorsPerTarget.get(target).length;
for (int i = 0; i < l; i++) {
out.add(target);
}
} else {
out.add(target);
}
}
return out.toArray(new String[0]);
}

// Extend Lombok's builder to overload some methods.
public static class NearVectorArgumentBuilder {
Map<String, Float[][]> vectorsPerTarget = new LinkedHashMap<>();

public NearVectorArgumentBuilder vectorPerTarget(Map<String, Float[]> vectors) {
this.vectorsPerTarget.clear(); // Overwrite the existing entries each time this is called.
for (Map.Entry<String, Float[]> e : vectors.entrySet()) {
this.vectorsPerTarget.put(e.getKey(), new Float[][]{e.getValue()});
}
return this;
}

public NearVectorArgumentBuilder vectorsPerTarget(Map<String, Float[][]> vectors) {
this.vectorsPerTarget = vectors;
return this;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.weaviate.client.v1.graphql.query.argument;

import io.weaviate.client.v1.graphql.query.util.Serializer;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
Expand All @@ -20,7 +21,7 @@
public class Targets {
CombinationMethod combinationMethod;
String[] targetVectors;
Map<String, Float> weights;
Map<String, Float[]> weights;

public enum CombinationMethod {
minimum("minimum"),
Expand Down Expand Up @@ -52,13 +53,38 @@ String build() {
}
if (weights != null && !weights.isEmpty()) {
Set<String> weightsArg = new LinkedHashSet<>();
for (Map.Entry<String, Float> entry : weights.entrySet()) {
weightsArg.add(String.format("%s:%s", entry.getKey(), entry.getValue()));
for (Map.Entry<String, Float[]> e : weights.entrySet()) {
Float[] weightsPerTarget = e.getValue();
String target = e.getKey();

String weight = Serializer.array(weightsPerTarget);
if (weightsPerTarget.length == 1) {
weight = weightsPerTarget[0].toString();
}
weightsArg.add(String.format("%s:%s", target, weight));
}
arg.add(String.format("weights:{%s}", String.join(" ", weightsArg)));
}

return String.format("targets:{%s}", String.join(" ", arg));
}

// Extend lombok's builder to overload some methods.
public static class TargetsBuilder {
Map<String, Float[]> weights = new LinkedHashMap<>();

public TargetsBuilder weights(Map<String, Float> weights) {
this.weights.clear(); // Overwrite the existing entries each time this is called.
for (Map.Entry<String, Float> e : weights.entrySet()) {
this.weights.put(e.getKey(), new Float[]{e.getValue()});
}
return this;
}

public TargetsBuilder weightsMulti(Map<String, Float[]> weights) {
this.weights = weights;
return this;
}
}
}

Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
package io.weaviate.client.v1.graphql.query.util;

import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.commons.lang3.StringUtils;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.commons.lang3.StringUtils;

public class Serializer {

private Serializer() {}
private Serializer() {
}

/**
* Creates graphql safe string
Expand Down Expand Up @@ -69,7 +68,7 @@ public static <T> String array(T[] input) {
* Creates array string
* It is up to user to make elements json safe
*
* @param input array of arbitrary elements
* @param input array of arbitrary elements
* @param mapper maps single element before building array
* @return array string
*/
Expand All @@ -78,7 +77,12 @@ public static <T, R> String array(T[] input, Function<T, R> mapper) {
if (input != null) {
inner = Arrays.stream(input)
.map(mapper)
.map(Objects::toString)
.map(obj -> {
if (obj.getClass().isArray()) {
return array((Object[]) obj);
}
return obj.toString();
})
.collect(Collectors.joining(","));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.weaviate.client.v1.misc.model;

import com.google.gson.annotations.SerializedName;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
Expand All @@ -15,4 +16,12 @@
public class ReplicationConfig {
Boolean asyncEnabled;
Integer factor;
DeletionStrategy deletionStrategy;

public enum DeletionStrategy {
@SerializedName("DeleteOnConflict")
DELETE_ON_CONFLICT,
@SerializedName("NoAutomatedResolution")
NO_AUTOMATED_RESOLUTION;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.weaviate.client.v1.misc.model;

import com.google.gson.annotations.SerializedName;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
Expand All @@ -20,11 +21,19 @@ public class VectorIndexConfig {
Integer dynamicEfMin;
Integer dynamicEfMax;
Integer dynamicEfFactor;
FilterStrategy filterStrategy;
Long vectorCacheMaxObjects;
Integer flatSearchCutoff;
Integer cleanupIntervalSeconds;
Boolean skip;
PQConfig pq;
BQConfig bq;
SQConfig sq;

public enum FilterStrategy {
@SerializedName("sweeping")
SWEEPING,
@SerializedName("acorn")
ACORN;
}
}
Loading
Loading