
I have datatypes which look like this:

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Matrix {
    #[serde(rename = "numColumns")]
    pub num_cols: usize,
    #[serde(rename = "numRows")]
    pub num_rows: usize,
    pub data: Vec<f64>,
}

My JSON bodies look something like this:

{
    "numRows": 2,
    "numColumns": 1,
    "data": [1.0, "NaN"]
}

This is the serialization provided by Jackson (from a Java server we use), and is valid JSON. Unfortunately if we call serde_json::from_str(&blob) we get an error:

Error("invalid type: string "NaN", expected f64", [snip]

I understand there are subtleties around floating point numbers and people get very opinionated about the way things ought to be. I respect that. Rust in particular likes to be very opinionated, and I like that.

However, at the end of the day these JSON blobs are what I'm going to receive, and I need that "NaN" string to deserialize to some f64 value for which is_nan() is true, and which serializes back to the string "NaN", because the rest of the ecosystem uses Jackson and this is fine there.

Can this be achieved in a reasonable way?

Edit: the suggested linked questions talk about overriding the derived deserializer, but they do not explain how to deserialize floats specifically.

Richard Rast
  • 1
    I believe your question is answered by the answers of [How to transform fields during deserialization using Serde?](https://stackoverflow.com/q/46753955/155423) and [How to transform fields during serialization using Serde?](https://stackoverflow.com/q/39383809/155423). If you disagree, please **[edit]** your question to explain the differences. Otherwise, we can mark this question as already answered. – Shepmaster Jan 31 '19 at 03:45
  • 1
    See also [Is there is a simpler way to convert a type upon deserialization?](https://stackoverflow.com/q/44836327/155423); [Convert two types into a single type with Serde](https://stackoverflow.com/q/37870428/155423). – Shepmaster Jan 31 '19 at 03:47
  • Possible duplicate of [How to transform fields during deserialization using Serde?](https://stackoverflow.com/questions/46753955/how-to-transform-fields-during-deserialization-using-serde) – Stargateur Jan 31 '19 at 10:53
  • 1
    Adding up all of those suggested posts kind of leads to an answer (that answer being, write a custom serializer / deserializer and figure it out yourself) but considering this is a not-rare use case with some closed-but-not-resolved github issues on the serde project, it seems like having a canonical answer (or better yet, a serde flag) would be a bit nicer. – Richard Rast Jan 31 '19 at 13:11
  • The title is misleading: deserializing `NaN` isn't the full problem, it's that you have the float in a list that makes it particularly tricky. Edited title would be helpful. – Cornelius Roemer Mar 17 '23 at 19:09

1 Answer


It actually seems like using a custom deserializer inside a Vec (or Map or etc.) is an open issue on serde and has been for a little over a year (as of time of writing): https://github.com/serde-rs/serde/issues/723

I believe the solution is to write a custom deserializer for f64 (which is fine), as well as for every container type that holds an f64 (e.g. Vec<f64>, HashMap<K, f64>, etc.). Unfortunately these do not seem to be composable, because implementations of these functions look like

fn deserialize<'de, D>(deserializer: D) -> Result<Vec<f64>, D::Error>
where D: Deserializer<'de> { /* snip */ }

and once you have a Deserializer you can only interact with it through visitors.

Long story short, I eventually got it working, but it seems like a lot of code that shouldn't be necessary. I'm posting it here after a whole day of reading docs and trial and error, in the hope that either (a) someone knows how to clean this up, or (b) this really is how it should be done and the answer will be useful to someone else. The functions serialize_float(s) and deserialize_float(s) should be attached to a field with #[serde(serialize_with = "...", deserialize_with = "...")] placed above the field name.

use serde::de::{self, SeqAccess, Visitor};
use serde::ser::SerializeSeq;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;

type Float = f64;

const NAN: Float = f64::NAN;

struct NiceFloat(Float);

impl Serialize for NiceFloat {
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serialize_float(&self.0, serializer)
    }
}

pub fn serialize_float<S>(x: &Float, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    if x.is_nan() {
        serializer.serialize_str("NaN")
    } else {
        serializer.serialize_f64(*x)
    }
}

pub fn serialize_floats<S>(floats: &[Float], serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let mut seq = serializer.serialize_seq(Some(floats.len()))?;

    for f in floats {
        seq.serialize_element(&NiceFloat(*f))?;
    }

    seq.end()
}

struct FloatDeserializeVisitor;

impl<'de> Visitor<'de> for FloatDeserializeVisitor {
    type Value = Float;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a float or the string \"NaN\"")
    }

    fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        if v == "NaN" {
            Ok(NAN)
        } else {
            Err(E::invalid_value(de::Unexpected::Str(v), &self))
        }
    }
}

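// Note: nothing below uses this visitor; the Deserialize impl for NiceFloat
// goes through deserialize_float (and FloatDeserializeVisitor) instead.
struct NiceFloatDeserializeVisitor;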

impl<'de> Visitor<'de> for NiceFloatDeserializeVisitor {
    type Value = NiceFloat;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a float or the string \"NaN\"")
    }

    fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(NiceFloat(v as Float))
    }

    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(NiceFloat(v as Float))
    }

    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        if v == "NaN" {
            Ok(NiceFloat(NAN))
        } else {
            Err(E::invalid_value(de::Unexpected::Str(v), &self))
        }
    }
}

pub fn deserialize_float<'de, D>(deserializer: D) -> Result<Float, D::Error>
where
    D: Deserializer<'de>,
{
    deserializer.deserialize_any(FloatDeserializeVisitor)
}

impl<'de> Deserialize<'de> for NiceFloat {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = deserialize_float(deserializer)?;
        Ok(NiceFloat(raw))
    }
}

pub struct VecDeserializeVisitor<T>(std::marker::PhantomData<T>);

impl<'de, T> Visitor<'de> for VecDeserializeVisitor<T>
where
    T: Deserialize<'de> + Sized,
{
    type Value = Vec<T>;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("A sequence of floats or \"NaN\" string values")
    }

    fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
    where
        S: SeqAccess<'de>,
    {
        let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(0));

        while let Some(value) = seq.next_element()? {
            out.push(value);
        }

        Ok(out)
    }
}

pub fn deserialize_floats<'de, D>(deserializer: D) -> Result<Vec<Float>, D::Error>
where
    D: Deserializer<'de>,
{
    let visitor: VecDeserializeVisitor<NiceFloat> = VecDeserializeVisitor(std::marker::PhantomData);

    let seq: Vec<NiceFloat> = deserializer.deserialize_seq(visitor)?;

    let raw: Vec<Float> = seq.into_iter().map(|nf| nf.0).collect::<Vec<Float>>();

    Ok(raw)
}
Richard Rast