@@ -9,7 +9,6 @@ import com.lanyuanxiaoyao.service.ai.web.entity.vo.KnowledgeVO;
import com.lanyuanxiaoyao.service.common.Constants ;
import io.qdrant.client.QdrantClient ;
import io.qdrant.client.grpc.Collections ;
import java.io.IOException ;
import java.util.List ;
import java.util.concurrent.ExecutionException ;
import java.util.stream.Collectors ;
@@ -32,14 +31,16 @@ import org.springframework.transaction.annotation.Transactional;
@Service
public class KnowledgeBaseService {
public static final String KNOWLEDGE_TABLE_NAME = Constants . DATABASE_NAME + " .service_ai_knowledge " ;
public static final String [ ] KNOWLEDGE_COLUMNS = new String [ ] { " id " , " vector_source_id " , " name " , " description " , " strategy " , " created_time " , " modified_time " } ;
private static final RowMapper < Knowledge > knowledgeMapper = ( rs , row ) - > {
Knowledge knowledge = new Knowledge ( ) ;
knowledge . setId ( rs . getLong ( 1 ) ) ;
knowledge . setVectorSourceId ( rs . getLong ( 2 ) ) ;
knowledge . setName ( rs . getString ( 3 ) ) ;
knowledge . setStrategy ( rs . getString ( 4 ) ) ;
knowledge . setCreatedTime ( rs . getTimestamp ( 5 ) . getTime ( ) ) ;
knowledge . setModifi edTime ( rs . getTimestamp ( 6 ) . getTime ( ) ) ;
knowledge . setDescription ( rs . getString ( 4 ) ) ;
knowledge . setStrategy ( rs . getString ( 5 ) ) ;
knowledge . setCreat edTime ( rs . getTimestamp ( 6 ) . getTime ( ) ) ;
knowledge . setModifiedTime ( rs . getTimestamp ( 7 ) . getTime ( ) ) ;
return knowledge ;
} ;
private final JdbcTemplate template ;
@@ -56,7 +57,7 @@ public class KnowledgeBaseService {
public Knowledge get ( Long id ) {
return template . queryForObject (
SqlBuilder . select ( " id " , " vector_source_id " , " name " , " strategy " , " created_time " , " modified_time " )
SqlBuilder . select ( KNOWLEDGE_COLUMNS )
. from ( KNOWLEDGE_TABLE_NAME )
. whereEq ( " id " , " ? " )
. precompileSql ( ) ,
@@ -66,7 +67,7 @@ public class KnowledgeBaseService {
}
@Transactional ( rollbackFor = Exception . class )
public void add ( String name , String strategy ) throws ExecutionException , InterruptedException {
public void add ( String name , String description , String strategy ) throws ExecutionException , InterruptedException {
Integer count = template . queryForObject (
SqlBuilder . select ( " count(*) " )
. from ( KNOWLEDGE_TABLE_NAME )
@@ -82,13 +83,14 @@ public class KnowledgeBaseService {
long id = SnowflakeId . next ( ) ;
long vectorSourceId = SnowflakeId . next ( ) ;
template . update (
SqlBuilder . insertInto ( KNOWLEDGE_TABLE_NAME , " id " , " vector_source_id " , " name " , " strategy " )
SqlBuilder . insertInto ( KNOWLEDGE_TABLE_NAME , " id " , " vector_source_id " , " name " , " description " , " strategy" )
. values ( )
. addValue ( " ? " , " ? " , " ? " , " ? " )
. addValue ( " ? " , " ? " , " ? " , " ? " , " ? " )
. precompileSql ( ) ,
id ,
vectorSourceId ,
name ,
description ,
strategy
) ;
client . createCollectionAsync (
@@ -100,6 +102,18 @@ public class KnowledgeBaseService {
) . get ( ) ;
}
@Transactional ( rollbackFor = Exception . class )
public void updateDescription ( Long id , String description ) {
template . update (
SqlBuilder . update ( KNOWLEDGE_TABLE_NAME )
. set ( " description " , " ? " )
. whereEq ( " id " , " ? " )
. precompileSql ( ) ,
description ,
id
) ;
}
public String getName ( Long id ) {
return template . queryForObject (
SqlBuilder . select ( " name " )
@@ -113,7 +127,7 @@ public class KnowledgeBaseService {
public ImmutableList < KnowledgeVO > list ( ) {
return template . query (
SqlBuilder . select ( " id " , " vector_source_id " , " name " , " strategy " , " created_time " , " modified_time " )
SqlBuilder . select ( KNOWLEDGE_COLUMNS )
. from ( KNOWLEDGE_TABLE_NAME )
. orderByDesc ( " created_time " )
. build ( ) ,
@@ -127,6 +141,7 @@ public class KnowledgeBaseService {
vo . setId ( knowledge . getId ( ) ) ;
vo . setVectorSourceId ( knowledge . getVectorSourceId ( ) ) ;
vo . setName ( knowledge . getName ( ) ) ;
vo . setDescription ( knowledge . getDescription ( ) ) ;
vo . setPoints ( info . getPointsCount ( ) ) ;
vo . setSegments ( info . getSegmentsCount ( ) ) ;
vo . setStatus ( info . getStatus ( ) . name ( ) ) ;
@@ -165,7 +180,7 @@ public class KnowledgeBaseService {
String text ,
Integer limit ,
Double threshold
) throws ExecutionException , InterruptedException , IOException {
) throws ExecutionException , InterruptedException {
Knowledge knowledge = get ( id ) ;
Boolean exists = client . collectionExistsAsync ( String . valueOf ( knowledge . getVectorSourceId ( ) ) ) . get ( ) ;
if ( ! exists ) {
@@ -182,13 +197,6 @@ public class KnowledgeBaseService {
. similarityThreshold ( threshold )
. build ( )
) ;
// 如果只是一个知识库的话, 似乎没有什么rerank的必要...
/* List<org.noear.solon.ai.rag.Document> rerankDocuments = rerankingModel.rerank(
text,
documents.stream()
.map(doc -> new org.noear.solon.ai.rag.Document(doc.getId(), doc.getText(), doc.getMetadata(), doc.getScore()))
.toList()
); */
return Lists . immutable . ofAll ( documents )
. collect ( Document : : getText ) ;
}