Skip to content

Commit 8f56d78

Browse files
committed
improve main.py sandbox-mode test coverage
moving some shared utilities to test/utils.py
1 parent ad3093b commit 8f56d78

File tree

3 files changed

+73
-19
lines changed

3 files changed

+73
-19
lines changed

test/pytests/test_main.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import click
1515
from click.testing import CliRunner
16+
import pymysql
1617
from pymysql.err import OperationalError
1718
import pytest
1819

@@ -38,7 +39,9 @@
3839
TEMPFILE_PREFIX,
3940
USER,
4041
DummyFormatter,
42+
DummyLogger,
4143
FakeCursorBase,
44+
RecordingSQLExecute,
4245
ReusableLock,
4346
call_click_entrypoint_direct,
4447
dbtest,
@@ -2365,3 +2368,53 @@ def test_get_last_query_returns_latest_query() -> None:
23652368
cli.query_history = [main.Query('select 1', True, False)]
23662369

23672370
assert main.MyCli.get_last_query(cli) == 'select 1'
2371+
2372+
2373+
def test_connect_reports_expired_password_login_error(monkeypatch: pytest.MonkeyPatch) -> None:
2374+
cli = make_bare_mycli()
2375+
cli.my_cnf = {'client': {}, 'mysqld': {}}
2376+
cli.config_without_package_defaults = {'connection': {}}
2377+
cli.config = {'connection': {}, 'main': {}}
2378+
cli.logger = cast(Any, DummyLogger())
2379+
echo_calls: list[str] = []
2380+
cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment]
2381+
monkeypatch.setattr(main, 'WIN', False)
2382+
monkeypatch.setattr(main, 'str_to_bool', lambda value: False)
2383+
2384+
class ExpiredPasswordSQLExecute(RecordingSQLExecute):
2385+
calls: list[dict[str, Any]] = []
2386+
side_effects: list[Any] = [pymysql.OperationalError(main.ER_MUST_CHANGE_PASSWORD_LOGIN, 'must change password')]
2387+
2388+
monkeypatch.setattr(main, 'SQLExecute', ExpiredPasswordSQLExecute)
2389+
2390+
with pytest.raises(SystemExit):
2391+
main.MyCli.connect(cli, host='db', port=3307)
2392+
2393+
assert any('password has expired' in message for message in echo_calls)
2394+
2395+
2396+
def test_connect_sets_cli_sandbox_mode_when_sqlexecute_enters_sandbox(monkeypatch: pytest.MonkeyPatch) -> None:
2397+
cli = make_bare_mycli()
2398+
cli.my_cnf = {'client': {}, 'mysqld': {}}
2399+
cli.config_without_package_defaults = {'connection': {}}
2400+
cli.config = {'connection': {}, 'main': {}}
2401+
cli.logger = cast(Any, DummyLogger())
2402+
echo_calls: list[str] = []
2403+
cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment]
2404+
monkeypatch.setattr(main, 'WIN', False)
2405+
monkeypatch.setattr(main, 'str_to_bool', lambda value: False)
2406+
2407+
class SandboxSQLExecute(RecordingSQLExecute):
2408+
calls: list[dict[str, Any]] = []
2409+
side_effects: list[Any] = []
2410+
2411+
def __init__(self, **kwargs: Any) -> None:
2412+
super().__init__(**kwargs)
2413+
self.sandbox_mode = True
2414+
2415+
monkeypatch.setattr(main, 'SQLExecute', SandboxSQLExecute)
2416+
2417+
main.MyCli.connect(cli, host='db', port=3307)
2418+
2419+
assert cli.sandbox_mode is True
2420+
assert any('password has expired' in message for message in echo_calls)

test/pytests/test_main_regression.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
DummyFormatter,
3939
DummyLogger,
4040
FakeCursorBase,
41+
RecordingSQLExecute,
4142
call_click_entrypoint_direct,
4243
make_bare_mycli,
4344
make_dummy_mycli_class,
@@ -60,25 +61,6 @@ def as_bool(self, key: str) -> bool:
6061
return str(self[key]).lower() == 'true'
6162

6263

63-
class RecordingSQLExecute:
64-
calls: list[dict[str, Any]] = []
65-
side_effects: list[Any] = []
66-
67-
def __init__(self, **kwargs: Any) -> None:
68-
type(self).calls.append(dict(kwargs))
69-
if type(self).side_effects:
70-
effect = type(self).side_effects.pop(0)
71-
if isinstance(effect, BaseException):
72-
raise effect
73-
if callable(effect):
74-
effect(kwargs)
75-
self.kwargs = kwargs
76-
self.dbname = kwargs.get('database')
77-
self.user = kwargs.get('user')
78-
self.conn = kwargs.get('conn')
79-
self.sandbox_mode = False
80-
81-
8264
class ToggleBool:
8365
def __init__(self, values: list[bool]) -> None:
8466
self.values = values

test/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,25 @@ def __iter__(self) -> Iterator[tuple[Any, ...]]:
101101
return iter(self._rows)
102102

103103

104+
class RecordingSQLExecute:
105+
calls: list[dict[str, Any]] = []
106+
side_effects: list[Any] = []
107+
108+
def __init__(self, **kwargs: Any) -> None:
109+
type(self).calls.append(dict(kwargs))
110+
if type(self).side_effects:
111+
effect = type(self).side_effects.pop(0)
112+
if isinstance(effect, BaseException):
113+
raise effect
114+
if callable(effect):
115+
effect(kwargs)
116+
self.kwargs = kwargs
117+
self.dbname = kwargs.get('database')
118+
self.user = kwargs.get('user')
119+
self.conn = kwargs.get('conn')
120+
self.sandbox_mode = False
121+
122+
104123
def make_bare_mycli() -> Any:
105124
cli = object.__new__(main.MyCli)
106125
cli.logger = cast(Any, DummyLogger())

0 commit comments

Comments
 (0)