refactor(knowledge): 加入数据库,优化代码结构

This commit is contained in:
v-zhangjc9
2025-05-22 18:10:44 +08:00
parent 907d2826a4
commit 0d7d009be2
12 changed files with 385 additions and 44 deletions

View File

@@ -131,6 +131,11 @@
<artifactId>hutool-all</artifactId>
<version>${hutool.version}</version>
</dependency>
<dependency>
<groupId>com.yomahub</groupId>
<artifactId>liteflow-spring-boot-starter</artifactId>
<version>2.13.2</version>
</dependency>
</dependencies>
</dependencyManagement>

View File

@@ -28,12 +28,24 @@
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-qdrant-store</artifactId>
<artifactId>spring-ai-starter-vector-store-qdrant</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-markdown-document-reader</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-jdbc</artifactId>
</dependency>
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
</dependency>
<dependency>
<groupId>com.yomahub</groupId>
<artifactId>liteflow-spring-boot-starter</artifactId>
</dependency>
</dependencies>
<build>

View File

@@ -1,9 +1,10 @@
package com.lanyuanxiaoyao.service.ai.knowledge.controller;
import com.lanyuanxiaoyao.service.ai.knowledge.entity.CollectionVO;
import com.lanyuanxiaoyao.service.ai.knowledge.entity.PointVO;
import com.lanyuanxiaoyao.service.ai.knowledge.entity.vo.KnowledgeVO;
import com.lanyuanxiaoyao.service.ai.knowledge.entity.vo.PointVO;
import com.lanyuanxiaoyao.service.ai.knowledge.reader.TextLineReader;
import com.lanyuanxiaoyao.service.ai.knowledge.service.KnowledgeService;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.grpc.Collections;
import io.qdrant.client.grpc.Points;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ExecutionException;
@@ -13,6 +14,7 @@ import org.eclipse.collections.api.list.ImmutableList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.reader.markdown.MarkdownDocumentReader;
import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig;
import org.springframework.ai.vectorstore.VectorStore;
@@ -34,10 +36,12 @@ import org.springframework.web.bind.annotation.RestController;
public class KnowledgeController {
private static final Logger logger = LoggerFactory.getLogger(KnowledgeController.class);
private final KnowledgeService knowledgeService;
private final QdrantClient client;
private final EmbeddingModel embeddingModel;
public KnowledgeController(VectorStore vectorStore, EmbeddingModel embeddingModel) {
public KnowledgeController(KnowledgeService knowledgeService, VectorStore vectorStore, EmbeddingModel embeddingModel) {
this.knowledgeService = knowledgeService;
client = (QdrantClient) vectorStore.getNativeClient().orElseThrow();
this.embeddingModel = embeddingModel;
}
@@ -47,39 +51,12 @@ public class KnowledgeController {
@RequestParam("name") String name,
@RequestParam("strategy") String strategy
) throws ExecutionException, InterruptedException {
logger.info("Enter method: add[name, strategy]. name:{},strategy:{}", name, strategy);
client.createCollectionAsync(
name,
Collections.VectorParams.newBuilder()
.setDistance(Collections.Distance.valueOf(strategy))
.setSize(embeddingModel.dimensions())
.build()
).get();
knowledgeService.add(name, strategy);
}
@GetMapping("list")
public ImmutableList<CollectionVO> list() throws ExecutionException, InterruptedException {
return client.listCollectionsAsync()
.get()
.stream()
.collect(Collectors.toCollection(Lists.mutable::empty))
.collect(name -> {
try {
Collections.CollectionInfo info = client.getCollectionInfoAsync(name).get();
CollectionVO vo = new CollectionVO();
vo.setName(name);
vo.setPoints(info.getPointsCount());
vo.setSegments(info.getSegmentsCount());
vo.setStatus(info.getStatus().name());
Collections.VectorParams vectorParams = info.getConfig().getParams().getVectorsConfig().getParams();
vo.setStrategy(vectorParams.getDistance().name());
vo.setSize(vectorParams.getSize());
return vo;
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
})
.toImmutable();
public ImmutableList<KnowledgeVO> list() {
return knowledgeService.list();
}
@GetMapping("list_points")
@@ -107,16 +84,27 @@ public class KnowledgeController {
@GetMapping("delete")
public void delete(@RequestParam("name") String name) throws ExecutionException, InterruptedException {
client.deleteCollectionAsync(name).get();
knowledgeService.remove(name);
}
@PostMapping(value = "preview_text", consumes = "text/plain;charset=utf-8")
public ImmutableList<String> previewText(
@PostMapping("preview_text")
public ImmutableList<PointVO> previewText(
@RequestParam("name") String name,
@RequestParam(value = "mode", defaultValue = "normal") String mode,
@RequestBody String text
@RequestParam(value = "type", defaultValue = "text") String type,
@RequestParam("content") String content
) {
return Lists.immutable.empty();
TextReader reader = new TextLineReader(new ByteArrayResource(content.getBytes(StandardCharsets.UTF_8)));
return reader.get()
.stream()
.collect(Collectors.toCollection(Lists.mutable::empty))
.collect(doc -> {
PointVO vo = new PointVO();
vo.setId(doc.getId());
vo.setText(doc.getText());
return vo;
})
.toImmutable();
}
@PostMapping(value = "process_text", consumes = "text/plain;charset=utf-8")

View File

@@ -0,0 +1,54 @@
package com.lanyuanxiaoyao.service.ai.knowledge.entity;
/**
* @author lanyuanxiaoyao
* @version 20250522
*/
public class Knowledge {
private Long id;
private Long vectorSourceId;
private String name;
private String strategy;
public Long getId() {
return id;
}
public void setId(Long id) {
this.id = id;
}
public Long getVectorSourceId() {
return vectorSourceId;
}
public void setVectorSourceId(Long vectorSourceId) {
this.vectorSourceId = vectorSourceId;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getStrategy() {
return strategy;
}
public void setStrategy(String strategy) {
this.strategy = strategy;
}
@Override
public String toString() {
return "Knowledge{" +
"id=" + id +
", vectorSourceId=" + vectorSourceId +
", name='" + name + '\'' +
", strategy='" + strategy + '\'' +
'}';
}
}

View File

@@ -1,10 +1,10 @@
package com.lanyuanxiaoyao.service.ai.knowledge.entity;
package com.lanyuanxiaoyao.service.ai.knowledge.entity.vo;
/**
* @author lanyuanxiaoyao
* @version 20250516
*/
public class CollectionVO {
public class KnowledgeVO {
private String name;
private String strategy;
private Long size;

View File

@@ -1,4 +1,4 @@
package com.lanyuanxiaoyao.service.ai.knowledge.entity;
package com.lanyuanxiaoyao.service.ai.knowledge.entity.vo;
/**
* @author lanyuanxiaoyao

View File

@@ -0,0 +1,34 @@
package com.lanyuanxiaoyao.service.ai.knowledge.reader;
import cn.hutool.core.util.StrUtil;
import java.util.List;
import java.util.stream.Stream;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.TextReader;
import org.springframework.core.io.Resource;
/**
* @author lanyuanxiaoyao
* @version 20250522
*/
public class TextLineReader extends TextReader {
public TextLineReader(Resource resource) {
super(resource);
}
@Override
public List<Document> get() {
return super.get()
.stream()
.flatMap(doc -> {
String text = doc.getText();
if (StrUtil.isBlank(text)) {
return Stream.of(doc);
}
return Stream.of(text.split("\n\n"))
.filter(StrUtil::isNotBlank)
.map(line -> new Document(line, doc.getMetadata()));
})
.toList();
}
}

View File

@@ -0,0 +1,16 @@
package com.lanyuanxiaoyao.service.ai.knowledge.service;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
/**
* @author lanyuanxiaoyao
* @version 20250522
*/
@Service
public class EmbeddingService {
private static final Logger logger = LoggerFactory.getLogger(EmbeddingService.class);
}

View File

@@ -0,0 +1,58 @@
package com.lanyuanxiaoyao.service.ai.knowledge.service;
import club.kingon.sql.builder.SqlBuilder;
import cn.hutool.core.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
/**
* @author lanyuanxiaoyao
* @version 20250522
*/
@Service
public class KnowledgeGroupService {
private static final Logger logger = LoggerFactory.getLogger(KnowledgeGroupService.class);
private static final String GROUP_TABLE_NAME = "service_ai_group";
private final JdbcTemplate template;
public KnowledgeGroupService(JdbcTemplate template) {
this.template = template;
}
@Transactional(rollbackFor = Exception.class)
public void add(Long knowledgeId, String name) {
template.update(
SqlBuilder.insertInto(GROUP_TABLE_NAME, "id", "knowledge_id", "name")
.values()
.addValue("?", "?", "?")
.precompileSql(),
IdUtil.getSnowflakeNextId(),
knowledgeId,
name
);
}
@Transactional(rollbackFor = Exception.class)
public void remove(Long groupId) {
template.update(
SqlBuilder.delete(GROUP_TABLE_NAME)
.whereEq("id", "?")
.precompileSql(),
groupId
);
}
@Transactional(rollbackFor = Exception.class)
public void removeByKnowledgeId(Long knowledgeId) {
template.update(
SqlBuilder.delete(GROUP_TABLE_NAME)
.whereEq("knowledge_id", "?")
.precompileSql(),
knowledgeId
);
}
}

View File

@@ -0,0 +1,151 @@
package com.lanyuanxiaoyao.service.ai.knowledge.service;
import club.kingon.sql.builder.SqlBuilder;
import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.lanyuanxiaoyao.service.ai.knowledge.entity.Knowledge;
import com.lanyuanxiaoyao.service.ai.knowledge.entity.vo.KnowledgeVO;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.grpc.Collections;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import org.eclipse.collections.api.factory.Lists;
import org.eclipse.collections.api.list.ImmutableList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
/**
* @author lanyuanxiaoyao
* @version 20250522
*/
@Service
public class KnowledgeService {
private static final Logger logger = LoggerFactory.getLogger(KnowledgeService.class);
private static final String KNOWLEDGE_TABLE_NAME = "service_ai_knowledge";
private final JdbcTemplate template;
private final EmbeddingModel embeddingModel;
private final QdrantClient client;
private final KnowledgeGroupService knowledgeGroupService;
public KnowledgeService(JdbcTemplate template, EmbeddingModel embeddingModel, VectorStore vectorStore, KnowledgeGroupService knowledgeGroupService) {
this.template = template;
this.embeddingModel = embeddingModel;
this.client = (QdrantClient) vectorStore.getNativeClient().orElseThrow();
this.knowledgeGroupService = knowledgeGroupService;
}
public Knowledge get(Long id) {
return template.queryForObject(
SqlBuilder.select("id", "vector_source_id", "name", "strategy")
.from(KNOWLEDGE_TABLE_NAME)
.whereEq("id", "?")
.precompileSql(),
Knowledge.class,
id
);
}
public Knowledge get(String name) {
return template.queryForObject(
SqlBuilder.select("id", "vector_source_id", "name", "strategy")
.from(KNOWLEDGE_TABLE_NAME)
.whereEq("name", "?")
.precompileSql(),
Knowledge.class,
name
);
}
@Transactional(rollbackFor = Exception.class)
public void add(String name, String strategy) throws ExecutionException, InterruptedException {
Integer count = template.queryForObject(
SqlBuilder.select("count(*)")
.from(KNOWLEDGE_TABLE_NAME)
.whereEq("name", "?")
.precompileSql(),
Integer.class,
name
);
if (count > 0) {
throw new RuntimeException("名称已存在");
}
long id = IdUtil.getSnowflakeNextId();
long vectorSourceId = IdUtil.getSnowflakeNextId();
template.update(
SqlBuilder.insertInto(KNOWLEDGE_TABLE_NAME, "id", "vector_source_id", "name", "strategy")
.values()
.addValue("?", "?", "?", "?")
.precompileSql(),
id,
vectorSourceId,
name,
strategy
);
client.createCollectionAsync(
String.valueOf(vectorSourceId),
Collections.VectorParams.newBuilder()
.setDistance(Collections.Distance.valueOf(strategy))
.setSize(embeddingModel.dimensions())
.build()
).get();
}
public ImmutableList<KnowledgeVO> list() {
return template.query(
SqlBuilder.select("id", "vector_source_id", "name", "strategy")
.from(KNOWLEDGE_TABLE_NAME)
.build(),
(rs, index) -> {
Knowledge knowledge = new Knowledge();
knowledge.setId(rs.getLong(1));
knowledge.setVectorSourceId(rs.getLong(2));
knowledge.setName(rs.getString(3));
knowledge.setStrategy(rs.getString(4));
return knowledge;
}
)
.stream()
.map(knowledge -> {
try {
Collections.CollectionInfo info = client.getCollectionInfoAsync(String.valueOf(knowledge.getVectorSourceId())).get();
KnowledgeVO vo = new KnowledgeVO();
vo.setName(knowledge.getName());
vo.setPoints(info.getPointsCount());
vo.setSegments(info.getSegmentsCount());
vo.setStatus(info.getStatus().name());
Collections.VectorParams vectorParams = info.getConfig().getParams().getVectorsConfig().getParams();
vo.setStrategy(vectorParams.getDistance().name());
vo.setSize(vectorParams.getSize());
return vo;
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toCollection(Lists.mutable::empty))
.toImmutable();
}
@Transactional(rollbackFor = Exception.class)
public void remove(String name) throws ExecutionException, InterruptedException {
Knowledge knowledge = get(name);
if (ObjectUtil.isNull(knowledge)) {
throw new RuntimeException(StrUtil.format("{} 不存在"));
}
template.update(
SqlBuilder.delete(KNOWLEDGE_TABLE_NAME)
.whereEq("id", "?")
.precompileSql(),
knowledge.getId()
);
knowledgeGroupService.removeByKnowledgeId(knowledge.getId());
client.deleteCollectionAsync(String.valueOf(knowledge.getVectorSourceId())).get();
}
}

View File

@@ -17,6 +17,11 @@ spring:
hostname: localhost
hostname_full: localhost
start_time: 20250514112750
datasource:
url: jdbc:mysql://localhost:3307/ai?useSSL=false
username: test
password: test
driver-class-name: com.mysql.cj.jdbc.Driver
security:
meta:
authority: ENC(GXKnbq1LS11U2HaONspvH+D/TkIx13aWTaokdkzaF7HSvq6Z0Rv1+JUWFnYopVXu)
@@ -39,4 +44,8 @@ jasypt:
encryptor:
password: 'r#(R,P"Dp^A47>WSn:Wn].gs/+"v:q_Q*An~zF*g-@j@jtSTv5H/,S-3:R?r9R}.'
server:
port: 8080
port: 8080
liteflow:
rule-source: config/flow.xml
print-banner: false
check-node-exists: false

View File

@@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<flow>
<chain name="embedding">
SER(
embedding_start,
SWITCH(embedding_mode_switch).TO(
normal_embedding,
llm_embedding,
qa_embedding
),
embedding_finish
);
</chain>
</flow>