Skip to content

Commit 8dde89e

Browse files
authored
Use stricter serializer for unions of simple types (#1132)
1 parent 4df7624 commit 8dde89e

File tree

2 files changed

+134
-10
lines changed

2 files changed

+134
-10
lines changed

src/serializers/type_serializers/simple.rs

+18-10
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@ use std::borrow::Cow;
55

66
use serde::Serialize;
77

8+
use crate::PydanticSerializationUnexpectedValue;
89
use crate::{definitions::DefinitionsBuilder, input::Int};
910

1011
use super::{
1112
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, IsType, ObType,
12-
SerMode, TypeSerializer,
13+
SerCheck, SerMode, TypeSerializer,
1314
};
1415

1516
#[derive(Debug, Clone)]
@@ -85,7 +86,7 @@ impl TypeSerializer for NoneSerializer {
8586
}
8687

8788
macro_rules! build_simple_serializer {
88-
($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident) => {
89+
($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident, $subtypes_allowed:expr) => {
8990
#[derive(Debug, Clone)]
9091
pub struct $struct_name;
9192

@@ -114,12 +115,15 @@ macro_rules! build_simple_serializer {
114115
let py = value.py();
115116
match extra.ob_type_lookup.is_type(value, $ob_type) {
116117
IsType::Exact => Ok(value.into_py(py)),
117-
IsType::Subclass => match extra.mode {
118-
SerMode::Json => {
119-
let rust_value = value.extract::<$rust_type>()?;
120-
Ok(rust_value.to_object(py))
121-
}
122-
_ => infer_to_python(value, include, exclude, extra),
118+
IsType::Subclass => match extra.check {
119+
SerCheck::Strict => Err(PydanticSerializationUnexpectedValue::new_err(None)),
120+
SerCheck::Lax | SerCheck::None => match extra.mode {
121+
SerMode::Json => {
122+
let rust_value = value.extract::<$rust_type>()?;
123+
Ok(rust_value.to_object(py))
124+
}
125+
_ => infer_to_python(value, include, exclude, extra),
126+
},
123127
},
124128
IsType::False => {
125129
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
@@ -160,6 +164,10 @@ macro_rules! build_simple_serializer {
160164
fn get_name(&self) -> &str {
161165
Self::EXPECTED_TYPE
162166
}
167+
168+
fn retry_with_lax_check(&self) -> bool {
169+
$subtypes_allowed
170+
}
163171
}
164172
};
165173
}
@@ -168,7 +176,7 @@ pub(crate) fn to_str_json_key(key: &PyAny) -> PyResult<Cow<str>> {
168176
Ok(key.str()?.to_string_lossy())
169177
}
170178

171-
build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key);
179+
build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key, true);
172180

173181
pub(crate) fn bool_json_key(key: &PyAny) -> PyResult<Cow<str>> {
174182
let v = if key.is_true().unwrap_or(false) {
@@ -179,4 +187,4 @@ pub(crate) fn bool_json_key(key: &PyAny) -> PyResult<Cow<str>> {
179187
Ok(Cow::Borrowed(v))
180188
}
181189

182-
build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key);
190+
build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key, false);

tests/serializers/test_union.py

+116
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import dataclasses
22
import json
33
import re
4+
import uuid
5+
from decimal import Decimal
46
from typing import Any, ClassVar, Union
57

68
import pytest
@@ -510,3 +512,117 @@ class Item(BaseModel):
510512
)
511513

512514
assert s.to_python([DBUser(name='John', password='secret')]) == [{'name': 'John'}]
515+
516+
517+
EXAMPLE_UUID = uuid.uuid4()
518+
519+
520+
class IntSubclass(int):
521+
pass
522+
523+
524+
@pytest.mark.parametrize('reverse', [False, True])
525+
@pytest.mark.parametrize(
526+
'core_schema_left,core_schema_right,input_value,expected_value',
527+
[
528+
(core_schema.int_schema(), core_schema.bool_schema(), True, True),
529+
(core_schema.int_schema(), core_schema.bool_schema(), 1, 1),
530+
(core_schema.str_schema(), core_schema.int_schema(), 1, 1),
531+
(core_schema.str_schema(), core_schema.int_schema(), '1', '1'),
532+
(core_schema.int_schema(), core_schema.bool_schema(), IntSubclass(1), 1),
533+
(
534+
core_schema.decimal_schema(),
535+
core_schema.int_schema(),
536+
Decimal('1'),
537+
Decimal('1'),
538+
),
539+
(core_schema.decimal_schema(), core_schema.int_schema(), 1, 1),
540+
(
541+
core_schema.decimal_schema(),
542+
core_schema.float_schema(),
543+
Decimal('1.'),
544+
Decimal('1.'),
545+
),
546+
(
547+
core_schema.decimal_schema(),
548+
core_schema.str_schema(),
549+
Decimal('_1'),
550+
Decimal('_1'),
551+
),
552+
(
553+
core_schema.decimal_schema(),
554+
core_schema.str_schema(),
555+
'_1',
556+
'_1',
557+
),
558+
(
559+
core_schema.uuid_schema(),
560+
core_schema.str_schema(),
561+
EXAMPLE_UUID,
562+
EXAMPLE_UUID,
563+
),
564+
(
565+
core_schema.uuid_schema(),
566+
core_schema.str_schema(),
567+
str(EXAMPLE_UUID),
568+
str(EXAMPLE_UUID),
569+
),
570+
],
571+
)
572+
def test_union_serializer_picks_exact_type_over_subclass(
573+
core_schema_left, core_schema_right, input_value, expected_value, reverse
574+
):
575+
s = SchemaSerializer(
576+
core_schema.union_schema(
577+
[core_schema_right, core_schema_left] if reverse else [core_schema_left, core_schema_right]
578+
)
579+
)
580+
assert s.to_python(input_value) == expected_value
581+
582+
583+
@pytest.mark.parametrize('reverse', [False, True])
584+
@pytest.mark.parametrize(
585+
'core_schema_left,core_schema_right,input_value,expected_value',
586+
[
587+
(core_schema.int_schema(), core_schema.bool_schema(), True, True),
588+
(core_schema.int_schema(), core_schema.bool_schema(), 1, 1),
589+
(core_schema.str_schema(), core_schema.int_schema(), 1, 1),
590+
(core_schema.str_schema(), core_schema.int_schema(), '1', '1'),
591+
(core_schema.int_schema(), core_schema.bool_schema(), IntSubclass(1), 1),
592+
(
593+
core_schema.decimal_schema(),
594+
core_schema.int_schema(),
595+
Decimal('1'),
596+
'1',
597+
),
598+
(core_schema.decimal_schema(), core_schema.int_schema(), 1, 1),
599+
(
600+
core_schema.decimal_schema(),
601+
core_schema.float_schema(),
602+
Decimal('1.'),
603+
'1',
604+
),
605+
(
606+
core_schema.decimal_schema(),
607+
core_schema.str_schema(),
608+
Decimal('_1'),
609+
'1',
610+
),
611+
(
612+
core_schema.decimal_schema(),
613+
core_schema.str_schema(),
614+
'_1',
615+
'_1',
616+
),
617+
],
618+
)
619+
def test_union_serializer_picks_exact_type_over_subclass_json(
620+
core_schema_left, core_schema_right, input_value, expected_value, reverse
621+
):
622+
s = SchemaSerializer(
623+
core_schema.union_schema(
624+
[core_schema_right, core_schema_left] if reverse else [core_schema_left, core_schema_right]
625+
)
626+
)
627+
assert s.to_python(input_value, mode='json') == expected_value
628+
assert s.to_json(input_value) == json.dumps(expected_value).encode()

0 commit comments

Comments
 (0)