Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions docs/cli-reference/field-customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -3678,8 +3678,7 @@ This is useful when schemas have descriptive titles that should be preserved.

class ProcessingTaskTitle(BaseModel):
processing_status_union: ProcessingStatusUnionTitle | None = Field(
default_factory=lambda: ProcessingStatusUnionTitle('COMPLETED'),
title='Processing Status Union Title',
'COMPLETED', title='Processing Status Union Title', validate_default=True
)
processing_status: ProcessingStatusTitle | None = 'COMPLETED'
name: str | None = None
Expand All @@ -3706,10 +3705,7 @@ This is useful when schemas have descriptive titles that should be preserved.
RootModel[ProcessingStatusDetail | ExtendedProcessingTask | ProcessingStatusTitle]
):
root: ProcessingStatusDetail | ExtendedProcessingTask | ProcessingStatusTitle = (
Field(
default_factory=lambda: ExtendedProcessingTask('COMPLETED'),
title='Processing Status Union Title',
)
Field('COMPLETED', title='Processing Status Union Title', validate_default=True)
)


Expand Down
4 changes: 1 addition & 3 deletions docs/cli-reference/model-customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -6014,9 +6014,7 @@ The `--use-one-literal-as-default` flag configures the code generation behavior.

class NestedNullableEnum(BaseModel):
nested_version: NestedVersion | None = Field(
default_factory=lambda: NestedVersion('RC1'),
description='nullable enum',
examples=['RC2'],
'RC1', description='nullable enum', examples=['RC2'], validate_default=True
)


Expand Down
2 changes: 1 addition & 1 deletion docs/cli-reference/template-customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -2892,7 +2892,7 @@ helps maintain consistency with codebases that prefer double-quote formatting.
class MapState2(BaseModel):
latitude: Latitude
longitude: Longitude
zoom: Zoom | None = Field(default_factory=lambda: Zoom(0))
zoom: Zoom | None = Field(0, validate_default=True)
bearing: Bearing | None = None
pitch: Pitch
drag_rotate: DragRotate | None = Field(None, alias="dragRotate")
Expand Down
4 changes: 1 addition & 3 deletions docs/cli-reference/typing-customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -1350,9 +1350,7 @@ of Enum classes for all enumerations.

class NestedNullableEnum(BaseModel):
nested_version: NestedVersion | None = Field(
default_factory=lambda: NestedVersion('RC1'),
description='nullable enum',
examples=['RC2'],
'RC1', description='nullable enum', examples=['RC2'], validate_default=True
)


Expand Down
18 changes: 5 additions & 13 deletions docs/llms-full.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7242,9 +7242,7 @@ The `--use-one-literal-as-default` flag configures the code generation behavior.

class NestedNullableEnum(BaseModel):
nested_version: NestedVersion | None = Field(
default_factory=lambda: NestedVersion('RC1'),
description='nullable enum',
examples=['RC2'],
'RC1', description='nullable enum', examples=['RC2'], validate_default=True
)


Expand Down Expand Up @@ -11141,8 +11139,7 @@ This is useful when schemas have descriptive titles that should be preserved.

class ProcessingTaskTitle(BaseModel):
processing_status_union: ProcessingStatusUnionTitle | None = Field(
default_factory=lambda: ProcessingStatusUnionTitle('COMPLETED'),
title='Processing Status Union Title',
'COMPLETED', title='Processing Status Union Title', validate_default=True
)
processing_status: ProcessingStatusTitle | None = 'COMPLETED'
name: str | None = None
Expand All @@ -11169,10 +11166,7 @@ This is useful when schemas have descriptive titles that should be preserved.
RootModel[ProcessingStatusDetail | ExtendedProcessingTask | ProcessingStatusTitle]
):
root: ProcessingStatusDetail | ExtendedProcessingTask | ProcessingStatusTitle = (
Field(
default_factory=lambda: ExtendedProcessingTask('COMPLETED'),
title='Processing Status Union Title',
)
Field('COMPLETED', title='Processing Status Union Title', validate_default=True)
)


Expand Down Expand Up @@ -12537,9 +12531,7 @@ of Enum classes for all enumerations.

class NestedNullableEnum(BaseModel):
nested_version: NestedVersion | None = Field(
default_factory=lambda: NestedVersion('RC1'),
description='nullable enum',
examples=['RC2'],
'RC1', description='nullable enum', examples=['RC2'], validate_default=True
)


Expand Down Expand Up @@ -19083,7 +19075,7 @@ helps maintain consistency with codebases that prefer double-quote formatting.
class MapState2(BaseModel):
latitude: Latitude
longitude: Longitude
zoom: Zoom | None = Field(default_factory=lambda: Zoom(0))
zoom: Zoom | None = Field(0, validate_default=True)
bearing: Bearing | None = None
pitch: Pitch
drag_rotate: DragRotate | None = Field(None, alias="dragRotate")
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ optional-dependencies.http = [
optional-dependencies.ruff = [
"ruff>=0.9.10",
]
optional-dependencies.ryaml = [
"ryaml>=0.5.1",
]
optional-dependencies.validation = [
"openapi-spec-validator>=0.2.8,<0.8",
"prance>=0.18.2",
Expand Down
14 changes: 11 additions & 3 deletions src/datamodel_code_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,15 @@


def load_yaml(stream: str | TextIO) -> YamlValue:
"""Load YAML content from a string or file-like object."""
"""Load YAML content using ryaml (if available) or PyYAML."""
from datamodel_code_generator.util import get_yaml_backend # noqa: PLC0415

if get_yaml_backend() == "ryaml":
import ryaml # noqa: PLC0415 # ty: ignore[unresolved-import]

text = stream if isinstance(stream, str) else stream.read()
return ryaml.loads(text)

import yaml # noqa: PLC0415

from datamodel_code_generator.util import SafeLoader # noqa: PLC0415
Expand Down Expand Up @@ -933,11 +941,11 @@ def get_header_and_first_line(csv_file: IO[str]) -> dict[str, Any]:

def infer_input_type(text: str) -> InputFileType:
"""Automatically detect the input file type from text content."""
import yaml.parser # noqa: PLC0415
from datamodel_code_generator.util import get_yaml_parse_errors # noqa: PLC0415

try:
data = load_yaml(text)
except yaml.parser.ParserError:
except get_yaml_parse_errors():
return InputFileType.CSV
if isinstance(data, dict):
if is_openapi(data):
Expand Down
31 changes: 30 additions & 1 deletion src/datamodel_code_generator/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
import warnings
from functools import lru_cache
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -83,6 +83,35 @@ class CustomSafeLoader(_SafeLoader): # type: ignore[valid-type,misc]
return CustomSafeLoader


YamlBackend = Literal["ryaml", "pyyaml"]


@lru_cache(maxsize=1)
def get_yaml_backend() -> YamlBackend:
    """Report which YAML backend is importable: 'ryaml' or 'pyyaml'.

    The probe is a plain ``import ryaml``; the answer is memoized by
    ``lru_cache`` so the import is attempted only once per cache lifetime.
    """
    backend: YamlBackend = "ryaml"
    try:
        import ryaml  # noqa: PLC0415, F401  # ty: ignore[unresolved-import]
    except ImportError:
        # ryaml could not be imported, so PyYAML is the backend.
        backend = "pyyaml"
    return backend


@lru_cache(maxsize=1)
def get_yaml_parse_errors() -> tuple[type[Exception], ...]:
    """Return the exception types that signal a YAML parse failure.

    ``yaml.YAMLError`` is always present. ``ryaml.InvalidYamlError`` is
    appended when ryaml can be imported, so the tuple can be used directly in
    an ``except`` clause whichever backend parsed the input.

    Returns:
        A tuple of exception classes, PyYAML's first.
    """
    import yaml  # noqa: PLC0415

    errors: list[type[Exception]] = [yaml.YAMLError]
    try:
        import ryaml  # noqa: PLC0415  # ty: ignore[unresolved-import]
    except ImportError:
        # ryaml is an optional extra; without it only PyYAML errors can occur,
        # so the missing import is deliberately not treated as a failure.
        pass
    else:
        # Kept outside the ``try`` so only the import itself is guarded.
        errors.append(ryaml.InvalidYamlError)
    return tuple(errors)


@lru_cache(maxsize=1)
def _get_base_model_class() -> type:
"""Get BaseModel class with strict=False config lazily."""
Expand Down
122 changes: 122 additions & 0 deletions tests/test_yaml_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Tests for YAML backend detection and ryaml/PyYAML switching."""

from __future__ import annotations

import io
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch

import pytest
import yaml

from datamodel_code_generator import InputFileType, infer_input_type, load_yaml
from datamodel_code_generator.util import get_yaml_backend, get_yaml_parse_errors

if TYPE_CHECKING:
from collections.abc import Iterator


@pytest.fixture(autouse=True)
def _clear_caches() -> Iterator[None]:
    """Reset the memoized backend helpers before and after each test.

    Both helpers are wrapped in ``lru_cache``, so a result computed under one
    test's patched ``sys.modules`` would otherwise leak into the next test.
    """
    cached_helpers = (get_yaml_backend, get_yaml_parse_errors)
    for helper in cached_helpers:
        helper.cache_clear()
    yield
    for helper in cached_helpers:
        helper.cache_clear()


class TestGetYamlBackend:
    """Exercise backend detection performed by get_yaml_backend()."""

    def test_without_ryaml(self) -> None:
        """Fall back to 'pyyaml' when ryaml cannot be imported."""
        # A ``None`` entry in sys.modules makes ``import ryaml`` raise ImportError.
        hidden_ryaml = {"ryaml": None}
        with patch.dict("sys.modules", hidden_ryaml):
            detected = get_yaml_backend()
            assert detected == "pyyaml"

    def test_with_ryaml(self) -> None:
        """Report 'ryaml' when the module can be imported."""
        # Any object registered under the name satisfies the import probe.
        fake_ryaml = MagicMock()
        with patch.dict("sys.modules", {"ryaml": fake_ryaml}):
            detected = get_yaml_backend()
            assert detected == "ryaml"


class TestGetYamlParseErrors:
    """Exercise the exception tuple built by get_yaml_parse_errors()."""

    def test_pyyaml_only(self) -> None:
        """Only yaml.YAMLError is reported when ryaml is not importable."""
        expected = (yaml.YAMLError,)
        with patch.dict("sys.modules", {"ryaml": None}):
            assert get_yaml_parse_errors() == expected

    def test_includes_ryaml(self) -> None:
        """ryaml's InvalidYamlError joins yaml.YAMLError when ryaml imports."""
        # Give the stand-in module a real exception class to expose.
        ryaml_error = type("InvalidYamlError", (Exception,), {})
        fake_ryaml = MagicMock()
        fake_ryaml.InvalidYamlError = ryaml_error
        with patch.dict("sys.modules", {"ryaml": fake_ryaml}):
            reported = get_yaml_parse_errors()
            assert yaml.YAMLError in reported
            assert fake_ryaml.InvalidYamlError in reported
            assert len(reported) == 2


class TestLoadYaml:
    """Exercise load_yaml() under both the PyYAML and ryaml backends."""

    def test_pyyaml_fallback_string(self) -> None:
        """String input is parsed by PyYAML when ryaml is not importable."""
        document = "key: value"
        with patch.dict("sys.modules", {"ryaml": None}):
            parsed = load_yaml(document)
            assert parsed == {"key": "value"}

    def test_pyyaml_fallback_textio(self) -> None:
        """TextIO input is parsed by PyYAML when ryaml is not importable."""
        handle = io.StringIO("key: value")
        with patch.dict("sys.modules", {"ryaml": None}):
            parsed = load_yaml(handle)
            assert parsed == {"key": "value"}

    def test_with_ryaml_string(self) -> None:
        """String input is handed straight to ryaml.loads() when ryaml imports."""
        fake_ryaml = MagicMock()
        fake_ryaml.loads.return_value = {"key": "value"}
        with patch.dict("sys.modules", {"ryaml": fake_ryaml}):
            parsed = load_yaml("key: value")
            fake_ryaml.loads.assert_called_once_with("key: value")
            assert parsed == {"key": "value"}

    def test_with_ryaml_textio(self) -> None:
        """TextIO input is read into text before being passed to ryaml.loads()."""
        fake_ryaml = MagicMock()
        fake_ryaml.loads.return_value = {"key": "value"}
        handle = io.StringIO("key: value")
        with patch.dict("sys.modules", {"ryaml": fake_ryaml}):
            parsed = load_yaml(handle)
            # The mock must have received the stream's contents, not the stream.
            fake_ryaml.loads.assert_called_once_with("key: value")
            assert parsed == {"key": "value"}


class TestInferInputType:
    """Exercise infer_input_type() handling of parse failures per backend."""

    def test_csv_with_pyyaml_error(self) -> None:
        """A PyYAML parse failure classifies the input as CSV."""
        csv_like_text = "a,b,c\n1,2,3\n::"
        with patch.dict("sys.modules", {"ryaml": None}):
            detected = infer_input_type(csv_like_text)
            assert detected == InputFileType.CSV

    def test_csv_with_ryaml_error(self) -> None:
        """A ryaml parse failure classifies the input as CSV."""
        # Stand-in exception class, raised by the fake module's loads().
        parse_failure = type("InvalidYamlError", (Exception,), {})
        fake_ryaml = MagicMock()
        fake_ryaml.InvalidYamlError = parse_failure
        fake_ryaml.loads.side_effect = parse_failure("parse error")
        with patch.dict("sys.modules", {"ryaml": fake_ryaml}):
            detected = infer_input_type(":::invalid yaml:::")
            assert detected == InputFileType.CSV

    def test_openapi_detection(self) -> None:
        """An OpenAPI document is recognised (checked here with ryaml hidden)."""
        openapi_text = "openapi: '3.0.0'\ninfo:\n  title: Test\n  version: '1.0'"
        with patch.dict("sys.modules", {"ryaml": None}):
            detected = infer_input_type(openapi_text)
            assert detected == InputFileType.OpenAPI
Loading
Loading