Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pydantic_forms/core/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import structlog
from pydantic import BaseModel, ConfigDict, PydanticUndefinedAnnotation, version

from pydantic_forms.utils.required import determine_required_form_fields

logger = structlog.get_logger(__name__)


Expand Down Expand Up @@ -45,6 +47,16 @@ def get_value(k: str, v: Any) -> Any:
mutable_data = {k: get_value(k, v) for k, v in data.items()}
super().__init__(**mutable_data)

@classmethod
def model_json_schema(cls, *args: Any, **kwargs: Any) -> dict[str, Any]:
schema = super().model_json_schema(*args, **kwargs)
required_fields = determine_required_form_fields(cls)

# TODO add toggle
if new_required := [k for k, v in required_fields.items() if v]:
schema["required"] = new_required
return schema

if PYDANTIC_VERSION in ("2.9", "2.10", "2.11"):

@classmethod
Expand Down
205 changes: 205 additions & 0 deletions pydantic_forms/utils/required.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# Copyright 2019-2026 SURF.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
from collections.abc import Iterable
from typing import (
Annotated,
Any,
TypeVar,
Union,
get_args,
get_origin,
)

from more_itertools import first
from pydantic import BaseModel
from pydantic.fields import FieldInfo


def is_union(tp: type[Any] | None) -> bool:
return tp is Union or tp is types.UnionType # type: ignore[comparison-overlap]


def get_origin_and_args(t: Any) -> tuple[Any, tuple[Any, ...]]:
"""Return the origin and args of the given type.

When wrapped in Annotated[] this is removed.
"""
origin, args = get_origin(t), get_args(t)
if origin is not Annotated:
return origin, args

t_unwrapped = first(args)
return get_origin(t_unwrapped), get_args(t_unwrapped)


def is_union_type(t: Any, test_type: type | None = None) -> bool:
"""Check if `t` is union type (Union[Type, AnotherType]).

Optionally check if T is of `test_type` We cannot check for literal Nones.

>>> is_union_type(Union[int, str])
True
>>> is_union_type(Annotated[Union[int, str], "foo"])
True
>>> is_union_type(Union[int, str], str)
True
>>> is_union_type(Union[int, str], bool)
False
>>> is_union_type(Union[int, str], Union[int, str])
True
>>> is_union_type(Union[int, None])
True
>>> is_union_type(Annotated[Union[int, None], "foo"])
True
>>> is_union_type(int)
False
"""
origin, args = get_origin_and_args(t)
if not is_union(origin):
return False
if not test_type:
return True

if is_of_type(t, test_type):
return True

for arg in args:
result = is_of_type(arg, test_type)
if result:
return result
return False


def is_of_type(t: Any, test_type: Any) -> bool:
"""Check if annotation type is valid for type.

>>> is_of_type(list, list)
True
>>> is_of_type(list[int], list[int])
True
>>> is_of_type(strEnum, str)
True
>>> is_of_type(strEnum, Enum)
True
>>> is_of_type(int, str)
False
>>> is_of_type(Any, Any)
True
>>> is_of_type(Any, int)
True
"""
if t is Any:
return True

if is_union_type(test_type):
return any(get_origin(t) is get_origin(arg) for arg in get_args(test_type))

if (
get_origin(t)
and get_origin(test_type)
and get_origin(t) is get_origin(test_type)
and get_args(t) == get_args(test_type)
):
return True

if test_type is t:
# Test type is a typing type instance and matches
return True

# Workaround for the fact that you can't call issubclass on typing types
try:
return issubclass(t, test_type)
except TypeError:
return False


def filter_nonetype(types_: Iterable[Any]) -> Iterable[Any]:
def not_nonetype(type_: Any) -> bool:
return type_ is not None.__class__

return filter(not_nonetype, types_)


def is_optional_type(t: Any, test_type: type | None = None) -> bool:
"""Check if `t` is optional type (Union[None, ...]).

And optionally check if T is of `test_type`

>>> is_optional_type(Optional[int])
True
>>> is_optional_type(Annotated[Optional[int], "foo"])
True
>>> is_optional_type(Annotated[int, "foo"])
False
>>> is_optional_type(Union[None, int])
True
>>> is_optional_type(Union[int, str, None])
True
>>> is_optional_type(Union[int, str])
False
>>> is_optional_type(Optional[int], int)
True
>>> is_optional_type(Optional[int], str)
False
>>> is_optional_type(Annotated[Optional[int], "foo"], int)
True
>>> is_optional_type(Annotated[Optional[int], "foo"], str)
False
>>> is_optional_type(Optional[State], int)
False
>>> is_optional_type(Optional[State], State)
True
"""
origin, args = get_origin_and_args(t)

if is_union(origin) and None.__class__ in args:
field_type = first(filter_nonetype(args))
return test_type is None or is_of_type(field_type, test_type)
return False


# TODO The above code is copy-pasted from orchestrator-core/types.py.
# Maybe something to move to a shared lib some day.


def _is_required(field: FieldInfo) -> bool:
"""Determine whether a FormPage field is required.

Our logic extends that of Pydantic because of our common practice to use FormPage to transmit data.
TODO explain better
"""
match field.annotation, field.is_required(), field.default, field.json_schema_extra:
case _, True, _, _:
# Pydantic considers the field as required
return True
case _, False, None, _:
# Pydantic considers the field as optional, and the default is none
return False
case _, _, _, {"format": "read_only_field"}:
# pydantic-forms fields which we never want to mark as required
# TODO: is this complete?
return False
case t, _, _, _:
# A field is required if it's not optional (makes sense, doesn't it?)
return not is_optional_type(t)
case _:
# For any combination we've missed, the safest assumption is that it's not required
return False


BaseModelDerivative = TypeVar("BaseModelDerivative", bound=BaseModel)


def determine_required_form_fields(form: type[BaseModelDerivative]) -> dict[str, bool]:
return {name: _is_required(field) for name, field in form.model_fields.items()}
1 change: 1 addition & 0 deletions pydantic_forms/validators/components/read_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def _get_read_only_schema(default: Any) -> dict:
"uniforms": forms_schema, # Deprecated
constants.EXTRA_PROPERTIES: forms_schema,
"type": _get_json_type(default),
"format": "read_only_field",
}


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ doc = [
dev = [
"toml",
"bumpversion",
"inline-snapshot",
"mypy_extensions",
"pre-commit",
"pydocstyle",
Expand Down
33 changes: 33 additions & 0 deletions tests/unit_tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from inline_snapshot import snapshot
from pydantic import ConfigDict, model_validator

from pydantic_forms.core import FormPage, generate_form, post_form
Expand All @@ -7,6 +8,7 @@

# TODO: Remove when generic forms of pydantic_forms are ready
from pydantic_forms.utils.json import json_dumps, json_loads
from pydantic_forms.utils.required import determine_required_form_fields


class TestChoices(strEnum):
Expand Down Expand Up @@ -223,3 +225,34 @@ def validator(cls, values: dict) -> dict:
assert len(e.value.errors) == 1
assert e.value.errors[0]["loc"] == ("__root__",)
assert e.value.errors[0]["msg"] == "too high"


class FormWithAllDefaultScenarios(FormPage):
field1: int
field2: int = 1
field3: int | None # Probably not used
field4: int | None = None # Dito
field5: int | None = 1


def test_defaults():
assert FormWithAllDefaultScenarios.model_json_schema() == snapshot(
{
"additionalProperties": False,
"properties": {
"field1": {"title": "Field1", "type": "integer"},
"field2": {"default": 1, "title": "Field2", "type": "integer"},
"field3": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "Field3"},
"field4": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None, "title": "Field4"},
"field5": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": 1, "title": "Field5"},
},
"required": ["field1", "field2", "field3"],
"title": "unknown",
"type": "object",
}
)


def test_defaults2():
requireds = determine_required_form_fields(FormWithAllDefaultScenarios)
assert requireds == snapshot({"field1": True, "field2": True, "field3": True, "field4": False, "field5": False})
61 changes: 33 additions & 28 deletions tests/unit_tests/test_display_subscription.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from uuid import uuid4

from inline_snapshot import Is, snapshot

from pydantic_forms.core import FormPage
from pydantic_forms.validators import DisplaySubscription, Label, migration_summary

Expand Down Expand Up @@ -35,34 +37,37 @@ class Form(FormPage):
label: Label
summary: Summary

expected = {
"$defs": {"MigrationSummaryValue": {"properties": {}, "title": "MigrationSummaryValue", "type": "object"}},
"additionalProperties": False,
"properties": {
"display_sub": {
"default": str(some_sub_id),
"format": "subscription",
"title": "Display Sub",
"type": "string",
},
"label": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"format": "label",
"default": None,
"title": "Label",
"type": "string",
},
"summary": {
"$ref": "#/$defs/MigrationSummaryValue",
"default": None,
"format": "summary",
"type": "string",
"uniforms": {"data": {"headers": ["one"]}},
"extraProperties": {"data": {"headers": ["one"]}},
expected = snapshot(
{
"$defs": {"MigrationSummaryValue": {"properties": {}, "title": "MigrationSummaryValue", "type": "object"}},
"additionalProperties": False,
"properties": {
"display_sub": {
"default": Is(str(some_sub_id)),
"format": "subscription",
"title": "Display Sub",
"type": "string",
},
"label": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"default": None,
"format": "label",
"title": "Label",
"type": "string",
},
"summary": {
"$ref": "#/$defs/MigrationSummaryValue",
"default": None,
"extraProperties": {"data": {"headers": ["one"]}},
"format": "summary",
"type": "string",
"uniforms": {"data": {"headers": ["one"]}},
},
},
},
"title": "unknown",
"type": "object",
}
"title": "unknown",
"type": "object",
"required": ["display_sub"],
}
)

assert Form.model_json_schema() == expected
3 changes: 3 additions & 0 deletions tests/unit_tests/test_read_only_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Form(FormPage):
**enum,
"const": schema_value,
"default": schema_value,
"format": "read_only_field",
"title": "Read Only",
"uniforms": {"disabled": True, "value": schema_value},
"extraProperties": {"disabled": True, "value": schema_value},
Expand Down Expand Up @@ -88,6 +89,7 @@ class Form(FormPage):
"read_only_list": {
"default": schema_value,
"items": expected_item_type,
"format": "read_only_field",
"title": "Read Only List",
"uniforms": {"disabled": True, "value": schema_value},
"extraProperties": {"disabled": True, "value": schema_value},
Expand Down Expand Up @@ -175,6 +177,7 @@ class Form(FormPage):
"read_only_list": {
"default": schema_value,
"items": expected_item_type,
"format": "read_only_field",
"title": "Read Only List",
"uniforms": {"disabled": True, "value": schema_value},
"extraProperties": {"disabled": True, "value": schema_value},
Expand Down
Loading