Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c2f78b6

Browse files
committedOct 28, 2024·
add discriminator property support
1 parent a0b1bb7 commit c2f78b6

File tree

14 files changed

+650
-66
lines changed

14 files changed

+650
-66
lines changed
 

‎end_to_end_tests/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
""" Generate a complete client and verify that it is correct """
2+
import pytest
3+
4+
pytest.register_assert_rewrite('end_to_end_tests.end_to_end_live_tests')

‎end_to_end_tests/baseline_openapi_3.0.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,15 +2841,17 @@
28412841
"modelType": {
28422842
"type": "string"
28432843
}
2844-
}
2844+
},
2845+
"required": ["modelType"]
28452846
},
28462847
"ADiscriminatedUnionType2": {
28472848
"type": "object",
28482849
"properties": {
28492850
"modelType": {
28502851
"type": "string"
28512852
}
2852-
}
2853+
},
2854+
"required": ["modelType"]
28532855
}
28542856
},
28552857
"parameters": {

‎end_to_end_tests/baseline_openapi_3.1.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,15 +2835,17 @@ info:
28352835
"modelType": {
28362836
"type": "string"
28372837
}
2838-
}
2838+
},
2839+
"required": ["modelType"]
28392840
},
28402841
"ADiscriminatedUnionType2": {
28412842
"type": "object",
28422843
"properties": {
28432844
"modelType": {
28442845
"type": "string"
28452846
}
2846-
}
2847+
},
2848+
"required": ["modelType"]
28472849
}
28482850
}
28492851
"parameters": {
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import importlib
2+
from typing import Any
3+
4+
import pytest
5+
6+
7+
def live_tests_3_x():
8+
_test_model_with_discriminated_union()
9+
10+
11+
def _import_model(module_name, class_name: str) -> Any:
12+
module = importlib.import_module(f"my_test_api_client.models.{module_name}")
13+
module = importlib.reload(module) # avoid test contamination from previous import
14+
return getattr(module, class_name)
15+
16+
17+
def _test_model_with_discriminated_union():
18+
ModelType1Class = _import_model("a_discriminated_union_type_1", "ADiscriminatedUnionType1")
19+
ModelType2Class = _import_model("a_discriminated_union_type_2", "ADiscriminatedUnionType2")
20+
ModelClass = _import_model("model_with_discriminated_union", "ModelWithDiscriminatedUnion")
21+
22+
assert (
23+
ModelClass.from_dict({"discriminated_union": {"modelType": "type1"}}) ==
24+
ModelClass(discriminated_union=ModelType1Class.from_dict({"modelType": "type1"}))
25+
)
26+
assert (
27+
ModelClass.from_dict({"discriminated_union": {"modelType": "type2"}}) ==
28+
ModelClass(discriminated_union=ModelType2Class.from_dict({"modelType": "type2"}))
29+
)
30+
with pytest.raises(TypeError):
31+
ModelClass.from_dict({"discriminated_union": {"modelType": "type3"}})
32+
with pytest.raises(TypeError):
33+
ModelClass.from_dict({"discriminated_union": {}})

‎end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_1.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
1-
from typing import Any, Dict, List, Type, TypeVar, Union
1+
from typing import Any, Dict, List, Type, TypeVar
22

33
from attrs import define as _attrs_define
44
from attrs import field as _attrs_field
55

6-
from ..types import UNSET, Unset
7-
86
T = TypeVar("T", bound="ADiscriminatedUnionType1")
97

108

119
@_attrs_define
1210
class ADiscriminatedUnionType1:
1311
"""
1412
Attributes:
15-
model_type (Union[Unset, str]):
13+
model_type (str):
1614
"""
1715

18-
model_type: Union[Unset, str] = UNSET
16+
model_type: str
1917
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
2018

2119
def to_dict(self) -> Dict[str, Any]:
2220
model_type = self.model_type
2321

2422
field_dict: Dict[str, Any] = {}
2523
field_dict.update(self.additional_properties)
26-
field_dict.update({})
27-
if model_type is not UNSET:
28-
field_dict["modelType"] = model_type
24+
field_dict.update(
25+
{
26+
"modelType": model_type,
27+
}
28+
)
2929

3030
return field_dict
3131

3232
@classmethod
3333
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
3434
d = src_dict.copy()
35-
model_type = d.pop("modelType", UNSET)
35+
model_type = d.pop("modelType")
3636

3737
a_discriminated_union_type_1 = cls(
3838
model_type=model_type,

‎end_to_end_tests/golden-record/my_test_api_client/models/a_discriminated_union_type_2.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
1-
from typing import Any, Dict, List, Type, TypeVar, Union
1+
from typing import Any, Dict, List, Type, TypeVar
22

33
from attrs import define as _attrs_define
44
from attrs import field as _attrs_field
55

6-
from ..types import UNSET, Unset
7-
86
T = TypeVar("T", bound="ADiscriminatedUnionType2")
97

108

119
@_attrs_define
1210
class ADiscriminatedUnionType2:
1311
"""
1412
Attributes:
15-
model_type (Union[Unset, str]):
13+
model_type (str):
1614
"""
1715

18-
model_type: Union[Unset, str] = UNSET
16+
model_type: str
1917
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
2018

2119
def to_dict(self) -> Dict[str, Any]:
2220
model_type = self.model_type
2321

2422
field_dict: Dict[str, Any] = {}
2523
field_dict.update(self.additional_properties)
26-
field_dict.update({})
27-
if model_type is not UNSET:
28-
field_dict["modelType"] = model_type
24+
field_dict.update(
25+
{
26+
"modelType": model_type,
27+
}
28+
)
2929

3030
return field_dict
3131

3232
@classmethod
3333
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
3434
d = src_dict.copy()
35-
model_type = d.pop("modelType", UNSET)
35+
model_type = d.pop("modelType")
3636

3737
a_discriminated_union_type_2 = cls(
3838
model_type=model_type,

‎end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,34 @@ def _parse_discriminated_union(
5959
return data
6060
if isinstance(data, Unset):
6161
return data
62-
try:
63-
if not isinstance(data, dict):
64-
raise TypeError()
65-
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)
66-
67-
return componentsschemas_a_discriminated_union_type_0
68-
except: # noqa: E722
69-
pass
70-
try:
71-
if not isinstance(data, dict):
72-
raise TypeError()
73-
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)
74-
75-
return componentsschemas_a_discriminated_union_type_1
76-
except: # noqa: E722
77-
pass
78-
return cast(Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], data)
62+
if not isinstance(data, dict):
63+
raise TypeError()
64+
if "modelType" in data:
65+
_discriminator_value = data["modelType"]
66+
67+
def _parse_componentsschemas_a_discriminated_union_type_1(data: object) -> ADiscriminatedUnionType1:
68+
if not isinstance(data, dict):
69+
raise TypeError()
70+
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType1.from_dict(data)
71+
72+
return componentsschemas_a_discriminated_union_type_1
73+
74+
def _parse_componentsschemas_a_discriminated_union_type_2(data: object) -> ADiscriminatedUnionType2:
75+
if not isinstance(data, dict):
76+
raise TypeError()
77+
componentsschemas_a_discriminated_union_type_2 = ADiscriminatedUnionType2.from_dict(data)
78+
79+
return componentsschemas_a_discriminated_union_type_2
80+
81+
_discriminator_mapping = {
82+
"type1": _parse_componentsschemas_a_discriminated_union_type_1,
83+
"type2": _parse_componentsschemas_a_discriminated_union_type_2,
84+
}
85+
if _parse_fn := _discriminator_mapping.get(_discriminator_value):
86+
return cast(
87+
Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], _parse_fn(data)
88+
)
89+
raise TypeError("unrecognized value for property modelType")
7990

8091
discriminated_union = _parse_discriminated_union(d.pop("discriminated_union", UNSET))
8192

‎end_to_end_tests/test_end_to_end.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
import os
12
import shutil
23
from filecmp import cmpfiles, dircmp
34
from pathlib import Path
4-
from typing import Dict, List, Optional, Set
5+
import sys
6+
from typing import Callable, Dict, List, Optional, Set
57

68
import pytest
79
from click.testing import Result
810
from typer.testing import CliRunner
911

1012
from openapi_python_client.cli import app
13+
from .end_to_end_live_tests import live_tests_3_x
14+
1115

1216

1317
def _compare_directories(
@@ -83,8 +87,10 @@ def run_e2e_test(
8387
golden_record_path: str = "golden-record",
8488
output_path: str = "my-test-api-client",
8589
expected_missing: Optional[Set[str]] = None,
90+
live_tests: Optional[Callable[[str], None]] = None,
8691
) -> Result:
87-
output_path = Path.cwd() / output_path
92+
cwd = Path.cwd()
93+
output_path = cwd / output_path
8894
shutil.rmtree(output_path, ignore_errors=True)
8995
result = generate(extra_args, openapi_document)
9096
gr_path = Path(__file__).parent / golden_record_path
@@ -97,6 +103,13 @@ def run_e2e_test(
97103
_compare_directories(
98104
gr_path, output_path, expected_differences=expected_differences, expected_missing=expected_missing
99105
)
106+
if live_tests:
107+
old_path = sys.path.copy()
108+
sys.path.insert(0, str(output_path))
109+
try:
110+
live_tests()
111+
finally:
112+
sys.path = old_path
100113

101114
import mypy.api
102115

@@ -131,11 +144,11 @@ def _run_command(command: str, extra_args: Optional[List[str]] = None, openapi_d
131144

132145

133146
def test_baseline_end_to_end_3_0():
134-
run_e2e_test("baseline_openapi_3.0.json", [], {})
147+
run_e2e_test("baseline_openapi_3.0.json", [], {}, live_tests=live_tests_3_x)
135148

136149

137150
def test_baseline_end_to_end_3_1():
138-
run_e2e_test("baseline_openapi_3.1.yaml", [], {})
151+
run_e2e_test("baseline_openapi_3.1.yaml", [], {}, live_tests=live_tests_3_x)
139152

140153

141154
def test_3_1_specific_features():

‎openapi_python_client/parser/properties/union.py

Lines changed: 122 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010
from ...utils import PythonIdentifier
1111
from ..errors import ParseError, PropertyError
1212
from .protocol import PropertyProtocol, Value
13-
from .schemas import Schemas
13+
from .schemas import Schemas, parse_reference_path
14+
15+
16+
@define
17+
class DiscriminatorDefinition:
18+
from .model_property import ModelProperty
19+
20+
property_name: str
21+
value_to_model_map: dict[str, ModelProperty]
1422

1523

1624
@define
@@ -24,6 +32,7 @@ class UnionProperty(PropertyProtocol):
2432
description: str | None
2533
example: str | None
2634
inner_properties: list[PropertyProtocol]
35+
discriminators: list[DiscriminatorDefinition] | None = None
2736
template: ClassVar[str] = "union_property.py.jinja"
2837

2938
@classmethod
@@ -67,16 +76,7 @@ def build(
6776
return PropertyError(detail=f"Invalid property in union {name}", data=sub_prop_data), schemas
6877
sub_properties.append(sub_prop)
6978

70-
def flatten_union_properties(sub_properties: list[PropertyProtocol]) -> list[PropertyProtocol]:
71-
flattened = []
72-
for sub_prop in sub_properties:
73-
if isinstance(sub_prop, UnionProperty):
74-
flattened.extend(flatten_union_properties(sub_prop.inner_properties))
75-
else:
76-
flattened.append(sub_prop)
77-
return flattened
78-
79-
sub_properties = flatten_union_properties(sub_properties)
79+
sub_properties, discriminators_list = _flatten_union_properties(sub_properties)
8080

8181
prop = UnionProperty(
8282
name=name,
@@ -92,6 +92,17 @@ def flatten_union_properties(sub_properties: list[PropertyProtocol]) -> list[Pro
9292
default_or_error.data = data
9393
return default_or_error, schemas
9494
prop = evolve(prop, default=default_or_error)
95+
96+
if data.discriminator:
97+
discriminator_or_error = _parse_discriminator(data.discriminator, sub_properties, schemas)
98+
if isinstance(discriminator_or_error, PropertyError):
99+
return discriminator_or_error, schemas
100+
discriminators_list = [discriminator_or_error, *discriminators_list]
101+
if discriminators_list:
102+
if error := _validate_discriminators(discriminators_list):
103+
return error, schemas
104+
prop = evolve(prop, discriminators=discriminators_list)
105+
95106
return prop, schemas
96107

97108
def convert_value(self, value: Any) -> Value | None | PropertyError:
@@ -189,3 +200,103 @@ def validate_location(self, location: oai.ParameterLocation) -> ParseError | Non
189200
if evolve(cast(Property, inner_prop), required=self.required).validate_location(location) is not None:
190201
return ParseError(detail=f"{self.get_type_string()} is not allowed in {location}")
191202
return None
203+
204+
205+
def _flatten_union_properties(
206+
sub_properties: list[PropertyProtocol],
207+
) -> tuple[list[PropertyProtocol], list[DiscriminatorDefinition]]:
208+
flattened = []
209+
discriminators = []
210+
for sub_prop in sub_properties:
211+
if isinstance(sub_prop, UnionProperty):
212+
if sub_prop.discriminators:
213+
discriminators.extend(sub_prop.discriminators)
214+
new_props, new_discriminators = _flatten_union_properties(sub_prop.inner_properties)
215+
flattened.extend(new_props)
216+
discriminators.extend(new_discriminators)
217+
else:
218+
flattened.append(sub_prop)
219+
return flattened, discriminators
220+
221+
222+
def _parse_discriminator(
223+
data: oai.Discriminator,
224+
subtypes: list[PropertyProtocol],
225+
schemas: Schemas,
226+
) -> DiscriminatorDefinition | PropertyError:
227+
from .model_property import ModelProperty
228+
229+
# See: https://spec.openapis.org/oas/v3.1.0.html#discriminator-object
230+
231+
def _find_top_level_model(matching_model: ModelProperty) -> ModelProperty | None:
232+
# This is needed because, when we built the union list, $refs were changed into a copy of
233+
# the type they referred to, without preserving the original name. We need to know that
234+
# every type in the discriminator is a $ref to a top-level type and we need its name.
235+
for prop in schemas.classes_by_reference.values():
236+
if isinstance(prop, ModelProperty):
237+
if prop.class_info == matching_model.class_info:
238+
return prop
239+
return None
240+
241+
model_types_by_name: dict[str, ModelProperty] = {}
242+
for model in subtypes:
243+
# Note, model here can never be a UnionProperty, because we've already done
244+
# flatten_union_properties() before this point.
245+
if not isinstance(model, ModelProperty):
246+
return PropertyError(
247+
detail="All schema variants must be objects when using a discriminator",
248+
)
249+
top_level_model = _find_top_level_model(model)
250+
if not top_level_model:
251+
return PropertyError(
252+
detail="Inline schema declarations are not allowed when using a discriminator",
253+
)
254+
name = top_level_model.name
255+
if name.startswith("/components/schemas/"):
256+
name = name.split("/", 3)[-1]
257+
model_types_by_name[name] = top_level_model
258+
259+
# The discriminator can specify an explicit mapping of values to types, but it doesn't
260+
# have to; the default behavior is that the value for each type is simply its name.
261+
mapping: dict[str, ModelProperty] = model_types_by_name.copy()
262+
if data.mapping:
263+
for discriminator_value, model_ref in data.mapping.items():
264+
ref_path = parse_reference_path(
265+
model_ref if model_ref.startswith("#/components/schemas/") else f"#/components/schemas/{model_ref}"
266+
)
267+
if isinstance(ref_path, ParseError) or ref_path not in schemas.classes_by_reference:
268+
return PropertyError(detail=f'Invalid reference "{model_ref}" in discriminator mapping')
269+
name = ref_path.split("/", 3)[-1]
270+
if not (lookup_model := model_types_by_name.get(name)):
271+
return PropertyError(
272+
detail=f'Discriminator mapping referred to "{model_ref}" which is not one of the schema variants',
273+
)
274+
for original_value in (name for name, m in model_types_by_name.items() if m == lookup_model):
275+
mapping.pop(original_value)
276+
mapping[discriminator_value] = lookup_model
277+
else:
278+
mapping = model_types_by_name
279+
280+
return DiscriminatorDefinition(property_name=data.propertyName, value_to_model_map=mapping)
281+
282+
283+
def _validate_discriminators(
284+
discriminators: list[DiscriminatorDefinition],
285+
) -> PropertyError | None:
286+
prop_names_values_classes = [
287+
(discriminator.property_name, key, model.class_info.name)
288+
for discriminator in discriminators
289+
for key, model in discriminator.value_to_model_map.items()
290+
]
291+
for p, v in {(p, v) for p, v, _ in prop_names_values_classes}:
292+
if len({c for p1, v1, c in prop_names_values_classes if (p1, v1) == (p, v)}) > 1:
293+
return PropertyError(f'Discriminator property "{p}" had more than one schema for value "{v}"')
294+
return None
295+
296+
# TODO: We should also validate that property_name refers to a property that 1. exists,
297+
# 2. is required, 3. is a string (in all of these models). However, currently we can't
298+
# do that because, at the time this function is called, the ModelProperties within the
299+
# union haven't yet been post-processed and so we don't have full information about
300+
# their properties. To fix this, we may need to generalize the post-processing phase so
301+
# that any Property type, not just ModelProperty, can say it needs post-processing; then
302+
# we can defer _validate_discriminators till that phase.

‎openapi_python_client/schema/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
__all__ = [
2+
"Discriminator",
23
"MediaType",
34
"OpenAPI",
45
"Operation",
@@ -17,6 +18,7 @@
1718

1819
from .data_type import DataType
1920
from .openapi_schema_pydantic import (
21+
Discriminator,
2022
MediaType,
2123
OpenAPI,
2224
Operation,

‎openapi_python_client/templates/property_templates/union_property.py.jinja

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,36 @@
1+
{% macro construct_inner_property(inner_property) %}
2+
{% import "property_templates/" + inner_property.template as inner_template %}
3+
{% if inner_template.check_type_for_construct %}
4+
if not {{ inner_template.check_type_for_construct(inner_property, "data") }}:
5+
raise TypeError()
6+
{% endif %}
7+
{{ inner_template.construct(inner_property, "data") }}
8+
return {{ inner_property.python_name }}
9+
{%- endmacro %}
10+
11+
{% macro construct_discriminator_lookup(property) %}
12+
{% set _discriminator_properties = [] -%}
13+
{% for discriminator in property.discriminators %}
14+
{{- _discriminator_properties.append(discriminator.property_name) or "" -}}
15+
if not isinstance(data, dict):
16+
raise TypeError()
17+
if "{{ discriminator.property_name }}" in data:
18+
_discriminator_value = data["{{ discriminator.property_name }}"]
19+
{% for model in discriminator.value_to_model_map.values() %}
20+
def _parse_{{ model.python_name }}(data: object) -> {{ model.get_type_string() }}:
21+
{{ construct_inner_property(model) | indent(12, True) }}
22+
{% endfor %}
23+
_discriminator_mapping = {
24+
{% for value, model in discriminator.value_to_model_map.items() %}
25+
"{{ value }}": _parse_{{ model.python_name }},
26+
{% endfor %}
27+
}
28+
if _parse_fn := _discriminator_mapping.get(_discriminator_value):
29+
return cast({{ property.get_type_string() }}, _parse_fn(data))
30+
{% endfor %}
31+
raise TypeError(f"unrecognized value for property {{ _discriminator_properties | join(' or ') }}")
32+
{% endmacro %}
33+
134
{% macro construct(property, source) %}
235
def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_string() }}:
336
{% if "None" in property.get_type_strings_in_union(json=True, multipart=False) %}
@@ -8,6 +41,9 @@ def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_stri
841
if isinstance(data, Unset):
942
return data
1043
{% endif %}
44+
{% if property.discriminators %}
45+
{{ construct_discriminator_lookup(property) | indent(4, True) }}
46+
{% else %}
1147
{% set ns = namespace(contains_unmodified_properties = false) %}
1248
{% for inner_property in property.inner_properties %}
1349
{% import "property_templates/" + inner_property.template as inner_template %}
@@ -17,24 +53,17 @@ def _parse_{{ property.python_name }}(data: object) -> {{ property.get_type_stri
1753
{% endif %}
1854
{% if inner_template.check_type_for_construct and (not loop.last or ns.contains_unmodified_properties) %}
1955
try:
20-
if not {{ inner_template.check_type_for_construct(inner_property, "data") }}:
21-
raise TypeError()
22-
{{ inner_template.construct(inner_property, "data") | indent(8) }}
23-
return {{ inner_property.python_name }}
56+
{{ construct_inner_property(inner_property) | indent(8, True) }}
2457
except: # noqa: E722
2558
pass
2659
{% else %}{# Don't do try/except for the last one nor any properties with no type checking #}
27-
{% if inner_template.check_type_for_construct %}
28-
if not {{ inner_template.check_type_for_construct(inner_property, "data") }}:
29-
raise TypeError()
30-
{% endif %}
31-
{{ inner_template.construct(inner_property, "data") | indent(4) }}
32-
return {{ inner_property.python_name }}
60+
{{ construct_inner_property(inner_property) | indent(4, True) }}
3361
{% endif %}
3462
{% endfor %}
3563
{% if ns.contains_unmodified_properties %}
3664
return cast({{ property.get_type_string() }}, data)
3765
{% endif %}
66+
{% endif %}
3867

3968
{{ property.python_name }} = _parse_{{ property.python_name }}({{ source }})
4069
{% endmacro %}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import re
2+
from typing import Any, Union
3+
4+
from openapi_python_client.parser.errors import PropertyError
5+
from openapi_python_client.parser.properties.property import Property
6+
7+
8+
def assert_prop_error(
9+
p: Union[Property, PropertyError],
10+
message_regex: str,
11+
data: Any = None,
12+
) -> None:
13+
assert isinstance(p, PropertyError)
14+
assert re.search(message_regex, p.detail)
15+
if data is not None:
16+
assert p.data == data

‎tests/test_parser/test_properties/test_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def test_property_from_data_ref_model(self, model_property_factory, config):
631631
required=required,
632632
class_info=class_info,
633633
)
634-
assert schemas == new_schemas
634+
assert new_schemas == schemas
635635

636636
def test_property_from_data_ref_not_found(self, mocker):
637637
from openapi_python_client.parser.properties import PropertyError, Schemas, property_from_data

‎tests/test_parser/test_properties/test_union.py

Lines changed: 363 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
1+
from typing import Dict, List, Optional, Tuple, Union
2+
3+
import pytest
4+
from attr import evolve
5+
16
import openapi_python_client.schema as oai
7+
from openapi_python_client.config import Config
28
from openapi_python_client.parser.errors import ParseError, PropertyError
39
from openapi_python_client.parser.properties import Schemas, UnionProperty
10+
from openapi_python_client.parser.properties.model_property import ModelProperty
11+
from openapi_python_client.parser.properties.property import Property
412
from openapi_python_client.parser.properties.protocol import Value
13+
from openapi_python_client.parser.properties.schemas import Class
514
from openapi_python_client.schema import DataType, ParameterLocation
15+
from tests.test_parser.test_properties.properties_test_helpers import assert_prop_error
616

717

818
def test_property_from_data_union(union_property_factory, date_time_property_factory, string_property_factory, config):
@@ -33,6 +43,206 @@ def test_property_from_data_union(union_property_factory, date_time_property_fac
3343
assert s == Schemas()
3444

3545

46+
def _make_basic_model(
47+
name: str,
48+
props: Dict[str, oai.Schema],
49+
required_prop: Optional[str],
50+
schemas: Schemas,
51+
config: Config,
52+
) -> Tuple[ModelProperty, Schemas]:
53+
model, schemas = ModelProperty.build(
54+
data=oai.Schema.model_construct(
55+
required=[required_prop] if required_prop else [],
56+
title=name,
57+
properties=props,
58+
),
59+
name=name or "some_generated_name",
60+
schemas=schemas,
61+
required=False,
62+
parent_name="",
63+
config=config,
64+
roots={"root"},
65+
process_properties=True,
66+
)
67+
assert isinstance(model, ModelProperty)
68+
if name:
69+
schemas = evolve(
70+
schemas, classes_by_reference={**schemas.classes_by_reference, f"/components/schemas/{name}": model}
71+
)
72+
return model, schemas
73+
74+
75+
def _assert_valid_discriminator(
76+
p: Union[Property, PropertyError],
77+
expected_discriminators: List[Tuple[str, Dict[str, Class]]],
78+
) -> None:
79+
assert isinstance(p, UnionProperty)
80+
assert p.discriminators
81+
assert [(d[0], {key: model.class_info for key, model in d[1].items()}) for d in expected_discriminators] == [
82+
(d.property_name, {key: model.class_info for key, model in d.value_to_model_map.items()})
83+
for d in p.discriminators
84+
]
85+
86+
87+
def test_discriminator_with_explicit_mapping(config):
88+
from openapi_python_client.parser.properties import Schemas, property_from_data
89+
90+
schemas = Schemas()
91+
props = {"type": oai.Schema.model_construct(type="string")}
92+
model1, schemas = _make_basic_model("Model1", props, "type", schemas, config)
93+
model2, schemas = _make_basic_model("Model2", props, "type", schemas, config)
94+
data = oai.Schema.model_construct(
95+
oneOf=[
96+
oai.Reference(ref="#/components/schemas/Model1"),
97+
oai.Reference(ref="#/components/schemas/Model2"),
98+
],
99+
discriminator=oai.Discriminator.model_construct(
100+
propertyName="type",
101+
mapping={
102+
# mappings can use either a fully-qualified schema reference or just the schema name
103+
"type1": "#/components/schemas/Model1",
104+
"type2": "Model2",
105+
},
106+
),
107+
)
108+
109+
p, schemas = property_from_data(
110+
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
111+
)
112+
_assert_valid_discriminator(p, [("type", {"type1": model1, "type2": model2})])
113+
114+
115+
def test_discriminator_with_implicit_mapping(config):
116+
from openapi_python_client.parser.properties import Schemas, property_from_data
117+
118+
schemas = Schemas()
119+
props = {"type": oai.Schema.model_construct(type="string")}
120+
model1, schemas = _make_basic_model("Model1", props, "type", schemas, config)
121+
model2, schemas = _make_basic_model("Model2", props, "type", schemas, config)
122+
data = oai.Schema.model_construct(
123+
oneOf=[
124+
oai.Reference(ref="#/components/schemas/Model1"),
125+
oai.Reference(ref="#/components/schemas/Model2"),
126+
],
127+
discriminator=oai.Discriminator.model_construct(
128+
propertyName="type",
129+
),
130+
)
131+
132+
p, schemas = property_from_data(
133+
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
134+
)
135+
_assert_valid_discriminator(p, [("type", {"Model1": model1, "Model2": model2})])
136+
137+
138+
def test_discriminator_with_partial_explicit_mapping(config):
139+
from openapi_python_client.parser.properties import Schemas, property_from_data
140+
141+
schemas = Schemas()
142+
props = {"type": oai.Schema.model_construct(type="string")}
143+
model1, schemas = _make_basic_model("Model1", props, "type", schemas, config)
144+
model2, schemas = _make_basic_model("Model2", props, "type", schemas, config)
145+
data = oai.Schema.model_construct(
146+
oneOf=[
147+
oai.Reference(ref="#/components/schemas/Model1"),
148+
oai.Reference(ref="#/components/schemas/Model2"),
149+
],
150+
discriminator=oai.Discriminator.model_construct(
151+
propertyName="type",
152+
mapping={
153+
"type1": "#/components/schemas/Model1",
154+
# no value specified for Model2, so it defaults to just "Model2"
155+
},
156+
),
157+
)
158+
159+
p, schemas = property_from_data(
160+
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
161+
)
162+
_assert_valid_discriminator(p, [("type", {"type1": model1, "Model2": model2})])
163+
164+
165+
def test_discriminators_in_nested_unions_same_property(config):
166+
from openapi_python_client.parser.properties import Schemas, property_from_data
167+
168+
schemas = Schemas()
169+
props = {"type": oai.Schema.model_construct(type="string")}
170+
model1, schemas = _make_basic_model("Model1", props, "type", schemas, config)
171+
model2, schemas = _make_basic_model("Model2", props, "type", schemas, config)
172+
model3, schemas = _make_basic_model("Model3", props, "type", schemas, config)
173+
model4, schemas = _make_basic_model("Model4", props, "type", schemas, config)
174+
data = oai.Schema.model_construct(
175+
oneOf=[
176+
oai.Schema.model_construct(
177+
oneOf=[
178+
oai.Reference(ref="#/components/schemas/Model1"),
179+
oai.Reference(ref="#/components/schemas/Model2"),
180+
],
181+
discriminator=oai.Discriminator.model_construct(propertyName="type"),
182+
),
183+
oai.Schema.model_construct(
184+
oneOf=[
185+
oai.Reference(ref="#/components/schemas/Model3"),
186+
oai.Reference(ref="#/components/schemas/Model4"),
187+
],
188+
discriminator=oai.Discriminator.model_construct(propertyName="type"),
189+
),
190+
],
191+
)
192+
193+
p, schemas = property_from_data(
194+
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
195+
)
196+
_assert_valid_discriminator(
197+
p,
198+
[
199+
("type", {"Model1": model1, "Model2": model2}),
200+
("type", {"Model3": model3, "Model4": model4}),
201+
],
202+
)
203+
204+
205+
def test_discriminators_in_nested_unions_different_property(config):
206+
from openapi_python_client.parser.properties import Schemas, property_from_data
207+
208+
schemas = Schemas()
209+
props1 = {"type": oai.Schema.model_construct(type="string")}
210+
props2 = {"other": oai.Schema.model_construct(type="string")}
211+
model1, schemas = _make_basic_model("Model1", props1, "type", schemas, config)
212+
model2, schemas = _make_basic_model("Model2", props1, "type", schemas, config)
213+
model3, schemas = _make_basic_model("Model3", props2, "other", schemas, config)
214+
model4, schemas = _make_basic_model("Model4", props2, "other", schemas, config)
215+
data = oai.Schema.model_construct(
216+
oneOf=[
217+
oai.Schema.model_construct(
218+
oneOf=[
219+
oai.Reference(ref="#/components/schemas/Model1"),
220+
oai.Reference(ref="#/components/schemas/Model2"),
221+
],
222+
discriminator=oai.Discriminator.model_construct(propertyName="type"),
223+
),
224+
oai.Schema.model_construct(
225+
oneOf=[
226+
oai.Reference(ref="#/components/schemas/Model3"),
227+
oai.Reference(ref="#/components/schemas/Model4"),
228+
],
229+
discriminator=oai.Discriminator.model_construct(propertyName="other"),
230+
),
231+
],
232+
)
233+
234+
p, schemas = property_from_data(
235+
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
236+
)
237+
_assert_valid_discriminator(
238+
p,
239+
[
240+
("type", {"Model1": model1, "Model2": model2}),
241+
("other", {"Model3": model3, "Model4": model4}),
242+
],
243+
)
244+
245+
36246
def test_build_union_property_invalid_property(config):
37247
name = "bad_union"
38248
required = True
@@ -42,7 +252,7 @@ def test_build_union_property_invalid_property(config):
42252
p, s = UnionProperty.build(
43253
name=name, required=required, data=data, schemas=Schemas(), parent_name="parent", config=config
44254
)
45-
assert p == PropertyError(detail=f"Invalid property in union {name}", data=reference)
255+
assert_prop_error(p, f"Invalid property in union {name}", data=reference)
46256

47257

48258
def test_invalid_default(config):
@@ -82,3 +292,155 @@ def test_not_required_in_path(config):
82292

83293
err = prop.validate_location(ParameterLocation.PATH)
84294
assert isinstance(err, ParseError)
295+
296+
297+
@pytest.mark.parametrize("bad_ref", ["#/components/schemas/UnknownModel", "http://remote/Model2"])
298+
def test_discriminator_invalid_reference(bad_ref, config):
299+
from openapi_python_client.parser.properties import Schemas, property_from_data
300+
301+
schemas = Schemas()
302+
props = {"type": oai.Schema.model_construct(type="string")}
303+
model1, schemas = _make_basic_model("Model1", props, "type", schemas, config)
304+
model2, schemas = _make_basic_model("Model2", props, "type", schemas, config)
305+
data = oai.Schema.model_construct(
306+
oneOf=[
307+
oai.Reference(ref="#/components/schemas/Model1"),
308+
oai.Reference(ref="#/components/schemas/Model2"),
309+
],
310+
discriminator=oai.Discriminator.model_construct(
311+
propertyName="type",
312+
mapping={
313+
"Model1": "#/components/schemas/Model1",
314+
"Model2": bad_ref,
315+
},
316+
),
317+
)
318+
319+
p, schemas = property_from_data(
320+
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
321+
)
322+
assert_prop_error(p, "^Invalid reference")
323+
324+
325+
def test_discriminator_mapping_uses_schema_not_in_list(config):
326+
from openapi_python_client.parser.properties import Schemas, property_from_data
327+
328+
schemas = Schemas()
329+
props = {"type": oai.Schema.model_construct(type="string")}
330+
model1, schemas = _make_basic_model("Model1", props, "type", schemas, config)
331+
model2, schemas = _make_basic_model("Model2", props, "type", schemas, config)
332+
model3, schemas = _make_basic_model("Model3", props, "type", schemas, config)
333+
data = oai.Schema.model_construct(
334+
oneOf=[
335+
oai.Reference(ref="#/components/schemas/Model1"),
336+
oai.Reference(ref="#/components/schemas/Model2"),
337+
],
338+
discriminator=oai.Discriminator.model_construct(
339+
propertyName="type",
340+
mapping={
341+
"Model1": "#/components/schemas/Model1",
342+
"Model3": "#/components/schemas/Model3",
343+
},
344+
),
345+
)
346+
347+
p, schemas = property_from_data(
348+
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
349+
)
350+
assert_prop_error(p, "not one of the schema variants")
351+
352+
353+
def test_discriminator_invalid_variant_is_not_object(config, string_property_factory):
354+
from openapi_python_client.parser.properties import Schemas, property_from_data
355+
356+
schemas = Schemas()
357+
props = {"type": oai.Schema.model_construct(type="string")}
358+
model_type, schemas = _make_basic_model("ModelType", props, "type", schemas, config)
359+
string_type = string_property_factory()
360+
schemas = evolve(
361+
schemas,
362+
classes_by_reference={
363+
**schemas.classes_by_reference,
364+
"/components/schemas/StringType": string_type,
365+
},
366+
)
367+
data = oai.Schema.model_construct(
368+
oneOf=[
369+
oai.Reference(ref="#/components/schemas/ModelType"),
370+
oai.Reference(ref="#/components/schemas/StringType"),
371+
],
372+
discriminator=oai.Discriminator.model_construct(
373+
propertyName="type",
374+
),
375+
)
376+
377+
p, schemas = property_from_data(
378+
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
379+
)
380+
assert_prop_error(p, "must be objects")
381+
382+
383+
def test_discriminator_invalid_inline_schema_variant(config, string_property_factory):
384+
from openapi_python_client.parser.properties import Schemas, property_from_data
385+
386+
schemas = Schemas()
387+
schemas = Schemas()
388+
props = {"type": oai.Schema.model_construct(type="string")}
389+
model1, schemas = _make_basic_model("Model1", props, "type", schemas, config)
390+
data = oai.Schema.model_construct(
391+
oneOf=[
392+
oai.Reference(ref="#/components/schemas/Model1"),
393+
oai.Schema.model_construct(
394+
type="object",
395+
properties=props,
396+
),
397+
],
398+
discriminator=oai.Discriminator.model_construct(
399+
propertyName="type",
400+
),
401+
)
402+
403+
p, schemas = property_from_data(
404+
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
405+
)
406+
assert_prop_error(p, "Inline schema")
407+
408+
409+
def test_conflicting_discriminator_mappings(config):
410+
from openapi_python_client.parser.properties import Schemas, property_from_data
411+
412+
schemas = Schemas()
413+
props = {"type": oai.Schema.model_construct(type="string")}
414+
model1, schemas = _make_basic_model("Model1", props, "type", schemas, config)
415+
model2, schemas = _make_basic_model("Model2", props, "type", schemas, config)
416+
model3, schemas = _make_basic_model("Model3", props, "type", schemas, config)
417+
model4, schemas = _make_basic_model("Model4", props, "type", schemas, config)
418+
data = oai.Schema.model_construct(
419+
oneOf=[
420+
oai.Schema.model_construct(
421+
oneOf=[
422+
oai.Reference(ref="#/components/schemas/Model1"),
423+
oai.Reference(ref="#/components/schemas/Model2"),
424+
],
425+
discriminator=oai.Discriminator.model_construct(
426+
propertyName="type",
427+
mapping={"a": "Model1", "b": "Model2"},
428+
),
429+
),
430+
oai.Schema.model_construct(
431+
oneOf=[
432+
oai.Reference(ref="#/components/schemas/Model3"),
433+
oai.Reference(ref="#/components/schemas/Model4"),
434+
],
435+
discriminator=oai.Discriminator.model_construct(
436+
propertyName="type",
437+
mapping={"a": "Model3", "x": "Model4"},
438+
),
439+
),
440+
],
441+
)
442+
443+
p, schemas = property_from_data(
444+
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
445+
)
446+
assert_prop_error(p, '"type" had more than one schema for value "a"')

0 commit comments

Comments
 (0)
Please sign in to comment.