Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
CLIENT_ORIGINS=https://fake-origin.example.com
CLIENT_ORIGINS_REGEX="^http://fake-localhost:.*"
SESSION_COOKIE_DOMAIN=.example.com
ENV=development

##### AZURE #####
Expand Down
260 changes: 26 additions & 234 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ langgraph-checkpoint-postgres = "^2.0.23"
azure-ai-inference = "^1.0.0b9"
azure-identity = "^1.25.0"
psycopg = {extras = ["binary"], version = "^3.2.10"}
welearn-database = "^1.4.5"
bs4 = "^0.0.2"
Comment thread
jmsevin marked this conversation as resolved.
urllib3 = "^2.6.3"
refinedoc = "^1.0.1"
Expand All @@ -50,7 +51,6 @@ langchain-mistralai = "^1.1.2"
langchain-azure-ai = "^1.2.3"
langgraph = "^1.1.10"
mistralai = "^2.4.3"
welearn-database = "^1.4.5"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
Expand Down
268 changes: 259 additions & 9 deletions src/app/api/api_v1/endpoints/chat.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
import uuid
from typing import Dict, Optional, cast
from typing import Any, AsyncGenerator, Dict, Optional, cast
from uuid import UUID

import backoff
import psycopg
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse
from langchain_core.messages import ToolMessage
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from openai import RateLimitError
from psycopg.rows import dict_row
from psycopg.rows import AsyncRowFactory, DictRow, dict_row
from pydantic import BaseModel

from src.app.models import chat as models
Expand Down Expand Up @@ -42,6 +44,32 @@
database=settings.PG_DATABASE,
)

# psycopg exposes dict_row with a BaseCursor annotation, while AsyncConnection.connect
# expects an async row factory type. Runtime is valid; cast keeps static typing happy.
ASYNC_DICT_ROW_FACTORY = cast(AsyncRowFactory[DictRow], dict_row)

# Headers passed as `headers=` to the text/event-stream StreamingResponse objects
# built by the streaming endpoints in this module.
# NOTE(review): "X-Accel-Buffering: no" presumably targets a buffering reverse proxy
# (e.g. nginx) sitting in front of the app — confirm against the deployment.
SSE_HEADERS = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "X-Accel-Buffering": "no",
}


def _format_sse_event(data: str) -> str:
lines = data.splitlines() or [""]
return "".join(f"data: {line}\n" for line in lines) + "\n"


async def _sse_wrap(stream: Any) -> AsyncGenerator[str, None]:
    """Re-emit every item of an async *stream* as an SSE-formatted message.

    ``str`` items are forwarded unchanged, ``bytes`` items are decoded as
    UTF-8 (undecodable bytes are replaced rather than raising), and any other
    item is passed through ``jsonable_encoder`` and dumped as JSON before
    being framed by ``_format_sse_event``.
    """
    async for item in stream:
        if isinstance(item, str):
            yield _format_sse_event(item)
            continue

        if isinstance(item, bytes):
            text = item.decode("utf-8", errors="replace")
        else:
            text = json.dumps(jsonable_encoder(item))
        yield _format_sse_event(text)


def get_params(body: models.Context) -> models.ContextOut:
body.sources = body.sources[:7]
Expand All @@ -55,6 +83,7 @@ def get_params(body: models.Context) -> models.ContextOut:
history=body.history or [],
query=body.query,
subject=body.subject,
conversation_id=None,
Comment thread
jmsevin marked this conversation as resolved.
)


Expand Down Expand Up @@ -221,8 +250,9 @@ async def q_and_a_rephrase_stream(
)

return StreamingResponse(
content=content,
content=_sse_wrap(content),
media_type="text/event-stream",
headers=SSE_HEADERS,
)


Expand Down Expand Up @@ -316,8 +346,9 @@ async def q_and_a_stream(
)

return StreamingResponse(
content=content,
content=_sse_wrap(content),
media_type="text/event-stream",
headers=SSE_HEADERS,
)
except LanguageNotSupportedError as e:
bad_request(message=e.message, msg_code=e.msg_code)
Expand All @@ -338,8 +369,11 @@ async def get_chat_history(
chatfactory=Depends(get_chat_service),
) -> list[Dict[str, str | list[Dict[str, str]] | None]]:
if thread_id:
async with await psycopg.AsyncConnection.connect(
DB_URI, autocommit=True, prepare_threshold=0, row_factory=dict_row
async with await psycopg.AsyncConnection[DictRow].connect(
DB_URI,
autocommit=True,
prepare_threshold=0,
row_factory=ASYNC_DICT_ROW_FACTORY,
) as conn:
await conn.execute("SET SEARCH_PATH to agent_related")
await conn.commit()
Expand All @@ -351,6 +385,219 @@ async def get_chat_history(
return res


def _resolve_thread_id(thread_id: UUID | None) -> UUID:
if thread_id:
return thread_id

logger.info("No thread_id provided. Generating new thread_id.")
return uuid.uuid4()


def _update_agent_stream_state(
chunk: dict[str, Any],
current_final_content: str,
current_docs: Any,
) -> tuple[str, Any]:
status = chunk.get("status")
docs = current_docs
final_content = current_final_content

if status == "processing" and chunk.get("docs"):
docs = chunk["docs"]
elif status == "streaming":
final_content += cast(str, chunk.get("content", ""))
elif status == "stop":
stop_content = cast(str, chunk.get("content", ""))
Comment thread
jmsevin marked this conversation as resolved.
if stop_content:
final_content = stop_content

return final_content, docs


def _serialize_agent_stream_chunk(chunk: dict[str, Any]) -> str:
    """Serialize the client-facing fields of an agent stream chunk to JSON.

    Only ``content``, ``status``, ``step``, ``label`` and ``docs`` are kept;
    a field missing from *chunk* is emitted as ``null``.
    """
    exposed_keys = ("content", "status", "step", "label", "docs")
    payload = {key: chunk.get(key) for key in exposed_keys}
    return json.dumps(jsonable_encoder(payload))


async def _stream_agent_with_memory(
    *,
    chatfactory: Any,
    body: models.AgentContext,
    sp: SearchService,
    background_tasks: BackgroundTasks,
    thread_id: UUID,
) -> AsyncGenerator[dict[str, Any], None]:
    """Yield agent chunks produced with a Postgres-backed checkpointer.

    Opens a connection to ``DB_URI``, switches its search path to
    ``agent_related``, wraps it in an ``AsyncPostgresSaver`` and forwards
    every chunk of ``chatfactory.agent_message(..., streamed_ans=True)``.
    The ``yield`` sits inside the connection's ``async with`` block, so the
    connection stays open for as long as the stream is being consumed.
    """
    connection = await psycopg.AsyncConnection[DictRow].connect(
        DB_URI,
        autocommit=True,
        prepare_threshold=0,
        row_factory=ASYNC_DICT_ROW_FACTORY,
    )
    async with connection as conn:
        await conn.execute("SET SEARCH_PATH to agent_related")
        await conn.commit()

        checkpointer = AsyncPostgresSaver(conn)
        agent_chunks = await chatfactory.agent_message(
            query=body.query,
            memory=checkpointer,
            thread_id=thread_id,
            corpora=body.corpora,
            sdg_filter=body.sdg_filter,
            sp=sp,
            background_tasks=background_tasks,
            streamed_ans=True,
        )

        async for agent_chunk in agent_chunks:
            yield agent_chunk


def _build_final_stream_payload(
*,
final_content: str,
docs: Any,
thread_id: UUID,
) -> dict[str, Any]:
return {
"content": final_content,
"status": "stop",
"docs": docs,
"thread_id": thread_id,
}


async def _register_stream_chat_data(
    *,
    data_collection: Any,
    session_id: UUID | None,
    user_query: str,
    conversation_id: UUID,
    answer_content: str,
    sources: Any,
) -> Any:
    """Register the finished exchange and return the id of the stored message.

    Delegates to ``data_collection.register_chat_data`` and keeps only the
    second element of the pair it returns.
    """
    registration = await data_collection.register_chat_data(
        session_id=session_id,
        user_query=user_query,
        conversation_id=conversation_id,
        answer_content=answer_content,
        sources=sources,
    )
    # The first element of the pair is not needed by the streaming endpoint.
    _, message_id = registration
    return message_id


async def _stream_agent_response(
    *,
    body: models.AgentContext,
    chatfactory: Any,
    sp: SearchService,
    background_tasks: BackgroundTasks,
    data_collection: Any,
    session_id: UUID | None,
    thread_id: UUID,
) -> AsyncGenerator[str, None]:
    """Stream the agent answer as SSE messages and finish with a summary event.

    Every chunk except those with status ``"stop"`` is forwarded to the
    client as it arrives, while the answer text and documents are accumulated
    on the side. Once the agent stream is exhausted, the exchange is
    registered through *data_collection* and a final ``"stop"`` event is
    emitted carrying the docs, the thread id and, when registration
    succeeded, the message id. The final event's ``content`` is blanked when
    text was already streamed, so the client does not receive the answer
    twice.
    """
    answer_so_far = ""
    collected_docs = []
    streamed_any_text = False

    agent_chunks = _stream_agent_with_memory(
        chatfactory=chatfactory,
        body=body,
        sp=sp,
        background_tasks=background_tasks,
        thread_id=thread_id,
    )

    async for chunk in agent_chunks:
        answer_so_far, collected_docs = _update_agent_stream_state(
            chunk, answer_so_far, collected_docs
        )
        chunk_status = chunk.get("status")
        if chunk_status == "streaming" and chunk.get("content"):
            streamed_any_text = True
        if chunk_status == "stop":
            # The terminal event is rebuilt below from the accumulated state.
            continue
        try:
            yield _format_sse_event(_serialize_agent_stream_chunk(chunk))
        except Exception as e:
            # Best effort: one chunk that cannot be emitted must not end the stream.
            logger.error("Error while yielding chunk: %s", e)

    final_payload = _build_final_stream_payload(
        final_content=answer_so_far,
        docs=collected_docs,
        thread_id=thread_id,
    )
    if streamed_any_text:
        final_payload["content"] = ""

    try:
        final_payload["message_id"] = await _register_stream_chat_data(
            data_collection=data_collection,
            session_id=session_id,
            user_query=cast(str, body.query),
            conversation_id=thread_id,
            answer_content=answer_so_far,
            sources=collected_docs,
        )
    except Exception as e:
        # Registration failure is logged; the final event is still sent, without message_id.
        logger.error("Error while registering chat data: %s", e)

    yield _format_sse_event(json.dumps(jsonable_encoder(final_payload)))


@router.post(
    "/chat/agent_stream",
    summary="Agent Response Stream",
    description="This endpoint streams an agent response to the user's message and ends with the full response payload.",
    response_class=StreamingResponse,
)
@backoff.on_exception(
    wait_gen=backoff.expo,
    exception=RateLimitError,
    logger=logger,
    max_tries=5,
    max_time=180,
    jitter=backoff.random_jitter,
    factor=2,
)
async def agent_stream_response(
    request: Request,
    background_tasks: BackgroundTasks,
    body: models.AgentContext = Depends(get_agent_params),
    chatfactory=Depends(get_chat_service),
    sp: SearchService = Depends(get_search_service),
    data_collection=Depends(get_data_collection_service),
) -> StreamingResponse:
    """Return a text/event-stream response that streams the agent's answer.

    Resolves the session cookie and the thread id (generating one when the
    body has none), rejects a missing query with ``EmptyQueryError`` and
    hands the actual work to ``_stream_agent_response``.

    NOTE(review): the handler only builds the async generator; its body runs
    while the response is being sent, after this function has returned. The
    ``backoff`` decorator and the ``except`` clause below therefore only see
    exceptions raised before the return — confirm that is intended.
    """
    try:
        session_id = extract_session_cookie(request)
        thread_id = _resolve_thread_id(body.thread_id)

        if body.query is None:
            raise EmptyQueryError()

        event_stream = _stream_agent_response(
            body=body,
            chatfactory=chatfactory,
            sp=sp,
            background_tasks=background_tasks,
            data_collection=data_collection,
            session_id=session_id,
            thread_id=thread_id,
        )
        return StreamingResponse(
            content=event_stream,
            media_type="text/event-stream",
            headers=SSE_HEADERS,
        )
    except LanguageNotSupportedError as e:
        bad_request(message=e.message, msg_code=e.msg_code)
        raise


@router.post(
"/chat/agent",
summary="Agent Response",
Expand Down Expand Up @@ -388,8 +635,11 @@ async def agent_response(
raise EmptyQueryError()

if thread_id:
async with await psycopg.AsyncConnection.connect(
DB_URI, autocommit=True, prepare_threshold=0, row_factory=dict_row
async with await psycopg.AsyncConnection[DictRow].connect(
DB_URI,
autocommit=True,
prepare_threshold=0,
row_factory=ASYNC_DICT_ROW_FACTORY,
) as conn:
await conn.execute("SET SEARCH_PATH to agent_related")
await conn.commit()
Expand Down Expand Up @@ -425,7 +675,7 @@ async def agent_response(
}

try:
conversation_id, message_id = await data_collection.register_chat_data(
_, message_id = await data_collection.register_chat_data(
session_id=session_id,
user_query=body.query,
conversation_id=thread_id,
Expand Down
2 changes: 2 additions & 0 deletions src/app/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class AgentContext(SDGFilter):

class AgentResponse(BaseModel):
content: str | None = None
status: str | None = None
step: str | None = None
docs: list[ScoredPoint] | None = None
thread_id: uuid.UUID | None = None

Expand Down
7 changes: 7 additions & 0 deletions src/app/search/services/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ async def search_handler(
background_tasks: BackgroundTasks,
qp: EnhancedSearchQuery,
method: SearchMethods = SearchMethods.BY_SLICES,
without_vectors: bool = False,
) -> list[http_models.ScoredPoint]:
assert isinstance(qp.query, str)

Expand Down Expand Up @@ -314,6 +315,12 @@ async def search_handler(
ex,
)

if without_vectors:
points_without_vectors = [
point.model_copy(update={"vector": None}) for point in sorted_data
]
return points_without_vectors

Comment thread
jmsevin marked this conversation as resolved.
return sorted_data

@log_time_and_error
Expand Down
Loading
Loading