Skip to content

Commit ef007bb

Browse files
committed
fix serializing complex
1 parent 9b21b0f commit ef007bb

File tree

4 files changed

+62
-41
lines changed

4 files changed

+62
-41
lines changed

src/serializers/infer.rs

+13-15
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,9 @@ pub(crate) fn infer_to_python_known(
231231
PyList::new_bound(py, items).into_py(py)
232232
}
233233
ObType::Complex => {
234-
let dict = value.downcast::<PyDict>()?;
235-
let new_dict = PyDict::new_bound(py);
236-
let _ = new_dict.set_item("real", dict.get_item("real")?);
237-
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
238-
new_dict.into_py(py)
234+
let v = value.downcast::<PyComplex>()?;
235+
let complex_str = type_serializers::complex::complex_to_str(v);
236+
complex_str.into_py(py)
239237
}
240238
ObType::Path => value.str()?.into_py(py),
241239
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py),
@@ -286,11 +284,9 @@ pub(crate) fn infer_to_python_known(
286284
iter.into_py(py)
287285
}
288286
ObType::Complex => {
289-
let dict = value.downcast::<PyDict>()?;
290-
let new_dict = PyDict::new_bound(py);
291-
let _ = new_dict.set_item("real", dict.get_item("real")?);
292-
let _ = new_dict.set_item("imag", dict.get_item("imag")?);
293-
new_dict.into_py(py)
287+
let v = value.downcast::<PyComplex>()?;
288+
let complex_str = type_serializers::complex::complex_to_str(v);
289+
complex_str.into_py(py)
294290
}
295291
ObType::Unknown => {
296292
if let Some(fallback) = extra.fallback {
@@ -422,10 +418,8 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
422418
ObType::Bool => serialize!(bool),
423419
ObType::Complex => {
424420
let v = value.downcast::<PyComplex>().map_err(py_err_se_err)?;
425-
let mut map = serializer.serialize_map(Some(2))?;
426-
map.serialize_entry(&"real", &v.real())?;
427-
map.serialize_entry(&"imag", &v.imag())?;
428-
map.end()
421+
let complex_str = type_serializers::complex::complex_to_str(v);
422+
Ok(serializer.collect_str::<String>(&complex_str)?)
429423
}
430424
ObType::Float | ObType::FloatSubclass => {
431425
let v = value.extract::<f64>().map_err(py_err_se_err)?;
@@ -672,7 +666,7 @@ pub(crate) fn infer_json_key_known<'a>(
672666
}
673667
Ok(Cow::Owned(key_build.finish()))
674668
}
675-
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator | ObType::Complex => {
669+
ObType::List | ObType::Set | ObType::Frozenset | ObType::Dict | ObType::Generator => {
676670
py_err!(PyTypeError; "`{}` not valid as object key", ob_type)
677671
}
678672
ObType::Dataclass | ObType::PydanticSerializable => {
@@ -689,6 +683,10 @@ pub(crate) fn infer_json_key_known<'a>(
689683
// FIXME it would be nice to have a "PyCow" which carries ownership of the Python type too
690684
Ok(Cow::Owned(key.str()?.to_string_lossy().into_owned()))
691685
}
686+
ObType::Complex => {
687+
let v = key.downcast::<PyComplex>()?;
688+
Ok(type_serializers::complex::complex_to_str(v).into())
689+
}
692690
ObType::Pattern => Ok(Cow::Owned(
693691
key.getattr(intern!(key.py(), "pattern"))?
694692
.str()?

src/serializers/ob_type.rs

+2
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,8 @@ impl ObTypeLookup {
252252
ObType::Url
253253
} else if ob_type == self.multi_host_url {
254254
ObType::MultiHostUrl
255+
} else if ob_type == self.complex {
256+
ObType::Complex
255257
} else if ob_type == self.uuid_object.as_ptr() as usize {
256258
ObType::Uuid
257259
} else if is_pydantic_serializable(op_value) {

src/serializers/type_serializers/complex.rs

+19-26
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,10 @@ impl TypeSerializer for ComplexSerializer {
3333
) -> PyResult<PyObject> {
3434
let py = value.py();
3535
match value.downcast::<PyComplex>() {
36-
Ok(py_complex) => match extra.mode {
37-
SerMode::Json => {
38-
let re = py_complex.real();
39-
let im = py_complex.imag();
40-
let mut s = format!("{im}j");
41-
if re != 0.0 {
42-
let mut sign = "";
43-
if im >= 0.0 {
44-
sign = "+";
45-
}
46-
s = format!("{re}{sign}{s}");
47-
}
48-
Ok(s.into_py(py))
49-
}
50-
_ => Ok(value.into_py(py)),
51-
},
36+
Ok(py_complex) => Ok(match extra.mode {
37+
SerMode::Json => complex_to_str(py_complex).into_py(py),
38+
_ => value.into_py(py),
39+
}),
5240
Err(_) => {
5341
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
5442
infer_to_python(value, include, exclude, extra)
@@ -70,16 +58,7 @@ impl TypeSerializer for ComplexSerializer {
7058
) -> Result<S::Ok, S::Error> {
7159
match value.downcast::<PyComplex>() {
7260
Ok(py_complex) => {
73-
let re = py_complex.real();
74-
let im = py_complex.imag();
75-
let mut s = format!("{im}j");
76-
if re != 0.0 {
77-
let mut sign = "";
78-
if im >= 0.0 {
79-
sign = "+";
80-
}
81-
s = format!("{re}{sign}{s}");
82-
}
61+
let s = complex_to_str(py_complex);
8362
Ok(serializer.collect_str::<String>(&s)?)
8463
}
8564
Err(_) => {
@@ -93,3 +72,17 @@ impl TypeSerializer for ComplexSerializer {
9372
"complex"
9473
}
9574
}
75+
76+
pub fn complex_to_str(py_complex: &Bound<'_, PyComplex>) -> String {
77+
let re = py_complex.real();
78+
let im = py_complex.imag();
79+
let mut s = format!("{im}j");
80+
if re != 0.0 {
81+
let mut sign = "";
82+
if im >= 0.0 {
83+
sign = "+";
84+
}
85+
s = format!("{re}{sign}{s}");
86+
}
87+
s
88+
}

tests/serializers/test_infer.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from enum import Enum
2+
3+
from pydantic_core import SchemaSerializer, core_schema
4+
5+
6+
# serializing enum calls methods in serializers::infer
7+
def test_infer_to_python():
8+
class MyEnum(Enum):
9+
complex_ = complex(1, 2)
10+
11+
v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
12+
assert v.to_python(MyEnum.complex_, mode='json') == '1+2j'
13+
14+
15+
def test_infer_serialize():
16+
class MyEnum(Enum):
17+
complex_ = complex(1, 2)
18+
19+
v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
20+
assert v.to_json(MyEnum.complex_) == b'"1+2j"'
21+
22+
23+
def test_infer_json_key():
24+
class MyEnum(Enum):
25+
complex_ = {complex(1, 2): 1}
26+
27+
v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
28+
assert v.to_json(MyEnum.complex_) == b'{"1+2j":1}'

0 commit comments

Comments
 (0)