|
21 | 21 | from fastapi import Body, Depends, Path, Query |
22 | 22 | from fastapi.responses import StreamingResponse |
23 | 23 | from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding, BaseEmbedding |
| 24 | +from nexent.core.models.rerank_model import OpenAICompatibleRerank, BaseRerank |
24 | 25 | from nexent.vector_database.base import VectorDatabaseCore |
25 | 26 | from nexent.vector_database.elasticsearch_core import ElasticSearchCore |
26 | 27 | from nexent.vector_database.datamate_core import DataMateCore |
@@ -241,6 +242,52 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): |
241 | 242 | return None |
242 | 243 |
|
243 | 244 |
|
| 245 | +def get_rerank_model(tenant_id: str, model_name: Optional[str] = None): |
| 246 | + """ |
| 247 | + Get the rerank model for the tenant, optionally using a specific model name. |
| 248 | +
|
| 249 | + Args: |
| 250 | + tenant_id: Tenant ID |
| 251 | + model_name: Optional specific model name to use (format: "model_repo/model_name" or just "model_name") |
| 252 | + If provided, will try to find the model in the tenant's model list. |
| 253 | +
|
| 254 | + Returns: |
| 255 | + Rerank model instance or None |
| 256 | + """ |
| 257 | + # If model_name is provided, try to find it in the tenant's models |
| 258 | + if model_name: |
| 259 | + try: |
| 260 | + models = get_model_records({"model_type": "rerank"}, tenant_id) |
| 261 | + for model in models: |
| 262 | + model_display_name = model.get("model_repo") + "/" + model["model_name"] if model.get("model_repo") else model["model_name"] |
| 263 | + if model_display_name == model_name: |
| 264 | + # Found the model, create rerank model instance |
| 265 | + return OpenAICompatibleRerank( |
| 266 | + model_name=get_model_name_from_config(model) or "", |
| 267 | + base_url=model.get("base_url", ""), |
| 268 | + api_key=model.get("api_key", ""), |
| 269 | + ssl_verify=model.get("ssl_verify", True), |
| 270 | + ) |
| 271 | + except Exception as e: |
| 272 | + logger.warning(f"Failed to get rerank model by name {model_name}: {e}") |
| 273 | + |
| 274 | + # Fall back to default rerank model |
| 275 | + model_config = tenant_config_manager.get_model_config( |
| 276 | + key="RERANK_ID", tenant_id=tenant_id) |
| 277 | + |
| 278 | + model_type = model_config.get("model_type", "") |
| 279 | + |
| 280 | + if model_type == "rerank": |
| 281 | + return OpenAICompatibleRerank( |
| 282 | + model_name=get_model_name_from_config(model_config) or "", |
| 283 | + base_url=model_config.get("base_url", ""), |
| 284 | + api_key=model_config.get("api_key", ""), |
| 285 | + ssl_verify=model_config.get("ssl_verify", True), |
| 286 | + ) |
| 287 | + else: |
| 288 | + return None |
| 289 | + |
| 290 | + |
244 | 291 | class ElasticSearchService: |
245 | 292 | @staticmethod |
246 | 293 | async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCore, user_id: str): |
|
0 commit comments