|
13 | 13 |
|
14 | 14 | import click |
15 | 15 | from click.testing import CliRunner |
| 16 | +import pymysql |
16 | 17 | from pymysql.err import OperationalError |
17 | 18 | import pytest |
18 | 19 |
|
|
38 | 39 | TEMPFILE_PREFIX, |
39 | 40 | USER, |
40 | 41 | DummyFormatter, |
| 42 | + DummyLogger, |
41 | 43 | FakeCursorBase, |
| 44 | + RecordingSQLExecute, |
42 | 45 | ReusableLock, |
43 | 46 | call_click_entrypoint_direct, |
44 | 47 | dbtest, |
@@ -2365,3 +2368,53 @@ def test_get_last_query_returns_latest_query() -> None: |
2365 | 2368 | cli.query_history = [main.Query('select 1', True, False)] |
2366 | 2369 |
|
2367 | 2370 | assert main.MyCli.get_last_query(cli) == 'select 1' |
| 2371 | + |
| 2372 | + |
| 2373 | +def test_connect_reports_expired_password_login_error(monkeypatch: pytest.MonkeyPatch) -> None: |
| 2374 | + cli = make_bare_mycli() |
| 2375 | + cli.my_cnf = {'client': {}, 'mysqld': {}} |
| 2376 | + cli.config_without_package_defaults = {'connection': {}} |
| 2377 | + cli.config = {'connection': {}, 'main': {}} |
| 2378 | + cli.logger = cast(Any, DummyLogger()) |
| 2379 | + echo_calls: list[str] = [] |
| 2380 | + cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment] |
| 2381 | + monkeypatch.setattr(main, 'WIN', False) |
| 2382 | + monkeypatch.setattr(main, 'str_to_bool', lambda value: False) |
| 2383 | + |
| 2384 | + class ExpiredPasswordSQLExecute(RecordingSQLExecute): |
| 2385 | + calls: list[dict[str, Any]] = [] |
| 2386 | + side_effects: list[Any] = [pymysql.OperationalError(main.ER_MUST_CHANGE_PASSWORD_LOGIN, 'must change password')] |
| 2387 | + |
| 2388 | + monkeypatch.setattr(main, 'SQLExecute', ExpiredPasswordSQLExecute) |
| 2389 | + |
| 2390 | + with pytest.raises(SystemExit): |
| 2391 | + main.MyCli.connect(cli, host='db', port=3307) |
| 2392 | + |
| 2393 | + assert any('password has expired' in message for message in echo_calls) |
| 2394 | + |
| 2395 | + |
| 2396 | +def test_connect_sets_cli_sandbox_mode_when_sqlexecute_enters_sandbox(monkeypatch: pytest.MonkeyPatch) -> None: |
| 2397 | + cli = make_bare_mycli() |
| 2398 | + cli.my_cnf = {'client': {}, 'mysqld': {}} |
| 2399 | + cli.config_without_package_defaults = {'connection': {}} |
| 2400 | + cli.config = {'connection': {}, 'main': {}} |
| 2401 | + cli.logger = cast(Any, DummyLogger()) |
| 2402 | + echo_calls: list[str] = [] |
| 2403 | + cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment] |
| 2404 | + monkeypatch.setattr(main, 'WIN', False) |
| 2405 | + monkeypatch.setattr(main, 'str_to_bool', lambda value: False) |
| 2406 | + |
| 2407 | + class SandboxSQLExecute(RecordingSQLExecute): |
| 2408 | + calls: list[dict[str, Any]] = [] |
| 2409 | + side_effects: list[Any] = [] |
| 2410 | + |
| 2411 | + def __init__(self, **kwargs: Any) -> None: |
| 2412 | + super().__init__(**kwargs) |
| 2413 | + self.sandbox_mode = True |
| 2414 | + |
| 2415 | + monkeypatch.setattr(main, 'SQLExecute', SandboxSQLExecute) |
| 2416 | + |
| 2417 | + main.MyCli.connect(cli, host='db', port=3307) |
| 2418 | + |
| 2419 | + assert cli.sandbox_mode is True |
| 2420 | + assert any('password has expired' in message for message in echo_calls) |
0 commit comments