Skip to content

Commit

Permalink
Refactor SimpleVectorStore
Browse files Browse the repository at this point in the history
 - Remove SimpleVectorStore's dependency on deprecated embeddings from Document object
- Create a custom Content object that represents the SimpleVectorStore's contents and embedding
- Add tests
  • Loading branch information
ilayaperumalg authored and Mark Pollack committed Nov 26, 2024
1 parent d173572 commit 9d207e3
Show file tree
Hide file tree
Showing 4 changed files with 504 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,15 @@
* @author Mark Pollack
* @author Christian Tzolov
* @author Sebastien Deleuze
* @author Ilayaperumal Gopinathan
*/
public class SimpleVectorStore extends AbstractObservationVectorStore {

private static final Logger logger = LoggerFactory.getLogger(SimpleVectorStore.class);

private final ObjectMapper objectMapper;

protected Map<String, Document> store = new ConcurrentHashMap<>();
protected Map<String, SimpleVectorStoreContent> store = new ConcurrentHashMap<>();

protected EmbeddingModel embeddingModel;

Expand All @@ -94,11 +95,17 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse

@Override
public void doAdd(List<Document> documents) {
Objects.requireNonNull(documents, "Documents list cannot be null");
if (documents.isEmpty()) {
throw new IllegalArgumentException("Documents list cannot be empty");
}

for (Document document : documents) {
logger.info("Calling EmbeddingModel for document id = {}", document.getId());
float[] embedding = this.embeddingModel.embed(document);
document.setEmbedding(embedding);
this.store.put(document.getId(), document);
SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent(document.getId(),
document.getContent(), document.getMetadata(), embedding);
this.store.put(document.getId(), storeContent);
}
}

Expand All @@ -120,12 +127,12 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery());
return this.store.values()
.stream()
.map(entry -> new Similarity(entry.getId(),
.map(entry -> new Similarity(entry,
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
.filter(s -> s.score >= request.getSimilarityThreshold())
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
.limit(request.getTopK())
.map(s -> this.store.get(s.key))
.map(s -> s.getDocument())
.toList();
}

Expand Down Expand Up @@ -176,12 +183,11 @@ public void save(File file) {
* @param file the file to load the vector store content
*/
public void load(File file) {
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {
TypeReference<HashMap<String, SimpleVectorStoreContent>> typeRef = new TypeReference<>() {

};
try {
Map<String, Document> deserializedMap = this.objectMapper.readValue(file, typeRef);
this.store = deserializedMap;
this.store = this.objectMapper.readValue(file, typeRef);
}
catch (IOException ex) {
throw new RuntimeException(ex);
Expand All @@ -193,12 +199,11 @@ public void load(File file) {
* @param resource the resource to load the vector store content
*/
public void load(Resource resource) {
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {
TypeReference<HashMap<String, SimpleVectorStoreContent>> typeRef = new TypeReference<>() {

};
try {
Map<String, Document> deserializedMap = this.objectMapper.readValue(resource.getInputStream(), typeRef);
this.store = deserializedMap;
this.store = this.objectMapper.readValue(resource.getInputStream(), typeRef);
}
catch (IOException ex) {
throw new RuntimeException(ex);
Expand Down Expand Up @@ -232,15 +237,23 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str

public static class Similarity {

private String key;
private SimpleVectorStoreContent content;

private double score;

public Similarity(String key, double score) {
this.key = key;
public Similarity(SimpleVectorStoreContent content, double score) {
this.content = content;
this.score = score;
}

Document getDocument() {
return Document.builder()
.withId(this.content.getId())
.withContent(this.content.getContent())
.withMetadata(this.content.getMetadata())
.build();
}

}

public final class EmbeddingMath {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.vectorstore;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.document.id.IdGenerator;
import org.springframework.ai.document.id.RandomIdGenerator;
import org.springframework.ai.model.Content;
import org.springframework.util.Assert;

/**
* An immutable {@link Content} implementation representing content, metadata, and its
* embeddings. This class is thread-safe and all its fields are final and deeply
* immutable. The embedding vector is required for all instances of this class.
*/
public final class SimpleVectorStoreContent implements Content {

private final String id;

private final String content;

private final Map<String, Object> metadata;

private final float[] embedding;

/**
* Creates a new instance with the given content, empty metadata, and embedding
* vector.
* @param content the content text, must not be null
* @param embedding the embedding vector, must not be null
*/
@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public SimpleVectorStoreContent(@JsonProperty("content") String content,
@JsonProperty("embedding") float[] embedding) {
this(content, new HashMap<>(), embedding);
}

/**
* Creates a new instance with the given content, metadata, and embedding vector.
* @param content the content text, must not be null
* @param metadata the metadata map, must not be null
* @param embedding the embedding vector, must not be null
*/
public SimpleVectorStoreContent(String content, Map<String, Object> metadata, float[] embedding) {
this(content, metadata, new RandomIdGenerator(), embedding);
}

/**
* Creates a new instance with the given content, metadata, custom ID generator, and
* embedding vector.
* @param content the content text, must not be null
* @param metadata the metadata map, must not be null
* @param idGenerator the ID generator to use, must not be null
* @param embedding the embedding vector, must not be null
*/
public SimpleVectorStoreContent(String content, Map<String, Object> metadata, IdGenerator idGenerator,
float[] embedding) {
this(idGenerator.generateId(content, metadata), content, metadata, embedding);
}

/**
* Creates a new instance with all fields specified.
* @param id the unique identifier, must not be empty
* @param content the content text, must not be null
* @param metadata the metadata map, must not be null
* @param embedding the embedding vector, must not be null
* @throws IllegalArgumentException if any parameter is null or if id is empty
*/
public SimpleVectorStoreContent(String id, String content, Map<String, Object> metadata, float[] embedding) {
Assert.hasText(id, "id must not be null or empty");
Assert.notNull(content, "content must not be null");
Assert.notNull(metadata, "metadata must not be null");
Assert.notNull(embedding, "embedding must not be null");
Assert.isTrue(embedding.length > 0, "embedding vector must not be empty");

this.id = id;
this.content = content;
this.metadata = Collections.unmodifiableMap(new HashMap<>(metadata));
this.embedding = Arrays.copyOf(embedding, embedding.length);
}

/**
* Creates a new instance with an updated embedding vector.
* @param embedding the new embedding vector, must not be null
* @return a new instance with the updated embedding
* @throws IllegalArgumentException if embedding is null or empty
*/
public SimpleVectorStoreContent withEmbedding(float[] embedding) {
Assert.notNull(embedding, "embedding must not be null");
Assert.isTrue(embedding.length > 0, "embedding vector must not be empty");
return new SimpleVectorStoreContent(this.id, this.content, this.metadata, embedding);
}

public String getId() {
return this.id;
}

@Override
public String getContent() {
return this.content;
}

@Override
public Map<String, Object> getMetadata() {
return this.metadata;
}

/**
* Returns a defensive copy of the embedding vector.
* @return a new array containing the embedding vector
*/
public float[] getEmbedding() {
return Arrays.copyOf(this.embedding, this.embedding.length);
}

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;

SimpleVectorStoreContent that = (SimpleVectorStoreContent) o;
return Objects.equals(this.id, that.id) && Objects.equals(this.content, that.content)
&& Objects.equals(this.metadata, that.metadata) && Arrays.equals(this.embedding, that.embedding);
}

@Override
public int hashCode() {
int result = Objects.hashCode(this.id);
result = 31 * result + Objects.hashCode(this.content);
result = 31 * result + Objects.hashCode(this.metadata);
result = 31 * result + Arrays.hashCode(this.embedding);
return result;
}

@Override
public String toString() {
return "SimpleVectorStoreContent{" + "id='" + this.id + '\'' + ", content='" + this.content + '\''
+ ", metadata=" + this.metadata + ", embedding=" + Arrays.toString(embedding) + '}';
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.vectorstore;

import java.util.HashMap;
import java.util.Map;

import org.junit.Test;

import org.springframework.ai.document.Document;

import static org.assertj.core.api.Assertions.assertThat;

/**
* @author Ilayaperumal Gopinathan
*/
public class SimpleVectorStoreSimilarityTests {

@Test
public void testSimilarity() {
Map<String, Object> metadata = new HashMap<>();
metadata.put("foo", "bar");
float[] testEmbedding = new float[] { 1.0f, 2.0f, 3.0f };

SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent("1", "hello, how are you?", metadata,
testEmbedding);
SimpleVectorStore.Similarity similarity = new SimpleVectorStore.Similarity(storeContent, 0.6d);
Document document = similarity.getDocument();
assertThat(document).isNotNull();
assertThat(document.getId()).isEqualTo("1");
assertThat(document.getContent()).isEqualTo("hello, how are you?");
assertThat(document.getMetadata().get("foo")).isEqualTo("bar");
}

}
Loading

0 comments on commit 9d207e3

Please sign in to comment.