feat(knowledge): 增加rerank模型适配

This commit is contained in:
v-zhangjc9
2025-06-04 17:43:22 +08:00
parent 4124a8a851
commit c4d5a7b300
10 changed files with 165 additions and 29 deletions

View File

@@ -25,6 +25,7 @@
<spring-boot.version>3.4.3</spring-boot.version> <spring-boot.version>3.4.3</spring-boot.version>
<spring-cloud.version>2024.0.1</spring-cloud.version> <spring-cloud.version>2024.0.1</spring-cloud.version>
<spring-ai.version>1.0.0</spring-ai.version> <spring-ai.version>1.0.0</spring-ai.version>
<solon-ai.version>3.3.1</solon-ai.version>
<eclipse-collections.version>11.1.0</eclipse-collections.version> <eclipse-collections.version>11.1.0</eclipse-collections.version>
<curator.version>5.1.0</curator.version> <curator.version>5.1.0</curator.version>
<hutool.version>5.8.27</hutool.version> <hutool.version>5.8.27</hutool.version>
@@ -136,6 +137,16 @@
<artifactId>liteflow-spring-boot-starter</artifactId> <artifactId>liteflow-spring-boot-starter</artifactId>
<version>2.13.2</version> <version>2.13.2</version>
</dependency> </dependency>
<dependency>
<groupId>org.noear</groupId>
<artifactId>solon-ai</artifactId>
<version>${solon-ai.version}</version>
</dependency>
<dependency>
<groupId>org.noear</groupId>
<artifactId>solon-ai-dialect-openai</artifactId>
<version>${solon-ai.version}</version>
</dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>

View File

@@ -54,6 +54,14 @@
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-pdf-document-reader</artifactId> <artifactId>spring-ai-pdf-document-reader</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.noear</groupId>
<artifactId>solon-ai</artifactId>
</dependency>
<dependency>
<groupId>org.noear</groupId>
<artifactId>solon-ai-dialect-openai</artifactId>
</dependency>
</dependencies> </dependencies>
<build> <build>

View File

@@ -0,0 +1,21 @@
package com.lanyuanxiaoyao.service.ai.knowledge.configuration;
import cn.hutool.core.util.StrUtil;
import org.noear.solon.ai.reranking.RerankingModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* @author lanyuanxiaoyao
* @version 20250604
*/
@Configuration
public class SolonConfiguration {
@Bean
public RerankingModel rerankingModel(SolonProperties solonProperties) {
return RerankingModel.of(StrUtil.format("{}{}", solonProperties.getBaseUrl(), solonProperties.getRerank().getEndpoint()))
.apiKey(solonProperties.getApiKey())
.model(solonProperties.getRerank().getModel())
.build();
}
}

View File

@@ -0,0 +1,78 @@
package com.lanyuanxiaoyao.service.ai.knowledge.configuration;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;
/**
* @author lanyuanxiaoyao
* @version 20250604
*/
@Configuration
@ConfigurationProperties(prefix = "solon")
public class SolonProperties {
private String baseUrl;
private String apiKey;
private Rerank rerank;
public String getBaseUrl() {
return baseUrl;
}
public void setBaseUrl(String baseUrl) {
this.baseUrl = baseUrl;
}
public String getApiKey() {
return apiKey;
}
public void setApiKey(String apiKey) {
this.apiKey = apiKey;
}
public Rerank getRerank() {
return rerank;
}
public void setRerank(Rerank rerank) {
this.rerank = rerank;
}
@Override
public String toString() {
return "SolonProperties{" +
"baseUrl='" + baseUrl + '\'' +
", apiKey='" + apiKey + '\'' +
", rerank=" + rerank +
'}';
}
public static final class Rerank {
private String model;
private String endpoint;
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public String getEndpoint() {
return endpoint;
}
public void setEndpoint(String endpoint) {
this.endpoint = endpoint;
}
@Override
public String toString() {
return "Rerank{" +
"model='" + model + '\'' +
", endpoint='" + endpoint + '\'' +
'}';
}
}
}

View File

@@ -5,7 +5,8 @@ import com.lanyuanxiaoyao.service.ai.core.entity.amis.AmisMapResponse;
import com.lanyuanxiaoyao.service.ai.core.entity.amis.AmisResponse; import com.lanyuanxiaoyao.service.ai.core.entity.amis.AmisResponse;
import com.lanyuanxiaoyao.service.ai.knowledge.entity.vo.SegmentVO; import com.lanyuanxiaoyao.service.ai.knowledge.entity.vo.SegmentVO;
import com.lanyuanxiaoyao.service.ai.knowledge.service.EmbeddingService; import com.lanyuanxiaoyao.service.ai.knowledge.service.EmbeddingService;
import com.lanyuanxiaoyao.service.ai.knowledge.service.KnowledgeService; import com.lanyuanxiaoyao.service.ai.knowledge.service.KnowledgeBaseService;
import java.io.IOException;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import org.eclipse.collections.api.factory.Lists; import org.eclipse.collections.api.factory.Lists;
import org.eclipse.collections.api.list.ImmutableList; import org.eclipse.collections.api.list.ImmutableList;
@@ -24,14 +25,14 @@ import org.springframework.web.bind.annotation.RestController;
*/ */
@RestController @RestController
@RequestMapping("knowledge") @RequestMapping("knowledge")
public class KnowledgeController { public class KnowledgeBaseController {
private static final Logger logger = LoggerFactory.getLogger(KnowledgeController.class); private static final Logger logger = LoggerFactory.getLogger(KnowledgeBaseController.class);
private final KnowledgeService knowledgeService; private final KnowledgeBaseService knowledgeBaseService;
private final EmbeddingService embeddingService; private final EmbeddingService embeddingService;
public KnowledgeController(KnowledgeService knowledgeService, EmbeddingService embeddingService) { public KnowledgeBaseController(KnowledgeBaseService knowledgeBaseService, EmbeddingService embeddingService) {
this.knowledgeService = knowledgeService; this.knowledgeBaseService = knowledgeBaseService;
this.embeddingService = embeddingService; this.embeddingService = embeddingService;
} }
@@ -40,23 +41,23 @@ public class KnowledgeController {
@RequestParam("name") String name, @RequestParam("name") String name,
@RequestParam("strategy") String strategy @RequestParam("strategy") String strategy
) throws ExecutionException, InterruptedException { ) throws ExecutionException, InterruptedException {
knowledgeService.add(name, strategy); knowledgeBaseService.add(name, strategy);
} }
@GetMapping("name") @GetMapping("name")
public AmisMapResponse name(@RequestParam("id") Long id) { public AmisMapResponse name(@RequestParam("id") Long id) {
return AmisResponse.responseMapData() return AmisResponse.responseMapData()
.setData("name", knowledgeService.getName(id)); .setData("name", knowledgeBaseService.getName(id));
} }
@GetMapping("list") @GetMapping("list")
public AmisResponse<?> list() { public AmisResponse<?> list() {
return AmisResponse.responseCrudData(knowledgeService.list()); return AmisResponse.responseCrudData(knowledgeBaseService.list());
} }
@GetMapping("delete") @GetMapping("delete")
public void delete(@RequestParam("id") Long id) throws ExecutionException, InterruptedException { public void delete(@RequestParam("id") Long id) throws ExecutionException, InterruptedException {
knowledgeService.remove(id); knowledgeBaseService.remove(id);
} }
@PostMapping("preview_text") @PostMapping("preview_text")
@@ -114,7 +115,7 @@ public class KnowledgeController {
@RequestParam(value = "limit", defaultValue = "5") Integer limit, @RequestParam(value = "limit", defaultValue = "5") Integer limit,
@RequestParam(value = "threshold", defaultValue = "0.6") Double threshold, @RequestParam(value = "threshold", defaultValue = "0.6") Double threshold,
@RequestBody String text @RequestBody String text
) throws ExecutionException, InterruptedException { ) throws ExecutionException, InterruptedException, IOException {
return knowledgeService.query(id, text, limit, threshold); return knowledgeBaseService.query(id, text, limit, threshold);
} }
} }

View File

@@ -29,15 +29,15 @@ public class EmbeddingService {
private final DataFileService dataFileService; private final DataFileService dataFileService;
private final FlowExecutor executor; private final FlowExecutor executor;
private final KnowledgeService knowledgeService; private final KnowledgeBaseService knowledgeBaseService;
private final GroupService groupService; private final GroupService groupService;
private final ExecutorService executors = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); private final ExecutorService executors = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
@SuppressWarnings("SpringJavaInjectionPointsAutowiringInspection") @SuppressWarnings("SpringJavaInjectionPointsAutowiringInspection")
public EmbeddingService(DataFileService dataFileService, FlowExecutor executor, KnowledgeService knowledgeService, GroupService groupService) { public EmbeddingService(DataFileService dataFileService, FlowExecutor executor, KnowledgeBaseService knowledgeBaseService, GroupService groupService) {
this.dataFileService = dataFileService; this.dataFileService = dataFileService;
this.executor = executor; this.executor = executor;
this.knowledgeService = knowledgeService; this.knowledgeBaseService = knowledgeBaseService;
this.groupService = groupService; this.groupService = groupService;
} }
@@ -63,7 +63,7 @@ public class EmbeddingService {
public void submit(Long id, String mode, String content) { public void submit(Long id, String mode, String content) {
executors.submit(() -> { executors.submit(() -> {
Knowledge knowledge = knowledgeService.get(id); Knowledge knowledge = knowledgeBaseService.get(id);
Long groupId = groupService.add(knowledge.getId(), StrUtil.format("文本-{}", IdUtil.nanoId(10))); Long groupId = groupService.add(knowledge.getId(), StrUtil.format("文本-{}", IdUtil.nanoId(10)));
EmbeddingContext context = EmbeddingContext.builder() EmbeddingContext context = EmbeddingContext.builder()
.vectorSourceId(knowledge.getVectorSourceId()) .vectorSourceId(knowledge.getVectorSourceId())
@@ -80,7 +80,7 @@ public class EmbeddingService {
public void submit(Long id, String mode, ImmutableList<String> ids) { public void submit(Long id, String mode, ImmutableList<String> ids) {
executors.submit(() -> { executors.submit(() -> {
Knowledge knowledge = knowledgeService.get(id); Knowledge knowledge = knowledgeBaseService.get(id);
List<Pair<Long, DataFileVO>> vos = Lists.mutable.empty(); List<Pair<Long, DataFileVO>> vos = Lists.mutable.empty();
for (String fileId : ids) { for (String fileId : ids) {
DataFileVO vo = dataFileService.downloadFile(Long.parseLong(fileId)); DataFileVO vo = dataFileService.downloadFile(Long.parseLong(fileId));

View File

@@ -102,7 +102,7 @@ public class GroupService {
public void remove(Long groupId) throws ExecutionException, InterruptedException { public void remove(Long groupId) throws ExecutionException, InterruptedException {
Long vectorSourceId = template.queryForObject( Long vectorSourceId = template.queryForObject(
SqlBuilder.select("k.vector_source_id") SqlBuilder.select("k.vector_source_id")
.from(Alias.of(GROUP_TABLE_NAME, "g"), Alias.of(KnowledgeService.KNOWLEDGE_TABLE_NAME, "k")) .from(Alias.of(GROUP_TABLE_NAME, "g"), Alias.of(KnowledgeBaseService.KNOWLEDGE_TABLE_NAME, "k"))
.whereEq("g.knowledge_id", Column.as("k.id")) .whereEq("g.knowledge_id", Column.as("k.id"))
.andEq("g.id", groupId) .andEq("g.id", groupId)
.precompileSql(), .precompileSql(),

View File

@@ -9,11 +9,13 @@ import com.lanyuanxiaoyao.service.ai.knowledge.entity.vo.KnowledgeVO;
import com.lanyuanxiaoyao.service.common.Constants; import com.lanyuanxiaoyao.service.common.Constants;
import io.qdrant.client.QdrantClient; import io.qdrant.client.QdrantClient;
import io.qdrant.client.grpc.Collections; import io.qdrant.client.grpc.Collections;
import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.eclipse.collections.api.factory.Lists; import org.eclipse.collections.api.factory.Lists;
import org.eclipse.collections.api.list.ImmutableList; import org.eclipse.collections.api.list.ImmutableList;
import org.noear.solon.ai.reranking.RerankingModel;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
@@ -31,9 +33,9 @@ import org.springframework.transaction.annotation.Transactional;
* @version 20250522 * @version 20250522
*/ */
@Service @Service
public class KnowledgeService { public class KnowledgeBaseService {
public static final String KNOWLEDGE_TABLE_NAME = Constants.DATABASE_NAME + ".service_ai_knowledge"; public static final String KNOWLEDGE_TABLE_NAME = Constants.DATABASE_NAME + ".service_ai_knowledge";
private static final Logger logger = LoggerFactory.getLogger(KnowledgeService.class); private static final Logger logger = LoggerFactory.getLogger(KnowledgeBaseService.class);
private static final RowMapper<Knowledge> knowledgeMapper = (rs, row) -> { private static final RowMapper<Knowledge> knowledgeMapper = (rs, row) -> {
Knowledge knowledge = new Knowledge(); Knowledge knowledge = new Knowledge();
knowledge.setId(rs.getLong(1)); knowledge.setId(rs.getLong(1));
@@ -48,12 +50,14 @@ public class KnowledgeService {
private final EmbeddingModel model; private final EmbeddingModel model;
private final QdrantClient client; private final QdrantClient client;
private final GroupService groupService; private final GroupService groupService;
private final RerankingModel rerankingModel;
public KnowledgeService(JdbcTemplate template, EmbeddingModel model, VectorStore vectorStore, GroupService groupService) { public KnowledgeBaseService(JdbcTemplate template, EmbeddingModel model, VectorStore vectorStore, GroupService groupService, RerankingModel rerankingModel) {
this.template = template; this.template = template;
this.model = model; this.model = model;
this.client = (QdrantClient) vectorStore.getNativeClient().orElseThrow(); this.client = (QdrantClient) vectorStore.getNativeClient().orElseThrow();
this.groupService = groupService; this.groupService = groupService;
this.rerankingModel = rerankingModel;
} }
public Knowledge get(Long id) { public Knowledge get(Long id) {
@@ -166,7 +170,8 @@ public class KnowledgeService {
Long id, Long id,
String text, String text,
Integer limit, Integer limit,
Double threshold) throws ExecutionException, InterruptedException { Double threshold
) throws ExecutionException, InterruptedException, IOException {
Knowledge knowledge = get(id); Knowledge knowledge = get(id);
Boolean exists = client.collectionExistsAsync(String.valueOf(knowledge.getVectorSourceId())).get(); Boolean exists = client.collectionExistsAsync(String.valueOf(knowledge.getVectorSourceId())).get();
if (!exists) { if (!exists) {
@@ -183,7 +188,13 @@ public class KnowledgeService {
.similarityThreshold(threshold) .similarityThreshold(threshold)
.build() .build()
); );
return Lists.immutable.ofAll(documents) List<org.noear.solon.ai.rag.Document> rerankDocuments = rerankingModel.rerank(
.collect(Document::getText); text,
documents.stream()
.map(doc -> new org.noear.solon.ai.rag.Document(doc.getId(), doc.getText(), doc.getMetadata(), doc.getScore()))
.toList()
);
return Lists.immutable.ofAll(rerankDocuments)
.collect(org.noear.solon.ai.rag.Document::getContent);
} }
} }

View File

@@ -23,16 +23,16 @@ import org.springframework.stereotype.Service;
public class SegmentService { public class SegmentService {
private static final Logger logger = LoggerFactory.getLogger(SegmentService.class); private static final Logger logger = LoggerFactory.getLogger(SegmentService.class);
private final KnowledgeService knowledgeService; private final KnowledgeBaseService knowledgeBaseService;
private final QdrantClient client; private final QdrantClient client;
public SegmentService(KnowledgeService knowledgeService, VectorStore vectorStore) { public SegmentService(KnowledgeBaseService knowledgeBaseService, VectorStore vectorStore) {
this.knowledgeService = knowledgeService; this.knowledgeBaseService = knowledgeBaseService;
this.client = (QdrantClient) vectorStore.getNativeClient().orElseThrow(); this.client = (QdrantClient) vectorStore.getNativeClient().orElseThrow();
} }
public ImmutableList<SegmentVO> list(Long id, Long groupId) throws ExecutionException, InterruptedException { public ImmutableList<SegmentVO> list(Long id, Long groupId) throws ExecutionException, InterruptedException {
Knowledge knowledge = knowledgeService.get(id); Knowledge knowledge = knowledgeBaseService.get(id);
Points.ScrollResponse response = client.scrollAsync( Points.ScrollResponse response = client.scrollAsync(
Points.ScrollPoints.newBuilder() Points.ScrollPoints.newBuilder()
.setCollectionName(String.valueOf(knowledge.getVectorSourceId())) .setCollectionName(String.valueOf(knowledge.getVectorSourceId()))
@@ -59,7 +59,7 @@ public class SegmentService {
} }
public void remove(Long knowledgeId, Long segmentId) throws ExecutionException, InterruptedException { public void remove(Long knowledgeId, Long segmentId) throws ExecutionException, InterruptedException {
Knowledge knowledge = knowledgeService.get(knowledgeId); Knowledge knowledge = knowledgeBaseService.get(knowledgeId);
client.deletePayloadAsync( client.deletePayloadAsync(
String.valueOf(knowledgeId), String.valueOf(knowledgeId),
List.of(String.valueOf(segmentId)), List.of(String.valueOf(segmentId)),

View File

@@ -22,3 +22,9 @@ liteflow:
rule-source: config/flow.xml rule-source: config/flow.xml
print-banner: false print-banner: false
check-node-exists: false check-node-exists: false
solon:
base-url: http://132.121.206.65:10086
api-key: ENC(K+Hff9QGC+fcyi510VIDd9CaeK/IN5WBJ9rlkUsHEdDgIidW+stHHJlsK0lLPUXXREha+ToQZqqDXJrqSE+GUKCXklFhelD8bRHFXBIeP/ZzT2cxhzgKUXgjw3S0Qw2R)
rerank:
model: 'Bge-reranker-v2-vllm'
endpoint: '/v1/rerank'