diff --git a/Cargo.lock b/Cargo.lock index f69a3417..2bac847b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -212,9 +212,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.1" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "cfg_aliases" @@ -273,6 +273,12 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crypto-common" version = "0.1.6" @@ -283,6 +289,20 @@ dependencies = [ "typenum", ] +[[package]] +name = "dashmap" +version = "6.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6361d5c062261c78a176addb82d4c821ae42bed6089de0e12603cd25de2059c" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "deadpool" version = "0.12.1" @@ -514,6 +534,12 @@ dependencies = [ "ahash", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.4" @@ -725,9 +751,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.3" +version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "openssl" @@ -998,6 +1024,7 @@ dependencies = [ "bytes", "chrono", "chrono-tz", + "dashmap", "deadpool-postgres", "futures", "futures-channel", diff --git a/Cargo.toml b/Cargo.toml index 6d4cefb7..7f690e7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,3 +64,4 @@ futures-channel = "0.3.31" futures = "0.3.31" regex = "1.11.1" once_cell = "1.20.3" +dashmap = "6" diff --git a/pyproject.toml b/pyproject.toml index a9c234fa..22351032 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,10 +104,12 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "python/psqlpy/*" = ["PYI021"] "python/tests/*" = [ - "S101", # Use of assert detected - "S608", # Possible SQL injection vector through string-based query construction - "D103", # Missing docstring in public function - "S311", # Standard pseudo-random generators are not suitable for security/cryptographic purposes + "S101", # Use of assert detected + "S608", # Possible SQL injection via string-based query construction + "D103", # Missing docstring in public function + "S311", # Standard pseudo-random generators not suitable for security + "PLR2004", # Magic value in comparison (common in test assertions) + "D205", # 1 blank line required between summary and description ] "python/psqlpy/_internal/exceptions.pyi" = [ "D205", diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index e0f4f794..f128df59 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -17,6 +17,33 @@ _RowFactoryRV = TypeVar( ParamsT: TypeAlias = Sequence[Any] | Mapping[str, Any] | None +class Record: + """An asyncpg-compatible row type with eagerly decoded column values. + + Supports positional indexing (`row[0]`), by-name indexing (`row["col"]`), + slicing, iteration, and dict-like access methods. + """ + + def __len__(self) -> int: ... + @typing.overload + def __getitem__(self, key: int) -> Any: ... + @typing.overload + def __getitem__(self, key: str) -> Any: ... + @typing.overload + def __getitem__(self, key: slice) -> list[Any]: ... + def __iter__(self) -> typing.Iterator[Any]: ... + def get(self, key: str, default: Any = None) -> Any: + """Return column value by name, or `default` if the column does not exist.""" + + def keys(self) -> list[str]: + """Return ordered list of column names.""" + + def values(self) -> list[Any]: + """Return ordered list of column values.""" + + def items(self) -> list[tuple[str, Any]]: + """Return ordered list of (column_name, value) pairs.""" + class QueryResult: """Result.""" @@ -107,6 +134,15 @@ class QueryResult: List of type that return passed `row_factory`. """ + def records(self: Self) -> list[Record]: + """Return result as a list of Record instances. + + Each Record shares column metadata with others from the same result set. + Supports positional/by-name indexing and dict-like access. + Unlike `result()`, the column-name lookup table is shared, not + re-created per row. + """ + class SingleQueryResult: """Single result.""" diff --git a/python/tests/test_copy_records.py b/python/tests/test_copy_records.py index 28aba34a..6316a823 100644 --- a/python/tests/test_copy_records.py +++ b/python/tests/test_copy_records.py @@ -1,5 +1,7 @@ import typing -from datetime import datetime, timezone +import uuid +from datetime import date, datetime, timezone +from decimal import Decimal import pytest from psqlpy import ConnectionPool @@ -172,3 +174,140 @@ async def test_copy_records_to_table_uses_schema_qualifier( finally: async with psql_pool.acquire() as connection: await connection.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE") + + +async def test_copy_records_heterogeneous_types( + psql_pool: ConnectionPool, +) -> None: + """Characterization test: covers int, float, text, bytea, UUID, numeric, + date, timestamp, NULL, and array column types (AC-3.4). + """ + target: typing.Final = "copy_records_hetero" + + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP TABLE IF EXISTS {target}") + await connection.execute( + f""" + CREATE TABLE {target} ( + col_int INTEGER, + col_float DOUBLE PRECISION, + col_text TEXT, + col_bytea BYTEA, + col_uuid UUID, + col_numeric NUMERIC, + col_date DATE, + col_ts TIMESTAMPTZ, + col_null TEXT, + col_arr INTEGER[] + ) + """, + ) + + try: + sample_uuid = uuid.uuid4() + records = [ + ( + 42, + 3.14, + "hello", + b"\x00\x01\x02", + sample_uuid, + Decimal("12345.6789"), + date(2024, 6, 1), + datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc), + None, + [1, 2, 3], + ), + ] + + async with psql_pool.acquire() as connection: + inserted = await connection.copy_records_to_table( + table_name=target, + records=records, + ) + + assert inserted == 1 + + async with psql_pool.acquire() as connection: + result = await connection.execute(f"SELECT * FROM {target}") + row = result.result()[0] + assert row["col_int"] == 42 + assert abs(row["col_float"] - 3.14) < 1e-9 + assert row["col_text"] == "hello" + assert bytes(row["col_bytea"]) == b"\x00\x01\x02" + assert row["col_uuid"] == str(sample_uuid) + assert row["col_numeric"] == Decimal("12345.6789") + assert row["col_date"] == date(2024, 6, 1) + assert row["col_null"] is None + assert row["col_arr"] == [1, 2, 3] + finally: + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP TABLE IF EXISTS {target}") + + +async def test_copy_records_introspection_cache( + psql_pool: ConnectionPool, +) -> None: + """Second call to copy_records_to_table against the same table should not + issue a new column-type introspection PREPARE (AC-4.3). + """ + target: typing.Final = "copy_records_cache_test" + records = [(1, "first"), (2, "second")] + + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP TABLE IF EXISTS {target}") + await connection.execute( + f"CREATE TABLE {target} (id INTEGER, label TEXT)", + ) + + # Snapshot introspection query count before — use pg_stat_statements if available. + introspect_pattern = f"%{target}%WHERE false%" + pre_calls: int | None = None + try: + async with psql_pool.acquire() as connection: + res = await connection.execute( + "SELECT COALESCE(SUM(calls), 0) AS n FROM pg_stat_statements " + "WHERE query ILIKE $1", + parameters=[introspect_pattern], + ) + pre_calls = res.result()[0]["n"] + except Exception: # noqa: BLE001, S110 + pass # pg_stat_statements not available — skip count check + + try: + async with psql_pool.acquire() as connection: + # First call — populates the cache. + await connection.copy_records_to_table( + table_name=target, + records=records[:1], + ) + # Second call on the same connection — must hit the type cache. + await connection.copy_records_to_table( + table_name=target, + records=records[1:], + ) + + # Verify both rows were written correctly. + async with psql_pool.acquire() as connection: + result = await connection.execute( + f"SELECT id, label FROM {target} ORDER BY id", + ) + rows = [(r["id"], r["label"]) for r in result.result()] + assert rows == [(1, "first"), (2, "second")] + + # Verify only one introspection query was issued (cache hit on second call). + if pre_calls is not None: + async with psql_pool.acquire() as connection: + res = await connection.execute( + "SELECT COALESCE(SUM(calls), 0) AS n FROM pg_stat_statements " + "WHERE query ILIKE $1", + parameters=[introspect_pattern], + ) + post_calls = res.result()[0]["n"] + # At most one introspection PREPARE should be issued (cache hit on call 2). + assert post_calls - pre_calls <= 1, ( + f"Expected at most 1 introspection call, got {post_calls - pre_calls}" + ) + finally: + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP TABLE IF EXISTS {target}") diff --git a/python/tests/test_record.py b/python/tests/test_record.py new file mode 100644 index 00000000..b5b85c84 --- /dev/null +++ b/python/tests/test_record.py @@ -0,0 +1,113 @@ +"""Tests for the Record pyclass and QueryResult.records() method (AC-5.5).""" + +import decimal + +import pytest +from psqlpy import ConnectionPool + +pytestmark = pytest.mark.anyio + + +async def test_record_positional_and_named_access( + psql_pool: ConnectionPool, +) -> None: + async with psql_pool.acquire() as connection: + result = await connection.execute( + "SELECT 1 AS a, 'hello' AS b, 3.14 AS c", + ) + + records = result.records() + assert len(records) == 1 + row = records[0] + + # positional access + assert row[0] == 1 + assert row[1] == "hello" + + # negative index resolves to last column (3.14 as numeric) + assert row[-1] == decimal.Decimal("3.14") + + # by-name access + assert row["a"] == 1 + assert row["b"] == "hello" + + +async def test_record_len(psql_pool: ConnectionPool) -> None: + async with psql_pool.acquire() as connection: + result = await connection.execute("SELECT 1 AS x, 2 AS y") + row = result.records()[0] + assert len(row) == 2 + + +async def test_record_iteration(psql_pool: ConnectionPool) -> None: + async with psql_pool.acquire() as connection: + result = await connection.execute("SELECT 10 AS x, 20 AS y, 30 AS z") + row = result.records()[0] + values = list(row) + assert values == [10, 20, 30] + + +async def test_record_get(psql_pool: ConnectionPool) -> None: + async with psql_pool.acquire() as connection: + result = await connection.execute("SELECT 42 AS answer") + row = result.records()[0] + assert row.get("answer") == 42 + assert row.get("missing") is None + assert row.get("missing", 99) == 99 + + +async def test_record_keys(psql_pool: ConnectionPool) -> None: + async with psql_pool.acquire() as connection: + result = await connection.execute("SELECT 1 AS alpha, 2 AS beta") + row = result.records()[0] + assert row.keys() == ["alpha", "beta"] + + +async def test_record_values(psql_pool: ConnectionPool) -> None: + async with psql_pool.acquire() as connection: + result = await connection.execute("SELECT 7 AS p, 8 AS q") + row = result.records()[0] + assert row.values() == [7, 8] + + +async def test_record_items(psql_pool: ConnectionPool) -> None: + async with psql_pool.acquire() as connection: + result = await connection.execute("SELECT 5 AS foo, 'bar' AS baz") + row = result.records()[0] + items = row.items() + assert items == [("foo", 5), ("baz", "bar")] + + +async def test_record_slice(psql_pool: ConnectionPool) -> None: + async with psql_pool.acquire() as connection: + result = await connection.execute( + "SELECT 1 AS a, 2 AS b, 3 AS c, 4 AS d", + ) + row = result.records()[0] + assert row[1:3] == [2, 3] + assert row[:2] == [1, 2] + assert row[::2] == [1, 3] + + +async def test_record_shared_descriptor(psql_pool: ConnectionPool) -> None: + """All records from one result set share the same column descriptor.""" + async with psql_pool.acquire() as connection: + result = await connection.execute( + "SELECT generate_series AS n FROM generate_series(1, 5)", + ) + records = result.records() + assert len(records) == 5 + # keys() returns the same column list for all rows + for row in records: + assert row.keys() == ["n"] + # values are distinct per row + assert [row["n"] for row in records] == list(range(1, 6)) + + +async def test_result_unchanged_by_records(psql_pool: ConnectionPool) -> None: + """result() still returns dicts after records() is added.""" + async with psql_pool.acquire() as connection: + result = await connection.execute("SELECT 1 AS n") + dicts = result.result() + assert isinstance(dicts[0], dict) + assert dicts[0]["n"] == 1 diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index fdb44bf0..745c544a 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -659,9 +659,9 @@ async def test_char_internal_type_byte_spectrum( value = decoded[i] assert isinstance(value, str) assert len(value) == 1 - assert ( - ord(value) == b - ), f"byte 0x{b:02x} round-tripped to ord(value)=0x{ord(value):02x}" + assert ord(value) == b, ( + f"byte 0x{b:02x} round-tripped to ord(value)=0x{ord(value):02x}" + ) assert decoded[len(bytes_under_test)] is None diff --git a/src/connection/impls.rs b/src/connection/impls.rs index 337b62fb..49227350 100644 --- a/src/connection/impls.rs +++ b/src/connection/impls.rs @@ -1,6 +1,6 @@ use bytes::Buf; -use futures::stream::{FuturesOrdered, StreamExt}; -use postgres_types::ToSql; +use futures::stream::{FuturesOrdered, Stream, StreamExt}; +use postgres_types::{ToSql, Type}; use pyo3::{PyAny, Python}; use tokio_postgres::{CopyInSink, Portal as tp_Portal, Row, Statement, ToStatement}; @@ -8,7 +8,10 @@ use crate::{ exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, options::{IsolationLevel, ReadVariant}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, - statement::{statement::PsqlpyStatement, statement_builder::StatementBuilder}, + statement::{ + parameters::ParametersBuilder, statement::PsqlpyStatement, + statement_builder::StatementBuilder, + }, transaction::structs::PSQLPyTransaction, value_converter::to_python::postgres_to_py, }; @@ -17,10 +20,34 @@ use deadpool_postgres::Transaction as dp_Transaction; use tokio_postgres::Transaction as tp_Transaction; use super::{ - structs::{PSQLPyConnection, PoolConnection, SingleConnection}, + structs::{CopyTypeCache, PSQLPyConnection, PoolConnection, SingleConnection}, traits::{CloseTransaction, Connection, StartTransaction, Transaction}, }; +/// Drain a `FuturesOrdered` stream to completion, returning the first error seen. +/// +/// All futures are polled to completion regardless of errors — this keeps +/// the underlying connection in a quiescent state before the caller issues +/// the next statement (e.g., ROLLBACK TO SAVEPOINT after a failed batch). +async fn drain_ordered( + mut stream: impl Stream> + Unpin, +) -> PSQLPyResult<()> { + let mut first_err: Option = None; + while let Some(res) = stream.next().await { + if let Err(err) = res { + if first_err.is_none() { + first_err = Some(RustPSQLDriverError::ConnectionExecuteError(format!( + "Error occurred in `execute_many` statement: {err}" + ))); + } + } + } + match first_err { + Some(e) => Err(e), + None => Ok(()), + } +} + impl Transaction for T where T: Connection, @@ -62,12 +89,18 @@ where impl Connection for SingleConnection { async fn prepare(&self, query: &str, prepared: bool) -> PSQLPyResult { - let prepared_stmt = self.connection.prepare(query).await?; - - if !prepared { - self.drop_prepared(&prepared_stmt).await?; + if prepared { + if let Some(cached) = self.stmt_cache.get(query) { + return Ok(cached.clone()); + } + let stmt = self.connection.prepare(query).await?; + self.stmt_cache.insert(query.to_string(), stmt.clone()); + return Ok(stmt); } - Ok(prepared_stmt) + + let stmt = self.connection.prepare(query).await?; + self.drop_prepared(&stmt).await?; + Ok(stmt) } async fn drop_prepared(&self, stmt: &Statement) -> PSQLPyResult<()> { @@ -340,6 +373,14 @@ impl PSQLPyConnection { } } + #[must_use] + pub fn copy_type_cache(&self) -> &CopyTypeCache { + match self { + PSQLPyConnection::PoolConn(conn) => &conn.copy_type_cache, + PSQLPyConnection::SingleConnection(conn) => &conn.copy_type_cache, + } + } + /// Prepare internal `PSQLPy` statement /// /// # Errors @@ -476,14 +517,19 @@ impl PSQLPyConnection { /// asymmetry between the two call sites already exists — savepoints /// just bring `Transaction::execute_many` into line. /// - /// ## Behavioural change vs prior implementation + /// ## Breaking change vs prior implementation (0.12.0) + /// + /// **`Connection::execute_many`**: previously each row was an independent + /// auto-commit, so a mid-batch failure left earlier rows committed. The + /// new `BEGIN`/`COMMIT` wrap makes the whole batch atomic. /// - /// Previously this method ran each row as an independent auto-commit, - /// so a mid-batch failure left earlier rows committed. The new wrap - /// makes the whole batch atomic. This matches asyncpg / psycopg - /// `executemany` expectations and the issue's framing of `execute_many` - /// as a bulk operation, but it is a semantic change worth flagging in - /// release notes. + /// **`Transaction::execute_many`**: previously a batch failure left the + /// outer transaction in an aborted state, requiring the caller to issue + /// `ROLLBACK`. Now the batch is wrapped in `SAVEPOINT psqlpy_execute_many`; + /// on failure the savepoint is rolled back and the outer transaction + /// remains live. **Callers that catch the error and explicitly call + /// `transaction.rollback()` should be updated to omit that call** — + /// the outer transaction is still usable after a batch failure. /// /// # Errors /// May return error if there is some problem with DB communication. @@ -502,23 +548,51 @@ impl PSQLPyConnection { let prepared = prepared.unwrap_or(true); - let mut statements: Vec = Vec::with_capacity(parameters.len()); - for param_set in parameters { - let statement = - StatementBuilder::new(&querystring, &Some(param_set), self, Some(prepared)) - .build() - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot build statement in execute_many: {err}" - )) - })?; - statements.push(statement); - } + // Build statement once using the first param set to resolve types and + // (for the prepared path) obtain the server-side Statement handle. + let template = StatementBuilder::new( + &querystring, + &Some(parameters[0].clone()), + self, + Some(prepared), + ) + .build() + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot build statement in execute_many: {err}" + )) + })?; - if statements.is_empty() { - return Ok(()); - } + let prepared_stmt: Option = if prepared { + Some(template.statement_query()?.clone()) + } else { + None + }; + let param_types: Vec = template.param_types().to_vec(); + let raw_query = template.raw_query().to_string(); + // Named-parameter names are already computed inside StatementBuilder::build(). + let param_names: Option> = template.param_names().map(<[_]>::to_vec); + + // Two GIL passes total: one inside StatementBuilder::build() for row 0, + // one here for all remaining rows — independent of batch size, not per-row. + let first_pp = template.into_prepared_parameters(); + let remaining_pp: PSQLPyResult> = if parameters.len() > 1 { + Python::with_gil(|gil| { + parameters[1..] + .iter() + .map(|param_set| { + ParametersBuilder::new(Some(param_set), Some(param_types.clone()), vec![]) + .prepare_with_gil(gil, param_names.clone()) + }) + .collect() + }) + } else { + Ok(vec![]) + }; + + let mut all_pp = vec![first_pp]; + all_pp.extend(remaining_pp?); let wrap = if self.in_transaction() { ExecuteManyWrap::Savepoint @@ -532,7 +606,9 @@ impl PSQLPyConnection { )) })?; - let batch_result = self.run_pipelined_batch(&statements, prepared).await; + let batch_result = self + .run_pipelined_batch(prepared_stmt.as_ref(), &raw_query, &all_pp, prepared) + .await; let close_sql = wrap.close_sql(batch_result.is_ok()); let close_result = self.batch_execute(close_sql).await; @@ -558,72 +634,53 @@ impl PSQLPyConnection { /// short-circuiting with `?`) so already-sent messages can be acknowledged /// and the connection returns to a quiescent state before the caller /// issues the close-wrap statement. The first error is what we propagate. + /// + /// The `prepared_stmt` is already resolved by `execute_many` — no second + /// prepare call is issued here. + /// + /// # TODO(bind-execute-many) + /// + /// The per-row `FuturesOrdered` loop below is the layered + /// fallback for batched execution. It should be replaced by a future + /// `tokio_postgres::Client::bind_execute_many(&Statement, impl Iterator)` + /// primitive once it lands upstream in sfackler/rust-postgres. + /// + /// Target behaviour: pack Bind+Execute frames for all rows into a shared + /// `BytesMut` (constants: `_EXECUTE_MANY_BUF_NUM=4`, `_EXECUTE_MANY_BUF_SIZE=32768`), + /// issue a single trailing Sync, and writev ~128 KiB per round-trip. + /// Reference algorithm: asyncpg `coreproto.pyx:1022-1092`. + /// + /// Tracking: no upstream issue open yet in sfackler/rust-postgres — + /// file one at if the + /// primitive is not yet tracked. + #[allow(clippy::type_complexity)] async fn run_pipelined_batch( &self, - statements: &[PsqlpyStatement], + prepared_stmt: Option<&Statement>, + raw_query: &str, + all_params: &[crate::statement::parameters::PreparedParameters], prepared: bool, ) -> PSQLPyResult<()> { - // Materialize parameter slices into owned boxes so the borrows feeding - // each future live for the whole pipeline (the slices reference data - // owned by each `PsqlpyStatement`, which already outlives this fn). if prepared { - let prepared_stmt = self - .prepare(statements[0].raw_query(), true) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement in execute_many: {err}" - )) - })?; + let stmt = prepared_stmt.expect("prepared_stmt required when prepared=true"); let param_boxes: Vec> = - statements.iter().map(PsqlpyStatement::params).collect(); + all_params.iter().map(|p| p.params()).collect(); - let mut ordered: FuturesOrdered<_> = param_boxes - .iter() - .map(|p| self.query(&prepared_stmt, p)) - .collect(); + let ordered: FuturesOrdered<_> = + param_boxes.iter().map(|p| self.query(stmt, p)).collect(); - let mut first_err: Option = None; - while let Some(res) = ordered.next().await { - if let Err(err) = res { - if first_err.is_none() { - first_err = Some(RustPSQLDriverError::ConnectionExecuteError(format!( - "Error occurred in `execute_many` statement: {err}" - ))); - } - } - } - match first_err { - Some(e) => Err(e), - None => Ok(()), - } + drain_ordered(ordered).await } else { - let param_boxes: Vec<_> = statements - .iter() - .map(PsqlpyStatement::params_typed) - .collect(); + let param_boxes: Vec> = + all_params.iter().map(|p| p.params_typed()).collect(); - let mut ordered: FuturesOrdered<_> = statements + let ordered: FuturesOrdered<_> = param_boxes .iter() - .zip(param_boxes.iter()) - .map(|(s, p)| self.query_typed(s.raw_query(), p)) + .map(|p| self.query_typed(raw_query, p)) .collect(); - let mut first_err: Option = None; - while let Some(res) = ordered.next().await { - if let Err(err) = res { - if first_err.is_none() { - first_err = Some(RustPSQLDriverError::ConnectionExecuteError(format!( - "Error occurred in `execute_many` statement: {err}" - ))); - } - } - } - match first_err { - Some(e) => Err(e), - None => Ok(()), - } + drain_ordered(ordered).await } } diff --git a/src/connection/structs.rs b/src/connection/structs.rs index 9cbd9d05..f28f8abc 100644 --- a/src/connection/structs.rs +++ b/src/connection/structs.rs @@ -1,7 +1,16 @@ use std::sync::Arc; +use dashmap::DashMap; use deadpool_postgres::Object; -use tokio_postgres::{Client, Config}; +use postgres_types::Type; +use tokio_postgres::{Client, Config, Statement}; + +/// Per-connection cache for COPY column-type introspection results. +/// Key: `(schema_name, table_name, columns_in_declaration_order)`. +/// Column order is significant: `["a","b"]` and `["b","a"]` produce different COPY targets. +/// Cache is per-checkout on `PoolConnection` — reuse the same acquired connection +/// for consecutive `copy_records_to_table` calls to benefit from this cache. +pub type CopyTypeCache = DashMap<(Option, String, Vec), Vec>; #[derive(Debug)] pub struct PoolConnection { @@ -9,6 +18,8 @@ pub struct PoolConnection { pub in_transaction: bool, pub in_cursor: bool, pub pg_config: Arc, + /// Per-connection cache for COPY column-type introspection results. + pub copy_type_cache: CopyTypeCache, } impl PoolConnection { @@ -19,6 +30,7 @@ impl PoolConnection { in_transaction: false, in_cursor: false, pg_config, + copy_type_cache: DashMap::new(), } } } @@ -29,6 +41,10 @@ pub struct SingleConnection { pub in_transaction: bool, pub in_cursor: bool, pub pg_config: Arc, + /// Per-connection prepared-statement cache. Keyed by the raw query string. + pub stmt_cache: DashMap, + /// Per-connection cache for COPY column-type introspection results. + pub copy_type_cache: CopyTypeCache, } impl SingleConnection { @@ -39,6 +55,8 @@ impl SingleConnection { in_transaction: false, in_cursor: false, pg_config, + stmt_cache: DashMap::new(), + copy_type_cache: DashMap::new(), } } } diff --git a/src/driver/common.rs b/src/driver/common.rs index 2794cba2..b78d66af 100644 --- a/src/driver/common.rs +++ b/src/driver/common.rs @@ -15,10 +15,14 @@ use crate::{ value_converter::{dto::enums::PythonDTO, from_python::from_python_typed}, }; -use bytes::BytesMut; +use byteorder::{BigEndian, ByteOrder}; +use bytes::{BufMut, BytesMut}; use futures_util::pin_mut; use pyo3::{buffer::PyBuffer, types::PyAnyMethods, Python}; -use tokio_postgres::{binary_copy::BinaryCopyInWriter, types::ToSql}; +use tokio_postgres::{ + binary_copy::BinaryCopyInWriter, + types::{IsNull, ToSql}, +}; use crate::format_helpers::quote_ident; @@ -322,6 +326,40 @@ macro_rules! impl_binary_copy_method { impl_binary_copy_method!(Connection); impl_binary_copy_method!(Transaction); +/// Asyncpg's `_COPY_BUFFER_SIZE`: flush when the encode buffer reaches 512 KiB. +const COPY_BUFFER_SIZE: usize = 524_288; + +/// `PostgreSQL` binary COPY file header. +const COPY_MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0"; + +/// Encode one field into `buf` using the `PostgreSQL` binary COPY wire format: +/// a 4-byte big-endian length prefix followed by the serialised value, +/// or -1 (as i32) for NULL. +pub(crate) fn encode_copy_field( + buf: &mut BytesMut, + dto: &PythonDTO, + ty: &tokio_postgres::types::Type, +) -> PSQLPyResult<()> { + let len_pos = buf.len(); + buf.put_i32(0); // placeholder — overwritten after encoding + let data_start = buf.len(); + let is_null = dto.to_sql_checked(ty, buf).map_err(|e| { + RustPSQLDriverError::PyToRustValueConversionError(format!("COPY binary encode error: {e}")) + })?; + // COPY field lengths fit in i32; a single encoded field cannot exceed 2 GiB. + #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] + let field_len = match is_null { + IsNull::No => (buf.len() - data_start) as i32, + IsNull::Yes => { + // NULL: truncate back to placeholder position, no data bytes + buf.truncate(data_start); + -1i32 + } + }; + BigEndian::write_i32(&mut buf[len_pos..], field_len); + Ok(()) +} + macro_rules! impl_copy_records_method { ($name:ident) => { #[pymethods] @@ -333,9 +371,15 @@ macro_rules! impl_copy_records_method { /// pass Python values directly (the same conversions used by /// `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`. /// + /// The encoder follows asyncpg's algorithm: a single BytesMut + /// accumulator flushed into 512 KiB (`_COPY_BUFFER_SIZE`) chunks. + /// All rows are encoded during the GIL pass; chunks are sent to the + /// server in a second pass after the GIL is released. + /// /// # Errors /// May return error if there is some problem with DB communication, /// the table cannot be introspected, or a value cannot be converted. + #[allow(clippy::too_many_lines)] #[pyo3(signature = (table_name, records, columns=None, schema_name=None))] pub async fn copy_records_to_table( self_: pyo3::Py, @@ -344,41 +388,20 @@ macro_rules! impl_copy_records_method { columns: Option>, schema_name: Option, ) -> PSQLPyResult { - let (db_client, raw_records) = Python::with_gil( - |gil| -> PSQLPyResult<(Option<_>, Vec>>)> { - let db_client = self_.borrow(gil).conn.clone(); - - let Some(db_client) = db_client else { - return Ok((None, Vec::new())); - }; - - let bound = records.bind(gil); - let mut rows: Vec>> = Vec::new(); - for item in bound.try_iter()? { - let row = item?; - let mut row_vec: Vec> = Vec::new(); - for cell in row.try_iter()? { - row_vec.push(cell?.unbind()); - } - rows.push(row_vec); - } - - Ok((Some(db_client), rows)) - }, - )?; + let db_client = Python::with_gil(|gil| self_.borrow(gil).conn.clone()); let Some(db_client) = db_client else { return Ok(0); }; - let full_table_name = match schema_name { - Some(ref schema) => { + let full_table_name = match schema_name.as_deref() { + Some(schema) => { format!("{}.{}", quote_ident(schema), quote_ident(&table_name)) } None => quote_ident(&table_name), }; - let columns_sql = match columns { + let columns_sql: Option = match columns { Some(ref cols) if !cols.is_empty() => Some( cols.iter() .map(|c| quote_ident(c)) @@ -388,16 +411,33 @@ macro_rules! impl_copy_records_method { _ => None, }; - let introspect_qs = match &columns_sql { - Some(cols) => format!("SELECT {} FROM {} WHERE false", cols, full_table_name), - None => format!("SELECT * FROM {} WHERE false", full_table_name), - }; - let read_conn_g = db_client.read().await; - let stmt = read_conn_g.prepare(&introspect_qs, false).await?; - let column_types: Vec = - stmt.columns().iter().map(|c| c.type_().clone()).collect(); + // Consult the per-connection type cache before issuing an + // introspection query (avoids PREPARE+DEALLOCATE round-trips). + let cache_key = ( + schema_name.clone(), + table_name.clone(), + columns.clone().unwrap_or_default(), + ); + let column_types: Vec = if let Some(cached) = + read_conn_g.copy_type_cache().get(&cache_key) + { + (*cached).clone() + } else { + let introspect_qs = match &columns_sql { + Some(cols) => { + format!("SELECT {} FROM {} WHERE false", cols, full_table_name) + } + None => format!("SELECT * FROM {} WHERE false", full_table_name), + }; + let stmt = read_conn_g.prepare(&introspect_qs, false).await?; + let types: Vec<_> = stmt.columns().iter().map(|c| c.type_().clone()).collect(); + read_conn_g + .copy_type_cache() + .insert(cache_key, types.clone()); + types + }; if column_types.is_empty() { return Err(RustPSQLDriverError::PyToRustValueConversionError( @@ -405,29 +445,6 @@ macro_rules! impl_copy_records_method { )); } - let typed_rows: Vec> = - Python::with_gil(|gil| -> PSQLPyResult>> { - let mut typed: Vec> = Vec::with_capacity(raw_records.len()); - for (row_idx, row) in raw_records.iter().enumerate() { - if row.len() != column_types.len() { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - format!( - "Record at index {} has {} fields, expected {}", - row_idx, - row.len(), - column_types.len() - ), - )); - } - let mut row_dto: Vec = Vec::with_capacity(row.len()); - for (cell, ty) in row.iter().zip(column_types.iter()) { - row_dto.push(from_python_typed(cell.bind(gil), ty)?); - } - typed.push(row_dto); - } - Ok(typed) - })?; - let copy_qs = match &columns_sql { Some(cols) => format!( "COPY {}({}) FROM STDIN (FORMAT binary)", @@ -436,17 +453,79 @@ macro_rules! impl_copy_records_method { None => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name), }; - let sink = read_conn_g.copy_in(©_qs).await?; - let writer = BinaryCopyInWriter::new(sink, &column_types); - pin_mut!(writer); + let sink = read_conn_g.copy_in::<_, bytes::Bytes>(©_qs).await?; + + // GIL pass: encode all rows into `COPY_BUFFER_SIZE` chunks. + // After the GIL is released, the chunks are sent to the server. + // This eliminates the prior two-phase approach (materialise + // Vec>> then re-visit for DTO conversion). + let mut chunks: Vec = Vec::new(); + + let gil_result: PSQLPyResult<()> = Python::with_gil(|gil| { + let n_cols = column_types.len(); + let mut buf = BytesMut::with_capacity(COPY_BUFFER_SIZE); + // Scratch vec allocated once and cleared between rows (T3#10). + let mut cells_scratch: Vec> = + Vec::with_capacity(n_cols); + + // COPY binary file header + buf.put_slice(COPY_MAGIC); + buf.put_i32(0); // flags + buf.put_i32(0); // header extension length + + for (row_idx, item) in records.bind(gil).try_iter()?.enumerate() { + let row = item?; + cells_scratch.clear(); + for cell in row.try_iter()? { + cells_scratch.push(cell?); + } - for row in &typed_rows { - let row_refs: Vec<&(dyn ToSql + Sync)> = - row.iter().map(|v| v as &(dyn ToSql + Sync)).collect(); - writer.as_mut().write(&row_refs).await?; + if cells_scratch.len() != n_cols { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + format!( + "Record at index {} has {} fields, expected {}", + row_idx, + cells_scratch.len(), + n_cols + ), + )); + } + + // PostgreSQL max columns = 1600, well within i16 range. + #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)] + buf.put_i16(n_cols as i16); + for (cell, ty) in cells_scratch.iter().zip(column_types.iter()) { + let dto = from_python_typed(cell, ty)?; + encode_copy_field(&mut buf, &dto, ty)?; + } + + if buf.len() >= COPY_BUFFER_SIZE { + chunks.push(buf.split().freeze()); + } + } + + // Binary COPY trailer + buf.put_i16(-1); + chunks.push(buf.freeze()); + Ok(()) + }); + + pin_mut!(sink); + + if let Err(e) = gil_result { + // Abort the sink so the server sees copy_fail + ReadyForQuery + // rather than a silent connection-level drop. + use futures_util::SinkExt; + let _ = sink.close().await; + return Err(e); } - let rows_created = writer.as_mut().finish().await?; + // Send all chunks outside the GIL. + for chunk in chunks { + use futures_util::SinkExt; + sink.send(chunk).await?; + } + let rows_created = sink.finish().await?; Ok(rows_created) } @@ -456,3 +535,48 @@ macro_rules! impl_copy_records_method { impl_copy_records_method!(Connection); impl_copy_records_method!(Transaction); + +#[cfg(test)] +mod tests { + use super::*; + use crate::value_converter::dto::enums::PythonDTO; + use byteorder::BigEndian; + use bytes::BytesMut; + use tokio_postgres::types::Type; + + fn decode_i32(buf: &[u8], offset: usize) -> i32 { + BigEndian::read_i32(&buf[offset..offset + 4]) + } + + #[test] + fn encode_copy_field_integer() { + let dto = PythonDTO::PyIntI32(42i32); + let mut buf = BytesMut::new(); + encode_copy_field(&mut buf, &dto, &Type::INT4).unwrap(); + // 4-byte length prefix + 4-byte INT4 payload + assert_eq!(buf.len(), 8); + assert_eq!(decode_i32(&buf, 0), 4); // field length + assert_eq!(decode_i32(&buf, 4), 42); // value + } + + #[test] + fn encode_copy_field_null() { + let dto = PythonDTO::PyNone; + let mut buf = BytesMut::new(); + encode_copy_field(&mut buf, &dto, &Type::INT4).unwrap(); + // NULL: 4-byte length prefix = -1, no payload + assert_eq!(buf.len(), 4); + assert_eq!(decode_i32(&buf, 0), -1); + } + + #[test] + fn encode_copy_field_text() { + let dto = PythonDTO::PyText("hi".to_string()); + let mut buf = BytesMut::new(); + encode_copy_field(&mut buf, &dto, &Type::TEXT).unwrap(); + // 4-byte length prefix + 2 bytes of UTF-8 + assert_eq!(buf.len(), 6); + assert_eq!(decode_i32(&buf, 0), 2); + assert_eq!(&buf[4..], b"hi"); + } +} diff --git a/src/lib.rs b/src/lib.rs index a20c1ce4..84ed6d38 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,6 +50,7 @@ fn psqlpy(py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyResult<()> { pymod.add_class::()?; pymod.add_class::()?; pymod.add_class::()?; + pymod.add_class::()?; add_module(py, pymod, "extra_types", extra_types_module)?; add_module(py, pymod, "exceptions", python_exceptions_module)?; add_module(py, pymod, "row_factories", row_factories_module)?; diff --git a/src/query_result.rs b/src/query_result.rs index 46047848..451cd0ba 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -1,7 +1,9 @@ +use std::{collections::HashMap, sync::Arc}; + use pyo3::{ prelude::*, pyclass, pymethods, - types::{PyDict, PyTuple}, + types::{PyDict, PyIterator, PyList, PySlice, PyTuple}, IntoPyObjectExt, Py, PyAny, Python, }; use tokio_postgres::Row; @@ -150,6 +152,54 @@ impl PSQLDriverPyQueryResult { } Ok(result.into_py_any(py)?) } + + /// Return result as a list of `Record` instances. + /// + /// Each `Record` shares a column-descriptor (`RecordDesc`) with all other + /// records from the same result set, avoiding per-row allocation of the + /// column-name map. Values are eagerly decoded (like asyncpg `Record`). + /// + /// `result()` is unchanged — this method is additive. + /// + /// # Errors + /// May return Err if a column value cannot be converted to Python. + pub fn records(&self, py: Python<'_>) -> PSQLPyResult> { + if self.inner.is_empty() { + return Ok(PyList::empty(py).into_py_any(py)?); + } + + // Build the shared column descriptor from the first row's columns. + let columns = self.inner[0].columns(); + let mut name_to_idx = HashMap::with_capacity(columns.len()); + let mut names = Vec::with_capacity(columns.len()); + for (i, col) in columns.iter().enumerate() { + if name_to_idx.contains_key(col.name()) { + return Err(crate::exceptions::rust_errors::RustPSQLDriverError::ConnectionExecuteError( + format!( + "Duplicate column name '{}' in result set; use positional indexing or aliases", + col.name() + ), + )); + } + name_to_idx.insert(col.name().to_string(), i); + names.push(col.name().to_string()); + } + let desc = Arc::new(RecordDesc { name_to_idx, names }); + + let mut records: Vec> = Vec::with_capacity(self.inner.len()); + for row in &self.inner { + let mut values: Vec> = Vec::with_capacity(row.columns().len()); + for (idx, col) in row.columns().iter().enumerate() { + values.push(postgres_to_py(py, row, col, idx, &None)?); + } + let record = Record { + desc: Arc::clone(&desc), + values, + }; + records.push(Py::new(py, record)?.into_py_any(py)?); + } + Ok(records.into_py_any(py)?) + } } #[pyclass(name = "SingleQueryResult")] @@ -230,3 +280,124 @@ impl PSQLDriverSinglePyQueryResult { Ok(row_factory.call(py, (pydict,), None)?) } } + +/// Shared column metadata for a result set. All `Record` instances from one +/// `records()` call point to the same `Arc`. +pub struct RecordDesc { + name_to_idx: HashMap, + names: Vec, +} + +/// An asyncpg-compatible row type: eagerly decoded values + shared column map. +/// +/// Supports positional (`row[0]`) and by-name (`row["col"]`) access, iteration, +/// and dict-like `.keys()` / `.values()` / `.items()` / `.get()`. +#[pyclass(name = "Record")] +pub struct Record { + desc: Arc, + values: Vec>, +} + +#[pymethods] +impl Record { + fn __len__(&self) -> usize { + self.values.len() + } + + fn __repr__(&self, py: Python<'_>) -> String { + let fields: Vec = self + .desc + .names + .iter() + .zip(self.values.iter()) + .map(|(name, val)| { + let repr = val + .bind(py) + .repr() + .map_or_else(|_| "?".into(), |r| r.to_string()); + format!("{name}: {repr}") + }) + .collect(); + format!("", fields.join(", ")) + } + + #[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)] + fn __getitem__(&self, py: Python<'_>, key: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult> { + use crate::exceptions::rust_errors::RustPSQLDriverError; + + // Integer index + if let Ok(idx) = key.extract::() { + // Safe: len() <= isize::MAX on any platform we target + let len = self.values.len() as isize; + let real_idx = if idx < 0 { len + idx } else { idx }; + if real_idx < 0 || real_idx >= len { + return Err(RustPSQLDriverError::RustPyError( + pyo3::exceptions::PyIndexError::new_err(format!( + "Record index {idx} out of range" + )), + )); + } + // Safe: real_idx checked >= 0 above + return Ok(self.values[real_idx as usize].clone_ref(py)); + } + + // Slice + if let Ok(slice) = key.downcast::() { + // Safe: len() <= isize::MAX on any platform we target + let indices = slice.indices(self.values.len() as isize)?; + let mut result: Vec> = Vec::new(); + let mut i = indices.start; + while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) { + // Safe: i is a valid index within the slice range + result.push(self.values[i as usize].clone_ref(py)); + i += indices.step; + } + return Ok(result.into_py_any(py)?); + } + + // String key + if let Ok(name) = key.extract::() { + if let Some(&idx) = self.desc.name_to_idx.get(&name) { + return Ok(self.values[idx].clone_ref(py)); + } + return Err(RustPSQLDriverError::RustPyError( + pyo3::exceptions::PyKeyError::new_err(name), + )); + } + + Err(RustPSQLDriverError::RustPyError( + pyo3::exceptions::PyTypeError::new_err("Record key must be int, slice, or str"), + )) + } + + fn __iter__(&self, py: Python<'_>) -> PSQLPyResult> { + let list = PyList::new(py, self.values.iter().map(|v| v.clone_ref(py)))?; + Ok(PyIterator::from_object(list.as_any())?.into_py_any(py)?) + } + + #[pyo3(signature = (key, default=None))] + fn get(&self, py: Python<'_>, key: &str, default: Option>) -> Py { + if let Some(&idx) = self.desc.name_to_idx.get(key) { + self.values[idx].clone_ref(py) + } else { + default.unwrap_or_else(|| py.None()) + } + } + + fn keys(&self) -> Vec { + self.desc.names.clone() + } + + fn values(&self, py: Python<'_>) -> Vec> { + self.values.iter().map(|v| v.clone_ref(py)).collect() + } + + fn items(&self, py: Python<'_>) -> Vec<(String, Py)> { + self.desc + .names + .iter() + .zip(self.values.iter()) + .map(|(k, v)| (k.clone(), v.clone_ref(py))) + .collect() + } +} diff --git a/src/statement/cache.rs b/src/statement/cache.rs deleted file mode 100644 index 90ba70ee..00000000 --- a/src/statement/cache.rs +++ /dev/null @@ -1,57 +0,0 @@ -use std::collections::HashMap; - -use postgres_types::Type; -use tokio::sync::RwLock; -use tokio_postgres::Statement; - -use super::{parameters::Column, query::QueryString, utils::hash_str}; - -#[derive(Default)] -pub(crate) struct StatementsCache(HashMap); - -impl StatementsCache { - pub fn add_cache(&mut self, query: &QueryString, inner_stmt: &Statement) { - self.0 - .insert(query.hash(), StatementCacheInfo::new(query, inner_stmt)); - } - - pub fn get_cache(&self, querystring: &String) -> Option { - let qs_hash = hash_str(querystring); - - if let Some(cache_info) = self.0.get(&qs_hash) { - return Some(cache_info.clone()); - } - - None - } -} - -#[derive(Clone)] -pub(crate) struct StatementCacheInfo { - pub(crate) query: QueryString, - pub(crate) inner_stmt: Statement, -} - -impl StatementCacheInfo { - fn new(query: &QueryString, inner_stmt: &Statement) -> Self { - Self { - query: query.clone(), - inner_stmt: inner_stmt.clone(), - } - } - - pub(crate) fn types(&self) -> Vec { - self.inner_stmt.params().to_vec() - } - - pub(crate) fn columns(&self) -> Vec { - self.inner_stmt - .columns() - .iter() - .map(|column| Column::new(column.name().to_string(), column.table_oid())) - .collect::>() - } -} - -pub(crate) static STMTS_CACHE: std::sync::LazyLock> = - std::sync::LazyLock::new(|| RwLock::new(StatementsCache::default())); diff --git a/src/statement/mod.rs b/src/statement/mod.rs index c5a82349..db88bca9 100644 --- a/src/statement/mod.rs +++ b/src/statement/mod.rs @@ -1,7 +1,5 @@ -pub mod cache; pub mod parameters; pub mod query; #[allow(clippy::module_inception)] pub mod statement; pub mod statement_builder; -pub mod utils; diff --git a/src/statement/parameters.rs b/src/statement/parameters.rs index 52417497..e1af46da 100644 --- a/src/statement/parameters.rs +++ b/src/statement/parameters.rs @@ -68,12 +68,28 @@ impl ParametersBuilder { self, parameters_names: Option>, ) -> PSQLPyResult { + if self.parameters.is_none() { + return Ok(PreparedParameters::default()); + } + let prepared_parameters = Python::with_gil(|gil| self.prepare_parameters(gil, parameters_names))?; Ok(prepared_parameters) } + /// Like `prepare` but reuses an already-held GIL token, avoiding a second acquisition. + /// + /// If `self.parameters` is `None`, returns `PreparedParameters::default()` without + /// any conversion work. If parameters is an empty sequence, also returns early. + pub(crate) fn prepare_with_gil( + self, + gil: Python<'_>, + parameters_names: Option>, + ) -> PSQLPyResult { + self.prepare_parameters(gil, parameters_names) + } + fn prepare_parameters( self, gil: Python<'_>, @@ -84,6 +100,12 @@ impl ParametersBuilder { } let sequence_typed = self.as_type::>(gil); + + // Empty sequence: no conversion work to do. + if sequence_typed.as_ref().is_some_and(Vec::is_empty) { + return Ok(PreparedParameters::default()); + } + let mapping_typed = self.downcast_as::(gil); let mut prepared_parameters: Option = None; @@ -316,6 +338,11 @@ impl PreparedParameters { .into_boxed_slice() } + #[must_use] + pub fn types(&self) -> &[Type] { + &self.types + } + #[must_use] pub fn columns(&self) -> &Vec { &self.columns diff --git a/src/statement/query.rs b/src/statement/query.rs index 312cdad0..5db5ad6f 100644 --- a/src/statement/query.rs +++ b/src/statement/query.rs @@ -4,8 +4,6 @@ use regex::Regex; use crate::value_converter::consts::KWARGS_PARAMS_REGEXP; -use super::utils::hash_str; - #[derive(Clone, Debug)] pub struct QueryString { pub(crate) initial_qs: String, @@ -37,10 +35,6 @@ impl QueryString { &self.initial_qs } - pub(crate) fn hash(&self) -> u64 { - hash_str(&self.initial_qs) - } - pub(crate) fn process_qs(&mut self) { if !self.is_kwargs_parametrized() { return; diff --git a/src/statement/statement.rs b/src/statement/statement.rs index e13900c3..62b82594 100644 --- a/src/statement/statement.rs +++ b/src/statement/statement.rs @@ -60,4 +60,26 @@ impl PsqlpyStatement { pub fn columns(&self) -> &Vec { self.prepared_parameters.columns() } + + #[must_use] + pub fn param_types(&self) -> &[postgres_types::Type] { + self.prepared_parameters.types() + } + + /// Return parameter placeholder names extracted from the query string. + /// + /// Returns `None` when the query uses positional `$1` syntax, and + /// `Some(&[String])` for kwargs-style `%(name)s` queries. + #[must_use] + pub fn param_names(&self) -> Option<&[String]> { + self.query + .converted_qs + .as_ref() + .map(|c| c.params_names().as_slice()) + } + + #[must_use] + pub fn into_prepared_parameters(self) -> PreparedParameters { + self.prepared_parameters + } } diff --git a/src/statement/statement_builder.rs b/src/statement/statement_builder.rs index 1ad81704..e35bb656 100644 --- a/src/statement/statement_builder.rs +++ b/src/statement/statement_builder.rs @@ -1,5 +1,4 @@ use pyo3::PyObject; -use tokio::sync::RwLockWriteGuard; use tokio_postgres::Statement; use crate::{ @@ -8,7 +7,6 @@ use crate::{ }; use super::{ - cache::{StatementCacheInfo, StatementsCache, STMTS_CACHE}, parameters::{Column, ParametersBuilder}, query::QueryString, statement::PsqlpyStatement, @@ -42,88 +40,41 @@ impl<'a> StatementBuilder<'a> { /// # Errors /// May return error if cannot prepare statement. pub async fn build(self) -> PSQLPyResult { - if !self.prepared { - { - let stmt_cache_guard = STMTS_CACHE.read().await; - if let Some(cached) = stmt_cache_guard.get_cache(self.querystring) { - return self.build_with_cached(cached); - } - } - } - - let stmt_cache_guard = STMTS_CACHE.write().await; - self.build_no_cached(stmt_cache_guard).await - } - - fn build_with_cached(self, cached: StatementCacheInfo) -> PSQLPyResult { - let raw_parameters = ParametersBuilder::new( - self.parameters.as_ref(), - Some(cached.types()), - cached.columns(), - ); - - let parameters_names = cached - .query - .converted_qs - .as_ref() - .map(|converted_qs| converted_qs.params_names().clone()); - - let prepared_parameters = raw_parameters.prepare(parameters_names)?; - - Ok(PsqlpyStatement::new( - cached.query, - prepared_parameters, - None, - )) - } - - async fn build_no_cached( - self, - cache_guard: RwLockWriteGuard<'_, StatementsCache>, - ) -> PSQLPyResult { let mut querystring = QueryString::new(self.querystring); querystring.process_qs(); - let prepared_stmt = self.prepare_query(&querystring, self.prepared).await?; + let stmt = self.prepare_query(&querystring, self.prepared).await?; - let columns = prepared_stmt + let columns = stmt .columns() .iter() - .map(|column| Column::new(column.name().to_string(), column.table_oid())) + .map(|c| Column::new(c.name().to_string(), c.table_oid())) .collect::>(); - let parameters_builder = ParametersBuilder::new( + + let params_builder = ParametersBuilder::new( self.parameters.as_ref(), - Some(prepared_stmt.params().to_vec()), + Some(stmt.params().to_vec()), columns, ); - let parameters_names = querystring + let param_names = querystring .converted_qs .as_ref() - .map(|converted_qs| converted_qs.params_names().clone()); + .map(|c| c.params_names().clone()); - let prepared_parameters = parameters_builder.prepare(parameters_names)?; + let prepared_parameters = params_builder.prepare(param_names)?; if self.prepared { Ok(PsqlpyStatement::new( querystring, prepared_parameters, - Some(prepared_stmt), + Some(stmt), )) } else { - Self::write_to_cache(cache_guard, &querystring, &prepared_stmt); Ok(PsqlpyStatement::new(querystring, prepared_parameters, None)) } } - fn write_to_cache( - mut cache_guard: RwLockWriteGuard<'_, StatementsCache>, - query: &QueryString, - inner_stmt: &Statement, - ) { - cache_guard.add_cache(query, inner_stmt); - } - async fn prepare_query(&self, query: &QueryString, prepared: bool) -> PSQLPyResult { self.inner_conn.prepare(query.query(), prepared).await } diff --git a/src/statement/utils.rs b/src/statement/utils.rs deleted file mode 100644 index a79f8bdd..00000000 --- a/src/statement/utils.rs +++ /dev/null @@ -1,8 +0,0 @@ -use std::hash::{DefaultHasher, Hash, Hasher}; - -pub(crate) fn hash_str(string: &String) -> u64 { - let mut hasher = DefaultHasher::new(); - string.hash(&mut hasher); - - hasher.finish() -} diff --git a/src/value_converter/from_python.rs b/src/value_converter/from_python.rs index e56d791b..282e0f04 100644 --- a/src/value_converter/from_python.rs +++ b/src/value_converter/from_python.rs @@ -7,13 +7,57 @@ use postgres_types::Type; use std::net::IpAddr; use pyo3::{ + sync::GILOnceCell, types::{ PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyInt, PyList, - PySequence, PySet, PyString, PyTime, PyTuple, PyTypeMethods, + PySequence, PySet, PyString, PyTime, PyTuple, PyType, }, Bound, Py, PyAny, Python, }; +/// Cached `uuid.UUID` type object for O(1) pointer-equality type dispatch. +static UUID_TYPE: GILOnceCell> = GILOnceCell::new(); +/// Cached `decimal.Decimal` type object for O(1) pointer-equality type dispatch. +static DECIMAL_TYPE: GILOnceCell> = GILOnceCell::new(); + +fn uuid_type(py: Python<'_>) -> PSQLPyResult> { + UUID_TYPE + .get_or_try_init(py, || { + pyo3::types::PyModule::import(py, "uuid") + .and_then(|m| m.getattr("UUID")) + .and_then(|t| { + t.downcast::() + .map(|t| t.clone().unbind()) + .map_err(Into::into) + }) + .map_err(|e| { + RustPSQLDriverError::PyToRustValueConversionError(format!( + "failed to import uuid.UUID: {e}" + )) + }) + }) + .map(|t| t.bind(py).to_owned()) +} + +fn decimal_type(py: Python<'_>) -> PSQLPyResult> { + DECIMAL_TYPE + .get_or_try_init(py, || { + pyo3::types::PyModule::import(py, "decimal") + .and_then(|m| m.getattr("Decimal")) + .and_then(|t| { + t.downcast::() + .map(|t| t.clone().unbind()) + .map_err(Into::into) + }) + .map_err(|e| { + RustPSQLDriverError::PyToRustValueConversionError(format!( + "failed to import decimal.Decimal: {e}" + )) + }) + }) + .map(|t| t.bind(py).to_owned()) +} + use crate::{ exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, extra_types::{self}, @@ -154,14 +198,14 @@ pub fn from_python_untyped(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult

::to_python_dto(parameter); } - if parameter.get_type().name()? == "UUID" { - return ::to_python_dto(parameter); - } - - if parameter.get_type().name()? == "decimal.Decimal" - || parameter.get_type().name()? == "Decimal" { - return ::to_python_dto(parameter); + let py = parameter.py(); + if parameter.is_exact_instance(uuid_type(py)?.as_any()) { + return ::to_python_dto(parameter); + } + if parameter.is_exact_instance(decimal_type(py)?.as_any()) { + return ::to_python_dto(parameter); + } } if let Ok(converted_array) = from_python_array_typed(parameter) { @@ -204,13 +248,12 @@ pub fn from_python_typed( return ::to_python_dto(parameter); } - if parameter.get_type().name()? == "UUID" { + let py = parameter.py(); + if parameter.is_exact_instance(uuid_type(py)?.as_any()) { return ::to_python_dto(parameter); } - if parameter.get_type().name()? == "decimal.Decimal" - || parameter.get_type().name()? == "Decimal" - { + if parameter.is_exact_instance(decimal_type(py)?.as_any()) { return ::to_python_dto(parameter); }