diff --git a/doc/code/targets/0_prompt_targets.md b/doc/code/targets/0_prompt_targets.md index e00983f769..ef46d29148 100644 --- a/doc/code/targets/0_prompt_targets.md +++ b/doc/code/targets/0_prompt_targets.md @@ -25,7 +25,7 @@ A `PromptTarget` is a generic place to send a prompt. With PyRIT, the idea is th With some algorithms, you want to send a prompt, set a system prompt, and modify conversation history (including PAIR [@chao2023pair], TAP [@mehrotra2023tap], and flip attack [@li2024flipattack]). These algorithms require a target whose [`TargetCapabilities`](#target-capabilities) declare both `supports_multi_turn=True` and `supports_editable_history=True` — i.e. you can modify a conversation history. Consumers express this requirement via `CHAT_TARGET_REQUIREMENTS` and validate it against `target.configuration` at construction time. See [Target Capabilities](#target-capabilities) below for the full list of capabilities and how they compose into a `TargetConfiguration`. -Note: The previous `PromptChatTarget` class is **deprecated** as of v0.13.0 and will be removed in v0.15.0. Use `PromptTarget` directly with a `TargetConfiguration` declaring `supports_multi_turn=True` and `supports_editable_history=True`. See [Target Capabilities](#target-capabilities) for details. +Note: The previous `PromptChatTarget` class is **deprecated** as of v0.14.0 and will be removed in v0.16.0. Use `PromptTarget` directly with a `TargetConfiguration` declaring `supports_multi_turn=True` and `supports_editable_history=True`. See [Target Capabilities](#target-capabilities) for details. Here are some examples: @@ -107,6 +107,27 @@ target = MyHTTPTarget(custom_configuration=config, ...) 
The full implementation lives in [`pyrit/prompt_target/common/target_capabilities.py`](https://github.com/microsoft/PyRIT/blob/main/pyrit/prompt_target/common/target_capabilities.py) and [`pyrit/prompt_target/common/target_configuration.py`](https://github.com/microsoft/PyRIT/blob/main/pyrit/prompt_target/common/target_configuration.py). For runnable examples — inspecting capabilities on a real target, comparing known model profiles, and `ADAPT` vs `RAISE` in action — see [Target Capabilities](./6_1_target_capabilities.ipynb). +### Discovering live target capabilities + +Declared capabilities describe what a target *should* support. For deployments where actual behavior is uncertain — custom OpenAI-compatible endpoints, gateways that strip features, models whose support drifts — you can probe what the target *actually* accepts at runtime: + +```python +from pyrit.prompt_target import ( + discover_target_capabilities_async, + discover_target_async, + discover_target_modalities_async, +) + +# Probe a single dimension: +queried_caps = await discover_target_capabilities_async(target=target) +queried_modalities = await discover_target_modalities_async(target=target) + +# Or do both at once and get a best-effort TargetCapabilities back: +queried = await discover_target_async(target=target) +``` + +Each probe sends a minimal request (bounded by `per_probe_timeout_s`, default 30s, with one retry on transient errors) and only marks a capability or modality as supported if the call returns cleanly. `discover_target_async` returns a merged view: probed where possible, declared where probing is unavailable or out of scope. "Supported" here means *the request was accepted* — a target that silently ignores a system prompt or `response_format` directive is still reported as supporting it, so validate response content out of band when the distinction matters. 
These functions are not safe to call concurrently with other operations on the same target instance: they temporarily mutate `target._configuration` and write probe rows to memory (rows are tagged with `prompt_metadata["capability_probe"] == "1"` for filtering). See [Target Capabilities](./6_1_target_capabilities.ipynb) for runnable examples. + ## Multi-Modal Targets Like most of PyRIT, targets can be multi-modal. diff --git a/doc/code/targets/6_1_target_capabilities.ipynb b/doc/code/targets/6_1_target_capabilities.ipynb index 18c1902062..a296b85bde 100644 --- a/doc/code/targets/6_1_target_capabilities.ipynb +++ b/doc/code/targets/6_1_target_capabilities.ipynb @@ -53,13 +53,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "No new upgrade operations detected.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "No new upgrade operations detected.\n", "supports_multi_turn: True\n", "supports_editable_history: True\n", "supports_system_prompt: True\n", @@ -382,19 +376,39 @@ } ], "source": [ - "from pyrit.models import Message\n", - "\n", - "conversation = [\n", - " Message.from_prompt(prompt=\"What is the capital of France?\", role=\"user\"),\n", - " Message.from_prompt(prompt=\"Paris.\", role=\"assistant\"),\n", - " Message.from_prompt(prompt=\"And of Germany?\", role=\"user\"),\n", - "]\n", - "\n", - "normalized = await adapt_target.configuration.normalize_async(messages=conversation) # type: ignore\n", - "print(f\"original turns: {len(conversation)}\")\n", - "print(f\"normalized turns: {len(normalized)}\")\n", - "print(\"flattened text:\")\n", - "print(normalized[-1].message_pieces[0].original_value)" + "from unittest.mock import AsyncMock\n", + "\n", + "from pyrit.models import MessagePiece\n", + "from pyrit.prompt_target import (\n", + " discover_target_async,\n", + " discover_target_capabilities_async,\n", + " discover_target_modalities_async,\n", + ")\n", + "\n", + "\n", + "def _ok_response():\n", + " return [\n", + " 
Message(\n",
+ "            [\n",
+ "                MessagePiece(\n",
+ "                    role=\"assistant\",\n",
+ "                    original_value=\"ok\",\n",
+ "                    original_value_data_type=\"text\",\n",
+ "                    conversation_id=\"probe\",\n",
+ "                    response_error=\"none\",\n",
+ "                )\n",
+ "            ]\n",
+ "        )\n",
+ "    ]\n",
+ "\n",
+ "\n",
+ "probe_target = OpenAIChatTarget(model_name=\"gpt-4o\", endpoint=\"https://example.invalid/\", api_key=\"sk-not-a-real-key\")\n",
+ "probe_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign]\n",
+ "\n",
+ "queried = await discover_target_capabilities_async(target=probe_target, per_probe_timeout_s=5.0) # type: ignore\n",
+ "print(\"queried capabilities:\")\n",
+ "for capability in sorted(queried, key=lambda c: c.value):\n",
+ "    print(f\"  - {capability.value}\")" ] }, { "cell_type": "markdown", "id": "15", "metadata": {}, "source": [ - "By contrast, the `RAISE` configuration validates eagerly: any consumer requiring `MULTI_TURN` will\n", - "get a `ValueError` before a single prompt is sent." + "To narrow the probe to specific capabilities (faster, fewer calls), pass `capabilities=`:\n", + "\n", + "```python\n", + "from pyrit.prompt_target.common.target_capabilities import CapabilityName\n", + "\n", + "queried = await discover_target_capabilities_async(\n", + "    target=target,\n", + "    capabilities=[CapabilityName.JSON_SCHEMA, CapabilityName.SYSTEM_PROMPT],\n", + ")\n", + "```\n", + "\n", + "If you only care about accepted input combinations, call\n", + "`discover_target_modalities_async` directly. 
The example below uses the\n", + "packaged default probe assets for the non-text modalities PyRIT ships.\n", + "Pass `test_assets=` only when you want to override those defaults or probe\n", + "a modality without a packaged asset.\n", + "\n", + "`discover_target_async` is the most common entry point: it runs both the capability and modality\n", + "probes and assembles a best-effort `TargetCapabilities` you can drop into a\n", + "`TargetConfiguration`, so the rest of PyRIT operates on probed values where available and\n", + "declared values otherwise." ] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, + "outputs": [], + "source": [ + "probe_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign]\n", + "\n", + "queried_modalities = await discover_target_modalities_async(\n", + "    target=probe_target,\n", + "    test_modalities={frozenset({\"text\"}), frozenset({\"text\", \"image_path\"})},\n", + "    per_probe_timeout_s=5.0,\n", + ") # type: ignore\n", + "\n", + "print(\"discover_target_modalities_async result:\")\n", + "for combination in sorted(sorted(m) for m in queried_modalities):\n", + "    print(f\"  - {combination}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ @@ -429,7 +482,7 @@ }, { "cell_type": "markdown", - "id": "17", + "id": "18", "metadata": {}, "source": [ "## 6. Non-adaptable capabilities\n", @@ -443,7 +496,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "19", "metadata": {}, "outputs": [ { @@ -462,12 +515,208 @@ "try:\n", " no_editable_history.ensure_can_handle(capability=CapabilityName.EDITABLE_HISTORY)\n", "except ValueError as exc:\n", - " print(exc)\n", - "# ---" + " print(exc)" + ] + }, + { + "cell_type": "markdown", + "id": "20", + "metadata": {}, + "source": [ + "## 7. 
Discovering live target capabilities\n", + "\n", + "Declared capabilities describe what a target *should* support. For deployments where the actual\n", + "behavior is uncertain — custom OpenAI-compatible endpoints, gateways that strip features, models\n", + "whose support drifts over time — you can probe what the target *actually* accepts at runtime with\n", + "`discover_target_capabilities_async`, `discover_target_modalities_async`, or the convenience wrapper\n", + "`discover_target_async` that runs both and returns a best-effort `TargetCapabilities`.\n", + "\n", + "`discover_target_capabilities_async` walks each capability that has a registered probe (currently\n", + "`SYSTEM_PROMPT`, `MULTI_MESSAGE_PIECES`, `MULTI_TURN`, `JSON_OUTPUT`, `JSON_SCHEMA`), sends a\n", + "minimal request, and includes the capability in the returned set only if the call succeeds.\n", + "During probing the target's configuration is temporarily replaced with a permissive one so\n", + "`ensure_can_handle` does not short-circuit a probe for a capability the target declares as\n", + "unsupported. The original configuration is restored before the function returns.\n", + "\n", + "`discover_target_modalities_async` does the same for input modality combinations declared in\n", + "`capabilities.input_modalities`, sending a small payload built from optional `test_assets`.\n", + "\n", + "Each probe call is bounded by `per_probe_timeout_s` (default 30s) and is retried once on\n", + "transient errors before being declared failed. `discover_target_async` returns a merged view:\n", + "probed where possible, declared where probing is unavailable or out of scope. 
\"Supported\" here\n", + "means *the request was accepted* — a target that silently ignores a system prompt or\n", + "`response_format` directive will still be reported as supporting that capability.\n", + "\n", + "These functions are **not safe to call concurrently** with other operations on the same target\n", + "instance: they temporarily mutate `target._configuration` and write probe rows to\n", + "`target._memory`. Probe-written memory rows are tagged with\n", + "`prompt_metadata[\"capability_probe\"] == \"1\"` so consumers can filter them.\n", + "\n", + "Typical usage against a real endpoint:\n", + "\n", + "```python\n", + "from pyrit.prompt_target import discover_target_async\n", + "\n", + "queried = await discover_target_async(target=target)\n", + "print(queried)\n", + "```\n", + "\n", + "Below we mock the target's underlying transport (`_send_prompt_to_target_async`) so the notebook\n", + "stays self-contained — the result shape is the same as a live run. We mock the protected method\n", + "rather than `send_prompt_async` so the probe still exercises the real validation and memory\n", + "pipeline." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "queried capabilities:\n", + " - supports_editable_history\n", + " - supports_json_output\n", + " - supports_json_schema\n", + " - supports_multi_message_pieces\n", + " - supports_multi_turn\n", + " - supports_system_prompt\n" + ] + } + ], + "source": [ + "from unittest.mock import AsyncMock\n", + "\n", + "from pyrit.models import MessagePiece\n", + "from pyrit.prompt_target import (\n", + " discover_target_async,\n", + " discover_target_capabilities_async,\n", + ")\n", + "\n", + "\n", + "def _ok_response():\n", + " return [\n", + " Message(\n", + " [\n", + " MessagePiece(\n", + " role=\"assistant\",\n", + " original_value=\"ok\",\n", + " original_value_data_type=\"text\",\n", + " conversation_id=\"probe\",\n", + " response_error=\"none\",\n", + " )\n", + " ]\n", + " )\n", + " ]\n", + "\n", + "\n", + "probe_target = OpenAIChatTarget(model_name=\"gpt-4o\", endpoint=\"https://example.invalid/\", api_key=\"sk-not-a-real-key\")\n", + "probe_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign]\n", + "\n", + "queried = await discover_target_capabilities_async(target=probe_target, per_probe_timeout_s=5.0) # type: ignore\n", + "print(\"queried capabilities:\")\n", + "for capability in sorted(queried, key=lambda c: c.value):\n", + " print(f\" - {capability.value}\")" + ] + }, + { + "cell_type": "markdown", + "id": "22", + "metadata": {}, + "source": [ + "To narrow the probe to specific capabilities (faster, fewer calls), pass `capabilities=`:\n", + "\n", + "```python\n", + "from pyrit.prompt_target.common.target_capabilities import CapabilityName\n", + "\n", + "queried = await query_target_capabilities_async(\n", + " target=target,\n", + " capabilities=[CapabilityName.JSON_SCHEMA, CapabilityName.SYSTEM_PROMPT],\n", + ")\n", + "```\n", + "\n", + 
"`discover_target_async` is the most common entry point: it runs both the capability and modality\n", + "probes and assembles a `TargetCapabilities` you can drop straight into a `TargetConfiguration`,\n", + "so the rest of PyRIT (attacks, scorers, the normalization pipeline) operates on capabilities\n", + "that have been observed to work end-to-end." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "discover_target_async result:\n", + "  supports_multi_turn: True\n", + "  supports_system_prompt: True\n", + "  supports_multi_message_pieces: True\n", + "  supports_json_output: True\n", + "  supports_json_schema: True\n", + "  input_modalities: [['text']]\n" + ] + } + ], + "source": [ + "probe_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign]\n", + "\n", + "queried_caps = await discover_target_async(target=probe_target, per_probe_timeout_s=5.0) # type: ignore\n", + "print(\"discover_target_async result:\")\n", + "print(f\"  supports_multi_turn: {queried_caps.supports_multi_turn}\")\n", + "print(f\"  supports_system_prompt: {queried_caps.supports_system_prompt}\")\n", + "print(f\"  supports_multi_message_pieces: {queried_caps.supports_multi_message_pieces}\")\n", + "print(f\"  supports_json_output: {queried_caps.supports_json_output}\")\n", + "print(f\"  supports_json_schema: {queried_caps.supports_json_schema}\")\n", + "print(f\"  input_modalities: {sorted(sorted(m) for m in queried_caps.input_modalities)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "24", + "metadata": {}, + "source": [ + "### Discovering undeclared modalities\n", + "\n", + "By default `discover_target_async` only probes modality combinations the target already\n", + "**declares** in `capabilities.input_modalities`. 
For an OpenAI-compatible endpoint that\n", + "claims text-only but might actually accept images, pass `test_modalities=` explicitly to\n", + "probe combinations beyond the declared baseline. Provide `test_assets=` as well if you need\n", + "to override the packaged defaults or probe a modality without one:\n", + "\n", + "```python\n", + "queried = await discover_target_async(\n", + "    target=target,\n", + "    test_modalities={frozenset({\"text\"}), frozenset({\"text\", \"image_path\"})},\n", + "    test_assets={\"image_path\": \"/path/to/test_image.png\"},\n", + ")\n", + "```\n", + "\n", + "Similarly, when narrowing the probe set with `capabilities=`, capabilities NOT in the\n", + "narrowed set are copied from the target's declared values rather than being reset to\n", + "`False` — narrowing controls *what is re-queried*, not what the returned dataclass\n", + "reports. This makes incremental probing safe:\n", + "\n", + "```python\n", + "# Re-query only JSON support; other declared flags pass through unchanged.\n", + "queried = await discover_target_async(\n", + "    target=target,\n", + "    capabilities={CapabilityName.JSON_OUTPUT, CapabilityName.JSON_SCHEMA},\n", + ")\n", + "```" ] } ], "metadata": { + "jupytext": { + "main_language": "python" + }, "language_info": { "codemirror_mode": { "name": "ipython", diff --git a/doc/code/targets/6_1_target_capabilities.py b/doc/code/targets/6_1_target_capabilities.py index 985374357b..c3104a54df 100644 --- a/doc/code/targets/6_1_target_capabilities.py +++ b/doc/code/targets/6_1_target_capabilities.py @@ -245,4 +245,160 @@ no_editable_history.ensure_can_handle(capability=CapabilityName.EDITABLE_HISTORY) except ValueError as exc: print(exc) -# --- + +# %% [markdown] +# ## 7. Discovering live target capabilities +# +# Declared capabilities describe what a target *should* support. 
For deployments where the actual +# behavior is uncertain — custom OpenAI-compatible endpoints, gateways that strip features, models +# whose support drifts over time — you can probe what the target *actually* accepts at runtime with +# `discover_target_capabilities_async`, `discover_target_modalities_async`, or the convenience wrapper +# `discover_target_async` that runs both and returns a best-effort `TargetCapabilities`. +# +# `discover_target_capabilities_async` walks each capability that has a registered probe (currently +# `SYSTEM_PROMPT`, `MULTI_MESSAGE_PIECES`, `MULTI_TURN`, `JSON_OUTPUT`, `JSON_SCHEMA`), sends a +# minimal request, and includes the capability in the returned set only if the call succeeds. +# During probing the target's configuration is temporarily replaced with a permissive one so +# `ensure_can_handle` does not short-circuit a probe for a capability the target declares as +# unsupported. The original configuration is restored before the function returns. +# +# `discover_target_modalities_async` does the same for input modality combinations declared in +# `capabilities.input_modalities`, sending a small payload built from optional `test_assets`. +# +# Each probe call is bounded by `per_probe_timeout_s` (default 30s) and is retried once on +# transient errors before being declared failed. `discover_target_async` returns a merged view: +# probed where possible, declared where probing is unavailable or out of scope. "Supported" here +# means *the request was accepted* — a target that silently ignores a system prompt or +# `response_format` directive will still be reported as supporting that capability. +# +# These functions are **not safe to call concurrently** with other operations on the same target +# instance: they temporarily mutate `target._configuration` and write probe rows to +# `target._memory`. Probe-written memory rows are tagged with +# `prompt_metadata["capability_probe"] == "1"` so consumers can filter them. 
+# +# Typical usage against a real endpoint: +# +# ```python +# from pyrit.prompt_target import discover_target_async +# +# queried = await discover_target_async(target=target) +# print(queried) +# ``` +# +# Below we mock the target's underlying transport (`_send_prompt_to_target_async`) so the notebook +# stays self-contained — the result shape is the same as a live run. We mock the protected method +# rather than `send_prompt_async` so the probe still exercises the real validation and memory +# pipeline. + +# %% +from unittest.mock import AsyncMock + +from pyrit.models import MessagePiece +from pyrit.prompt_target import ( + discover_target_async, + discover_target_capabilities_async, + discover_target_modalities_async, +) + + +def _ok_response(): + return [ + Message( + [ + MessagePiece( + role="assistant", + original_value="ok", + original_value_data_type="text", + conversation_id="probe", + response_error="none", + ) + ] + ) + ] + + +probe_target = OpenAIChatTarget(model_name="gpt-4o", endpoint="https://example.invalid/", api_key="sk-not-a-real-key") +probe_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + +queried = await discover_target_capabilities_async(target=probe_target, per_probe_timeout_s=5.0) # type: ignore +print("queried capabilities:") +for capability in sorted(queried, key=lambda c: c.value): + print(f" - {capability.value}") + +# %% [markdown] +# To narrow the probe to specific capabilities (faster, fewer calls), pass `capabilities=`: +# +# ```python +# from pyrit.prompt_target.common.target_capabilities import CapabilityName +# +# queried = await discover_target_capabilities_async( +# target=target, +# capabilities=[CapabilityName.JSON_SCHEMA, CapabilityName.SYSTEM_PROMPT], +# ) +# ``` +# +# If you only care about accepted input combinations, call +# `discover_target_modalities_async` directly. 
The example below uses the +# packaged default probe assets for the non-text modalities PyRIT ships. +# Pass `test_assets=` only when you want to override those defaults or probe +# a modality without a packaged asset. +# +# `discover_target_async` is the most common entry point: it runs both the capability and modality +# probes and assembles a best-effort `TargetCapabilities` you can drop into a +# `TargetConfiguration`, so the rest of PyRIT operates on probed values where available and +# declared values otherwise. + +# %% +probe_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + +queried_modalities = await discover_target_modalities_async( + target=probe_target, + test_modalities={frozenset({"text"}), frozenset({"text", "image_path"})}, + per_probe_timeout_s=5.0, +) # type: ignore + +print("discover_target_modalities_async result:") +for combination in sorted(sorted(m) for m in queried_modalities): + print(f" - {combination}") + +# %% +probe_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + +queried_caps = await discover_target_async(target=probe_target, per_probe_timeout_s=5.0) # type: ignore +print("discover_target_async result:") +print(f" supports_multi_turn: {queried_caps.supports_multi_turn}") +print(f" supports_system_prompt: {queried_caps.supports_system_prompt}") +print(f" supports_multi_message_pieces: {queried_caps.supports_multi_message_pieces}") +print(f" supports_json_output: {queried_caps.supports_json_output}") +print(f" supports_json_schema: {queried_caps.supports_json_schema}") +print(f" input_modalities: {sorted(sorted(m) for m in queried_caps.input_modalities)}") + +# %% [markdown] +# ### Discovering undeclared modalities +# +# By default `discover_target_async` only probes modality combinations the target already +# **declares** in `capabilities.input_modalities`. 
For an OpenAI-compatible endpoint that +# claims text-only but might actually accept images, pass `test_modalities=` explicitly to +# probe combinations beyond the declared baseline. Provide `test_assets=` as well if you need +# to override the packaged defaults or probe a modality without one: +# +# ```python +# queried = await discover_target_async( +# target=target, +# test_modalities={frozenset({"text"}), frozenset({"text", "image_path"})}, +# test_assets={"image_path": "/path/to/test_image.png"}, +# ) +# ``` +# +# Similarly, when narrowing the probe set with `capabilities=`, capabilities NOT in the +# narrowed set are copied from the target's declared values rather than being reset to +# `False` — narrowing controls *what is re-queried*, not what the returned dataclass +# reports. This makes incremental probing safe: +# +# ```python +# # Re-query only JSON support; other declared flags pass through unchanged. +# queried = await discover_target_async( +# target=target, +# capabilities={CapabilityName.JSON_OUTPUT, CapabilityName.JSON_SCHEMA}, +# ) +# ``` diff --git a/pyrit/datasets/prompt_target/target_capabilities/probe_audio.wav b/pyrit/datasets/prompt_target/target_capabilities/probe_audio.wav new file mode 100644 index 0000000000..8dbde9545c Binary files /dev/null and b/pyrit/datasets/prompt_target/target_capabilities/probe_audio.wav differ diff --git a/pyrit/datasets/prompt_target/target_capabilities/probe_image.png b/pyrit/datasets/prompt_target/target_capabilities/probe_image.png new file mode 100644 index 0000000000..85dda3a6b1 Binary files /dev/null and b/pyrit/datasets/prompt_target/target_capabilities/probe_image.png differ diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 489fe34900..7cf7020a76 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -16,6 +16,11 @@ from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline from 
pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.query_target_capabilities import ( + discover_target_async, + discover_target_capabilities_async, + discover_target_modalities_async, +) from pyrit.prompt_target.common.target_capabilities import ( CapabilityHandlingPolicy, CapabilityName, @@ -97,11 +102,14 @@ def __getattr__(name: str) -> object: "PromptChatTarget", "PromptShieldTarget", "PromptTarget", + "discover_target_capabilities_async", "RealtimeTarget", "TargetCapabilities", "TargetConfiguration", "TargetRequirements", "UnsupportedCapabilityBehavior", "TextTarget", + "discover_target_async", + "discover_target_modalities_async", "WebSocketCopilotTarget", ] diff --git a/pyrit/prompt_target/common/query_target_capabilities.py b/pyrit/prompt_target/common/query_target_capabilities.py new file mode 100644 index 0000000000..de024611e2 --- /dev/null +++ b/pyrit/prompt_target/common/query_target_capabilities.py @@ -0,0 +1,762 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Runtime capability and modality discovery for prompt targets. + +This module exposes two complementary probes: + +* :func:`discover_target_capabilities_async` discovers the boolean capability flags + defined on :class:`TargetCapabilities` (e.g. ``supports_system_prompt``, + ``supports_multi_message_pieces``). For each capability that has a probe + defined, a minimal request is sent to the target. If the request succeeds, + the capability is included in the returned set. Capabilities without a + registered probe fall back to the target's declared native support from + ``target.capabilities``. +* :func:`discover_target_modalities_async` discovers which input modality + combinations a target actually supports by sending a minimal test request + for each combination declared in ``TargetCapabilities.input_modalities``. + +.. 
note:: + Output modality probing is intentionally not provided. Unlike inputs, + output modality is largely a property of the endpoint type (chat models + return text, image models return images, TTS endpoints return audio) + rather than something the caller controls per request, and there is no + PyRIT-level ``response_format=image`` style hint to assert against. + Eliciting non-text output reliably depends on prompt phrasing, costs + real compute per probe, and is prone to false negatives from safety + filters. Trust ``target.capabilities.output_modalities`` as declared. + +.. warning:: + These probes only verify that a request was *accepted*. They do not prove + that the endpoint enforced the feature, and the JSON probes are only + meaningful for targets that translate ``prompt_metadata`` JSON hints into + provider request fields. Treat the results as an upper bound on support and + validate response content separately when that distinction matters. +""" + +import asyncio +import json +import logging +import os +import uuid +from collections.abc import Awaitable, Callable, Iterable, Iterator +from contextlib import contextmanager +from dataclasses import replace + +from pyrit.common.path import DATASETS_PATH +from pyrit.models import Message, MessagePiece, PromptDataType +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityName, + TargetCapabilities, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration + +logger = logging.getLogger(__name__) + +# Per-call timeout (seconds) applied to every discovery request. Override per-call via +# the ``per_probe_timeout_s`` parameter on the public functions. +DEFAULT_PROBE_TIMEOUT_SECONDS: float = 30.0 +DEFAULT_PROBE_RETRY_BACKOFF_SECONDS: float = 0.1 +MAX_PROBE_RETRY_BACKOFF_SECONDS: float = 1.0 + +# Marker stamped onto every MessagePiece this module writes to memory. 
Consumers +# that aggregate or display memory rows can filter probe-written rows by checking +# ``piece.prompt_metadata.get("capability_probe") == "1"``. Memory does not yet +# expose a delete-by-conversation-id API, so tagging is the cleanup mechanism. +PROBE_METADATA_KEY: str = "capability_probe" +PROBE_METADATA_VALUE: str = "1" + +_CapabilityProbe = Callable[[PromptTarget, float, int], Awaitable[bool]] + + +# Every text probe sends a text-only payload. Permissive overrides therefore +# always include this combination so that ``_validate_request``'s per-piece +# data-type check does not reject text probes against text-less targets. +_TEXT_MODALITY: frozenset[frozenset[PromptDataType]] = frozenset({frozenset({"text"})}) + +# Packaged fallback assets for non-text modality discovery. +_TARGET_CAPABILITIES_DATASET_PATH = DATASETS_PATH / "prompt_target" / "target_capabilities" + + +@contextmanager +def _permissive_configuration( + *, + target: PromptTarget, + extra_input_modalities: Iterable[frozenset[PromptDataType]] | None = None, +) -> Iterator[None]: + """ + Temporarily replace ``target``'s configuration with one that declares every + boolean capability as natively supported. + + This bypasses :meth:`PromptTarget._validate_request`, which would otherwise + short-circuit probes for capabilities the target declares as unsupported + before any API call is made. The original configuration is restored on exit. + + Args: + target (PromptTarget): The target whose configuration is temporarily replaced. + extra_input_modalities (Iterable[frozenset[PromptDataType]] | None): + Additional modality combinations to include in ``input_modalities`` + during the override. Used by modality probes so that + ``_validate_request``'s per-piece data-type check does not reject + combinations the caller asked us to test but the target does not + yet declare. Defaults to None. + + Yields: + None: Control returns to the ``with`` block while the permissive + configuration is in effect. 
+ """ + original = target.configuration + merged_modalities = original.capabilities.input_modalities | _TEXT_MODALITY + if extra_input_modalities is not None: + merged_modalities = frozenset(merged_modalities | frozenset(extra_input_modalities)) + permissive_caps = replace( + original.capabilities, + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_json_schema=True, + supports_json_output=True, + supports_editable_history=True, + supports_system_prompt=True, + input_modalities=merged_modalities, + ) + # Rebuild a fresh configuration from the instance's native capabilities so + # probes bypass preflight validation without inheriting ADAPT policy or + # custom normalizer overrides from the target's runtime configuration. + probe_configuration = TargetConfiguration(capabilities=permissive_caps) + target._configuration = probe_configuration + try: + yield + finally: + target._configuration = original + + +def _new_conversation_id() -> str: + """ + Generate a unique conversation id for a single capability probe. + + Returns: + str: A conversation id of the form ``"capability-probe-"``. + """ + return f"capability-probe-{uuid.uuid4()}" + + +def _probe_metadata(extra: dict[str, str | int] | None = None) -> dict[str, str | int]: + """Return a fresh ``prompt_metadata`` dict tagged as a capability probe.""" + metadata: dict[str, str | int] = {PROBE_METADATA_KEY: PROBE_METADATA_VALUE} + if extra: + metadata.update(extra) + return metadata + + +def _user_text_piece(*, value: str, conversation_id: str) -> MessagePiece: + """ + Build a single user-role text :class:`MessagePiece` for use in a probe. + + The piece's ``prompt_metadata`` is tagged with :data:`PROBE_METADATA_KEY` + so that consumers aggregating memory can filter out probe-written rows. + + Args: + value (str): The text payload to send. + conversation_id (str): The conversation id to attach to the piece. + + Returns: + MessagePiece: A user-role text piece bound to ``conversation_id``. 
+ """ + return MessagePiece( + role="user", + original_value=value, + original_value_data_type="text", + conversation_id=conversation_id, + prompt_metadata=_probe_metadata(), + ) + + +async def _send_and_check_async( + *, + target: PromptTarget, + message: Message, + timeout_s: float, + retries: int = 1, + label: str = "Capability probe", +) -> bool: + """ + Send ``message`` and report whether the call succeeded cleanly. + + Each attempt is bounded by ``timeout_s``. Exceptions (network errors, + timeouts, validation failures) trigger up to ``retries`` retries before + the probe is declared failed, with a short exponential backoff between + retry attempts; an explicit error response from the target is treated as + deterministic and never retried. + + Args: + target (PromptTarget): The target to send the probe message to. + message (Message): The probe message to send. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only exceptions are retried; a non-error response is final. + Retry attempts use exponential backoff starting at + :data:`DEFAULT_PROBE_RETRY_BACKOFF_SECONDS`. Defaults to 1. + label (str): Short label used in log messages. Defaults to + ``"Capability probe"``. + + Returns: + bool: ``True`` iff the call returned without raising and every response + piece reported ``response_error == "none"``; ``False`` otherwise. + Any other ``response_error`` value (``"blocked"``, ``"processing"``, + ``"empty"``, ``"unknown"``) is treated as failure. An empty response + list (or responses with no message pieces) is also treated as a failure. 
+ """ + attempts = max(1, retries + 1) + last_exc: Exception | None = None + for attempt in range(attempts): + try: + responses = await asyncio.wait_for(target.send_prompt_async(message=message), timeout=timeout_s) + except asyncio.TimeoutError: + last_exc = TimeoutError(f"timed out after {timeout_s}s") + logger.debug("%s timed out (attempt %d/%d)", label, attempt + 1, attempts) + if attempt + 1 < attempts: + await _sleep_before_retry_async(attempt=attempt) + continue + except Exception as exc: + last_exc = exc + logger.debug("%s failed (attempt %d/%d): %s", label, attempt + 1, attempts, exc) + if attempt + 1 < attempts: + await _sleep_before_retry_async(attempt=attempt) + continue + + if not responses or not any(r.message_pieces for r in responses): + logger.debug("%s returned an empty response; treating as failure", label) + return False + for response in responses: + for piece in response.message_pieces: + if piece.response_error != "none": + logger.debug("%s returned error response: %s", label, piece.converted_value) + return False + return True + + logger.info("%s exhausted %d attempt(s); last error: %s", label, attempts, last_exc) + return False + + +def _retry_backoff_seconds(*, attempt: int) -> float: + """Return the exponential backoff delay for a retry attempt.""" + return min(DEFAULT_PROBE_RETRY_BACKOFF_SECONDS * (2**attempt), MAX_PROBE_RETRY_BACKOFF_SECONDS) + + +async def _sleep_before_retry_async(*, attempt: int) -> None: + """Sleep for the retry backoff associated with ``attempt``.""" + await asyncio.sleep(_retry_backoff_seconds(attempt=attempt)) + + +async def _probe_system_prompt_async(target: PromptTarget, timeout_s: float, retries: int = 1) -> bool: + """ + Probe whether ``target`` accepts a system prompt followed by a user message. + + Writes a system-role :class:`MessagePiece` directly to ``target._memory`` + rather than calling :meth:`PromptTarget.set_system_prompt`. ``set_system_prompt`` + can be overridden by subclasses (e.g. 
mocks) to do nothing or to perform + extra work, which would mask whether the underlying API actually accepts a + system message. A direct memory write guarantees the probe sees the same + multi-piece, system-then-user payload the target's wire layer would see + via the standard pipeline. + + Args: + target (PromptTarget): The target to probe. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only exceptions/timeouts are retried; an explicit error response + is final. Defaults to 1. + + Returns: + bool: ``True`` if the system + user request succeeded; ``False`` otherwise. + """ + conversation_id = _new_conversation_id() + system_piece = MessagePiece( + role="system", + original_value="You are a helpful assistant.", + original_value_data_type="text", + conversation_id=conversation_id, + prompt_metadata=_probe_metadata(), + ) + try: + target._memory.add_message_to_memory(request=Message([system_piece])) + except Exception as exc: + logger.debug("System-prompt probe could not seed system message: %s", exc) + return False + user_piece = _user_text_piece(value="hi", conversation_id=conversation_id) + return await _send_and_check_async( + target=target, + message=Message([user_piece]), + timeout_s=timeout_s, + retries=retries, + label="System-prompt probe", + ) + + +async def _probe_multi_message_pieces_async(target: PromptTarget, timeout_s: float, retries: int = 1) -> bool: + """ + Probe whether ``target`` accepts a single message containing multiple pieces. + + Args: + target (PromptTarget): The target to probe. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only exceptions/timeouts are retried; an explicit error response + is final. Defaults to 1. + + Returns: + bool: ``True`` if the multi-piece request succeeded; ``False`` otherwise. 
+ """ + conversation_id = _new_conversation_id() + pieces = [ + _user_text_piece(value="part one", conversation_id=conversation_id), + _user_text_piece(value="part two", conversation_id=conversation_id), + ] + return await _send_and_check_async( + target=target, + message=Message(pieces), + timeout_s=timeout_s, + retries=retries, + label="Multi-message-pieces probe", + ) + + +async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retries: int = 1) -> bool: + """ + Probe whether ``target`` accepts a request that includes prior conversation history. + + ``PromptTarget.send_prompt_async`` reads conversation history from memory but + does not write to it (persistence normally happens in the orchestrator + layer). To exercise true multi-turn behavior, this probe: + + 1. Sends an initial user message. + 2. Persists that user message and a synthetic assistant reply directly to + the target's memory under the same ``conversation_id``. + 3. Sends a second user message; ``send_prompt_async`` then fetches the + 2-message history and the target receives a real 3-message + multi-turn payload. + + The synthetic assistant reply's content is irrelevant — we are testing + whether the target's API accepts a multi-turn payload, not whether the + model recalls anything. + + Args: + target (PromptTarget): The target to probe. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only exceptions/timeouts are retried; an explicit error response + is final. Defaults to 1. + + Returns: + bool: ``True`` if both turns succeeded; ``False`` if either turn failed. 
+ """ + conversation_id = _new_conversation_id() + first = _user_text_piece(value="My favorite color is blue.", conversation_id=conversation_id) + if not await _send_and_check_async( + target=target, message=Message([first]), timeout_s=timeout_s, retries=retries, label="Multi-turn probe (turn 1)" + ): + return False + + # Seed memory so the second send sees real prior history. + try: + target._memory.add_message_to_memory(request=Message([first])) + assistant_reply = MessagePiece( + role="assistant", + original_value="Got it.", + original_value_data_type="text", + conversation_id=conversation_id, + prompt_metadata=_probe_metadata(), + ).to_message() + target._memory.add_message_to_memory(request=assistant_reply) + except Exception as exc: + logger.debug("Multi-turn probe could not seed conversation history: %s", exc) + return False + + second = _user_text_piece(value="What did I just tell you?", conversation_id=conversation_id) + return await _send_and_check_async( + target=target, + message=Message([second]), + timeout_s=timeout_s, + retries=retries, + label="Multi-turn probe (turn 2)", + ) + + +async def _probe_json_output_async(target: PromptTarget, timeout_s: float, retries: int = 1) -> bool: + """ + Probe whether ``target`` accepts a request asking for JSON-mode output. + + This probe is only meaningful for targets that translate PyRIT's JSON + metadata hints into native provider request fields. + + Args: + target (PromptTarget): The target to probe. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only exceptions/timeouts are retried; an explicit error response + is final. Defaults to 1. + + Returns: + bool: ``True`` if the JSON-mode request succeeded; ``False`` otherwise. 
+ """ + conversation_id = _new_conversation_id() + piece = MessagePiece( + role="user", + original_value='Respond with a JSON object: {"ok": true}.', + original_value_data_type="text", + conversation_id=conversation_id, + # This only becomes a real JSON-mode request on targets that honor + # PyRIT's JSON metadata contract when building the provider payload. + prompt_metadata=_probe_metadata({"response_format": "json"}), + ) + return await _send_and_check_async( + target=target, message=Message([piece]), timeout_s=timeout_s, retries=retries, label="JSON-output probe" + ) + + +async def _probe_json_schema_async(target: PromptTarget, timeout_s: float, retries: int = 1) -> bool: + """ + Probe whether ``target`` accepts a request constrained by a JSON schema. + + This probe is only meaningful for targets that translate PyRIT's JSON + metadata hints into native provider request fields. + + Args: + target (PromptTarget): The target to probe. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only exceptions/timeouts are retried; an explicit error response + is final. Defaults to 1. + + Returns: + bool: ``True`` if the schema-constrained request succeeded; ``False`` otherwise. + """ + schema = { + "type": "object", + "properties": {"ok": {"type": "boolean"}}, + "required": ["ok"], + "additionalProperties": False, + } + conversation_id = _new_conversation_id() + piece = MessagePiece( + role="user", + original_value='Respond with a JSON object matching the schema: {"ok": true}.', + original_value_data_type="text", + conversation_id=conversation_id, + # As above, this probe is only strong for targets that map these + # metadata keys to native JSON-schema request parameters. 
+ prompt_metadata=_probe_metadata( + { + "response_format": "json", + "json_schema": json.dumps(schema), + } + ), + ) + return await _send_and_check_async( + target=target, message=Message([piece]), timeout_s=timeout_s, retries=retries, label="JSON-schema probe" + ) + + +# Registry of capabilities that can be queried via a live API call. +# Capabilities not present here fall back to the target's declared support. +_CAPABILITY_PROBES: dict[CapabilityName, _CapabilityProbe] = { + CapabilityName.SYSTEM_PROMPT: _probe_system_prompt_async, + CapabilityName.MULTI_MESSAGE_PIECES: _probe_multi_message_pieces_async, + CapabilityName.MULTI_TURN: _probe_multi_turn_async, + CapabilityName.JSON_OUTPUT: _probe_json_output_async, + CapabilityName.JSON_SCHEMA: _probe_json_schema_async, +} + + +async def discover_target_capabilities_async( + *, + target: PromptTarget, + capabilities: Iterable[CapabilityName] | None = None, + per_probe_timeout_s: float = DEFAULT_PROBE_TIMEOUT_SECONDS, + retries: int = 1, +) -> set[CapabilityName]: + """ + Probe which capabilities ``target`` accepts. + + Registered capabilities are checked with live requests. Capabilities + without a live probe fall back to declared native support. + + Args: + target (PromptTarget): The target to probe. + capabilities (Iterable[CapabilityName] | None): Capabilities to check. + Defaults to every member of :class:`CapabilityName`. + per_probe_timeout_s (float): Per-attempt timeout (seconds) applied to + each probe request. Defaults to + :data:`DEFAULT_PROBE_TIMEOUT_SECONDS`. + retries (int): Number of additional attempts after the first failure + for each probe. Only exceptions/timeouts are retried; an explicit + error response is final. Set to ``0`` to disable retries. + Defaults to 1. + + Returns: + set[CapabilityName]: The capabilities confirmed to work against the target. 
+ """ + capabilities_to_check: list[CapabilityName] = ( + list(capabilities) if capabilities is not None else list(CapabilityName) + ) + + queried: set[CapabilityName] = set() + with _permissive_configuration(target=target): + for capability in capabilities_to_check: + probe = _CAPABILITY_PROBES.get(capability) + if probe is None: + # Capabilities without a probe are handled after the permissive + # override is removed so we can read the target's native flags. + continue + + try: + # "Supported" means the request was accepted. A target can + # still ignore the feature semantics after accepting the call. + if await probe(target, per_probe_timeout_s, retries): + queried.add(capability) + except Exception as exc: + logger.debug("Probe for %s raised: %s", capability.value, exc) + + # Read unprobed capabilities from target.capabilities, not + # target.configuration, so ADAPTed behavior is not reported as native + # support. + for capability in capabilities_to_check: + if capability not in _CAPABILITY_PROBES and target.capabilities.includes(capability=capability): + queried.add(capability) + + return queried + + +# --------------------------------------------------------------------------- +# Modality query +# --------------------------------------------------------------------------- + + +# Default mapping of non-text modalities to packaged probe assets. Callers can +# override via the ``test_assets`` parameter of +# :func:`discover_target_modalities_async`. Modalities whose assets do not exist +# on disk are skipped (logged and excluded from the result). 
+DEFAULT_TEST_ASSETS: dict[PromptDataType, str] = { + "audio_path": str(_TARGET_CAPABILITIES_DATASET_PATH / "probe_audio.wav"), + "image_path": str(_TARGET_CAPABILITIES_DATASET_PATH / "probe_image.png"), +} + + +async def discover_target_modalities_async( + *, + target: PromptTarget, + test_modalities: set[frozenset[PromptDataType]] | None = None, + test_assets: dict[PromptDataType, str] | None = None, + per_probe_timeout_s: float = DEFAULT_PROBE_TIMEOUT_SECONDS, + retries: int = 1, +) -> set[frozenset[PromptDataType]]: + """ + Probe which input modality combinations ``target`` accepts. + + Each modality combination is checked with a minimal request built from the + supplied test assets. + + Args: + target (PromptTarget): The target to probe. + test_modalities (set[frozenset[PromptDataType]] | None): Specific + modality combinations to test. Defaults to the combinations + declared in ``target.capabilities.input_modalities``. + test_assets (dict[PromptDataType, str] | None): Mapping from + non-text modality to a file path used as the probe payload. + Defaults to :data:`DEFAULT_TEST_ASSETS`. Combinations whose + non-text assets are missing on disk are skipped. + per_probe_timeout_s (float): Per-attempt timeout (seconds) applied to + each probe request. Defaults to + :data:`DEFAULT_PROBE_TIMEOUT_SECONDS`. + retries (int): Number of additional attempts after the first failure + for each probe. Only exceptions/timeouts are retried; an explicit + error response is final. Set to ``0`` to disable retries. + Defaults to 1. + + Returns: + set[frozenset[PromptDataType]]: The modality combinations confirmed + to work against the target. 
+ """ + if test_modalities is None: + declared = target.capabilities.input_modalities + test_modalities = set(declared) + + assets = test_assets if test_assets is not None else DEFAULT_TEST_ASSETS + + queried: set[frozenset[PromptDataType]] = set() + with _permissive_configuration(target=target, extra_input_modalities=test_modalities): + for combination in test_modalities: + try: + message = _create_test_message(modalities=combination, test_assets=assets) + except FileNotFoundError as exc: + # Skip combinations we cannot construct a valid probe payload for. + logger.info("Skipping modality %s: %s", combination, exc) + continue + except ValueError as exc: + logger.info("Skipping modality %s: %s", combination, exc) + continue + + # "Supported" means the request was accepted. A target may still + # ignore the non-text payload after accepting it. + if await _send_and_check_async( + target=target, + message=message, + timeout_s=per_probe_timeout_s, + retries=retries, + label=f"Modality probe {sorted(combination)}", + ): + queried.add(combination) + + return queried + + +async def discover_target_async( + *, + target: PromptTarget, + per_probe_timeout_s: float = DEFAULT_PROBE_TIMEOUT_SECONDS, + test_modalities: set[frozenset[PromptDataType]] | None = None, + test_assets: dict[PromptDataType, str] | None = None, + capabilities: Iterable[CapabilityName] | None = None, + retries: int = 1, +) -> TargetCapabilities: + """ + Probe capabilities and modalities and return a merged result. + + This wraps :func:`discover_target_capabilities_async` and + :func:`discover_target_modalities_async` and returns a best-effort + :class:`TargetCapabilities`. + + Args: + target (PromptTarget): The target to probe. + per_probe_timeout_s (float): Per-attempt timeout (seconds) applied to + each probe request. + test_modalities (set[frozenset[PromptDataType]] | None): Specific + modality combinations to probe. See + :func:`discover_target_modalities_async`. 
Defaults to the + target's declared ``input_modalities``. + test_assets (dict[PromptDataType, str] | None): Mapping from non-text + modality to a file path. See :func:`discover_target_modalities_async`. + capabilities (Iterable[CapabilityName] | None): Capabilities to probe. + See :func:`discover_target_capabilities_async`. Defaults to every + member of :class:`CapabilityName`. + retries (int): Number of additional attempts after the first failure + for each probe. Only exceptions/timeouts are retried; an explicit + error response is final. Set to ``0`` to disable retries. + Defaults to 1. + + Returns: + TargetCapabilities: A merged capability view: probed where possible, + declared where probing is unavailable or out of scope. + """ + capabilities_to_probe = list(capabilities) if capabilities is not None else None + + queried_caps = await discover_target_capabilities_async( + target=target, + capabilities=capabilities_to_probe, + per_probe_timeout_s=per_probe_timeout_s, + retries=retries, + ) + queried_modalities = await discover_target_modalities_async( + target=target, + test_modalities=test_modalities, + test_assets=test_assets, + per_probe_timeout_s=per_probe_timeout_s, + retries=retries, + ) + + declared = target.capabilities + # If the caller narrows the capability set, leave the rest at their + # declared values instead of silently forcing them to False. + probed: set[CapabilityName] = ( + set(capabilities_to_probe) if capabilities_to_probe is not None else set(CapabilityName) + ) + + def _resolve(name: CapabilityName) -> bool: + if name in probed: + return name in queried_caps + return bool(getattr(declared, name.value)) + + resolved_multi_turn = _resolve(CapabilityName.MULTI_TURN) + # Editable history is only meaningful if multi-turn probing/declaration + # also resolved to True. 
+ resolved_editable_history = declared.supports_editable_history and resolved_multi_turn + if test_modalities is None: + resolved_input_modalities = frozenset(queried_modalities) + else: + resolved_input_modalities = frozenset( + queried_modalities | (declared.input_modalities - frozenset(test_modalities)) + ) + + return TargetCapabilities( + supports_multi_turn=resolved_multi_turn, + supports_multi_message_pieces=_resolve(CapabilityName.MULTI_MESSAGE_PIECES), + supports_json_schema=_resolve(CapabilityName.JSON_SCHEMA), + supports_json_output=_resolve(CapabilityName.JSON_OUTPUT), + supports_editable_history=resolved_editable_history, + supports_system_prompt=_resolve(CapabilityName.SYSTEM_PROMPT), + input_modalities=resolved_input_modalities, + # Output modalities are still declarative because probing them would + # require target-specific response inspection. + output_modalities=declared.output_modalities, + ) + + +def _create_test_message( + *, + modalities: frozenset[PromptDataType], + test_assets: dict[PromptDataType, str], +) -> Message: + """ + Build a minimal :class:`Message` that exercises ``modalities``. + + Args: + modalities (frozenset[PromptDataType]): The modalities to include. + test_assets (dict[PromptDataType, str]): Mapping from non-text + modality to a file path used for the probe. + + Returns: + Message: A message containing one piece per modality. + + Raises: + FileNotFoundError: If a configured asset path does not exist. + ValueError: If a non-text modality has no configured asset, or if + no pieces could be constructed. 
+ """ + conversation_id = f"modality-probe-{uuid.uuid4()}" + pieces: list[MessagePiece] = [] + + for modality in modalities: + if modality == "text": + pieces.append( + MessagePiece( + role="user", + original_value="test", + original_value_data_type="text", + conversation_id=conversation_id, + prompt_metadata=_probe_metadata(), + ) + ) + continue + + asset_path = test_assets.get(modality) + if asset_path is None: + raise ValueError(f"No test asset configured for modality '{modality}'.") + if not os.path.isfile(asset_path): + raise FileNotFoundError(f"Test asset for modality '{modality}' not found at: {asset_path}") + + pieces.append( + MessagePiece( + role="user", + original_value=asset_path, + original_value_data_type=modality, + conversation_id=conversation_id, + prompt_metadata=_probe_metadata(), + ) + ) + + if not pieces: + raise ValueError(f"Could not create test message for modalities: {modalities}") + + return Message(pieces) diff --git a/tests/unit/common/test_common_net_utility.py b/tests/unit/common/test_common_net_utility.py index 58fff4b222..5088a166de 100644 --- a/tests/unit/common/test_common_net_utility.py +++ b/tests/unit/common/test_common_net_utility.py @@ -77,8 +77,14 @@ def response_callback(request): async def test_debug_is_false_by_default(): with patch("pyrit.common.net_utility.get_httpx_client") as mock_get_httpx_client: - mock_client_instance = MagicMock() - mock_get_httpx_client.return_value = mock_client_instance + mock_client_context = MagicMock() + mock_client = MagicMock() + mock_client.request = AsyncMock( + return_value=httpx.Response(status_code=200, request=httpx.Request("GET", "http://example.com")) + ) + mock_client_context.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_context.__aexit__ = AsyncMock(return_value=None) + mock_get_httpx_client.return_value = mock_client_context await make_request_and_raise_if_error_async(endpoint_uri="http://example.com", method="GET") diff --git 
a/tests/unit/prompt_target/test_query_target_capabilities.py b/tests/unit/prompt_target/test_query_target_capabilities.py new file mode 100644 index 0000000000..ec53895355 --- /dev/null +++ b/tests/unit/prompt_target/test_query_target_capabilities.py @@ -0,0 +1,909 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import asyncio +import json +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest + +from pyrit.models import Message, MessagePiece, PromptDataType +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.query_target_capabilities import ( + _CAPABILITY_PROBES, + DEFAULT_TEST_ASSETS, + _create_test_message, + _permissive_configuration, + discover_target_async, + discover_target_capabilities_async, + discover_target_modalities_async, +) +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from tests.unit.mocks import MockPromptTarget + + +class _RealValidationTarget(PromptTarget): + """ + Bare ``PromptTarget`` subclass that does NOT override ``_validate_request``. + + Tests that need to verify ``_permissive_configuration`` actually bypasses + the validation guard use this instead of ``MockPromptTarget`` (which + no-ops ``_validate_request``). 
+ """ + + _DEFAULT_CONFIGURATION: TargetConfiguration = TargetConfiguration( + capabilities=TargetCapabilities(), + ) + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return _ok_response() + + +def _ok_response(*, conversation_id: str = "probe", text: str = "ok") -> list[Message]: + return [ + Message( + [ + MessagePiece( + role="assistant", + original_value=text, + original_value_data_type="text", + conversation_id=conversation_id, + response_error="none", + ) + ] + ) + ] + + +def _error_response(*, conversation_id: str = "probe") -> list[Message]: + return [ + Message( + [ + MessagePiece( + role="assistant", + original_value="blocked", + original_value_data_type="text", + conversation_id=conversation_id, + response_error="blocked", + ) + ] + ) + ] + + +@pytest.mark.usefixtures("patch_central_database") +class TestPermissiveConfiguration: + def test_replaces_and_restores_configuration(self) -> None: + target = MockPromptTarget() + original = target.configuration + + with _permissive_configuration(target=target): + permissive = target.configuration + assert permissive is not original + for capability in CapabilityName: + assert permissive.includes(capability=capability) + + assert target.configuration is original + + def test_restores_on_exception(self) -> None: + target = MockPromptTarget() + original = target.configuration + + with pytest.raises(RuntimeError): + with _permissive_configuration(target=target): + raise RuntimeError("boom") + + assert target.configuration is original + + +@pytest.mark.usefixtures("patch_central_database") +class TestQueryTargetCapabilitiesAsync: + async def test_returns_only_supported_when_all_probes_succeed(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_capabilities_async(target=target) + + # Every capability with a probe should be 
in the result. + for capability in _CAPABILITY_PROBES: + assert capability in result + + async def test_excludes_capabilities_when_probe_fails(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(side_effect=Exception("nope")) # type: ignore[method-assign] + + result = await discover_target_capabilities_async(target=target) + + for capability in _CAPABILITY_PROBES: + assert capability not in result + + async def test_excludes_capabilities_when_response_has_error(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_error_response()) # type: ignore[method-assign] + + result = await discover_target_capabilities_async(target=target) + + for capability in _CAPABILITY_PROBES: + assert capability not in result + + async def test_filters_by_requested_capabilities(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + requested = {CapabilityName.SYSTEM_PROMPT, CapabilityName.MULTI_TURN} + result = await discover_target_capabilities_async(target=target, capabilities=requested) + + assert result == requested + + async def test_capability_without_probe_falls_back_to_declared_support(self) -> None: + target = MockPromptTarget() + # Override the configuration so editable_history is declared as supported. 
+ target._configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_editable_history=True), + ) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.EDITABLE_HISTORY}, + ) + + assert result == {CapabilityName.EDITABLE_HISTORY} + + async def test_capability_without_probe_excluded_when_not_declared(self) -> None: + target = MockPromptTarget() + # Override to a configuration that does NOT declare editable_history. + target._configuration = TargetConfiguration(capabilities=TargetCapabilities()) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.EDITABLE_HISTORY}, + ) + + assert result == set() + + async def test_capability_without_probe_excluded_when_only_adapted(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ + ADAPT in the policy must NOT count as native support for the fallback. + + Today every adaptable capability also has a probe, so this scenario only + arises if a future capability is declared adaptable without a probe. + We simulate that by removing SYSTEM_PROMPT from the registry and + configuring the target with ``ADAPT`` for it but no native support. 
+ """ + from pyrit.prompt_target.common import query_target_capabilities as qtc + from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + UnsupportedCapabilityBehavior, + ) + + patched_probes = {k: v for k, v in qtc._CAPABILITY_PROBES.items() if k is not CapabilityName.SYSTEM_PROMPT} + monkeypatch.setattr(qtc, "_CAPABILITY_PROBES", patched_probes) + + target = MockPromptTarget() + target._configuration = TargetConfiguration( + capabilities=TargetCapabilities(), # no native SYSTEM_PROMPT + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + } + ), + ) + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.SYSTEM_PROMPT}, + ) + + assert result == set() + + async def test_accepts_single_pass_iterable(self) -> None: + """Passing a generator must not silently drop fallback (non-probed) capabilities.""" + target = MockPromptTarget() + target._configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_editable_history=True), + ) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + gen = (c for c in [CapabilityName.SYSTEM_PROMPT, CapabilityName.EDITABLE_HISTORY]) + result = await discover_target_capabilities_async(target=target, capabilities=gen) + + assert CapabilityName.SYSTEM_PROMPT in result + assert CapabilityName.EDITABLE_HISTORY in result + + async def test_retries_zero_disables_retry(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(side_effect=Exception("boom")) # type: ignore[method-assign] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + retries=0, + ) + + assert result == set() + assert target._send_prompt_to_target_async.await_count == 1 + + async def 
test_retries_use_exponential_backoff(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(side_effect=Exception("boom")) # type: ignore[method-assign] + + with patch( + "pyrit.prompt_target.common.query_target_capabilities.asyncio.sleep", new_callable=AsyncMock + ) as sleep_mock: + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + retries=2, + ) + + assert result == set() + assert sleep_mock.await_args_list[0].args == (0.1,) + assert sleep_mock.await_args_list[1].args == (0.2,) + + async def test_restores_configuration_after_probing(self) -> None: + target = MockPromptTarget() + original = target.configuration + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await discover_target_capabilities_async(target=target) + + assert target.configuration is original + + async def test_multi_turn_probe_sends_history_on_second_call(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.MULTI_TURN}, + ) + + # Multi-turn probe sends two requests on the same conversation_id, and + # seeds memory between them so the second call carries real history. + calls = target._send_prompt_to_target_async.await_args_list + assert len(calls) == 2 + + first_conv = calls[0].kwargs["normalized_conversation"] + second_conv = calls[1].kwargs["normalized_conversation"] + + first_conv_id = first_conv[-1].message_pieces[0].conversation_id + second_conv_id = second_conv[-1].message_pieces[0].conversation_id + assert first_conv_id == second_conv_id + + # First call is a single-turn user message; the second call must include + # the seeded user + assistant history followed by the new user turn. 
+ assert len(first_conv) == 1 + assert len(second_conv) >= 3 + roles = [msg.message_pieces[0]._role for msg in second_conv] + assert roles[-3:] == ["user", "assistant", "user"] + + async def test_multi_turn_probe_short_circuits_on_first_failure(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(side_effect=Exception("first call fails")) # type: ignore[method-assign] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.MULTI_TURN}, + ) + + assert result == set() + # _send_and_check_async retries once on exception, so the failing + # first turn is attempted twice; the second turn is never reached. + assert target._send_prompt_to_target_async.await_count == 2 + + async def test_json_schema_probe_sends_schema_in_metadata(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.JSON_SCHEMA}, + ) + + normalized: list[Message] = target._send_prompt_to_target_async.await_args.kwargs["normalized_conversation"] + metadata = normalized[-1].message_pieces[0].prompt_metadata + assert metadata is not None + assert metadata["response_format"] == "json" + # Schema is JSON-encoded into a string for prompt_metadata's value type. 
+ schema = json.loads(metadata["json_schema"]) + assert schema["type"] == "object" + + async def test_system_prompt_probe_installs_system_message_and_sends_user(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.SYSTEM_PROMPT}, + ) + + # The probe writes a system message directly to memory (bypassing + # PromptTarget.set_system_prompt, which subclasses can override) and + # then sends a user-role message. Message.validate forbids mixed + # roles in a single Message, so the system and user turns are + # separate. Verify the system message is in memory and the wire + # payload contains the system + user history. + normalized: list[Message] = target._send_prompt_to_target_async.await_args.kwargs["normalized_conversation"] + roles_sent = [piece._role for msg in normalized for piece in msg.message_pieces] + assert "system" in roles_sent + assert roles_sent[-1] == "user" + # The last sent Message itself should be user-only. + assert [piece._role for piece in normalized[-1].message_pieces] == ["user"] + + async def test_multi_message_pieces_probe_sends_two_pieces(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.MULTI_MESSAGE_PIECES}, + ) + + normalized: list[Message] = target._send_prompt_to_target_async.await_args.kwargs["normalized_conversation"] + assert len(normalized[-1].message_pieces) == 2 + + async def test_probes_run_under_permissive_configuration(self) -> None: + """ + Even when the target declares no boolean capabilities, the probe should + still execute because the configuration is temporarily permissive. 
+ + Uses ``_RealValidationTarget`` so that ``_validate_request`` actually + runs and would reject the multi-piece probe were the override absent. + """ + target = _RealValidationTarget() + send_mock = AsyncMock(return_value=_ok_response()) + target._send_prompt_to_target_async = send_mock # type: ignore[method-assign] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.MULTI_MESSAGE_PIECES}, + ) + + # Probe was actually invoked through the full send_prompt_async pipeline, + # which means _validate_request ran and was satisfied by the permissive + # override (the bare target declares no capabilities natively). + assert send_mock.await_count >= 1 + assert CapabilityName.MULTI_MESSAGE_PIECES in result + + async def test_probed_capability_excluded_when_only_adapted(self) -> None: + target = MockPromptTarget() + target._configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_system_prompt=False), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + } + ), + ) + + async def reject_system_roles(*, normalized_conversation: list[Message]) -> list[Message]: + roles = [piece._role for message in normalized_conversation for piece in message.message_pieces] + if "system" in roles: + raise RuntimeError("system messages are not natively supported") + return _ok_response() + + target._send_prompt_to_target_async = AsyncMock(side_effect=reject_system_roles) # type: ignore[method-assign] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.SYSTEM_PROMPT}, + ) + + assert result == set() + + async def test_probe_configuration_does_not_reuse_adapted_pipeline(self) -> None: + target = MockPromptTarget() + target._configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_system_prompt=False), + 
policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + } + ), + ) + + async def require_native_system_role(*, normalized_conversation: list[Message]) -> list[Message]: + roles = [piece._role for message in normalized_conversation for piece in message.message_pieces] + if "system" not in roles: + raise RuntimeError("probe used adapted system-prompt shaping") + return _ok_response() + + target._send_prompt_to_target_async = AsyncMock(side_effect=require_native_system_role) # type: ignore[method-assign] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.SYSTEM_PROMPT}, + ) + + assert result == {CapabilityName.SYSTEM_PROMPT} + + +@pytest.mark.usefixtures("patch_central_database") +class TestQueryTargetCapabilitiesIsolatedTarget: + """Tests using a bare PromptTarget subclass (no PromptChatTarget extras).""" + + async def test_with_minimal_target_subclass(self) -> None: + class _MinimalTarget(PromptTarget): + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return _ok_response() + + target = _MinimalTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_capabilities_async(target=target) + + for capability in _CAPABILITY_PROBES: + assert capability in result + + +# --------------------------------------------------------------------------- +# Modality query tests +# --------------------------------------------------------------------------- + + +def _set_input_modalities( + *, + target: MockPromptTarget, + modalities: set[frozenset[PromptDataType]], +) -> None: + target._configuration = TargetConfiguration( + capabilities=TargetCapabilities( + input_modalities=frozenset(modalities), + ), + ) + + +@pytest.fixture +def image_asset(tmp_path: 
Path) -> str: + """Create a tiny placeholder file usable as an image_path asset.""" + asset = tmp_path / "test_image.png" + asset.write_bytes(b"\x89PNG\r\n\x1a\n") + return str(asset) + + +@pytest.mark.usefixtures("patch_central_database") +class TestCreateTestMessage: + def test_default_assets_exist_for_packaged_modalities(self) -> None: + msg = _create_test_message( + modalities=frozenset({"audio_path", "image_path"}), + test_assets=DEFAULT_TEST_ASSETS, + ) + + types = {piece.original_value_data_type for piece in msg.message_pieces} + assert types == {"audio_path", "image_path"} + + def test_text_only(self) -> None: + msg = _create_test_message(modalities=frozenset({"text"}), test_assets={}) + assert len(msg.message_pieces) == 1 + assert msg.message_pieces[0].original_value_data_type == "text" + + def test_multimodal_uses_assets(self, image_asset: str) -> None: + msg = _create_test_message( + modalities=frozenset({"text", "image_path"}), + test_assets={"image_path": image_asset}, + ) + types = {piece.original_value_data_type for piece in msg.message_pieces} + assert types == {"text", "image_path"} + + # All pieces share the same conversation_id (Message.validate requires it). 
+ conv_ids = {piece.conversation_id for piece in msg.message_pieces} + assert len(conv_ids) == 1 + + def test_missing_asset_file_raises_filenotfound(self, tmp_path: Path) -> None: + missing_path = str(tmp_path / "does_not_exist.png") + with pytest.raises(FileNotFoundError): + _create_test_message( + modalities=frozenset({"image_path"}), + test_assets={"image_path": missing_path}, + ) + + def test_unconfigured_modality_raises_valueerror(self) -> None: + with pytest.raises(ValueError, match="No test asset configured"): + _create_test_message( + modalities=frozenset({"image_path"}), + test_assets={}, + ) + + +@pytest.mark.usefixtures("patch_central_database") +class TestDiscoverTargetModalitiesAsync: + async def test_all_combinations_supported(self) -> None: + target = MockPromptTarget() + _set_input_modalities(target=target, modalities={frozenset({"text"})}) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_modalities_async(target=target) + + assert frozenset({"text"}) in result + + async def test_exception_excludes_combination(self) -> None: + target = MockPromptTarget() + _set_input_modalities(target=target, modalities={frozenset({"text"})}) + target._send_prompt_to_target_async = AsyncMock(side_effect=Exception("nope")) # type: ignore[method-assign] + + result = await discover_target_modalities_async(target=target) + + assert result == set() + + async def test_error_response_excludes_combination(self) -> None: + target = MockPromptTarget() + _set_input_modalities(target=target, modalities={frozenset({"text"})}) + target._send_prompt_to_target_async = AsyncMock(return_value=_error_response()) # type: ignore[method-assign] + + result = await discover_target_modalities_async(target=target) + + assert result == set() + + async def test_partial_support_via_selective_failure(self, image_asset: str) -> None: + target = MockPromptTarget() + _set_input_modalities( + target=target, + 
modalities={frozenset({"text"}), frozenset({"text", "image_path"})}, + ) + + async def selective_send(*, normalized_conversation: list[Message]) -> list[Message]: + message = normalized_conversation[-1] + types = {p.original_value_data_type for p in message.message_pieces} + if "image_path" in types: + raise Exception("image not supported") + return _ok_response() + + target._send_prompt_to_target_async = selective_send # type: ignore[method-assign] + + result = await discover_target_modalities_async( + target=target, + test_assets={"image_path": image_asset}, + ) + + assert frozenset({"text"}) in result + assert frozenset({"text", "image_path"}) not in result + + async def test_explicit_test_modalities_overrides_declared(self, image_asset: str) -> None: + target = MockPromptTarget() + # Declared as text-only, but caller asks us to probe text+image too. + _set_input_modalities(target=target, modalities={frozenset({"text"})}) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_modalities_async( + target=target, + test_modalities={frozenset({"text"}), frozenset({"text", "image_path"})}, + test_assets={"image_path": image_asset}, + ) + + assert frozenset({"text"}) in result + assert frozenset({"text", "image_path"}) in result + + async def test_combination_skipped_when_asset_missing(self, tmp_path: Path) -> None: + target = MockPromptTarget() + _set_input_modalities(target=target, modalities={frozenset({"text", "image_path"})}) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + # An explicit empty mapping disables the packaged defaults, so + # image_path combinations are skipped instead of probed. 
+ result = await discover_target_modalities_async(target=target, test_assets={}) + + assert result == set() + assert target._send_prompt_to_target_async.await_count == 0 + + async def test_explicit_test_modalities_runs_under_permissive_configuration(self, image_asset: str) -> None: + """ + Probing a modality combination the target does NOT declare must still + succeed. Uses ``_RealValidationTarget`` so ``_validate_request`` runs + and would reject the multi-piece, non-text payload were the + permissive override absent. + """ + target = _RealValidationTarget() + send_mock = AsyncMock(return_value=_ok_response()) + target._send_prompt_to_target_async = send_mock # type: ignore[method-assign] + + result = await discover_target_modalities_async( + target=target, + test_modalities={frozenset({"text", "image_path"})}, + test_assets={"image_path": image_asset}, + ) + + assert send_mock.await_count == 1 + assert frozenset({"text", "image_path"}) in result + + +@pytest.mark.usefixtures("patch_central_database") +class TestSendAndCheckTimeout: + async def test_timeout_returns_false_after_retries(self) -> None: + """ + When ``send_prompt_async`` exceeds ``per_probe_timeout_s``, the probe + is treated as failed. ``_send_and_check_async`` retries once on + timeout, so the underlying mock is awaited twice and the capability + is excluded from the queried set. + """ + target = MockPromptTarget() + + async def _hang(**_kwargs: object) -> list[Message]: + await asyncio.sleep(10) + return _ok_response() + + target._send_prompt_to_target_async = AsyncMock(side_effect=_hang) # type: ignore[method-assign] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + per_probe_timeout_s=0.01, + ) + + assert result == set() + # One initial attempt plus one retry. 
+ assert target._send_prompt_to_target_async.await_count == 2 + + +@pytest.mark.usefixtures("patch_central_database") +class TestSystemPromptProbeMemoryFailure: + async def test_returns_false_when_memory_seed_raises(self) -> None: + """ + If seeding the system message into memory raises (e.g. backend + offline), the system-prompt probe returns False without attempting + the user send. + """ + target = MockPromptTarget() + send_mock = AsyncMock(return_value=_ok_response()) + target._send_prompt_to_target_async = send_mock # type: ignore[method-assign] + + with patch.object(target._memory, "add_message_to_memory", side_effect=RuntimeError("memory offline")): + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.SYSTEM_PROMPT}, + ) + + assert result == set() + # The user send is never attempted because seeding failed. + send_mock.assert_not_awaited() + + +@pytest.mark.usefixtures("patch_central_database") +class TestDiscoverTargetAsync: + async def test_returns_target_capabilities_assembled_from_probes(self) -> None: + """ + ``discover_target_async`` runs both the capability and modality probes + and assembles a :class:`TargetCapabilities` populated from the + queried results, copying ``output_modalities`` from the target's + declared capabilities and deriving editable history conservatively. + """ + declared = TargetCapabilities( + input_modalities=frozenset({frozenset({"text"})}), + output_modalities=frozenset({frozenset({"text"})}), + ) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_async(target=target, per_probe_timeout_s=5.0) + + assert isinstance(result, TargetCapabilities) + # Single-piece probes that don't touch memory always succeed when + # the underlying send returns a clean response.
+ assert result.supports_multi_message_pieces is True + assert result.supports_json_schema is True + assert result.supports_json_output is True + # Editable history is conservative and therefore cannot remain true + # when multi-turn support was not confirmed by probing. + assert result.supports_editable_history is False + # Modalities returned from the modality probe (text combination). + assert frozenset({"text"}) in result.input_modalities + # Output modalities copied through (not probed). + assert result.output_modalities == declared.output_modalities + + async def test_excludes_capabilities_when_probe_send_fails(self) -> None: + """ + When the underlying send raises, no capability or modality is + queried, but ``supports_editable_history`` and ``output_modalities`` + are still copied conservatively from the declared capabilities. + """ + declared = TargetCapabilities( + supports_editable_history=True, + output_modalities=frozenset({frozenset({"text"})}), + ) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + + result = await discover_target_async(target=target, per_probe_timeout_s=0.5) + + assert result.supports_multi_turn is False + assert result.supports_system_prompt is False + assert result.supports_json_output is False + assert result.supports_json_schema is False + assert result.supports_multi_message_pieces is False + # Editable history is derived conservatively and must fall when + # multi-turn probing disproves the prerequisite capability. + assert result.supports_editable_history is False + # No modalities queried because send always fails. + assert result.input_modalities == frozenset() + # Output modalities still copied. 
+ assert result.output_modalities == declared.output_modalities + + async def test_empty_response_treated_as_failure(self) -> None: + """A target returning an empty response list must NOT be reported as supporting probes.""" + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=[]) # type: ignore[method-assign] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT, CapabilityName.MULTI_MESSAGE_PIECES}, + ) + + assert result == set() + + async def test_response_with_no_pieces_treated_as_failure(self) -> None: + """Responses whose Messages have no pieces must also be rejected.""" + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock( # type: ignore[method-assign] + return_value=[Message.__new__(Message)] + ) + # Bypass __init__ to construct a Message with no pieces (Message.__init__ rejects empty). + empty_msg = target._send_prompt_to_target_async.return_value[0] + empty_msg.message_pieces = [] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + ) + + assert result == set() + + async def test_discover_target_async_forwards_test_modalities(self, image_asset: str) -> None: + declared = TargetCapabilities(input_modalities=frozenset({frozenset({"text"})})) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) + + extra_combo = frozenset({"text", "image_path"}) + result = await discover_target_async( + target=target, + test_modalities={extra_combo}, + test_assets={"image_path": image_asset}, + per_probe_timeout_s=2.0, + ) + + # The undeclared combination is in the result only if test_modalities was forwarded. 
+ assert extra_combo in result.input_modalities + + async def test_discover_target_async_preserves_declared_modalities_when_test_modalities_narrowed( + self, image_asset: str + ) -> None: + declared_combo = frozenset({"text"}) + probed_combo = frozenset({"text", "image_path"}) + declared = TargetCapabilities(input_modalities=frozenset({declared_combo, probed_combo})) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) + + result = await discover_target_async( + target=target, + test_modalities={probed_combo}, + test_assets={"image_path": image_asset}, + per_probe_timeout_s=2.0, + ) + + assert result.input_modalities == frozenset({declared_combo, probed_combo}) + + async def test_discover_target_async_forwards_capabilities(self) -> None: + """``discover_target_async`` must forward ``capabilities`` to narrow the probe set.""" + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await discover_target_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + per_probe_timeout_s=2.0, + ) + + # Only the JSON_OUTPUT probe (1 send) and the modality probe(s) should run; + # if `capabilities` were ignored, all 5 capability probes would fire (>= 6 sends + # because multi-turn issues 2 sends). + assert target._send_prompt_to_target_async.await_count <= 3 + + async def test_discover_target_async_preserves_declared_when_capabilities_narrowed(self) -> None: + """ + When ``capabilities`` narrows the probe set, capabilities NOT in the + narrowed set must fall back to the target's declared values rather + than being silently reset to False. 
+ """ + declared = TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=True, + supports_json_schema=True, + supports_editable_history=True, + ) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + per_probe_timeout_s=2.0, + ) + + # The probed capability reflects the queried result. + assert result.supports_json_output is True + # Non-probed capabilities fall back to declared values. + assert result.supports_multi_turn is True + assert result.supports_system_prompt is True + assert result.supports_json_schema is True + assert result.supports_editable_history is True + + async def test_discover_target_async_drops_editable_history_when_multi_turn_probe_fails(self) -> None: + """Editable history must not remain true when probing disproves multi-turn support.""" + declared = TargetCapabilities( + supports_multi_turn=True, + supports_editable_history=True, + output_modalities=frozenset({frozenset({"text"})}), + ) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + + async def selective_send(*, normalized_conversation: list[Message]) -> list[Message]: + latest_text = normalized_conversation[-1].message_pieces[0].original_value + if latest_text == "My favorite color is blue." 
or latest_text == "What did I just tell you?": + raise RuntimeError("multi-turn unsupported") + return _ok_response() + + target._send_prompt_to_target_async = AsyncMock(side_effect=selective_send) # type: ignore[method-assign] + + result = await discover_target_async(target=target, per_probe_timeout_s=2.0) + + assert result.supports_multi_turn is False + assert result.supports_editable_history is False + + async def test_discover_target_async_accepts_single_pass_iterable(self) -> None: + declared = TargetCapabilities( + supports_multi_turn=True, + supports_editable_history=True, + ) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + gen = (c for c in [CapabilityName.JSON_OUTPUT, CapabilityName.EDITABLE_HISTORY]) + result = await discover_target_async( + target=target, + capabilities=gen, + per_probe_timeout_s=2.0, + ) + + assert result.supports_json_output is True + assert result.supports_editable_history is True + + +@pytest.mark.usefixtures("patch_central_database") +class TestMultiTurnProbeMemoryFailure: + async def test_returns_false_when_history_seed_raises(self) -> None: + """ + If seeding conversation history into memory raises, the multi-turn + probe returns False rather than proceeding with a half-seeded + conversation that would produce a false positive. 
+ """ + target = MockPromptTarget() + send_mock = AsyncMock(return_value=_ok_response()) + target._send_prompt_to_target_async = send_mock # type: ignore[method-assign] + + with patch.object(target._memory, "add_message_to_memory", side_effect=RuntimeError("memory offline")): + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.MULTI_TURN}, + ) + + assert result == set() + # The first turn ran (1 send); the second turn must NOT run because + # seeding failed, otherwise the probe would falsely succeed. + assert send_mock.await_count == 1