Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 93 additions & 4 deletions sqlit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,69 @@ def _parse_float_value(value: str | None, default: float) -> float:
return default


def _add_stdin_secret_flags(parser: argparse.ArgumentParser, *, include_ssh: bool) -> None:
"""Attach --password-stdin (and optionally --ssh-password-stdin) to a parser."""
parser.add_argument(
"--password-stdin",
dest="password_stdin",
action="store_true",
help="Read the password from stdin (one line, trailing newline stripped)",
)
if include_ssh:
parser.add_argument(
"--ssh-password-stdin",
dest="ssh_password_stdin",
action="store_true",
help="Read the SSH password from stdin (one line, trailing newline stripped)",
)


def _resolve_stdin_secrets(args: argparse.Namespace) -> None:
"""Populate args.password / args.url / args.ssh_password from stdin if requested.

Recognised stdin-trigger attrs: ``password_stdin``, ``url_stdin``,
``ssh_password_stdin``. At most one may be set per invocation — stdin
is a single stream and we read one line from it. The corresponding
cleartext flag must not also be set.
"""
from sqlit.domains.connections.domain.stdin_secret import (
StdinSecretError,
read_secret_from_stdin,
)

requests: list[tuple[str, str]] = []
if getattr(args, "password_stdin", False):
requests.append(("password", "password"))
if getattr(args, "url_stdin", False):
requests.append(("url", "url"))
if getattr(args, "ssh_password_stdin", False):
requests.append(("ssh_password", "ssh-password"))

if not requests:
return

if len(requests) > 1:
flags = ", ".join(f"--{label}-stdin" for _, label in requests)
raise SystemExit(
f"Error: only one of {flags} may be used per invocation "
f"(stdin can only feed one secret)."
)

attr, label = requests[0]
existing = getattr(args, attr, None)
if existing:
raise SystemExit(
f"Error: --{label} and --{label}-stdin are mutually exclusive."
)

try:
value = read_secret_from_stdin(label=label)
except StdinSecretError as exc:
raise SystemExit(f"Error: {exc}")

setattr(args, attr, value)


def _resolve_startup_log_path(argv: list[str]) -> Path | None:
env_profile = os.environ.get("SQLIT_PROFILE_STARTUP") == "1"
env_exit = os.environ.get("SQLIT_PROFILE_STARTUP_EXIT") == "1"
Expand Down Expand Up @@ -430,7 +493,16 @@ def main() -> int:
parser.add_argument("--port", help="Temporary connection port")
parser.add_argument("--database", help="Temporary connection database name")
parser.add_argument("--username", help="Temporary connection username")
parser.add_argument("--password", help="Temporary connection password")
parser.add_argument(
"--password",
help="Temporary connection password (or use --password-stdin to read from stdin)",
)
parser.add_argument(
"--password-stdin",
dest="password_stdin",
action="store_true",
help="Read the password from stdin (one line, trailing newline stripped)",
)
parser.add_argument("--file-path", help="Temporary connection file path (SQLite/DuckDB)")
parser.add_argument(
"--auth-type",
Expand Down Expand Up @@ -568,13 +640,22 @@ def main() -> int:
add_parser.add_argument(
"--url",
metavar="URL",
help="Connection URL (e.g., postgresql://user:pass@host:5432/db). Requires --name.",
help=(
"Connection URL (e.g., postgresql://user:pass@host:5432/db). "
"Requires --name. Use --url-stdin to read it from stdin instead."
),
)
add_parser.add_argument(
"--url-stdin",
dest="url_stdin",
action="store_true",
help="Read the connection URL from stdin (one line, trailing newline stripped)",
)
add_parser.add_argument(
"--name",
"-n",
dest="url_name",
help="Connection name (required when using --url)",
help="Connection name (required when using --url / --url-stdin)",
)
add_provider_parsers = add_parser.add_subparsers(dest="provider", metavar="PROVIDER")
for db_type in get_supported_db_types():
Expand All @@ -587,6 +668,7 @@ def main() -> int:
add_schema_arguments(provider_parser, schema, include_name=True, name_required=True)
provider_parser.add_argument("--password-command", dest="password_command", help="Shell command to retrieve the database password")
provider_parser.add_argument("--ssh-password-command", dest="ssh_password_command", help="Shell command to retrieve the SSH password")
_add_stdin_secret_flags(provider_parser, include_ssh=True)
provider_parser.add_argument(
"--alert",
metavar="MODE",
Expand All @@ -601,7 +683,11 @@ def main() -> int:
edit_parser.add_argument("--port", "-P", help="Port")
edit_parser.add_argument("--database", "-d", help="Database name")
edit_parser.add_argument("--username", "-u", help="Username")
edit_parser.add_argument("--password", "-p", help="Password")
edit_parser.add_argument(
"--password",
"-p",
help="Password (or use --password-stdin to read from stdin)",
)
edit_parser.add_argument(
"--auth-type",
"-a",
Expand All @@ -611,6 +697,7 @@ def main() -> int:
edit_parser.add_argument("--file-path", help="Database file path (SQLite only)")
edit_parser.add_argument("--password-command", dest="password_command", help="Shell command to retrieve the database password")
edit_parser.add_argument("--ssh-password-command", dest="ssh_password_command", help="Shell command to retrieve the SSH password")
_add_stdin_secret_flags(edit_parser, include_ssh=True)
edit_parser.add_argument(
"--alert",
metavar="MODE",
Expand All @@ -632,6 +719,7 @@ def main() -> int:
add_schema_arguments(provider_parser, schema, include_name=True, name_required=False)
provider_parser.add_argument("--password-command", dest="password_command", help="Shell command to retrieve the database password")
provider_parser.add_argument("--ssh-password-command", dest="ssh_password_command", help="Shell command to retrieve the SSH password")
_add_stdin_secret_flags(provider_parser, include_ssh=True)
provider_parser.add_argument(
"--alert",
metavar="MODE",
Expand Down Expand Up @@ -705,6 +793,7 @@ def main() -> int:

with startup_span("cli_parse_args"):
args = parser.parse_args(filtered_argv[1:]) # Skip program name
_resolve_stdin_secrets(args)
log_startup_step("cli_parse_end")

with startup_span("runtime_build"):
Expand Down
47 changes: 47 additions & 0 deletions sqlit/domains/connections/domain/stdin_secret.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Read a secret (password, connection URL, ...) from stdin.

Used by the `--password-stdin` / `--url-stdin` / `--ssh-password-stdin`
flags so callers can pipe credentials in instead of passing them on the
command line, where they'd be visible to other users via ``ps`` or
``/proc/<pid>/cmdline``.
"""

from __future__ import annotations

import sys
from typing import TextIO


class StdinSecretError(Exception):
"""Raised when a secret can't be read from stdin."""


def read_secret_from_stdin(
*,
label: str = "secret",
stream: TextIO | None = None,
) -> str:
"""Read one line from stdin and strip the trailing newline.

Refuses to read when stdin is a TTY — there's no plausible
non-interactive workflow for that, and silently waiting on user
input would be confusing when the caller intended a piped value.
Use ``label`` to make the error point at the offending flag (e.g.
``password``, ``url``).
"""
source: TextIO = stream if stream is not None else sys.stdin
if source.isatty():
raise StdinSecretError(
f"Refusing to read {label} from stdin: stdin is a TTY. "
f"Pipe the value in, e.g. `echo $SECRET | sqlit ... --{label}-stdin`."
)

line = source.readline()
if line == "":
raise StdinSecretError(f"No {label} received on stdin (EOF).")

if line.endswith("\r\n"):
return line[:-2]
if line.endswith("\n"):
return line[:-1]
return line
103 changes: 103 additions & 0 deletions tests/cli/test_cli_main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from __future__ import annotations

import subprocess
import sys
from pathlib import Path

from tests.conftest import run_cli


def _run_cli_with_stdin(*args: str, stdin: str, env_config_dir: Path) -> subprocess.CompletedProcess:
"""Invoke the sqlit CLI with a piped stdin payload."""
cmd = [sys.executable, "-m", "sqlit.cli", *args]
return subprocess.run(
cmd,
input=stdin,
capture_output=True,
text=True,
env={"SQLIT_CONFIG_DIR": str(env_config_dir), "PATH": __import__("os").environ.get("PATH", "")},
)


def test_cli_connections_list_empty(tmp_path: Path, monkeypatch):
settings_path = tmp_path / "settings.json"
settings_path.write_text('{"allow_plaintext_credentials": true}', encoding="utf-8")
Expand All @@ -15,3 +29,92 @@ def test_cli_connections_list_empty(tmp_path: Path, monkeypatch):

assert result.returncode == 0
assert "No saved connections." in result.stdout


def test_url_stdin_creates_connection(tmp_path: Path):
settings_path = tmp_path / "settings.json"
settings_path.write_text('{"allow_plaintext_credentials": true}', encoding="utf-8")

result = _run_cli_with_stdin(
"connections", "add", "--url-stdin", "--name", "StdinURL",
stdin="sqlite:///tmp/sqlit-stdin-test.db\n",
env_config_dir=tmp_path,
)

assert result.returncode == 0, result.stderr
assert "StdinURL" in result.stdout


def test_url_stdin_rejects_when_url_also_provided(tmp_path: Path):
settings_path = tmp_path / "settings.json"
settings_path.write_text('{"allow_plaintext_credentials": true}', encoding="utf-8")

result = _run_cli_with_stdin(
"connections", "add",
"--url", "sqlite:///tmp/a.db",
"--url-stdin",
"--name", "X",
stdin="sqlite:///tmp/b.db\n",
env_config_dir=tmp_path,
)

assert result.returncode != 0
assert "mutually exclusive" in (result.stderr + result.stdout)


def test_password_stdin_mutex_with_password(tmp_path: Path):
settings_path = tmp_path / "settings.json"
settings_path.write_text('{"allow_plaintext_credentials": true}', encoding="utf-8")

result = _run_cli_with_stdin(
"connect", "postgresql",
"--name", "X",
"--server", "localhost",
"--port", "5432",
"--database", "d",
"--username", "u",
"--password", "cleartext",
"--password-stdin",
stdin="frompipe\n",
env_config_dir=tmp_path,
)

assert result.returncode != 0
assert "mutually exclusive" in (result.stderr + result.stdout)


def test_multiple_stdin_flags_rejected(tmp_path: Path):
settings_path = tmp_path / "settings.json"
settings_path.write_text('{"allow_plaintext_credentials": true}', encoding="utf-8")

result = _run_cli_with_stdin(
"connections", "edit", "Nonexistent",
"--password-stdin",
"--ssh-password-stdin",
stdin="x\n",
env_config_dir=tmp_path,
)

assert result.returncode != 0
output = result.stderr + result.stdout
assert "only one" in output and "stdin" in output


def test_password_stdin_eof_errors_cleanly(tmp_path: Path):
settings_path = tmp_path / "settings.json"
settings_path.write_text('{"allow_plaintext_credentials": true}', encoding="utf-8")

result = _run_cli_with_stdin(
"connect", "postgresql",
"--name", "X",
"--server", "localhost",
"--port", "5432",
"--database", "d",
"--username", "u",
"--password-stdin",
stdin="",
env_config_dir=tmp_path,
)

assert result.returncode != 0
assert "EOF" in (result.stderr + result.stdout)
61 changes: 61 additions & 0 deletions tests/unit/test_stdin_secret.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Tests for the stdin-secret reader used by --password-stdin / --url-stdin."""

from __future__ import annotations

import io
from unittest.mock import patch

import pytest

from sqlit.domains.connections.domain.stdin_secret import (
StdinSecretError,
read_secret_from_stdin,
)


class _FakeStream(io.StringIO):
def __init__(self, contents: str, *, isatty: bool = False) -> None:
super().__init__(contents)
self._isatty = isatty

def isatty(self) -> bool: # type: ignore[override]
return self._isatty


class TestReadSecretFromStdin:
def test_strips_trailing_newline(self) -> None:
assert read_secret_from_stdin(stream=_FakeStream("secret\n")) == "secret"

def test_strips_crlf(self) -> None:
assert read_secret_from_stdin(stream=_FakeStream("secret\r\n")) == "secret"

def test_preserves_internal_spaces(self) -> None:
assert read_secret_from_stdin(stream=_FakeStream("a b c\n")) == "a b c"

def test_no_trailing_newline_is_returned_verbatim(self) -> None:
assert read_secret_from_stdin(stream=_FakeStream("naked")) == "naked"

def test_only_reads_first_line(self) -> None:
stream = _FakeStream("first\nsecond\n")
assert read_secret_from_stdin(stream=stream) == "first"

def test_refuses_tty(self) -> None:
with pytest.raises(StdinSecretError, match="TTY"):
read_secret_from_stdin(stream=_FakeStream("ignored\n", isatty=True))

def test_refuses_empty_stream(self) -> None:
with pytest.raises(StdinSecretError, match="EOF"):
read_secret_from_stdin(stream=_FakeStream(""))

def test_label_appears_in_tty_error(self) -> None:
with pytest.raises(StdinSecretError, match="url"):
read_secret_from_stdin(label="url", stream=_FakeStream("x", isatty=True))

def test_label_appears_in_eof_error(self) -> None:
with pytest.raises(StdinSecretError, match="ssh-password"):
read_secret_from_stdin(label="ssh-password", stream=_FakeStream(""))

def test_defaults_to_sys_stdin(self) -> None:
with patch("sqlit.domains.connections.domain.stdin_secret.sys") as mock_sys:
mock_sys.stdin = _FakeStream("from-real-stdin\n")
assert read_secret_from_stdin() == "from-real-stdin"
Loading