Skip to content

Commit d8c0de7

Browse files
authored
✨ New implementation: support Rerank models in model management and agent configurations
✨ New implementation: support Rerank models in model management and agent configurations
2 parents 268708b + 3cb2fec commit d8c0de7

37 files changed

+3243
-112
lines changed

backend/agents/create_agent_info.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ElasticSearchService,
1515
get_vector_db_core,
1616
get_embedding_model,
17+
get_rerank_model,
1718
)
1819
from services.remote_mcp_service import get_remote_mcp_server_list
1920
from services.memory_config_service import build_memory_context
@@ -350,11 +351,32 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
350351
tool_config.metadata = langchain_tool
351352
break
352353

353-
# special logic for knowledge base search tool
354+
# special logic for search tools that may use reranking models
354355
if tool_config.class_name == "KnowledgeBaseSearchTool":
355-
tool_config.metadata = {
356+
rerank = param_dict.get("rerank", False)
357+
rerank_model_name = param_dict.get("rerank_model_name", "")
358+
rerank_model = None
359+
if rerank and rerank_model_name:
360+
rerank_model = get_rerank_model(
361+
tenant_id=tenant_id, model_name=rerank_model_name
362+
)
363+
364+
tool_config.metadata = {
356365
"vdb_core": get_vector_db_core(),
357366
"embedding_model": get_embedding_model(tenant_id=tenant_id),
367+
"rerank_model": rerank_model,
368+
}
369+
elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]:
370+
rerank = param_dict.get("rerank", False)
371+
rerank_model_name = param_dict.get("rerank_model_name", "")
372+
rerank_model = None
373+
if rerank and rerank_model_name:
374+
rerank_model = get_rerank_model(
375+
tenant_id=tenant_id, model_name=rerank_model_name
376+
)
377+
378+
tool_config.metadata = {
379+
"rerank_model": rerank_model,
358380
}
359381
elif tool_config.class_name == "AnalyzeTextFileTool":
360382
tool_config.metadata = {

backend/services/model_health_service.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from nexent.core import MessageObserver
44
from nexent.core.models import OpenAIModel, OpenAIVLModel
55
from nexent.core.models.embedding_model import JinaEmbedding, OpenAICompatibleEmbedding
6+
from nexent.core.models.rerank_model import OpenAICompatibleRerank
67

78
from services.voice_service import get_voice_service
89
from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST
@@ -102,7 +103,13 @@ async def _perform_connectivity_check(
102103
ssl_verify=ssl_verify
103104
).check_connectivity()
104105
elif model_type == "rerank":
105-
connectivity = False
106+
rerank_model = OpenAICompatibleRerank(
107+
model_name=model_name,
108+
base_url=model_base_url,
109+
api_key=model_api_key,
110+
ssl_verify=ssl_verify,
111+
)
112+
connectivity = await rerank_model.connectivity_check()
106113
elif model_type == "vlm":
107114
observer = MessageObserver()
108115
connectivity = await OpenAIVLModel(

backend/services/model_provider_service.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a
132132
model_dict["base_url"] = f"{model_url.rstrip('/')}/{MODEL_ENGINE_NORTH_PREFIX}/embeddings"
133133
# The embedding dimension might differ from the provided max_tokens.
134134
model_dict["max_tokens"] = await embedding_dimension_check(model_dict)
135+
elif model["model_type"] == "rerank":
136+
if provider == ProviderEnum.DASHSCOPE.value:
137+
model_dict["base_url"] = f"{model_url.replace('compatible-mode/v1','api/v1').rstrip('/')}/services/rerank/text-rerank/text-rerank"
138+
else:
139+
model_dict["base_url"] = f"{model_url.rstrip('/')}/rerank"
135140
else:
136141
# For non-embedding models
137142
if provider == ProviderEnum.MODELENGINE.value:

backend/services/providers/dashscope_provider.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
5858
"chat": [], # Maps to "llm"
5959
"vlm": [], # Maps to "vlm"
6060
"embedding": [], # Maps to "embedding" / "multi_embedding"
61-
"reranker": [], # Maps to "reranker"
61+
"rerank": [], # Maps to "rerank"
6262
"tts": [], # Maps to "tts"
6363
"stt": [] # Maps to "stt"
6464
}
@@ -88,10 +88,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
8888
categorized_models['embedding'].append(cleaned_model)
8989
continue
9090

91-
# 2. Reranker
91+
# 2. Rerank
9292
if 'rerank' in m_id.lower() or '重排序' in desc:
93-
cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"})
94-
categorized_models['reranker'].append(cleaned_model)
93+
cleaned_model.update({"model_tag": "rerank", "model_type": "rerank"})
94+
categorized_models['rerank'].append(cleaned_model)
9595
continue
9696

9797
# 3. STT

backend/services/providers/silicon_provider.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
3030
silicon_url = f"{SILICON_GET_URL}?sub_type=chat"
3131
elif model_type in ("embedding", "multi_embedding"):
3232
silicon_url = f"{SILICON_GET_URL}?sub_type=embedding"
33+
elif model_type == "rerank":
34+
silicon_url = f"{SILICON_GET_URL}?sub_type=reranker"
3335
else:
3436
silicon_url = SILICON_GET_URL
3537

@@ -48,6 +50,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
4850
for item in model_list:
4951
item["model_tag"] = "embedding"
5052
item["model_type"] = model_type
53+
elif model_type == "rerank":
54+
for item in model_list:
55+
item["model_tag"] = "rerank"
56+
item["model_type"] = model_type
5157

5258
# Return empty list to indicate successful API call but no models
5359
if not model_list:

backend/services/providers/tokenpony_provider.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
4747
"chat": [], # Maps to "llm"
4848
"vlm": [], # Maps to "vlm"
4949
"embedding": [], # Maps to "embedding" / "multi_embedding"
50-
"reranker": [], # Maps to "reranker"
50+
"rerank": [], # Maps to "rerank"
5151
"tts": [], # Maps to "tts"
5252
"stt": [] # Maps to "stt"
5353
}
@@ -66,10 +66,10 @@ async def get_models(self, provider_config: Dict) -> List[Dict]:
6666
"model_type": "",
6767
"max_tokens": DEFAULT_LLM_MAX_TOKENS
6868
}
69-
# 1. reranker
69+
# 1. rerank
7070
if 'rerank' in m_id:
71-
cleaned_model.update({"model_tag": "reranker", "model_type": "reranker"})
72-
categorized_models['reranker'].append(cleaned_model)
71+
cleaned_model.update({"model_tag": "rerank", "model_type": "rerank"})
72+
categorized_models['rerank'].append(cleaned_model)
7373
#2. embedding
7474
elif 'embedding' in m_id or m_id.startswith('bge-'):
7575
cleaned_model.update({"model_tag": "embedding", "model_type": "embedding"})

backend/services/tool_configuration_service.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
check_tool_list_initialized,
2929
)
3030
from services.file_management_service import get_llm_model
31-
from services.vectordatabase_service import get_embedding_model, get_vector_db_core
31+
from services.vectordatabase_service import get_embedding_model, get_rerank_model, get_vector_db_core
3232
from database.client import minio_client
3333
from services.image_service import get_vlm_model
3434
from utils.tool_utils import get_local_tools_classes, get_local_tools_description_zh
@@ -694,10 +694,32 @@ def _validate_local_tool(
694694
if tool_name == "knowledge_base_search":
695695
embedding_model = get_embedding_model(tenant_id=tenant_id)
696696
vdb_core = get_vector_db_core()
697+
698+
# Get rerank configuration
699+
rerank = instantiation_params.get("rerank", False)
700+
rerank_model_name = instantiation_params.get("rerank_model_name", "")
701+
rerank_model = None
702+
if rerank and rerank_model_name:
703+
rerank_model = get_rerank_model(tenant_id=tenant_id, model_name=rerank_model_name)
704+
697705
params = {
698706
**instantiation_params,
699707
'vdb_core': vdb_core,
700708
'embedding_model': embedding_model,
709+
'rerank_model': rerank_model,
710+
}
711+
tool_instance = tool_class(**params)
712+
elif tool_name in ["dify_search", "datamate_search"]:
713+
# Get rerank configuration for dify and datamate search tools
714+
rerank = instantiation_params.get("rerank", False)
715+
rerank_model_name = instantiation_params.get("rerank_model_name", "")
716+
rerank_model = None
717+
if rerank and rerank_model_name:
718+
rerank_model = get_rerank_model(tenant_id=tenant_id, model_name=rerank_model_name)
719+
720+
params = {
721+
**instantiation_params,
722+
'rerank_model': rerank_model,
701723
}
702724
tool_instance = tool_class(**params)
703725
elif tool_name == "analyze_image":

backend/services/vectordatabase_service.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from fastapi import Body, Depends, Path, Query
2222
from fastapi.responses import StreamingResponse
2323
from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding, BaseEmbedding
24+
from nexent.core.models.rerank_model import OpenAICompatibleRerank, BaseRerank
2425
from nexent.vector_database.base import VectorDatabaseCore
2526
from nexent.vector_database.elasticsearch_core import ElasticSearchCore
2627
from nexent.vector_database.datamate_core import DataMateCore
@@ -241,6 +242,52 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None):
241242
return None
242243

243244

245+
def get_rerank_model(tenant_id: str, model_name: Optional[str] = None):
246+
"""
247+
Get the rerank model for the tenant, optionally using a specific model name.
248+
249+
Args:
250+
tenant_id: Tenant ID
251+
model_name: Optional specific model name to use (format: "model_repo/model_name" or just "model_name")
252+
If provided, will try to find the model in the tenant's model list.
253+
254+
Returns:
255+
Rerank model instance or None
256+
"""
257+
# If model_name is provided, try to find it in the tenant's models
258+
if model_name:
259+
try:
260+
models = get_model_records({"model_type": "rerank"}, tenant_id)
261+
for model in models:
262+
model_display_name = model.get("model_repo") + "/" + model["model_name"] if model.get("model_repo") else model["model_name"]
263+
if model_display_name == model_name:
264+
# Found the model, create rerank model instance
265+
return OpenAICompatibleRerank(
266+
model_name=get_model_name_from_config(model) or "",
267+
base_url=model.get("base_url", ""),
268+
api_key=model.get("api_key", ""),
269+
ssl_verify=model.get("ssl_verify", True),
270+
)
271+
except Exception as e:
272+
logger.warning(f"Failed to get rerank model by name {model_name}: {e}")
273+
274+
# Fall back to default rerank model
275+
model_config = tenant_config_manager.get_model_config(
276+
key="RERANK_ID", tenant_id=tenant_id)
277+
278+
model_type = model_config.get("model_type", "")
279+
280+
if model_type == "rerank":
281+
return OpenAICompatibleRerank(
282+
model_name=get_model_name_from_config(model_config) or "",
283+
base_url=model_config.get("base_url", ""),
284+
api_key=model_config.get("api_key", ""),
285+
ssl_verify=model_config.get("ssl_verify", True),
286+
)
287+
else:
288+
return None
289+
290+
244291
class ElasticSearchService:
245292
@staticmethod
246293
async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCore, user_id: str):

doc/docs/zh/user-guide/agent-development.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,11 @@
130130
- 检索的模式 `search_mode`(默认为 `hybrid`)
131131
- 目标检索的知识库列表 `index_names`,如 `["医疗", "维生素知识大全"]`
132132
- 若不输入 `index_names`,则默认检索知识库页面所选中的全部知识库
133+
- 是否启用重排模型(默认为 `false`),启用后配置重排模型,实现对检索结果的重排优化
133134
6. 输入完成后点击"执行测试"开始测试,并在下方查看测试结果
134135
135136
<div style="display: flex; justify-content: left;">
136-
<img src="./assets/agent-development/tool-test-run.png" style="width: 80%; height: auto;" />
137+
<img src="./assets/agent-development/tool-test-run-1.png" style="width: 80%; height: auto;" />
137138
</div>
138139
139140
## 📝 描述业务逻辑
73.4 KB
Loading

0 commit comments

Comments
 (0)