Skip to content

Commit c438fd5

Browse files
committed
Fixed required form field and security in schema.
1 parent d29647b commit c438fd5

File tree

2 files changed

+89
-36
lines changed

2 files changed

+89
-36
lines changed

fastopenapi/openapi/generator.py

Lines changed: 88 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def process_route_parameters(self, route) -> tuple[list[dict], dict | None]:
202202
body_fields: dict[str, dict] = {}
203203
form_fields = {}
204204
multipart_fields = {}
205+
form_required = []
205206
has_explicit_embed = False
206207

207208
for param_name, param in sig.parameters.items():
@@ -215,36 +216,70 @@ def process_route_parameters(self, route) -> tuple[list[dict], dict | None]:
215216
if result is None:
216217
continue
217218

218-
param_type, data = result
219-
220-
if param_type == "parameter":
221-
parameters.append(data)
222-
elif param_type == "parameters": # Multiple parameters from model
223-
parameters.extend(data)
224-
elif param_type == "request_body":
225-
body_fields[param_name] = data
226-
if isinstance(param.default, Body) and param.default.embed:
227-
has_explicit_embed = True
228-
elif param_type == "form":
229-
form_fields[param_name] = data
230-
elif param_type == "multipart": # pragma: no cover
231-
multipart_fields[param_name] = data
219+
self._classify_parameter_result(
220+
result,
221+
param_name,
222+
param,
223+
parameters,
224+
body_fields,
225+
form_fields,
226+
multipart_fields,
227+
form_required,
228+
)
229+
if isinstance(param.default, Body) and param.default.embed:
230+
has_explicit_embed = True
232231

233232
request_body = self._resolve_request_body(
234-
body_fields, form_fields, multipart_fields, has_explicit_embed
233+
body_fields,
234+
form_fields,
235+
multipart_fields,
236+
has_explicit_embed,
237+
form_required,
235238
)
236239
return parameters, request_body
237240

241+
def _classify_parameter_result(
242+
self,
243+
result: tuple[str, Any],
244+
param_name: str,
245+
param: inspect.Parameter,
246+
parameters: list,
247+
body_fields: dict,
248+
form_fields: dict,
249+
multipart_fields: dict,
250+
form_required: list,
251+
) -> None:
252+
"""Classify a processed parameter result into the appropriate collection"""
253+
param_type, data = result
254+
255+
if param_type == "parameter":
256+
parameters.append(data)
257+
elif param_type == "parameters":
258+
parameters.extend(data)
259+
elif param_type == "request_body":
260+
body_fields[param_name] = data
261+
elif param_type == "form":
262+
form_fields[param_name] = data
263+
if self._is_form_field_required(param):
264+
form_required.append(param_name)
265+
elif param_type == "multipart": # pragma: no cover
266+
multipart_fields[param_name] = data
267+
if self._is_form_field_required(param):
268+
form_required.append(param_name)
269+
238270
def _resolve_request_body(
239271
self,
240272
body_fields: dict[str, dict],
241273
form_fields: dict,
242274
multipart_fields: dict,
243275
has_explicit_embed: bool,
276+
form_required: list[str] | None = None,
244277
) -> dict | None:
245278
"""Resolve final request body from collected fields"""
246279
if form_fields or multipart_fields:
247-
return self._build_form_request_body(form_fields, multipart_fields)
280+
return self._build_form_request_body(
281+
form_fields, multipart_fields, form_required or []
282+
)
248283
if len(body_fields) > 1 or has_explicit_embed:
249284
return self._build_embedded_request_body(body_fields)
250285
if body_fields:
@@ -496,30 +531,43 @@ def _build_embedded_request_body(body_fields: dict[str, dict]) -> dict:
496531
}
497532

498533
def _build_form_request_body(
499-
self, form_fields: dict, multipart_fields: dict
534+
self,
535+
form_fields: dict,
536+
multipart_fields: dict,
537+
required: list[str] | None = None,
500538
) -> dict | None:
501539
"""Build request body for form/multipart data"""
502540
if multipart_fields:
503-
# Multipart form data
504541
all_fields = {**form_fields, **multipart_fields}
505-
return {
506-
"content": {
507-
"multipart/form-data": {
508-
"schema": {"type": "object", "properties": all_fields}
509-
}
510-
}
542+
schema: dict[str, Any] = {
543+
"type": "object",
544+
"properties": all_fields,
511545
}
546+
if required:
547+
schema["required"] = required
548+
return {"content": {"multipart/form-data": {"schema": schema}}}
512549
elif form_fields:
513-
# URL-encoded form data
550+
schema = {
551+
"type": "object",
552+
"properties": form_fields,
553+
}
554+
if required:
555+
schema["required"] = required
514556
return {
515-
"content": {
516-
"application/x-www-form-urlencoded": {
517-
"schema": {"type": "object", "properties": form_fields}
518-
}
519-
}
557+
"content": {"application/x-www-form-urlencoded": {"schema": schema}}
520558
}
521559
return None
522560

561+
@staticmethod
562+
def _is_form_field_required(param: inspect.Parameter) -> bool:
563+
"""Check if a Form/File parameter is required"""
564+
if isinstance(param.default, BaseParam):
565+
return (
566+
param.default.default is ...
567+
or param.default.default is PydanticUndefined
568+
)
569+
return param.default is inspect.Parameter.empty
570+
523571
def _build_query_params_from_model(
524572
self, model_class: type[BaseModel]
525573
) -> list[dict]:
@@ -564,7 +612,7 @@ class ResponseBuilder:
564612
def __init__(self, schema_builder: SchemaBuilder):
565613
self.schema_builder = schema_builder
566614

567-
def build_responses(self, route) -> dict:
615+
def build_responses(self, route, has_security: bool = False) -> dict:
568616
"""Build responses section with enhanced error handling"""
569617
from http import HTTPStatus
570618

@@ -577,7 +625,7 @@ def build_responses(self, route) -> dict:
577625
)
578626

579627
# Add error responses
580-
self._add_security_error_responses(responses, route)
628+
self._add_security_error_responses(responses, route, has_security)
581629
self._add_custom_error_responses(responses, route)
582630

583631
return responses
@@ -603,9 +651,11 @@ def _add_response_model(
603651
schema = self.schema_builder.get_model_schema(response_model)
604652
responses[status_code]["content"] = {"application/json": {"schema": schema}}
605653

606-
def _add_security_error_responses(self, responses: dict, route) -> None:
654+
def _add_security_error_responses(
655+
self, responses: dict, route, has_security: bool = False
656+
) -> None:
607657
"""Add security-related error responses"""
608-
if not route.meta.get("security"):
658+
if not has_security:
609659
return
610660

611661
error_responses = {"401": "Unauthorized", "403": "Forbidden"}
@@ -763,7 +813,10 @@ def _build_operation(self, route) -> dict:
763813
parameters, request_body = self.parameter_processor.process_route_parameters(
764814
route
765815
)
766-
responses = self.response_builder.build_responses(route)
816+
has_security = bool(
817+
route.meta.get("security")
818+
) or self._has_security_dependency(route)
819+
responses = self.response_builder.build_responses(route, has_security)
767820

768821
operation = {
769822
"summary": route.meta.get("summary")

tests/openapi/test_openapi_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def test_response_builder_error_responses(self):
437437
route.meta = {"security": True}
438438

439439
builder = ResponseBuilder(self.generator.schema_builder)
440-
responses = builder.build_responses(route)
440+
responses = builder.build_responses(route, has_security=True)
441441

442442
# Should add security error responses
443443
assert "401" in responses

0 commit comments

Comments
 (0)