Skip to content

Commit d864822

Browse files
authored
PYTHON-5788 - Refine withTransaction timeout error wrapping semantics… (#2745)
1 parent 02320d6 commit d864822

File tree

4 files changed

+40
-14
lines changed

4 files changed

+40
-14
lines changed

pymongo/asynchronous/client_session.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,14 @@ def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
516516
def _make_timeout_error(error: BaseException) -> PyMongoError:
517517
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
518518
if _csot.remaining() is not None:
519-
return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50})
519+
timeout_error: PyMongoError = ExecutionTimeout(
520+
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
521+
)
520522
else:
521-
return NetworkTimeout(str(error))
523+
timeout_error = NetworkTimeout(str(error))
524+
if isinstance(error, PyMongoError):
525+
timeout_error._error_labels = error._error_labels.copy()
526+
return timeout_error
522527

523528

524529
_T = TypeVar("_T")
@@ -804,15 +809,17 @@ async def callback(session, custom_arg, custom_kwarg=None):
804809
await self.commit_transaction()
805810
except PyMongoError as exc:
806811
last_error = exc
807-
if not _within_time_limit(start_time):
808-
raise _make_timeout_error(last_error) from exc
809812
if exc.has_error_label(
810813
"UnknownTransactionCommitResult"
811814
) and not _max_time_expired_error(exc):
815+
if not _within_time_limit(start_time):
816+
raise _make_timeout_error(last_error) from exc
812817
# Retry the commit.
813818
continue
814819

815820
if exc.has_error_label("TransientTransactionError"):
821+
if not _within_time_limit(start_time):
822+
raise _make_timeout_error(last_error) from exc
816823
# Retry the entire transaction.
817824
break
818825
raise

pymongo/synchronous/client_session.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -514,9 +514,14 @@ def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
514514
def _make_timeout_error(error: BaseException) -> PyMongoError:
515515
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
516516
if _csot.remaining() is not None:
517-
return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50})
517+
timeout_error: PyMongoError = ExecutionTimeout(
518+
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
519+
)
518520
else:
519-
return NetworkTimeout(str(error))
521+
timeout_error = NetworkTimeout(str(error))
522+
if isinstance(error, PyMongoError):
523+
timeout_error._error_labels = error._error_labels.copy()
524+
return timeout_error
520525

521526

522527
_T = TypeVar("_T")
@@ -800,15 +805,17 @@ def callback(session, custom_arg, custom_kwarg=None):
800805
self.commit_transaction()
801806
except PyMongoError as exc:
802807
last_error = exc
803-
if not _within_time_limit(start_time):
804-
raise _make_timeout_error(last_error) from exc
805808
if exc.has_error_label(
806809
"UnknownTransactionCommitResult"
807810
) and not _max_time_expired_error(exc):
811+
if not _within_time_limit(start_time):
812+
raise _make_timeout_error(last_error) from exc
808813
# Retry the commit.
809814
continue
810815

811816
if exc.has_error_label("TransientTransactionError"):
817+
if not _within_time_limit(start_time):
818+
raise _make_timeout_error(last_error) from exc
812819
# Retry the entire transaction.
813820
break
814821
raise

test/asynchronous/test_transactions.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -500,10 +500,12 @@ async def callback(session):
500500
listener.reset()
501501
async with client.start_session() as s:
502502
with PatchSessionTimeout(0):
503-
with self.assertRaises(NetworkTimeout):
503+
with self.assertRaises(NetworkTimeout) as context:
504504
await s.with_transaction(callback)
505505

506506
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
507+
# Assert that the timeout error has the same labels as the error it wraps.
508+
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
507509

508510
@async_client_context.require_test_commands
509511
@async_client_context.require_transactions
@@ -534,10 +536,12 @@ async def callback(session):
534536

535537
async with client.start_session() as s:
536538
with PatchSessionTimeout(0):
537-
with self.assertRaises(NetworkTimeout):
539+
with self.assertRaises(NetworkTimeout) as context:
538540
await s.with_transaction(callback)
539541

540542
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
543+
# Assert that the timeout error has the same labels as the error it wraps.
544+
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
541545

542546
@async_client_context.require_test_commands
543547
@async_client_context.require_transactions
@@ -565,14 +569,16 @@ async def callback(session):
565569

566570
async with client.start_session() as s:
567571
with PatchSessionTimeout(0):
568-
with self.assertRaises(NetworkTimeout):
572+
with self.assertRaises(NetworkTimeout) as context:
569573
await s.with_transaction(callback)
570574

571575
# One insert for the callback and two commits (includes the automatic
572576
# retry).
573577
self.assertEqual(
574578
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
575579
)
580+
# Assert that the timeout error has the same labels as the error it wraps.
581+
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))
576582

577583
@async_client_context.require_transactions
578584
async def test_callback_not_retried_after_csot_timeout(self):

test/test_transactions.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,12 @@ def callback(session):
492492
listener.reset()
493493
with client.start_session() as s:
494494
with PatchSessionTimeout(0):
495-
with self.assertRaises(NetworkTimeout):
495+
with self.assertRaises(NetworkTimeout) as context:
496496
s.with_transaction(callback)
497497

498498
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
499+
# Assert that the timeout error has the same labels as the error it wraps.
500+
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
499501

500502
@client_context.require_test_commands
501503
@client_context.require_transactions
@@ -524,10 +526,12 @@ def callback(session):
524526

525527
with client.start_session() as s:
526528
with PatchSessionTimeout(0):
527-
with self.assertRaises(NetworkTimeout):
529+
with self.assertRaises(NetworkTimeout) as context:
528530
s.with_transaction(callback)
529531

530532
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
533+
# Assert that the timeout error has the same labels as the error it wraps.
534+
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
531535

532536
@client_context.require_test_commands
533537
@client_context.require_transactions
@@ -553,14 +557,16 @@ def callback(session):
553557

554558
with client.start_session() as s:
555559
with PatchSessionTimeout(0):
556-
with self.assertRaises(NetworkTimeout):
560+
with self.assertRaises(NetworkTimeout) as context:
557561
s.with_transaction(callback)
558562

559563
# One insert for the callback and two commits (includes the automatic
560564
# retry).
561565
self.assertEqual(
562566
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
563567
)
568+
# Assert that the timeout error has the same labels as the error it wraps.
569+
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))
564570

565571
@client_context.require_transactions
566572
def test_callback_not_retried_after_csot_timeout(self):

0 commit comments

Comments
 (0)