Skip to content

Commit 03910e9

Browse files
Fix async mocks
1 parent 9387cb3 commit 03910e9

File tree

3 files changed

+42
-20
lines changed

3 files changed

+42
-20
lines changed

tests/unittests/memory/test_firestore_memory_service.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,18 @@
2424

2525
@pytest.fixture
2626
def mock_firestore_client():
27-
client = mock.AsyncMock()
28-
collection_ref = mock.AsyncMock()
27+
client = mock.MagicMock()
28+
collection_ref = mock.MagicMock()
2929
client.collection_group.return_value = collection_ref
30+
31+
# where() should return self (collection_ref) to allow chaining
3032
collection_ref.where.return_value = collection_ref
3133

3234
# Mock get() for documents
33-
doc_snapshot = mock.AsyncMock()
35+
doc_snapshot = mock.MagicMock()
3436
doc_snapshot.to_dict.return_value = {}
35-
collection_ref.get.return_value = [doc_snapshot]
37+
38+
collection_ref.get = mock.AsyncMock(return_value=[doc_snapshot])
3639

3740
return client
3841

tests/unittests/sessions/test_firestore_session_service.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,36 @@
2323

2424
@pytest.fixture
2525
def mock_firestore_client():
26-
client = mock.AsyncMock()
27-
# Mock collection and document references
28-
collection_ref = mock.AsyncMock()
29-
doc_ref = mock.AsyncMock()
30-
subcollection_ref = mock.AsyncMock()
31-
subdoc_ref = mock.AsyncMock()
26+
client = mock.MagicMock()
27+
collection_ref = mock.MagicMock()
28+
doc_ref = mock.MagicMock()
29+
subcollection_ref = mock.MagicMock()
30+
subdoc_ref = mock.MagicMock()
3231

3332
client.collection.return_value = collection_ref
3433
collection_ref.document.return_value = doc_ref
3534
doc_ref.collection.return_value = subcollection_ref
3635
subcollection_ref.document.return_value = subdoc_ref
3736

3837
# Mock get() for documents
39-
doc_snapshot = mock.AsyncMock()
38+
doc_snapshot = mock.MagicMock()
4039
doc_snapshot.exists = False
4140
doc_snapshot.to_dict.return_value = {}
42-
doc_ref.get.return_value = doc_snapshot
43-
subdoc_ref.get.return_value = doc_snapshot
41+
42+
doc_ref.get = mock.AsyncMock(return_value=doc_snapshot)
43+
subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot)
44+
45+
# Mock subcollection get() (for events list in delete_session)
46+
subcollection_ref.get = mock.AsyncMock(return_value=[])
4447

4548
# Mock collection group
4649
client.collection_group.return_value = collection_ref
4750

51+
# Mock batch
52+
batch = mock.MagicMock()
53+
client.batch.return_value = batch
54+
batch.commit = mock.AsyncMock()
55+
4856
return client
4957

5058

tests/unittests/test_firestore_database_runner.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,15 @@ def mock_agent():
3131
def test_create_firestore_runner_with_arg(mock_agent, monkeypatch):
3232
monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False)
3333

34-
# Mock GcsArtifactService to avoid real client init
35-
with mock.patch(
36-
"google.adk.firestore_database_runner.GcsArtifactService"
37-
) as mock_gcs:
34+
with (
35+
mock.patch(
36+
"google.adk.firestore_database_runner.FirestoreSessionService"
37+
),
38+
mock.patch("google.adk.firestore_database_runner.FirestoreMemoryService"),
39+
mock.patch(
40+
"google.adk.firestore_database_runner.GcsArtifactService"
41+
) as mock_gcs,
42+
):
3843
runner = create_firestore_runner(mock_agent, gcs_bucket_name="test_bucket")
3944

4045
assert runner is not None
@@ -44,9 +49,15 @@ def test_create_firestore_runner_with_arg(mock_agent, monkeypatch):
4449
def test_create_firestore_runner_with_env(mock_agent, monkeypatch):
4550
monkeypatch.setenv("ADK_GCS_BUCKET_NAME", "env_bucket")
4651

47-
with mock.patch(
48-
"google.adk.firestore_database_runner.GcsArtifactService"
49-
) as mock_gcs:
52+
with (
53+
mock.patch(
54+
"google.adk.firestore_database_runner.FirestoreSessionService"
55+
),
56+
mock.patch("google.adk.firestore_database_runner.FirestoreMemoryService"),
57+
mock.patch(
58+
"google.adk.firestore_database_runner.GcsArtifactService"
59+
) as mock_gcs,
60+
):
5061
runner = create_firestore_runner(mock_agent)
5162

5263
assert runner is not None

0 commit comments

Comments
 (0)