It seems that applying a custom deserializer to the elements inside a `Vec` (or a map, etc.) is an open issue in serde, and has been for a little over a year as of this writing: https://github.com/serde-rs/serde/issues/723
I believe the solution is to write a custom deserializer for `f64` (which is fine), as well as for everything that contains `f64` as a component (e.g. `Vec<f64>`, `HashMap<K, f64>`, etc.). Unfortunately, these do not seem to be composable, because the implementations of these functions look like
```rust
fn deserialize<'de, D>(deserializer: D) -> Result<Vec<f64>, D::Error>
where
    D: Deserializer<'de>,
{ /* snip */ }
```

and once you have a `Deserializer`, you can only interact with it through visitors.
Long story short, I eventually got it working, but it takes a lot of code that seems like it shouldn't be necessary. I'm posting it here in the hope that either (a) someone knows how to clean it up, or (b) this really is how it has to be done, in which case this answer might save someone else the whole day I spent reading docs and guessing by trial and error. Use `serialize_float`/`deserialize_float` for an `f64` field and `serialize_floats`/`deserialize_floats` for a `Vec<f64>` field, via the matching `#[serde(serialize_with = "...", deserialize_with = "...")]` attribute above the field name.
```rust
use serde::de::{self, SeqAccess, Visitor};
use serde::ser::SerializeSeq;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;

type Float = f64;
const NAN: Float = f64::NAN;

/// Newtype wrapper so that the NaN-aware (de)serialization can be applied
/// to the elements of a sequence.
struct NiceFloat(Float);

impl Serialize for NiceFloat {
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serialize_float(&self.0, serializer)
    }
}

/// Serializes NaN as the string "NaN" and every other value as a plain f64.
pub fn serialize_float<S>(x: &Float, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    if x.is_nan() {
        serializer.serialize_str("NaN")
    } else {
        serializer.serialize_f64(*x)
    }
}

/// Serializes a slice of floats element by element through `NiceFloat`.
pub fn serialize_floats<S>(floats: &[Float], serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let mut seq = serializer.serialize_seq(Some(floats.len()))?;
    for f in floats {
        seq.serialize_element(&NiceFloat(*f))?;
    }
    seq.end()
}

/// Accepts integers, floats, or the string "NaN".
struct FloatDeserializeVisitor;

impl<'de> Visitor<'de> for FloatDeserializeVisitor {
    type Value = Float;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a float or the string \"NaN\"")
    }

    fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        if v == "NaN" {
            Ok(NAN)
        } else {
            Err(E::invalid_value(de::Unexpected::Str(v), &self))
        }
    }
}

// Not used below: `NiceFloat`'s `Deserialize` impl delegates to
// `deserialize_float` instead, which also accepts integers.
#[allow(dead_code)]
struct NiceFloatDeserializeVisitor;

impl<'de> Visitor<'de> for NiceFloatDeserializeVisitor {
    type Value = NiceFloat;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a float or the string \"NaN\"")
    }

    fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(NiceFloat(v as Float))
    }

    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(NiceFloat(v as Float))
    }

    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        if v == "NaN" {
            Ok(NiceFloat(NAN))
        } else {
            Err(E::invalid_value(de::Unexpected::Str(v), &self))
        }
    }
}

/// For an `f64` field. `deserialize_any` relies on a self-describing
/// format such as JSON.
pub fn deserialize_float<'de, D>(deserializer: D) -> Result<Float, D::Error>
where
    D: Deserializer<'de>,
{
    deserializer.deserialize_any(FloatDeserializeVisitor)
}

impl<'de> Deserialize<'de> for NiceFloat {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = deserialize_float(deserializer)?;
        Ok(NiceFloat(raw))
    }
}

/// Generic visitor that collects any sequence of `T` into a `Vec<T>`.
pub struct VecDeserializeVisitor<T>(std::marker::PhantomData<T>);

impl<'de, T> Visitor<'de> for VecDeserializeVisitor<T>
where
    T: Deserialize<'de>,
{
    type Value = Vec<T>;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a sequence of floats or \"NaN\" string values")
    }

    fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
    where
        S: SeqAccess<'de>,
    {
        let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(0));
        while let Some(value) = seq.next_element()? {
            out.push(value);
        }
        Ok(out)
    }
}

/// For a `Vec<f64>` field: deserialize as `Vec<NiceFloat>`, then unwrap.
pub fn deserialize_floats<'de, D>(deserializer: D) -> Result<Vec<Float>, D::Error>
where
    D: Deserializer<'de>,
{
    let visitor: VecDeserializeVisitor<NiceFloat> = VecDeserializeVisitor(std::marker::PhantomData);
    let seq: Vec<NiceFloat> = deserializer.deserialize_seq(visitor)?;
    Ok(seq.into_iter().map(|nf| nf.0).collect())
}
```