Skip to content

Commit cb19676

Browse files
zouyi100ZOU Yi (BD/SWD-WDE1)
authored andcommitted
fix(sessions): code cleanup create shared logic to calc dt with diff database
1 parent 1732b38 commit cb19676

File tree

5 files changed

+110
-96
lines changed

5 files changed

+110
-96
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from contextlib import asynccontextmanager
1818
import copy
1919
from datetime import datetime
20-
from datetime import timezone
2120
import logging
2221
from typing import Any
2322
from typing import AsyncIterator
@@ -59,6 +58,7 @@
5958
from .schemas.v1 import StorageMetadata
6059
from .schemas.v1 import StorageSession as StorageSessionV1
6160
from .schemas.v1 import StorageUserState as StorageUserStateV1
61+
from .schemas.shared import update_time_from_timestamp
6262
from .session import Session
6363
from .state import State
6464

@@ -458,11 +458,10 @@ async def create_session(
458458
storage_user_state.state = storage_user_state.state | user_state_delta
459459

460460
# Store the session
461-
now = datetime.fromtimestamp(platform_time.get_time(), tz=timezone.utc)
462-
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
463-
is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT
464-
if is_sqlite or is_postgresql:
465-
now = now.replace(tzinfo=None)
461+
dialect_name = self.db_engine.dialect.name
462+
now = update_time_from_timestamp(
463+
platform_time.get_time(), dialect_name
464+
)
466465

467466
storage_session = schema.StorageSession(
468467
app_name=app_name,
@@ -480,7 +479,7 @@ async def create_session(
480479
storage_app_state.state, storage_user_state.state, session_state
481480
)
482481
session = storage_session.to_session(
483-
state=merged_state, is_sqlite=is_sqlite, is_postgresql=is_postgresql
482+
state=merged_state, dialect_name=dialect_name
484483
)
485484
return session
486485

@@ -498,8 +497,7 @@ async def get_session(
498497
# 2. Get all the events based on session id and filtering config
499498
# 3. Convert and return the session
500499
schema = self._get_schema_classes()
501-
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
502-
is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT
500+
dialect_name = self.db_engine.dialect.name
503501
async with self._rollback_on_exception_session(
504502
read_only=True
505503
) as sql_session:
@@ -548,8 +546,7 @@ async def get_session(
548546
session = storage_session.to_session(
549547
state=merged_state,
550548
events=events,
551-
is_sqlite=is_sqlite,
552-
is_postgresql=is_postgresql,
549+
dialect_name=dialect_name,
553550
)
554551
return session
555552

@@ -595,17 +592,15 @@ async def list_sessions(
595592
user_states_map[storage_user_state.user_id] = storage_user_state.state
596593

597594
sessions = []
598-
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
599-
is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT
595+
dialect_name = self.db_engine.dialect.name
600596
for storage_session in results:
601597
session_state = storage_session.state
602598
user_state = user_states_map.get(storage_session.user_id, {})
603599
merged_state = _merge_state(app_state, user_state, session_state)
604600
sessions.append(
605601
storage_session.to_session(
606602
state=merged_state,
607-
is_sqlite=is_sqlite,
608-
is_postgresql=is_postgresql,
603+
dialect_name=dialect_name,
609604
)
610605
)
611606
return ListSessionsResponse(sessions=sessions)
@@ -641,8 +636,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
641636
# 2. Update session attributes based on event config.
642637
# 3. Store the new event.
643638
schema = self._get_schema_classes()
644-
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
645-
is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT
639+
dialect_name = self.db_engine.dialect.name
646640
use_row_level_locking = self._supports_row_level_locking()
647641

648642
state_delta = (
@@ -672,9 +666,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
672666
storage_session = storage_session_result.scalars().one_or_none()
673667
if storage_session is None:
674668
raise ValueError(f"Session {session.id} not found.")
675-
storage_update_time = storage_session.get_update_timestamp(
676-
is_sqlite, is_postgresql
677-
)
669+
storage_update_time = storage_session.get_update_timestamp(dialect_name)
678670
storage_update_marker = storage_session.get_update_marker()
679671

680672
storage_app_state = await _select_required_state(
@@ -740,20 +732,16 @@ async def append_event(self, session: Session, event: Event) -> Event:
740732
storage_session.state | state_deltas["session"]
741733
)
742734

743-
if is_sqlite or is_postgresql:
744-
update_time = datetime.fromtimestamp(
745-
event.timestamp, timezone.utc
746-
).replace(tzinfo=None)
747-
else:
748-
update_time = datetime.fromtimestamp(event.timestamp, timezone.utc)
749-
storage_session.update_time = update_time
735+
storage_session.update_time = update_time_from_timestamp(
736+
event.timestamp, dialect_name
737+
)
750738
sql_session.add(schema.StorageEvent.from_event(session, event))
751739

752740
await sql_session.commit()
753741

754742
# Update timestamp with commit time
755743
session.last_update_time = storage_session.get_update_timestamp(
756-
is_sqlite, is_postgresql
744+
dialect_name
757745
)
758746
session._storage_update_marker = storage_session.get_update_marker()
759747

src/google/adk/sessions/schemas/shared.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
from datetime import datetime
17+
from datetime import timezone
1618
import json
1719

1820
from sqlalchemy import Dialect
@@ -25,6 +27,33 @@
2527
DEFAULT_MAX_KEY_LENGTH = 128
2628
DEFAULT_MAX_VARCHAR_LENGTH = 256
2729

30+
# Dialects that store TIMESTAMP values as UTC-naive datetimes and therefore
31+
# require us to reattach UTC tzinfo on read and strip it on write.
32+
_NAIVE_UTC_DIALECTS = frozenset({"sqlite", "postgresql"})
33+
34+
35+
def update_timestamp_from_dt(dt: datetime, dialect_name: str) -> float:
36+
"""Converts a DB-returned datetime to a POSIX timestamp.
37+
38+
SQLite and PostgreSQL store naive datetimes that represent UTC values.
39+
All other dialects return timezone-aware datetimes directly.
40+
"""
41+
if dialect_name in _NAIVE_UTC_DIALECTS:
42+
return dt.replace(tzinfo=timezone.utc).timestamp()
43+
return dt.timestamp()
44+
45+
46+
def update_time_from_timestamp(posix_ts: float, dialect_name: str) -> datetime:
47+
"""Converts a POSIX timestamp to the datetime format expected by the DB.
48+
49+
SQLite and PostgreSQL require a UTC-naive datetime; every other dialect
50+
accepts (and prefers) a UTC-aware datetime.
51+
"""
52+
dt = datetime.fromtimestamp(posix_ts, timezone.utc)
53+
if dialect_name in _NAIVE_UTC_DIALECTS:
54+
return dt.replace(tzinfo=None)
55+
return dt
56+
2857

2958
class DynamicJSON(TypeDecorator):
3059
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""

src/google/adk/sessions/schemas/v0.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from .shared import DEFAULT_MAX_VARCHAR_LENGTH
6262
from .shared import DynamicJSON
6363
from .shared import PreciseTimestamp
64+
from .shared import update_timestamp_from_dt
6465

6566
logger = logging.getLogger("google_adk." + __name__)
6667

@@ -172,21 +173,11 @@ def update_timestamp_tz(self) -> float:
172173
if sqlalchemy_session and sqlalchemy_session.bind
173174
else None
174175
)
175-
is_sqlite = dialect_name == "sqlite"
176-
is_postgresql = dialect_name == "postgresql"
177-
return self.get_update_timestamp(
178-
is_sqlite=is_sqlite, is_postgresql=is_postgresql
179-
)
176+
return self.get_update_timestamp(dialect_name)
180177

181-
def get_update_timestamp(
182-
self, is_sqlite: bool, is_postgresql: bool = False
183-
) -> float:
184-
"""Returns the time zone aware update timestamp."""
185-
if is_sqlite or is_postgresql:
186-
# SQLite and PostgreSQL store naive datetimes as UTC values. We need to
187-
# attach UTC timezone info before converting to a POSIX timestamp.
188-
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
189-
return self.update_time.timestamp()
178+
def get_update_timestamp(self, dialect_name: str | None) -> float:
179+
"""Returns the update timestamp as a POSIX timestamp."""
180+
return update_timestamp_from_dt(self.update_time, dialect_name or "")
190181

191182
def get_update_marker(self) -> str:
192183
"""Returns a stable revision marker for optimistic concurrency checks."""
@@ -199,8 +190,7 @@ def to_session(
199190
self,
200191
state: dict[str, Any] | None = None,
201192
events: list[Event] | None = None,
202-
is_sqlite: bool = False,
203-
is_postgresql: bool = False,
193+
dialect_name: str | None = None,
204194
) -> Session:
205195
"""Converts the storage session to a session object."""
206196
if state is None:
@@ -214,9 +204,7 @@ def to_session(
214204
id=self.id,
215205
state=state,
216206
events=events,
217-
last_update_time=self.get_update_timestamp(
218-
is_sqlite=is_sqlite, is_postgresql=is_postgresql
219-
),
207+
last_update_time=self.get_update_timestamp(dialect_name),
220208
)
221209
session._storage_update_marker = self.get_update_marker()
222210
return session

src/google/adk/sessions/schemas/v1.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from .shared import DEFAULT_MAX_VARCHAR_LENGTH
4747
from .shared import DynamicJSON
4848
from .shared import PreciseTimestamp
49+
from .shared import update_timestamp_from_dt
4950

5051

5152
class Base(DeclarativeBase):
@@ -119,21 +120,11 @@ def update_timestamp_tz(self) -> float:
119120
if sqlalchemy_session and sqlalchemy_session.bind
120121
else None
121122
)
122-
is_sqlite = dialect_name == "sqlite"
123-
is_postgresql = dialect_name == "postgresql"
124-
return self.get_update_timestamp(
125-
is_sqlite=is_sqlite, is_postgresql=is_postgresql
126-
)
123+
return self.get_update_timestamp(dialect_name)
127124

128-
def get_update_timestamp(
129-
self, is_sqlite: bool, is_postgresql: bool = False
130-
) -> float:
131-
"""Returns the time zone aware update timestamp."""
132-
if is_sqlite or is_postgresql:
133-
# SQLite and PostgreSQL store naive datetimes as UTC values. We need to
134-
# attach UTC timezone info before converting to a POSIX timestamp.
135-
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
136-
return self.update_time.timestamp()
125+
def get_update_timestamp(self, dialect_name: str | None) -> float:
126+
"""Returns the update timestamp as a POSIX timestamp."""
127+
return update_timestamp_from_dt(self.update_time, dialect_name or "")
137128

138129
def get_update_marker(self) -> str:
139130
"""Returns a stable revision marker for optimistic concurrency checks."""
@@ -146,8 +137,7 @@ def to_session(
146137
self,
147138
state: dict[str, Any] | None = None,
148139
events: list[Event] | None = None,
149-
is_sqlite: bool = False,
150-
is_postgresql: bool = False,
140+
dialect_name: str | None = None,
151141
) -> Session:
152142
"""Converts the storage session to a session object."""
153143
if state is None:
@@ -161,9 +151,7 @@ def to_session(
161151
id=self.id,
162152
state=state,
163153
events=events,
164-
last_update_time=self.get_update_timestamp(
165-
is_sqlite=is_sqlite, is_postgresql=is_postgresql
166-
),
154+
last_update_time=self.get_update_timestamp(dialect_name),
167155
)
168156
session._storage_update_marker = self.get_update_marker()
169157
return session

tests/unittests/sessions/test_session_service.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from google.adk.sessions.base_session_service import GetSessionConfig
3030
from google.adk.sessions.database_session_service import DatabaseSessionService
3131
from google.adk.sessions.in_memory_session_service import InMemorySessionService
32+
from google.adk.sessions.schemas.shared import update_time_from_timestamp
33+
from google.adk.sessions.schemas.shared import update_timestamp_from_dt
3234
from google.adk.sessions.sqlite_session_service import SqliteSessionService
3335
from google.genai import types
3436
import pytest
@@ -103,44 +105,63 @@ def fake_create_async_engine(_db_url: str, **kwargs):
103105

104106

105107
@pytest.mark.parametrize('dialect_name', ['sqlite', 'postgresql'])
106-
def test_database_session_service_strips_timezone_for_dialect(dialect_name):
107-
"""Verifies that timezone-aware datetimes are converted to naive datetimes
108-
for SQLite and PostgreSQL to avoid 'can't subtract offset-naive and
109-
offset-aware datetimes' errors.
110-
111-
PostgreSQL's default TIMESTAMP type is WITHOUT TIME ZONE, which cannot
112-
accept timezone-aware datetime objects when using asyncpg. SQLite also
113-
requires naive datetimes.
114-
"""
115-
# Simulate the logic in create_session
116-
is_sqlite = dialect_name == 'sqlite'
117-
is_postgres = dialect_name == 'postgresql'
108+
def test_update_time_from_timestamp_strips_timezone_for_naive_utc_dialects(
109+
dialect_name,
110+
):
111+
"""update_time_from_timestamp returns a UTC-naive datetime for SQLite and
112+
PostgreSQL, which store TIMESTAMP WITHOUT TIME ZONE values."""
113+
posix_ts = 1_700_000_000.0
114+
result = update_time_from_timestamp(posix_ts, dialect_name)
115+
assert result.tzinfo is None
116+
# Value must represent the correct UTC instant.
117+
assert result == datetime.fromtimestamp(posix_ts, timezone.utc).replace(
118+
tzinfo=None
119+
)
120+
121+
122+
def test_update_time_from_timestamp_preserves_timezone_for_other_dialects():
123+
"""update_time_from_timestamp returns a UTC-aware datetime for dialects
124+
that support TIMESTAMP WITH TIME ZONE (e.g. MySQL)."""
125+
posix_ts = 1_700_000_000.0
126+
result = update_time_from_timestamp(posix_ts, 'mysql')
127+
assert result.tzinfo is not None
128+
assert result == datetime.fromtimestamp(posix_ts, timezone.utc)
118129

119-
now = datetime.now(timezone.utc)
120-
assert now.tzinfo is not None # Starts with timezone
121130

122-
if is_sqlite or is_postgres:
123-
now = now.replace(tzinfo=None)
131+
@pytest.mark.parametrize('dialect_name', ['sqlite', 'postgresql'])
132+
def test_update_timestamp_from_dt_treats_naive_dt_as_utc_for_naive_utc_dialects(
133+
dialect_name,
134+
):
135+
"""update_timestamp_from_dt must reattach UTC tzinfo before computing the
136+
POSIX timestamp for SQLite and PostgreSQL.
137+
138+
This is the core of the bug fixed in commit 0e5790805a2f4d:
139+
PostgreSQL returns a UTC-naive datetime, so calling .timestamp() directly
140+
on a non-UTC host would interpret it as local time and produce a wrong
141+
POSIX value.
142+
"""
143+
posix_ts = 1_700_000_000.0
144+
# Simulate a naive datetime as returned by PostgreSQL / SQLite.
145+
naive_utc_dt = datetime.fromtimestamp(posix_ts, timezone.utc).replace(
146+
tzinfo=None
147+
)
148+
assert naive_utc_dt.tzinfo is None
124149

125-
# Both SQLite and PostgreSQL should have timezone stripped
126-
assert now.tzinfo is None
150+
result = update_timestamp_from_dt(naive_utc_dt, dialect_name)
127151

152+
assert result == posix_ts
128153

129-
def test_database_session_service_preserves_timezone_for_other_dialects():
130-
"""Verifies that timezone info is preserved for dialects that support it."""
131-
# For dialects like MySQL with explicit timezone support, we don't strip
132-
dialect_name = 'mysql'
133-
is_sqlite = dialect_name == 'sqlite'
134-
is_postgres = dialect_name == 'postgresql'
135154

136-
now = datetime.now(timezone.utc)
137-
assert now.tzinfo is not None
155+
def test_update_timestamp_from_dt_uses_tzinfo_for_aware_dialects():
156+
"""update_timestamp_from_dt uses the datetime's own tzinfo for dialects
157+
that return timezone-aware datetimes (e.g. MySQL)."""
158+
posix_ts = 1_700_000_000.0
159+
aware_dt = datetime.fromtimestamp(posix_ts, timezone.utc)
160+
assert aware_dt.tzinfo is not None
138161

139-
if is_sqlite or is_postgres:
140-
now = now.replace(tzinfo=None)
162+
result = update_timestamp_from_dt(aware_dt, 'mysql')
141163

142-
# MySQL should preserve timezone (if the column type supports it)
143-
assert now.tzinfo is not None
164+
assert result == posix_ts
144165

145166

146167
def test_database_session_service_respects_pool_pre_ping_override():

0 commit comments

Comments
 (0)