Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 166 additions & 62 deletions src/cfnlint/rules/resources/stepfunctions/StateMachineDefinition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
from collections import deque
from copy import deepcopy
from typing import Any
from typing import Any, Deque, Dict, Iterable, List, Set

import cfnlint.data.schemas.other.step_functions
from cfnlint.helpers import is_function
Expand Down Expand Up @@ -54,9 +54,8 @@ def _fix_message(self, err: ValidationError) -> ValidationError:
err.context[i] = self._fix_message(c_err)
return err

def _convert_schema_to_jsonata(self):
schema = self.schema
schema = deepcopy(schema)
def _convert_schema_to_jsonata(self) -> dict[str, Any]:
schema = deepcopy(self.schema)
schema["definitions"]["common"] = deepcopy(schema["definitions"]["common"])
schema["definitions"]["common"]["allOf"] = [
{
Expand Down Expand Up @@ -104,7 +103,10 @@ def _convert_schema_to_jsonata(self):
}
return schema

def _clean_schema(self, validator: Validator, instance: Any):
def _clean_schema(
self, validator: Validator, instance: Any
) -> Iterable[tuple[dict[str, Any], Validator]]:
"""Yield the appropriate schema and validator depending on QueryLanguage."""
for ql, ql_validator in get_value_from_path(
validator, instance, deque(["QueryLanguage"])
):
Expand All @@ -114,75 +116,174 @@ def _clean_schema(self, validator: Validator, instance: Any):
if ql == "JSONata":
yield self._convert_schema_to_jsonata(), ql_validator

def _validate_start_at(
def _add_transition_error(
self,
errors: List[ValidationError],
states: Dict[str, Any],
target: Any,
k: str,
base_path: Deque[str | int],
from_state_name: str,
field_path: List[str | int],
) -> None:
"""Record an error if a transition target does not exist in States."""
if not isinstance(target, str):
return
if target in states:
return

path = deque([k, *base_path, "States", from_state_name, *field_path])
display_path = "/" + "/".join(
[*map(str, base_path), "States", from_state_name, *map(str, field_path)]
)
errors.append(
ValidationError(
f"Missing transition target {target!r} at {display_path}",
path=path,
rule=self,
)
)

def _validate_states(
self,
definition: Any,
k: str,
add_path_to_message: bool,
path: deque | None = None,
add_path_to_message: bool, # kept for API compatibility, not used
path: Deque[str | int] | None = None,
) -> ValidationResult:
"""
Per the Amazon States Language specification, 'StartAt must' reference
a valid state name that exists in the States object.
"""Validate state reachability and transition targets.

Reference: https://states-language.net/spec.html#toplevelfields
Per the Amazon States Language specification:
- 'StartAt' must reference an existing state name.
- Every state must be reachable from 'StartAt'.
- Every transition target (Next, Default, Choices[].Next, Catch[].Next)
must reference an existing state name.
"""

start_at = definition.get("StartAt")
states = definition.get("States")

# Check if StartAt is missing
if not isinstance(start_at, str) or not isinstance(states, dict):
return # Early return to avoid further checks
return

# Check if StartAt state exists in States object
path_list: List[str | int] = list(path) if path is not None else []

# 1. Validate that StartAt points to an existing state
if start_at not in states:
if path is None: # Top level StartAt
error_path = deque([k, "StartAt"])
display_path = "/StartAt"
else: # Nested StartAt like Parallel or Map
error_path = deque([k] + list(path) + ["StartAt"])
display_path = f"/{'/'.join(str(item) for item in path)}/StartAt"
display_path = "/" + "/".join([*map(str, path_list), "StartAt"])
yield ValidationError(
f"Missing 'Next' target {start_at!r} at {display_path}",
path=deque([k, *path_list, "StartAt"]),
rule=self,
)
return

message = f"Missing 'Next' target '{start_at}' at {display_path}"
# 2. Traverse reachable states from StartAt and collect errors for
# missing transition targets along the way.
reachable: Set[str] = set()
to_visit: Set[str] = {start_at}
transition_errors: List[ValidationError] = []

yield ValidationError(message, path=error_path, rule=self)
while to_visit:
current = to_visit.pop()
if current in reachable or current not in states:
continue

# Validate nested StartAt in Parallel and Map states
base_path = deque() if path is None else deque(path)
for state_name, state in states.items():
reachable.add(current)
state = states[current]
if not isinstance(state, dict):
continue

base_path = deque([*path_list])

# Next / Default
for field in ("Next", "Default"):
target = state.get(field)
self._add_transition_error(
transition_errors,
states,
target,
k,
base_path,
current,
[field],
)
if isinstance(target, str) and target in states:
to_visit.add(target)

# Choices[].Next
for idx, choice in enumerate(state.get("Choices") or []):
if not isinstance(choice, dict):
continue
target = choice.get("Next")
self._add_transition_error(
transition_errors,
states,
target,
k,
base_path,
current,
["Choices", idx, "Next"],
)
if isinstance(target, str) and target in states:
to_visit.add(target)

# Catch[].Next
for idx, catch in enumerate(state.get("Catch") or []):
if not isinstance(catch, dict):
continue
target = catch.get("Next")
self._add_transition_error(
transition_errors,
states,
target,
k,
base_path,
current,
["Catch", idx, "Next"],
)
if isinstance(target, str) and target in states:
to_visit.add(target)

# Recurse into nested state machines (Parallel branches, Map processors)
state_type = state.get("Type")

if state_type == "Parallel":
branches = state.get("Branches", [])
if not isinstance(branches, list):
continue
for idx, branch in enumerate(branches):
branch_path = deque(base_path)
branch_path.extend(["States", state_name, "Branches", idx])
yield from self._validate_start_at(
branch, k, add_path_to_message, branch_path
)
for idx, branch in enumerate(state.get("Branches") or []):
if isinstance(branch, dict):
yield from self._validate_states(
branch,
k,
add_path_to_message,
deque([*path_list, "States", current, "Branches", idx]),
)

if state_type == "Map":
# ItemProcessor (distributed/inline mode)
processor = state.get("ItemProcessor")
if isinstance(processor, dict):
processor_path = deque(base_path)
processor_path.extend(["States", state_name, "ItemProcessor"])
yield from self._validate_start_at(
processor, k, add_path_to_message, processor_path
)
# Iterator (classic map)
iterator = state.get("Iterator")
if isinstance(iterator, dict):
iterator_path = deque(base_path)
iterator_path.extend(["States", state_name, "Iterator"])
yield from self._validate_start_at(
iterator, k, add_path_to_message, iterator_path
)
for sub_key in ("ItemProcessor", "Iterator"):
sub = state.get(sub_key)
if isinstance(sub, dict):
yield from self._validate_states(
sub,
k,
add_path_to_message,
deque([*path_list, "States", current, sub_key]),
)

# 3. Report unreachable states (defined but never visited)
for state_name in states:
if state_name not in reachable:
display_path = "/" + "/".join(
[*map(str, path_list), "States", state_name]
)
yield ValidationError(
f"State {state_name!r} is not reachable at {display_path}",
path=deque([k, *path_list, "States", state_name]),
rule=self,
)

# 4. Emit any transition target errors collected during traversal
for err in transition_errors:
yield err

def _validate_step(
self,
Expand All @@ -192,6 +293,7 @@ def _validate_step(
add_path_to_message: bool,
k: str,
) -> ValidationResult:
# First run JSON Schema validation
for err in validator.iter_errors(value):
if validator.is_type(err.instance, "string"):
if (
Expand All @@ -215,8 +317,8 @@ def _validate_step(

yield self._clean_error(err)

# Validate StartAt exists
yield from self._validate_start_at(value, k, add_path_to_message)
# Then run start/transition/reachability validation on the same value
yield from self._validate_states(value, k, add_path_to_message)

def validate(
self, validator: Validator, keywords: Any, instance: Any, schema: dict[str, Any]
Expand All @@ -225,7 +327,7 @@ def validate(
if not validator.cfn.has_serverless_transform():
definition_keys.append("DefinitionString")

substitutions = []
substitutions: list[str] = []
props_substitutions = instance.get("DefinitionSubstitutions", {})
if validator.is_type(props_substitutions, "object"):
substitutions = list(props_substitutions.keys())
Expand All @@ -245,10 +347,8 @@ def validate(
try:
value = json.loads(value)
add_path_to_message = True
for schema, schema_validator in self._clean_schema(
validator, value
):
resolver = RefResolver.from_schema(schema, store=self.store)
for schema_obj, schema_validator in self._clean_schema(validator, value):
resolver = RefResolver.from_schema(schema_obj, store=self.store)
step_validator = schema_validator.evolve(
context=validator.context.evolve(
functions=[],
Expand All @@ -263,12 +363,16 @@ def validate(
except json.JSONDecodeError:
return
else:
for schema, schema_validator in self._clean_schema(validator, value):
resolver = RefResolver.from_schema(schema, store=self.store)
for schema_obj, schema_validator in self._clean_schema(validator, value):
resolver = RefResolver.from_schema(schema_obj, store=self.store)
step_validator = schema_validator.evolve(
resolver=resolver,
schema=schema,
schema=schema_obj,
)
yield from self._validate_step(
step_validator, substitutions, value, add_path_to_message, k
step_validator,
substitutions,
value,
add_path_to_message,
k,
)
Loading