feat(chat): 尝试在对话中加入知识库
This commit is contained in:
@@ -8,10 +8,12 @@ import com.lanyuanxiaoyao.service.ai.knowledge.service.EmbeddingService;
|
||||
import com.lanyuanxiaoyao.service.ai.knowledge.service.KnowledgeService;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import org.eclipse.collections.api.factory.Lists;
|
||||
import org.eclipse.collections.api.list.ImmutableList;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
@@ -105,4 +107,14 @@ public class KnowledgeController {
|
||||
throw new IllegalArgumentException("Unsupported type: " + type);
|
||||
}
|
||||
}
|
||||
|
||||
@PostMapping("query")
|
||||
public ImmutableList<String> query(
|
||||
@RequestParam("id") Long id,
|
||||
@RequestParam(value = "limit", defaultValue = "5") Integer limit,
|
||||
@RequestParam(value = "threshold", defaultValue = "0.6") Double threshold,
|
||||
@RequestBody String text
|
||||
) throws ExecutionException, InterruptedException {
|
||||
return knowledgeService.query(id, text, limit, threshold);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package com.lanyuanxiaoyao.service.ai.knowledge.entity.vo;
|
||||
*/
|
||||
public class KnowledgeVO {
|
||||
private String id;
|
||||
private String vectorSourceId;
|
||||
private String name;
|
||||
private String strategy;
|
||||
private Long size;
|
||||
@@ -23,6 +24,14 @@ public class KnowledgeVO {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public String getVectorSourceId() {
|
||||
return vectorSourceId;
|
||||
}
|
||||
|
||||
public void setVectorSourceId(String vectorSourceId) {
|
||||
this.vectorSourceId = vectorSourceId;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
@@ -91,6 +100,7 @@ public class KnowledgeVO {
|
||||
public String toString() {
|
||||
return "KnowledgeVO{" +
|
||||
"id='" + id + '\'' +
|
||||
", vectorSourceId='" + vectorSourceId + '\'' +
|
||||
", name='" + name + '\'' +
|
||||
", strategy='" + strategy + '\'' +
|
||||
", size=" + size +
|
||||
|
||||
@@ -109,7 +109,6 @@ public class GroupService {
|
||||
Long.class,
|
||||
groupId
|
||||
);
|
||||
logger.info("Delete {} {}", vectorSourceId, groupId);
|
||||
client.deleteAsync(
|
||||
String.valueOf(vectorSourceId),
|
||||
Points.Filter.newBuilder()
|
||||
|
||||
@@ -9,14 +9,18 @@ import com.lanyuanxiaoyao.service.ai.knowledge.entity.vo.KnowledgeVO;
|
||||
import com.lanyuanxiaoyao.service.common.Constants;
|
||||
import io.qdrant.client.QdrantClient;
|
||||
import io.qdrant.client.grpc.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.stream.Collectors;
|
||||
import org.eclipse.collections.api.factory.Lists;
|
||||
import org.eclipse.collections.api.list.ImmutableList;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.jdbc.core.RowMapper;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -41,13 +45,13 @@ public class KnowledgeService {
|
||||
return knowledge;
|
||||
};
|
||||
private final JdbcTemplate template;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
private final EmbeddingModel model;
|
||||
private final QdrantClient client;
|
||||
private final GroupService groupService;
|
||||
|
||||
public KnowledgeService(JdbcTemplate template, EmbeddingModel embeddingModel, VectorStore vectorStore, GroupService groupService) {
|
||||
public KnowledgeService(JdbcTemplate template, EmbeddingModel model, VectorStore vectorStore, GroupService groupService) {
|
||||
this.template = template;
|
||||
this.embeddingModel = embeddingModel;
|
||||
this.model = model;
|
||||
this.client = (QdrantClient) vectorStore.getNativeClient().orElseThrow();
|
||||
this.groupService = groupService;
|
||||
}
|
||||
@@ -93,7 +97,7 @@ public class KnowledgeService {
|
||||
String.valueOf(vectorSourceId),
|
||||
Collections.VectorParams.newBuilder()
|
||||
.setDistance(Collections.Distance.valueOf(strategy))
|
||||
.setSize(embeddingModel.dimensions())
|
||||
.setSize(model.dimensions())
|
||||
.build()
|
||||
).get();
|
||||
}
|
||||
@@ -123,6 +127,7 @@ public class KnowledgeService {
|
||||
Collections.CollectionInfo info = client.getCollectionInfoAsync(String.valueOf(knowledge.getVectorSourceId())).get();
|
||||
KnowledgeVO vo = new KnowledgeVO();
|
||||
vo.setId(String.valueOf(knowledge.getId()));
|
||||
vo.setVectorSourceId(String.valueOf(knowledge.getVectorSourceId()));
|
||||
vo.setName(knowledge.getName());
|
||||
vo.setPoints(info.getPointsCount());
|
||||
vo.setSegments(info.getSegmentsCount());
|
||||
@@ -156,4 +161,29 @@ public class KnowledgeService {
|
||||
groupService.removeByKnowledgeId(knowledge.getId());
|
||||
client.deleteCollectionAsync(String.valueOf(knowledge.getVectorSourceId())).get();
|
||||
}
|
||||
|
||||
public ImmutableList<String> query(
|
||||
Long id,
|
||||
String text,
|
||||
Integer limit,
|
||||
Double threshold) throws ExecutionException, InterruptedException {
|
||||
Knowledge knowledge = get(id);
|
||||
Boolean exists = client.collectionExistsAsync(String.valueOf(knowledge.getVectorSourceId())).get();
|
||||
if (!exists) {
|
||||
throw new RuntimeException(StrUtil.format("{} not exists", id));
|
||||
}
|
||||
VectorStore vs = QdrantVectorStore.builder(client, model)
|
||||
.collectionName(String.valueOf(knowledge.getVectorSourceId()))
|
||||
.initializeSchema(false)
|
||||
.build();
|
||||
List<Document> documents = vs.similaritySearch(
|
||||
SearchRequest.builder()
|
||||
.query(text)
|
||||
.topK(limit)
|
||||
.similarityThreshold(threshold)
|
||||
.build()
|
||||
);
|
||||
return Lists.immutable.ofAll(documents)
|
||||
.collect(Document::getText);
|
||||
}
|
||||
}
|
||||
@@ -215,8 +215,10 @@ public class EmbeddingNodes {
|
||||
.build();
|
||||
for (Document document : context.getDocuments()) {
|
||||
Map<String, Object> metadata = document.getMetadata();
|
||||
metadata.put("filename", context.getFileFormat());
|
||||
metadata.put("filepath", context.getFile());
|
||||
if (StrUtil.isNotBlank(context.getFileFormat()))
|
||||
metadata.put("filename", context.getFileFormat());
|
||||
if (StrUtil.isNotBlank(context.getFile()))
|
||||
metadata.put("filepath", context.getFile());
|
||||
metadata.put("group_id", String.valueOf(context.getGroupId()));
|
||||
metadata.put("vector_source_id", String.valueOf(context.getVectorSourceId()));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user