diff --git a/pyrit/datasets/prompt_converters/translation_user_prompt.yaml b/pyrit/datasets/prompt_converters/translation_user_prompt.yaml new file mode 100644 index 0000000000..de65a35006 --- /dev/null +++ b/pyrit/datasets/prompt_converters/translation_user_prompt.yaml @@ -0,0 +1,13 @@ +name: translation_user_prompt +description: | + User-side prompt template that wraps the input with begin/end tags for the translation converter. +authors: + - AI Red Team +groups: + - AI Red Team +source: AI Red Team +parameters: + - language + - objective +data_type: text +value: "Translate the following to {{ language }} between the begin and end tags:=== begin ===\n{{ objective }}\n=== end ===\n\n\n" diff --git a/pyrit/datasets/prompt_converters/variation_user_prompt.yaml b/pyrit/datasets/prompt_converters/variation_user_prompt.yaml new file mode 100644 index 0000000000..d4d5bd4ad9 --- /dev/null +++ b/pyrit/datasets/prompt_converters/variation_user_prompt.yaml @@ -0,0 +1,12 @@ +name: variation_user_prompt +description: | + User-side prompt template wrapping input with begin/end tags for the variation converter. +authors: + - AI Red Team +groups: + - AI Red Team +source: AI Red Team +parameters: + - objective +data_type: text +value: "Create 1 variation of the seed prompt given by the user between the begin and end tags=== begin ==={{ objective }}=== end ===" diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index 258d94168c..26bb5760d4 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -4,9 +4,13 @@ import hashlib import logging import uuid -from typing import Any, Optional +from typing import TYPE_CHECKING, Any + +from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt, wait_exponential, wait_none from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults +from pyrit.exceptions.exception_classes import _DynamicStopAfterAttempt, get_retry_max_num_attempts +from pyrit.exceptions.exceptions_helpers import log_exception from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( Message, @@ -17,25 +21,42 @@ from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS, PromptTarget +if TYPE_CHECKING: + from tenacity.stop import stop_base + from tenacity.wait import wait_base + logger = logging.getLogger(__name__) class LLMGenericTextConverter(PromptConverter): """ - Represents a generic LLM converter that expects text to be transformed (e.g. no JSON parsing or format). + Represents a generic LLM-backed converter for text-in/text-out transformations. + + Subclasses may override ``_process_response`` to parse, extract, or otherwise post-process + the raw LLM response (e.g., JSON parsing). Subclasses opt into retry behavior by setting + ``RETRY_EXCEPTIONS`` to the tuple of exception types that should trigger a retry; by default + the attempt count is read from the ``RETRY_MAX_NUM_ATTEMPTS`` environment variable and no + wait is applied between attempts (matching ``pyrit_json_retry``). Subclasses needing a + fixed attempt count or exponential backoff can pass ``max_retry_attempts`` and/or + ``retry_wait_max_seconds`` to ``__init__``. """ SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) TARGET_REQUIREMENTS = CHAT_TARGET_REQUIREMENTS + RETRY_EXCEPTIONS: tuple[type[BaseException], ...] = () + @apply_defaults def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - system_prompt_template: Optional[SeedPrompt] = None, - user_prompt_template_with_objective: Optional[SeedPrompt] = None, + system_prompt_template: SeedPrompt | None = None, + user_prompt_template_with_objective: SeedPrompt | None = None, + retry_exceptions: tuple[type[BaseException], ...] | None = None, + max_retry_attempts: int | None = None, + retry_wait_max_seconds: int | None = None, **kwargs: Any, ) -> None: """ @@ -46,18 +67,34 @@ def __init__( ``CHAT_TARGET_REQUIREMENTS`` (multi-turn + editable history capabilities, possibly via normalization-pipeline adaptation). Can be omitted if a default has been configured via PyRIT initialization. - system_prompt_template (SeedPrompt, Optional): The prompt template to set as the system prompt. - user_prompt_template_with_objective (SeedPrompt, Optional): The prompt template to set as the user prompt. - expects - kwargs: Additional parameters for the prompt template. + system_prompt_template (SeedPrompt | None): The prompt template to set as the system prompt. + user_prompt_template_with_objective (SeedPrompt | None): The prompt template to wrap the + user input with. Must include an ``objective`` parameter; the raw user prompt is rendered + as ``objective``. Additional ``**kwargs`` are also forwarded to the renderer, so subclasses + can pass static template parameters (e.g., ``language``). + retry_exceptions (tuple[type[BaseException], ...] | None): Exception types that should + trigger a retry. Overrides the class-level ``RETRY_EXCEPTIONS`` for this instance only. + If ``None``, ``RETRY_EXCEPTIONS`` is used. + max_retry_attempts (int | None): Maximum number of retry attempts. If ``None``, the + value is read at retry time from the ``RETRY_MAX_NUM_ATTEMPTS`` environment variable. + retry_wait_max_seconds (int | None): Upper bound (in seconds) for exponential backoff + between retry attempts. If ``None``, no wait is applied between attempts (matches + ``pyrit_json_retry``). + kwargs: Additional parameters forwarded to both the system prompt and user prompt templates + during rendering. Raises: ValueError: If converter_target is not provided and no default has been configured. + ValueError: If ``user_prompt_template_with_objective`` does not declare an ``objective`` + parameter. """ super().__init__(converter_target=converter_target) self._converter_target = converter_target self._system_prompt_template = system_prompt_template self._prompt_kwargs = kwargs + self._retry_exceptions = retry_exceptions if retry_exceptions is not None else self.RETRY_EXCEPTIONS + self._max_retry_attempts = max_retry_attempts + self._retry_wait_max_seconds = retry_wait_max_seconds if user_prompt_template_with_objective and ( user_prompt_template_with_objective.parameters is None @@ -109,30 +146,33 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text Raises: ValueError: If the input type is not supported. """ - conversation_id = str(uuid.uuid4()) + if not self.input_supported(input_type): + raise ValueError("Input type not supported") + conversation_id = str(uuid.uuid4()) kwargs = self._prompt_kwargs.copy() if self._system_prompt_template: system_prompt = self._system_prompt_template.render_template_value(**kwargs) - self._converter_target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, attack_identifier=None, ) - if not self.input_supported(input_type): - raise ValueError("Input type not supported") - + converted_prompt = prompt if self._user_prompt_template_with_objective: - prompt = self._user_prompt_template_with_objective.render_template_value(objective=prompt) + template_kwargs = {k: v for k, v in kwargs.items() if k != "objective"} + converted_prompt = self._user_prompt_template_with_objective.render_template_value( + objective=prompt, **template_kwargs + ) request = Message( [ MessagePiece( role="user", original_value=prompt, + converted_value=converted_prompt, conversation_id=conversation_id, sequence=1, prompt_target_identifier=self._converter_target.get_identifier(), @@ -143,5 +183,68 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text ] ) - response = await self._converter_target.send_prompt_async(message=request) - return ConverterResult(output_text=response[0].get_value(), output_type="text") + response_text = await self._send_with_retries_async(request) + return ConverterResult(output_text=response_text, output_type="text") + + async def _send_with_retries_async(self, request: Message) -> str: + """ + Send the request to the converter target, retrying on configured exception types. + + When ``self._retry_exceptions`` is empty, the request is sent once with no retry. + Otherwise, the attempt count comes from ``self._max_retry_attempts`` (or the + ``RETRY_MAX_NUM_ATTEMPTS`` env variable when unset) and the wait between attempts + comes from ``self._retry_wait_max_seconds`` (or no wait when unset). The final + exception is re-raised unchanged. + + Args: + request (Message): The message to send to the converter target. + + Returns: + str: The post-processed response text from ``_process_response``. + + Raises: + RuntimeError: Defensive guard for an unreachable code path; tenacity always + re-raises the underlying exception when retries are exhausted. + """ + if not self._retry_exceptions: + response = await self._converter_target.send_prompt_async(message=request) + return self._process_response(response[0].get_value()) + + stop_strategy: stop_base = ( + stop_after_attempt(self._max_retry_attempts) + if self._max_retry_attempts is not None + else _DynamicStopAfterAttempt(get_retry_max_num_attempts) + ) + wait_strategy: wait_base = ( + wait_exponential(multiplier=1, min=1, max=self._retry_wait_max_seconds) + if self._retry_wait_max_seconds is not None + else wait_none() + ) + + async for attempt in AsyncRetrying( + retry=retry_if_exception_type(self._retry_exceptions), + stop=stop_strategy, + wait=wait_strategy, + reraise=True, + after=log_exception, + ): + with attempt: + response = await self._converter_target.send_prompt_async(message=request) + return self._process_response(response[0].get_value()) + + raise RuntimeError("unreachable: tenacity reraises on exhaustion") # pragma: no cover + + def _process_response(self, response_text: str) -> str: + """ + Post-process the raw LLM response text. + + Subclasses override this to parse JSON, extract fields, strip whitespace, etc. + The default implementation returns the response unchanged. + + Args: + response_text (str): The raw text returned by the LLM. + + Returns: + str: The processed text used as the converter output. + """ + return response_text diff --git a/pyrit/prompt_converter/persuasion_converter.py b/pyrit/prompt_converter/persuasion_converter.py index cbcfb66d4d..748de2d26a 100644 --- a/pyrit/prompt_converter/persuasion_converter.py +++ b/pyrit/prompt_converter/persuasion_converter.py @@ -4,29 +4,26 @@ import json import logging import pathlib -import uuid +import warnings from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH from pyrit.exceptions import ( InvalidJsonException, - pyrit_json_retry, remove_markdown_json, ) from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( Message, - MessagePiece, - PromptDataType, SeedPrompt, ) -from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS, PromptTarget +from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) -class PersuasionConverter(PromptConverter): +class PersuasionConverter(LLMGenericTextConverter): """ Rephrases prompts using a variety of persuasion techniques. @@ -45,9 +42,7 @@ class PersuasionConverter(PromptConverter): Presenting oneself or an issue in a way that's not genuine or true. """ - SUPPORTED_INPUT_TYPES = ("text",) - SUPPORTED_OUTPUT_TYPES = ("text",) - TARGET_REQUIREMENTS = CHAT_TARGET_REQUIREMENTS + RETRY_EXCEPTIONS = (InvalidJsonException,) @apply_defaults def __init__( @@ -70,20 +65,24 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the persuasion technique is not supported or does not exist. """ - super().__init__(converter_target=converter_target) - self.converter_target = converter_target - try: - prompt_template = SeedPrompt.from_yaml_file( + system_prompt_template = SeedPrompt.from_yaml_file( pathlib.Path(CONVERTER_SEED_PROMPT_PATH) / "persuasion" / f"{persuasion_technique}.yaml" ) except FileNotFoundError: raise ValueError( f"Persuasion technique '{persuasion_technique}' does not exist or is not supported." ) from None - self.system_prompt = str(prompt_template.value) + + self.system_prompt = str(system_prompt_template.value) self._persuasion_technique = persuasion_technique + super().__init__( + converter_target=converter_target, + system_prompt_template=system_prompt_template, + ) + self.converter_target = converter_target + def _build_identifier(self) -> ComponentIdentifier: """ Build the converter identifier with persuasion parameters. @@ -98,77 +97,42 @@ def _build_identifier(self) -> ComponentIdentifier: children={"converter_target": self.converter_target.get_identifier()}, ) - async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: + def _process_response(self, response_text: str) -> str: """ - Convert the given prompt using the persuasion technique specified during initialization. + Parse the JSON response and extract the ``mutated_text`` field. Args: - prompt (str): The input prompt to be converted. - input_type (PromptDataType): The type of input data. + response_text (str): The raw text returned by the LLM. Returns: - ConverterResult: The result containing the converted prompt text. + str: The value of the ``mutated_text`` key. Raises: - ValueError: If the input type is not supported. + InvalidJsonException: If the response is not valid JSON or the ``mutated_text`` key is missing. """ - if not self.input_supported(input_type): - raise ValueError("Input type not supported") - - conversation_id = str(uuid.uuid4()) - - self.converter_target.set_system_prompt( - system_prompt=self.system_prompt, - conversation_id=conversation_id, - attack_identifier=None, - ) - - request = Message( - [ - MessagePiece( - role="user", - original_value=prompt, - converted_value=prompt, - conversation_id=conversation_id, - sequence=1, - prompt_target_identifier=self.converter_target.get_identifier(), - original_value_data_type=input_type, - converted_value_data_type=input_type, - converter_identifiers=[self.get_identifier()], - ) - ] - ) - - response = await self.send_persuasion_prompt_async(request) - - return ConverterResult(output_text=response, output_type="text") + cleaned = remove_markdown_json(response_text) + try: + parsed = json.loads(cleaned) + if "mutated_text" not in parsed: + raise InvalidJsonException(message=f"Invalid JSON encountered; missing 'mutated_text' key: {cleaned}") + return str(parsed["mutated_text"]) + except (json.JSONDecodeError, TypeError): + raise InvalidJsonException(message=f"Invalid JSON encountered: {cleaned}") from None - @pyrit_json_retry async def send_persuasion_prompt_async(self, request: Message) -> str: """ - Send the prompt to the converter target and process the response. + Delegate to the unified retry helper. Deprecated shim retained for backward compatibility. Args: - request (Message): The message containing the prompt to be converted. + request (Message): The message to send to the converter target. Returns: - str: The converted prompt text extracted from the response. - - Raises: - InvalidJsonException: If the response is not valid JSON or missing expected keys. + str: The post-processed response text. """ - response = await self.converter_target.send_prompt_async(message=request) - - response_msg = response[0].get_value() - response_msg = remove_markdown_json(response_msg) - - try: - parsed_response = json.loads(response_msg) - if "mutated_text" not in parsed_response: - raise InvalidJsonException( - message=f"Invalid JSON encountered; missing 'mutated_text' key: {response_msg}" - ) - return str(parsed_response["mutated_text"]) - - except json.JSONDecodeError: - raise InvalidJsonException(message=f"Invalid JSON encountered: {response_msg}") from None + warnings.warn( + "send_persuasion_prompt_async is deprecated; the converter now uses the unified " + "_send_with_retries_async helper from LLMGenericTextConverter.", + DeprecationWarning, + stacklevel=2, + ) + return await self._send_with_retries_async(request) diff --git a/pyrit/prompt_converter/translation_converter.py b/pyrit/prompt_converter/translation_converter.py index b940f12c93..052dd47375 100644 --- a/pyrit/prompt_converter/translation_converter.py +++ b/pyrit/prompt_converter/translation_converter.py @@ -3,40 +3,23 @@ import logging import pathlib -import uuid -from textwrap import dedent -from typing import Optional - -from tenacity import ( - AsyncRetrying, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH from pyrit.identifiers import ComponentIdentifier -from pyrit.models import ( - Message, - MessagePiece, - PromptDataType, - SeedPrompt, -) -from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS, PromptTarget +from pyrit.models import SeedPrompt +from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) -class TranslationConverter(PromptConverter): +class TranslationConverter(LLMGenericTextConverter): """ Translates prompts into different languages using an LLM. """ - SUPPORTED_INPUT_TYPES = ("text",) - SUPPORTED_OUTPUT_TYPES = ("text",) - TARGET_REQUIREMENTS = CHAT_TARGET_REQUIREMENTS + RETRY_EXCEPTIONS = (Exception,) @apply_defaults def __init__( @@ -44,7 +27,7 @@ def __init__( *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] language: str, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, max_retries: int = 3, max_wait_time_in_seconds: int = 60, ) -> None: @@ -55,34 +38,40 @@ def __init__( converter_target (PromptTarget): The target chat support for the conversion which will translate. Can be omitted if a default has been configured via PyRIT initialization. language (str): The language for the conversion. E.g. Spanish, French, leetspeak, etc. - prompt_template (SeedPrompt, Optional): The prompt template for the conversion. - max_retries (int): Maximum number of retries for the conversion. - max_wait_time_in_seconds (int): Maximum wait time in seconds between retries. + prompt_template (SeedPrompt | None): The prompt template for the conversion. + max_retries (int): Maximum number of retry attempts on failure. + max_wait_time_in_seconds (int): Upper bound for exponential backoff between retries. Raises: ValueError: If converter_target is not provided and no default has been configured. ValueError: If the language is not provided. """ - super().__init__(converter_target=converter_target) - self.converter_target = converter_target - - # Retry strategy for the conversion - self._max_retries = max_retries - self._max_wait_time_in_seconds = max_wait_time_in_seconds + if not language: + raise ValueError("Language must be provided for translation conversion") - # set to default strategy if not provided - prompt_template = ( + system_prompt_template = ( prompt_template if prompt_template else SeedPrompt.from_yaml_file(pathlib.Path(CONVERTER_SEED_PROMPT_PATH) / "translation_converter.yaml") ) - if not language: - raise ValueError("Language must be provided for translation conversion") + user_prompt_template = SeedPrompt.from_yaml_file( + pathlib.Path(CONVERTER_SEED_PROMPT_PATH) / "translation_user_prompt.yaml" + ) self.language = language.lower() - - self.system_prompt = prompt_template.render_template_value(languages=language) + self.system_prompt = system_prompt_template.render_template_value(languages=language) + + super().__init__( + converter_target=converter_target, + system_prompt_template=system_prompt_template, + user_prompt_template_with_objective=user_prompt_template, + max_retry_attempts=max_retries, + retry_wait_max_seconds=max_wait_time_in_seconds, + languages=language, + language=self.language, + ) + self.converter_target = converter_target def _build_identifier(self) -> ComponentIdentifier: """ @@ -98,66 +87,14 @@ def _build_identifier(self) -> ComponentIdentifier: children={"converter_target": self.converter_target.get_identifier()}, ) - async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: + def _process_response(self, response_text: str) -> str: """ - Convert the given prompt by translating it using the converter target. + Strip surrounding whitespace from the LLM response. Args: - prompt (str): The prompt to be converted. - input_type (PromptDataType): The type of input data. + response_text (str): The raw text returned by the LLM. Returns: - ConverterResult: The result containing the generated version of the prompt. - - Raises: - ValueError: If the input type is not supported. + str: The trimmed response text. """ - conversation_id = str(uuid.uuid4()) - - self.converter_target.set_system_prompt(system_prompt=self.system_prompt, conversation_id=conversation_id) - - if not self.input_supported(input_type): - raise ValueError("Input type not supported") - - formatted_prompt = dedent( - f"Translate the following to {self.language} between the begin and end tags:" - "=== begin ===\n" - f"{prompt}\n" - "=== end ===\n" - ) - - logger.debug(f"Formatted Prompt: {formatted_prompt}") - - request = Message( - [ - MessagePiece( - role="user", - original_value=prompt, - converted_value=formatted_prompt, - conversation_id=conversation_id, - sequence=1, - prompt_target_identifier=self.converter_target.get_identifier(), - original_value_data_type=input_type, - converted_value_data_type=input_type, - converter_identifiers=[self.get_identifier()], - ) - ] - ) - - translation = await self._send_translation_prompt_async(request) - return ConverterResult(output_text=translation, output_type="text") - - async def _send_translation_prompt_async(self, request: Message) -> str: - async for attempt in AsyncRetrying( - stop=stop_after_attempt(self._max_retries), - wait=wait_exponential(multiplier=1, min=1, max=self._max_wait_time_in_seconds), - retry=retry_if_exception_type(Exception), # covers all exceptions - ): - with attempt: - logger.debug(f"Attempt {attempt.retry_state.attempt_number} for translation") - response = await self.converter_target.send_prompt_async(message=request) - response_msg = response[0].get_value() - return response_msg.strip() - - # when we exhaust all retries without success, raise an exception - raise Exception(f"Failed to translate after {self._max_retries} attempts") + return response_text.strip() diff --git a/pyrit/prompt_converter/variation_converter.py b/pyrit/prompt_converter/variation_converter.py index 44dca26ff0..a5311ccd0f 100644 --- a/pyrit/prompt_converter/variation_converter.py +++ b/pyrit/prompt_converter/variation_converter.py @@ -4,45 +4,38 @@ import json import logging import pathlib -import uuid -from textwrap import dedent -from typing import Optional +import warnings from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH from pyrit.exceptions import ( InvalidJsonException, - pyrit_json_retry, remove_markdown_json, ) from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( Message, - MessagePiece, - PromptDataType, SeedPrompt, ) -from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS, PromptTarget +from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) -class VariationConverter(PromptConverter): +class VariationConverter(LLMGenericTextConverter): """ Generates variations of the input prompts using the converter target. """ - SUPPORTED_INPUT_TYPES = ("text",) - SUPPORTED_OUTPUT_TYPES = ("text",) - TARGET_REQUIREMENTS = CHAT_TARGET_REQUIREMENTS + RETRY_EXCEPTIONS = (InvalidJsonException,) @apply_defaults def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with the specified target and prompt template. @@ -50,25 +43,34 @@ def __init__( Args: converter_target (PromptTarget): The target to which the prompt will be sent for conversion. Can be omitted if a default has been configured via PyRIT initialization. - prompt_template (SeedPrompt, optional): The template used for generating the system prompt. + prompt_template (SeedPrompt | None): The template used for generating the system prompt. If not provided, a default template will be used. Raises: ValueError: If converter_target is not provided and no default has been configured. """ - super().__init__(converter_target=converter_target) - self.converter_target = converter_target - - # set to default strategy if not provided - prompt_template = ( + system_prompt_template = ( prompt_template if prompt_template else SeedPrompt.from_yaml_file(pathlib.Path(CONVERTER_SEED_PROMPT_PATH) / "variation_converter.yaml") ) + user_prompt_template = SeedPrompt.from_yaml_file( + pathlib.Path(CONVERTER_SEED_PROMPT_PATH) / "variation_user_prompt.yaml" + ) + self.number_variations = 1 + self.system_prompt = str( + system_prompt_template.render_template_value(number_iterations=str(self.number_variations)) + ) - self.system_prompt = str(prompt_template.render_template_value(number_iterations=str(self.number_variations))) + super().__init__( + converter_target=converter_target, + system_prompt_template=system_prompt_template, + user_prompt_template_with_objective=user_prompt_template, + number_iterations=str(self.number_variations), + ) + self.converter_target = converter_target def _build_identifier(self) -> ComponentIdentifier: """ @@ -81,83 +83,40 @@ def _build_identifier(self) -> ComponentIdentifier: children={"converter_target": self.converter_target.get_identifier()}, ) - async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: + def _process_response(self, response_text: str) -> str: """ - Convert the given prompt by generating variations of it using the converter target. + Parse the JSON list response and return the first variation. Args: - prompt (str): The prompt to be converted. - input_type (PromptDataType): The type of input data. + response_text (str): The raw text returned by the LLM. Returns: - ConverterResult: The result containing the generated variations. + str: The first variation extracted from the JSON list. Raises: - ValueError: If the input type is not supported. + InvalidJsonException: If the response is not valid JSON or does not contain the expected list shape. """ - if not self.input_supported(input_type): - raise ValueError("Input type not supported") - - conversation_id = str(uuid.uuid4()) - - self.converter_target.set_system_prompt( - system_prompt=self.system_prompt, - conversation_id=conversation_id, - attack_identifier=None, - ) - - prompt = dedent( - f"Create {self.number_variations} variation of the seed prompt given by the user between the " - "begin and end tags" - "=== begin ===" - f"{prompt}" - "=== end ===" - ) - - request = Message( - [ - MessagePiece( - role="user", - original_value=prompt, - converted_value=prompt, - conversation_id=conversation_id, - sequence=1, - prompt_target_identifier=self.converter_target.get_identifier(), - original_value_data_type=input_type, - converted_value_data_type=input_type, - converter_identifiers=[self.get_identifier()], - ) - ] - ) - response_msg = await self.send_variation_prompt_async(request) - - return ConverterResult(output_text=response_msg, output_type="text") + cleaned = remove_markdown_json(response_text) + try: + parsed = json.loads(cleaned) + return str(parsed[0]) + except (json.JSONDecodeError, IndexError, KeyError, TypeError): + raise InvalidJsonException(message=f"Invalid JSON response: {cleaned}") from None - @pyrit_json_retry async def send_variation_prompt_async(self, request: Message) -> str: """ - Send the message to the converter target and retrieve the response. + Delegate to the unified retry helper. Deprecated shim retained for backward compatibility. Args: - request (Message): The message to be sent to the converter target. + request (Message): The message to send to the converter target. Returns: - str: The response message from the converter target. - - Raises: - InvalidJsonException: If the response is not valid JSON or does not contain the expected keys. + str: The post-processed response text. """ - response = await self.converter_target.send_prompt_async(message=request) - - response_msg = response[0].get_value() - response_msg = remove_markdown_json(response_msg) - try: - response = json.loads(response_msg) - - except json.JSONDecodeError: - raise InvalidJsonException(message=f"Invalid JSON response: {response_msg}") from None - - try: - return str(response[0]) - except KeyError: - raise InvalidJsonException(message=f"Invalid JSON response: {response_msg}") from None + warnings.warn( + "send_variation_prompt_async is deprecated; the converter now uses the unified " + "_send_with_retries_async helper from LLMGenericTextConverter.", + DeprecationWarning, + stacklevel=2, + ) + return await self._send_with_retries_async(request) diff --git a/tests/unit/prompt_converter/test_generic_llm_converter.py b/tests/unit/prompt_converter/test_generic_llm_converter.py index c2e6e88ad9..c99308ee38 100644 --- a/tests/unit/prompt_converter/test_generic_llm_converter.py +++ b/tests/unit/prompt_converter/test_generic_llm_converter.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from unit.mocks import get_mock_target_identifier -from pyrit.models import Message, MessagePiece +from pyrit.exceptions import InvalidJsonException +from pyrit.models import Message, MessagePiece, SeedPrompt from pyrit.prompt_converter import ( LLMGenericTextConverter, MaliciousQuestionGeneratorConverter, @@ -107,3 +108,187 @@ def test_generic_llm_converter_init_default_templates_empty() -> None: converter = LLMGenericTextConverter(converter_target=target) assert converter._system_prompt_template is None assert converter._user_prompt_template_with_objective is None + + +def test_generic_llm_converter_default_no_retry_exceptions() -> None: + target = MagicMock() + converter = LLMGenericTextConverter(converter_target=target) + assert converter._retry_exceptions == () + + +def test_generic_llm_converter_class_attr_retry_exceptions() -> None: + class _RetryingConverter(LLMGenericTextConverter): + RETRY_EXCEPTIONS = (ValueError,) + + target = MagicMock() + converter = _RetryingConverter(converter_target=target) + assert converter._retry_exceptions == (ValueError,) + + +def test_generic_llm_converter_instance_retry_exceptions_overrides_class_attr() -> None: + class _RetryingConverter(LLMGenericTextConverter): + RETRY_EXCEPTIONS = (ValueError,) + + target = MagicMock() + converter = _RetryingConverter(converter_target=target, retry_exceptions=(KeyError,)) + assert converter._retry_exceptions == (KeyError,) + + +async def test_convert_async_no_user_template_sets_only_original_value(mock_target) -> None: + converter = LLMGenericTextConverter(converter_target=mock_target) + await converter.convert_async(prompt="hello") + + sent_message = mock_target.send_prompt_async.call_args[1]["message"] + piece = sent_message.message_pieces[0] + assert piece.original_value == "hello" + assert piece.converted_value == "hello" + + +async def test_convert_async_with_user_template_preserves_original_and_renders_converted(mock_target) -> None: + user_template = SeedPrompt( + value="Wrap: [{{ objective }}]", + parameters=["objective"], + data_type="text", + ) + converter = LLMGenericTextConverter(converter_target=mock_target, user_prompt_template_with_objective=user_template) + await converter.convert_async(prompt="raw input") + + sent_message = mock_target.send_prompt_async.call_args[1]["message"] + piece = sent_message.message_pieces[0] + assert piece.original_value == "raw input" + assert piece.converted_value == "Wrap: [raw input]" + + +async def test_convert_async_user_template_receives_extra_kwargs(mock_target) -> None: + user_template = SeedPrompt( + value="Lang={{ language }} Obj={{ objective }}", + parameters=["objective", "language"], + data_type="text", + ) + converter = LLMGenericTextConverter( + converter_target=mock_target, + user_prompt_template_with_objective=user_template, + language="spanish", + ) + await converter.convert_async(prompt="hello") + + sent_message = mock_target.send_prompt_async.call_args[1]["message"] + assert sent_message.message_pieces[0].converted_value == "Lang=spanish Obj=hello" + + +async def test_convert_async_input_validation_raises_before_set_system_prompt(mock_target) -> None: + system_template = SeedPrompt(value="sys", data_type="text") + converter = LLMGenericTextConverter(converter_target=mock_target, system_prompt_template=system_template) + with pytest.raises(ValueError, match="Input type not supported"): + await converter.convert_async(prompt="hello", input_type="image_path") + mock_target.set_system_prompt.assert_not_called() + mock_target.send_prompt_async.assert_not_called() + + +async def test_convert_async_process_response_hook_called(mock_target) -> None: + class _UpperCaseConverter(LLMGenericTextConverter): + def _process_response(self, response_text: str) -> str: + return response_text.upper() + + converter = _UpperCaseConverter(converter_target=mock_target) + result = await converter.convert_async(prompt="anything") + assert result.output_text == "PROMPT VALUE" + + +async def test_send_with_retries_no_retry_when_empty_exception_tuple(mock_target) -> None: + converter = LLMGenericTextConverter(converter_target=mock_target) + mock_target.send_prompt_async.side_effect = ValueError("boom") + with pytest.raises(ValueError, match="boom"): + await converter.convert_async(prompt="hello") + assert mock_target.send_prompt_async.call_count == 1 + + +async def test_send_with_retries_retries_on_configured_exception(mock_target) -> None: + class _RetryingConverter(LLMGenericTextConverter): + RETRY_EXCEPTIONS = (InvalidJsonException,) + + def _process_response(self, response_text: str) -> str: + raise InvalidJsonException(message="bad") + + converter = _RetryingConverter(converter_target=mock_target) + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(InvalidJsonException): + await converter.convert_async(prompt="hello") + assert mock_target.send_prompt_async.call_count == 2 # RETRY_MAX_NUM_ATTEMPTS=2 in conftest + + +async def test_send_with_retries_does_not_retry_unrelated_exception(mock_target) -> None: + class _RetryingConverter(LLMGenericTextConverter): + RETRY_EXCEPTIONS = (InvalidJsonException,) + + converter = _RetryingConverter(converter_target=mock_target) + mock_target.send_prompt_async.side_effect = ValueError("not retried") + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(ValueError, match="not retried"): + await converter.convert_async(prompt="hello") + assert mock_target.send_prompt_async.call_count == 1 + + +async def test_send_with_retries_succeeds_after_one_failure(mock_target) -> None: + class _RetryingConverter(LLMGenericTextConverter): + RETRY_EXCEPTIONS = (InvalidJsonException,) + _calls = 0 + + def _process_response(self, response_text: str) -> str: + type(self)._calls += 1 + if type(self)._calls == 1: + raise InvalidJsonException(message="first") + return response_text + + converter = _RetryingConverter(converter_target=mock_target) + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await converter.convert_async(prompt="hello") + assert result.output_text == "prompt value" + assert mock_target.send_prompt_async.call_count == 2 + + +async def test_send_with_retries_uses_static_attempt_count_when_provided(mock_target) -> None: + class _RetryingConverter(LLMGenericTextConverter): + RETRY_EXCEPTIONS = (InvalidJsonException,) + + def _process_response(self, response_text: str) -> str: + raise InvalidJsonException(message="bad") + + converter = _RetryingConverter(converter_target=mock_target, max_retry_attempts=4) + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(InvalidJsonException): + await converter.convert_async(prompt="hello") + assert mock_target.send_prompt_async.call_count == 4 + + +async def test_send_with_retries_no_wait_by_default(mock_target) -> None: + """Default wait is none (0 seconds), matching pyrit_json_retry behavior.""" + + class _RetryingConverter(LLMGenericTextConverter): + RETRY_EXCEPTIONS = (InvalidJsonException,) + + def _process_response(self, response_text: str) -> str: + raise InvalidJsonException(message="bad") + + converter = _RetryingConverter(converter_target=mock_target) + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + with pytest.raises(InvalidJsonException): + await converter.convert_async(prompt="hello") + for call in mock_sleep.call_args_list: + assert call.args[0] == 0.0 + + +async def test_send_with_retries_uses_exponential_wait_when_max_seconds_provided(mock_target) -> None: + class _RetryingConverter(LLMGenericTextConverter): + RETRY_EXCEPTIONS = (InvalidJsonException,) + + def _process_response(self, response_text: str) -> str: + raise InvalidJsonException(message="bad") + + converter = _RetryingConverter(converter_target=mock_target, retry_wait_max_seconds=10) + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + with pytest.raises(InvalidJsonException): + await converter.convert_async(prompt="hello") + # waits should be > 0 between attempts (exponential backoff) + nonzero_waits = [c for c in mock_sleep.call_args_list if c.args[0] > 0] + assert len(nonzero_waits) >= 1 diff --git a/tests/unit/prompt_converter/test_persuasion_converter.py b/tests/unit/prompt_converter/test_persuasion_converter.py index 7fd66e45ed..95910f3c21 100644 --- a/tests/unit/prompt_converter/test_persuasion_converter.py +++ b/tests/unit/prompt_converter/test_persuasion_converter.py @@ -81,10 +81,58 @@ async def test_persuasion_converter_send_prompt_async_bad_json_exception_retries ) mock_create.return_value = [message] - with pytest.raises(InvalidJsonException): - await prompt_persuasion.convert_async(prompt="testing", input_type="text") - # RETRY_MAX_NUM_ATTEMPTS is set to 2 in conftest.py - assert mock_create.call_count == 2 + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(InvalidJsonException): + await prompt_persuasion.convert_async(prompt="testing", input_type="text") + + # RETRY_MAX_NUM_ATTEMPTS is set to 2 in conftest.py + assert mock_create.call_count == 2 + + +async def test_persuasion_converter_extracts_mutated_text(sqlite_instance): + prompt_target = MockPromptTarget() + prompt_persuasion = PersuasionConverter( + converter_target=prompt_target, persuasion_technique="authority_endorsement" + ) + + response = Message( + message_pieces=[ + MessagePiece( + role="assistant", + conversation_id="test-id", + original_value='{"mutated_text": "rephrased prompt"}', + original_value_data_type="text", + prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), + sequence=1, + ) + ] + ) + with patch.object(prompt_target, "send_prompt_async", new=AsyncMock(return_value=[response])): + result = await prompt_persuasion.convert_async(prompt="testing") + assert result.output_text == "rephrased prompt" + + +async def test_persuasion_converter_missing_mutated_text_raises_invalid_json(sqlite_instance): + prompt_target = MockPromptTarget() + prompt_persuasion = PersuasionConverter( + converter_target=prompt_target, persuasion_technique="authority_endorsement" + ) + response = Message( + message_pieces=[ + MessagePiece( + role="assistant", + conversation_id="test-id", + original_value='{"other_key": "value"}', + original_value_data_type="text", + prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), + sequence=1, + ) + ] + ) + with patch.object(prompt_target, "send_prompt_async", new=AsyncMock(return_value=[response])): + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(InvalidJsonException, match="missing 'mutated_text' key"): + await prompt_persuasion.convert_async(prompt="testing") def test_persuasion_converter_input_supported(): @@ -94,3 +142,10 @@ def test_persuasion_converter_input_supported(): ) assert prompt_persuasion.input_supported("text") is True assert prompt_persuasion.input_supported("image_path") is False + + +def test_persuasion_converter_identifier_includes_technique(sqlite_instance): + prompt_target = MockPromptTarget() + prompt_persuasion = PersuasionConverter(converter_target=prompt_target, persuasion_technique="logical_appeal") + identifier = prompt_persuasion.get_identifier() + assert identifier.params["persuasion_technique"] == "logical_appeal" diff --git a/tests/unit/prompt_converter/test_translation_converter.py b/tests/unit/prompt_converter/test_translation_converter.py index 3b8969cbe4..f12ad2ab61 100644 --- a/tests/unit/prompt_converter/test_translation_converter.py +++ b/tests/unit/prompt_converter/test_translation_converter.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from textwrap import dedent from unittest.mock import AsyncMock, patch import pytest @@ -29,18 +30,61 @@ def test_translator_converter_languages_validation_throws(languages, sqlite_inst TranslationConverter(converter_target=prompt_target, language=languages) -async def test_translation_converter_convert_async_retrieve_key_capitalization_mismatch(sqlite_instance): +async def test_translation_converter_returns_stripped_response(sqlite_instance): prompt_target = MockPromptTarget() translation_converter = TranslationConverter(converter_target=prompt_target, language="spanish") - with patch.object(translation_converter, "_send_translation_prompt_async", new=AsyncMock(return_value="hola")): - raised = False - try: - await translation_converter.convert_async(prompt="hello") - except KeyError: - raised = True # There should be no KeyError + response = Message( + message_pieces=[ + MessagePiece( + role="assistant", + conversation_id="test-id", + original_value=" hola \n", + original_value_data_type="text", + prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), + sequence=1, + ) + ] + ) + with patch.object(prompt_target, "send_prompt_async", new=AsyncMock(return_value=[response])): + result = await translation_converter.convert_async(prompt="hello") - assert raised is False + assert result.output_text == "hola" + assert result.output_type == "text" + + +async def test_translation_converter_user_prompt_byte_for_byte_equivalent(sqlite_instance): + """Regression: the SeedPrompt-rendered user prompt must match the previous f-string output exactly.""" + prompt_target = MockPromptTarget() + translation_converter = TranslationConverter(converter_target=prompt_target, language="Spanish") + + raw_prompt = "tell me about the history of the internet" + expected = dedent( + f"Translate the following to {translation_converter.language} between the begin and end tags:" + "=== begin ===\n" + f"{raw_prompt}\n" + "=== end ===\n" + ) + + response = Message( + message_pieces=[ + MessagePiece( + role="assistant", + conversation_id="test-id", + original_value="hola", + original_value_data_type="text", + prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), + sequence=1, + ) + ] + ) + with patch.object(prompt_target, "send_prompt_async", new=AsyncMock(return_value=[response])) as mock_send: + await translation_converter.convert_async(prompt=raw_prompt) + + sent_message = mock_send.call_args[1]["message"] + piece = sent_message.message_pieces[0] + assert piece.original_value == raw_prompt + assert piece.converted_value == expected async def test_translation_converter_retries_on_exception(sqlite_instance): @@ -51,7 +95,6 @@ async def test_translation_converter_retries_on_exception(sqlite_instance): ) mock_send_prompt = AsyncMock(side_effect=Exception("Test failure")) - # Mock asyncio.sleep to avoid exponential backoff delays with patch.object(prompt_target, "send_prompt_async", mock_send_prompt): with patch("asyncio.sleep", new_callable=AsyncMock): with pytest.raises(Exception): # noqa: B017 @@ -87,7 +130,6 @@ async def test_translation_converter_succeeds_after_retries(sqlite_instance): mock_send_prompt = AsyncMock() mock_send_prompt.side_effect = [Exception("First failure"), Exception("Second failure"), [success_response]] - # Mock asyncio.sleep to avoid exponential backoff delays with patch.object(prompt_target, "send_prompt_async", mock_send_prompt): with patch("asyncio.sleep", new_callable=AsyncMock): result = await translation_converter.convert_async(prompt="hello") @@ -102,3 +144,10 @@ def test_translation_converter_input_supported(sqlite_instance): translation_converter = TranslationConverter(converter_target=prompt_target, language="spanish") assert translation_converter.input_supported("text") is True assert translation_converter.input_supported("image_path") is False + + +def test_translation_converter_identifier_includes_language(sqlite_instance): + prompt_target = MockPromptTarget() + translation_converter = TranslationConverter(converter_target=prompt_target, language="Spanish") + identifier = translation_converter.get_identifier() + assert identifier.params["language"] == "spanish" diff --git a/tests/unit/prompt_converter/test_variation_converter.py b/tests/unit/prompt_converter/test_variation_converter.py index 2bfda055fa..30bcf6cf8e 100644 --- a/tests/unit/prompt_converter/test_variation_converter.py +++ b/tests/unit/prompt_converter/test_variation_converter.py @@ -54,13 +54,62 @@ async def test_variation_converter_send_prompt_async_bad_json_exception_retries( mock_create.return_value = [message] - with pytest.raises(InvalidJsonException): - await prompt_variation.convert_async(prompt="testing", input_type="text") + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(InvalidJsonException): + await prompt_variation.convert_async(prompt="testing", input_type="text") # RETRY_MAX_NUM_ATTEMPTS is set to 2 in conftest.py assert mock_create.call_count == 2 +async def test_variation_converter_extracts_first_element_from_json_list(sqlite_instance): + prompt_target = MockPromptTarget() + prompt_variation = VariationConverter(converter_target=prompt_target) + + response = Message( + message_pieces=[ + MessagePiece( + role="assistant", + conversation_id="test-id", + original_value='["first variation", "second variation"]', + original_value_data_type="text", + prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), + sequence=1, + ) + ] + ) + with patch.object(prompt_target, "send_prompt_async", new=AsyncMock(return_value=[response])): + result = await prompt_variation.convert_async(prompt="testing") + assert result.output_text == "first variation" + + +async def test_variation_converter_preserves_original_and_converted_values(sqlite_instance): + prompt_target = MockPromptTarget() + prompt_variation = VariationConverter(converter_target=prompt_target) + + response = Message( + message_pieces=[ + MessagePiece( + role="assistant", + conversation_id="test-id", + original_value='["variation"]', + original_value_data_type="text", + prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), + sequence=1, + ) + ] + ) + with patch.object(prompt_target, "send_prompt_async", new=AsyncMock(return_value=[response])) as mock_send: + await prompt_variation.convert_async(prompt="hello world") + + sent_message = mock_send.call_args[1]["message"] + piece = sent_message.message_pieces[0] + assert piece.original_value == "hello world" + assert "hello world" in piece.converted_value + assert "=== begin ===" in piece.converted_value + assert "=== end ===" in piece.converted_value + + def test_variation_converter_input_supported(sqlite_instance): prompt_target = MockPromptTarget() converter = VariationConverter(converter_target=prompt_target)