Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from contextlib import asynccontextmanager
import copy
from datetime import datetime
from datetime import timezone
import logging
from typing import Any
from typing import AsyncIterator
Expand Down Expand Up @@ -59,6 +58,7 @@
from .schemas.v1 import StorageMetadata
from .schemas.v1 import StorageSession as StorageSessionV1
from .schemas.v1 import StorageUserState as StorageUserStateV1
from .schemas.shared import update_time_from_timestamp
from .session import Session
from .state import State

Expand Down Expand Up @@ -458,11 +458,10 @@ async def create_session(
storage_user_state.state = storage_user_state.state | user_state_delta

# Store the session
now = datetime.fromtimestamp(platform_time.get_time(), tz=timezone.utc)
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT
if is_sqlite or is_postgresql:
now = now.replace(tzinfo=None)
dialect_name = self.db_engine.dialect.name
now = update_time_from_timestamp(
platform_time.get_time(), dialect_name
)

storage_session = schema.StorageSession(
app_name=app_name,
Expand All @@ -480,7 +479,7 @@ async def create_session(
storage_app_state.state, storage_user_state.state, session_state
)
session = storage_session.to_session(
state=merged_state, is_sqlite=is_sqlite
state=merged_state, dialect_name=dialect_name
)
return session

Expand All @@ -498,6 +497,7 @@ async def get_session(
# 2. Get all the events based on session id and filtering config
# 3. Convert and return the session
schema = self._get_schema_classes()
dialect_name = self.db_engine.dialect.name
async with self._rollback_on_exception_session(
read_only=True
) as sql_session:
Expand Down Expand Up @@ -543,9 +543,10 @@ async def get_session(

# Convert storage session to session
events = [e.to_event() for e in reversed(storage_events)]
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
session = storage_session.to_session(
state=merged_state, events=events, is_sqlite=is_sqlite
state=merged_state,
events=events,
dialect_name=dialect_name,
)
return session

Expand Down Expand Up @@ -591,13 +592,16 @@ async def list_sessions(
user_states_map[storage_user_state.user_id] = storage_user_state.state

sessions = []
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
dialect_name = self.db_engine.dialect.name
for storage_session in results:
session_state = storage_session.state
user_state = user_states_map.get(storage_session.user_id, {})
merged_state = _merge_state(app_state, user_state, session_state)
sessions.append(
storage_session.to_session(state=merged_state, is_sqlite=is_sqlite)
storage_session.to_session(
state=merged_state,
dialect_name=dialect_name,
)
)
return ListSessionsResponse(sessions=sessions)

Expand Down Expand Up @@ -632,7 +636,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
# 2. Update session attributes based on event config.
# 3. Store the new event.
schema = self._get_schema_classes()
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
dialect_name = self.db_engine.dialect.name
use_row_level_locking = self._supports_row_level_locking()

state_delta = (
Expand Down Expand Up @@ -662,7 +666,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
storage_session = storage_session_result.scalars().one_or_none()
if storage_session is None:
raise ValueError(f"Session {session.id} not found.")
storage_update_time = storage_session.get_update_timestamp(is_sqlite)
storage_update_time = storage_session.get_update_timestamp(dialect_name)
storage_update_marker = storage_session.get_update_marker()

storage_app_state = await _select_required_state(
Expand Down Expand Up @@ -728,20 +732,16 @@ async def append_event(self, session: Session, event: Event) -> Event:
storage_session.state | state_deltas["session"]
)

if is_sqlite:
update_time = datetime.fromtimestamp(
event.timestamp, timezone.utc
).replace(tzinfo=None)
else:
update_time = datetime.fromtimestamp(event.timestamp)
storage_session.update_time = update_time
storage_session.update_time = update_time_from_timestamp(
event.timestamp, dialect_name
)
sql_session.add(schema.StorageEvent.from_event(session, event))

await sql_session.commit()

# Update timestamp with commit time
session.last_update_time = storage_session.get_update_timestamp(
is_sqlite
dialect_name
)
session._storage_update_marker = storage_session.get_update_marker()

Expand Down
29 changes: 29 additions & 0 deletions src/google/adk/sessions/schemas/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
from __future__ import annotations

from datetime import datetime
from datetime import timezone
import json

from sqlalchemy import Dialect
Expand All @@ -25,6 +27,33 @@
DEFAULT_MAX_KEY_LENGTH = 128
DEFAULT_MAX_VARCHAR_LENGTH = 256

# Dialects that store TIMESTAMP values as UTC-naive datetimes and therefore
# require us to reattach UTC tzinfo on read and strip it on write.
_NAIVE_UTC_DIALECTS = frozenset({"sqlite", "postgresql"})


def update_timestamp_from_dt(dt: datetime, dialect_name: str) -> float:
  """Converts a DB-returned datetime into a POSIX timestamp.

  SQLite and PostgreSQL hand back naive datetimes whose wall-clock value is
  UTC; every other dialect returns a timezone-aware datetime that converts
  directly.
  """
  if dialect_name not in _NAIVE_UTC_DIALECTS:
    return dt.timestamp()
  # Reattach UTC so .timestamp() does not interpret the naive value as the
  # host's local time.
  return dt.replace(tzinfo=timezone.utc).timestamp()


def update_time_from_timestamp(posix_ts: float, dialect_name: str) -> datetime:
  """Converts a POSIX timestamp into the datetime format the DB expects.

  SQLite and PostgreSQL columns are TIMESTAMP WITHOUT TIME ZONE and need a
  UTC-naive datetime; every other dialect takes a UTC-aware one.
  """
  utc_dt = datetime.fromtimestamp(posix_ts, timezone.utc)
  # Same dialect set as _NAIVE_UTC_DIALECTS: strip tzinfo before writing.
  if dialect_name in ("sqlite", "postgresql"):
    return utc_dt.replace(tzinfo=None)
  return utc_dt


class DynamicJSON(TypeDecorator):
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
Expand Down
26 changes: 11 additions & 15 deletions src/google/adk/sessions/schemas/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from .shared import DEFAULT_MAX_VARCHAR_LENGTH
from .shared import DynamicJSON
from .shared import PreciseTimestamp
from .shared import update_timestamp_from_dt

logger = logging.getLogger("google_adk." + __name__)

Expand Down Expand Up @@ -167,21 +168,16 @@ def update_timestamp_tz(self) -> float:
This is a compatibility alias for callers that used the pre-`main` API.
"""
sqlalchemy_session = inspect(self).session
is_sqlite = bool(
sqlalchemy_session
and sqlalchemy_session.bind
and sqlalchemy_session.bind.dialect.name == "sqlite"
dialect_name = (
sqlalchemy_session.bind.dialect.name
if sqlalchemy_session and sqlalchemy_session.bind
else None
)
return self.get_update_timestamp(is_sqlite=is_sqlite)
return self.get_update_timestamp(dialect_name)

def get_update_timestamp(self, is_sqlite: bool) -> float:
"""Returns the time zone aware update timestamp."""
if is_sqlite:
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
# object without timezone information. We need to convert it to UTC
# manually.
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
return self.update_time.timestamp()
def get_update_timestamp(self, dialect_name: str | None) -> float:
  """Returns ``update_time`` as a POSIX timestamp for the given dialect."""
  # None means the dialect is unknown; "" falls through to the aware-dialect
  # path in update_timestamp_from_dt.
  effective_dialect = dialect_name if dialect_name else ""
  return update_timestamp_from_dt(self.update_time, effective_dialect)

def get_update_marker(self) -> str:
"""Returns a stable revision marker for optimistic concurrency checks."""
Expand All @@ -194,7 +190,7 @@ def to_session(
self,
state: dict[str, Any] | None = None,
events: list[Event] | None = None,
is_sqlite: bool = False,
dialect_name: str | None = None,
) -> Session:
"""Converts the storage session to a session object."""
if state is None:
Expand All @@ -208,7 +204,7 @@ def to_session(
id=self.id,
state=state,
events=events,
last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
last_update_time=self.get_update_timestamp(dialect_name),
)
session._storage_update_marker = self.get_update_marker()
return session
Expand Down
26 changes: 11 additions & 15 deletions src/google/adk/sessions/schemas/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .shared import DEFAULT_MAX_VARCHAR_LENGTH
from .shared import DynamicJSON
from .shared import PreciseTimestamp
from .shared import update_timestamp_from_dt


class Base(DeclarativeBase):
Expand Down Expand Up @@ -114,21 +115,16 @@ def update_timestamp_tz(self) -> float:
This is a compatibility alias for callers that used the pre-`main` API.
"""
sqlalchemy_session = inspect(self).session
is_sqlite = bool(
sqlalchemy_session
and sqlalchemy_session.bind
and sqlalchemy_session.bind.dialect.name == "sqlite"
dialect_name = (
sqlalchemy_session.bind.dialect.name
if sqlalchemy_session and sqlalchemy_session.bind
else None
)
return self.get_update_timestamp(is_sqlite=is_sqlite)
return self.get_update_timestamp(dialect_name)

def get_update_timestamp(self, is_sqlite: bool) -> float:
"""Returns the time zone aware update timestamp."""
if is_sqlite:
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
# object without timezone information. We need to convert it to UTC
# manually.
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
return self.update_time.timestamp()
def get_update_timestamp(self, dialect_name: str | None) -> float:
  """Returns ``update_time`` as a POSIX timestamp for the given dialect."""
  # None means the dialect is unknown; "" falls through to the aware-dialect
  # path in update_timestamp_from_dt.
  effective_dialect = dialect_name if dialect_name else ""
  return update_timestamp_from_dt(self.update_time, effective_dialect)

def get_update_marker(self) -> str:
"""Returns a stable revision marker for optimistic concurrency checks."""
Expand All @@ -141,7 +137,7 @@ def to_session(
self,
state: dict[str, Any] | None = None,
events: list[Event] | None = None,
is_sqlite: bool = False,
dialect_name: str | None = None,
) -> Session:
"""Converts the storage session to a session object."""
if state is None:
Expand All @@ -155,7 +151,7 @@ def to_session(
id=self.id,
state=state,
events=events,
last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
last_update_time=self.get_update_timestamp(dialect_name),
)
session._storage_update_marker = self.get_update_marker()
return session
Expand Down
81 changes: 51 additions & 30 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from google.adk.sessions.base_session_service import GetSessionConfig
from google.adk.sessions.database_session_service import DatabaseSessionService
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.schemas.shared import update_time_from_timestamp
from google.adk.sessions.schemas.shared import update_timestamp_from_dt
from google.adk.sessions.sqlite_session_service import SqliteSessionService
from google.genai import types
import pytest
Expand Down Expand Up @@ -103,44 +105,63 @@ def fake_create_async_engine(_db_url: str, **kwargs):


@pytest.mark.parametrize('dialect_name', ['sqlite', 'postgresql'])
def test_database_session_service_strips_timezone_for_dialect(dialect_name):
"""Verifies that timezone-aware datetimes are converted to naive datetimes
for SQLite and PostgreSQL to avoid 'can't subtract offset-naive and
offset-aware datetimes' errors.

PostgreSQL's default TIMESTAMP type is WITHOUT TIME ZONE, which cannot
accept timezone-aware datetime objects when using asyncpg. SQLite also
requires naive datetimes.
"""
# Simulate the logic in create_session
is_sqlite = dialect_name == 'sqlite'
is_postgres = dialect_name == 'postgresql'
def test_update_time_from_timestamp_strips_timezone_for_naive_utc_dialects(
    dialect_name,
):
  """update_time_from_timestamp hands back a UTC-naive datetime for SQLite
  and PostgreSQL, whose columns are TIMESTAMP WITHOUT TIME ZONE."""
  ts = 1_700_000_000.0

  converted = update_time_from_timestamp(ts, dialect_name)

  # No tzinfo may survive for these dialects.
  assert converted.tzinfo is None
  # Stripping tzinfo must not shift the UTC instant being represented.
  expected = datetime.fromtimestamp(ts, timezone.utc).replace(tzinfo=None)
  assert converted == expected


def test_update_time_from_timestamp_preserves_timezone_for_other_dialects():
  """update_time_from_timestamp hands back a UTC-aware datetime for dialects
  whose TIMESTAMP columns accept timezone-aware values (e.g. MySQL)."""
  ts = 1_700_000_000.0

  converted = update_time_from_timestamp(ts, 'mysql')

  assert converted.tzinfo is not None
  assert converted == datetime.fromtimestamp(ts, timezone.utc)

now = datetime.now(timezone.utc)
assert now.tzinfo is not None # Starts with timezone

if is_sqlite or is_postgres:
now = now.replace(tzinfo=None)
@pytest.mark.parametrize('dialect_name', ['sqlite', 'postgresql'])
def test_update_timestamp_from_dt_treats_naive_dt_as_utc_for_naive_utc_dialects(
    dialect_name,
):
  """update_timestamp_from_dt must reattach UTC tzinfo before computing the
  POSIX timestamp for SQLite and PostgreSQL.

  This is the core of the bug fixed in commit 0e5790805a2f4d:
  PostgreSQL returns a UTC-naive datetime, so calling .timestamp() directly
  on a non-UTC host would interpret it as local time and produce a wrong
  POSIX value.
  """
  posix_ts = 1_700_000_000.0
  # Simulate a naive datetime as returned by PostgreSQL / SQLite.
  naive_utc_dt = datetime.fromtimestamp(posix_ts, timezone.utc).replace(
      tzinfo=None
  )
  assert naive_utc_dt.tzinfo is None

  result = update_timestamp_from_dt(naive_utc_dt, dialect_name)

  # Exact round-trip proves the helper reattached UTC before converting.
  assert result == posix_ts

def test_database_session_service_preserves_timezone_for_other_dialects():
"""Verifies that timezone info is preserved for dialects that support it."""
# For dialects like MySQL with explicit timezone support, we don't strip
dialect_name = 'mysql'
is_sqlite = dialect_name == 'sqlite'
is_postgres = dialect_name == 'postgresql'

now = datetime.now(timezone.utc)
assert now.tzinfo is not None
def test_update_timestamp_from_dt_uses_tzinfo_for_aware_dialects():
  """update_timestamp_from_dt uses the datetime's own tzinfo for dialects
  that return timezone-aware datetimes (e.g. MySQL)."""
  posix_ts = 1_700_000_000.0
  aware_dt = datetime.fromtimestamp(posix_ts, timezone.utc)
  assert aware_dt.tzinfo is not None

  result = update_timestamp_from_dt(aware_dt, 'mysql')

  assert result == posix_ts


def test_database_session_service_respects_pool_pre_ping_override():
Expand Down