
I have an enum with the following structure:

enum Expression {
    Add(Add),
    Mul(Mul),
    Var(Var),
    Coeff(Coeff)
}

where the 'members' of each variant are structs.

Now I want to compare if two enums have the same variant. So if I have

let a = Expression::Add({something});
let b = Expression::Add({somethingelse});

cmpvariant(a, b) should be true. I can imagine a simple double match that goes through all the options for both enum instances. However, I am looking for a fancier solution, if one exists. If not, is there overhead for the double match? I imagine that internally I am just comparing two ints (ideally).

Shepmaster
Ben Ruijl
  • Not related to your question, but could you please tell me what the syntax `Add(Add)` means? What is the first `Add` and what is the second one? – Martin Thoma Sep 13 '15 at 21:11
  • @moose: The first `Add` is the name of the enum variant. The second one is the type of that variant, and it presumes the existence of another type (a `struct` or another `enum`, or perhaps a type alias) named `Add`, whose definition is not shown here. Note that the names for the variants do not need to be the same as the names of the types of those variants, that's just how the OP chose to name them. – Benjamin Lindley Sep 14 '15 at 03:30

1 Answer


As of Rust 1.21.0, you can use std::mem::discriminant:

fn variant_eq(a: &Op, b: &Op) -> bool {
    std::mem::discriminant(a) == std::mem::discriminant(b)
}

This is nice because it can be very generic:

fn variant_eq<T>(a: &T, b: &T) -> bool {
    std::mem::discriminant(a) == std::mem::discriminant(b)
}
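
As a rough sketch of how the generic version could be applied to the enum from the question: the payload structs below (`Add`, `Mul`, `Var`, `Coeff` and their fields) are made-up stand-ins, since the question does not show their definitions.

```rust
use std::mem::discriminant;

// Hypothetical payload types; the question does not show the real ones.
#[allow(dead_code)]
struct Add(u8);
#[allow(dead_code)]
struct Mul(u8);
#[allow(dead_code)]
struct Var(&'static str);
#[allow(dead_code)]
struct Coeff(i64);

#[allow(dead_code)]
enum Expression {
    Add(Add),
    Mul(Mul),
    Var(Var),
    Coeff(Coeff),
}

// Same generic helper as above: only the variant is compared,
// the contents of the variant are never inspected.
fn variant_eq<T>(a: &T, b: &T) -> bool {
    discriminant(a) == discriminant(b)
}

fn main() {
    let a = Expression::Add(Add(1));
    let b = Expression::Add(Add(2));
    let c = Expression::Var(Var("x"));

    // Same variant, different payload.
    assert!(variant_eq(&a, &b));
    // Different variants.
    assert!(!variant_eq(&a, &c));
}
```

One caveat with the fully generic signature: nothing restricts `T` to enums. According to the standard library documentation, calling `discriminant` on a non-enum type is not undefined behavior, but the returned value is unspecified, so the result is only meaningful for enums.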

Before Rust 1.21.0, I'd match on the tuple of both arguments and ignore the contents of the tuple with _ or ..:

struct Add(u8);
struct Sub(u8);

enum Op {
    Add(Add),
    Sub(Sub),
}

fn variant_eq(a: &Op, b: &Op) -> bool {
    match (a, b) {
        (&Op::Add(..), &Op::Add(..)) => true,
        (&Op::Sub(..), &Op::Sub(..)) => true,
        _ => false,
    }
}

fn main() {
    let a = Op::Add(Add(42));
    
    let b = Op::Add(Add(42));
    let c = Op::Add(Add(21));
    let d = Op::Sub(Sub(42));

    println!("{}", variant_eq(&a, &b));
    println!("{}", variant_eq(&a, &c));
    println!("{}", variant_eq(&a, &d));
}

I took the liberty of renaming the function though, as the components of enums are called variants, and really you are testing to see if they are equal, not comparing them (which is usually used for ordering / sorting).

For performance, let's look at the LLVM IR generated by Rust 1.60.0 in release mode (with variant_eq marked as #[inline(never)]). The Rust Playground can show you this:

; playground::variant_eq
; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind nonlazybind readonly uwtable willreturn
define internal fastcc noundef zeroext i1 @_ZN10playground10variant_eq17hc64d59c7864eb861E(i8 %a.0.0.val, i8 %b.0.0.val) unnamed_addr #2 {
start:
  %_8.not = icmp eq i8 %a.0.0.val, %b.0.0.val
  ret i1 %_8.not
}

This code directly compares the two variant discriminants: a single icmp eq on two i8 values, with no branching.

If you wanted to have a macro to generate the function, something like this might be a good start.

struct Add(u8);
struct Sub(u8);

macro_rules! foo {
    (enum $name:ident {
        $($vname:ident($inner:ty),)*
    }) => {
        enum $name {
            $($vname($inner),)*
        }

        impl $name {
            fn variant_eq(&self, b: &Self) -> bool {
                match (self, b) {
                    $((&$name::$vname(..), &$name::$vname(..)) => true,)*
                    _ => false,
                }
            }
        }
    }
}

foo! {
    enum Op {
        Add(Add),
        Sub(Sub),
    }
}

fn main() {
    let a = Op::Add(Add(42));

    let b = Op::Add(Add(42));
    let c = Op::Add(Add(21));
    let d = Op::Sub(Sub(42));

    println!("{}", Op::variant_eq(&a, &b));
    println!("{}", Op::variant_eq(&a, &c));
    println!("{}", Op::variant_eq(&a, &d));
}

The macro does have limitations though: every variant must have exactly one field. Supporting unit variants, variants with more than one field, struct variants, visibility, etc. is all really hard. Perhaps a procedural macro would make it a bit easier.
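
On Rust 1.21.0 or later, these limitations may not matter in practice, because std::mem::discriminant does not care about the shape of a variant. A minimal sketch, assuming a made-up `Shape` enum (not from the question) that mixes unit, tuple, and struct variants:

```rust
use std::mem::discriminant;

// Hypothetical enum, chosen only to mix every variant shape.
#[allow(dead_code)]
enum Shape {
    Empty,                            // unit variant
    Circle(f64),                      // one field
    Rect(f64, f64),                   // more than one field
    Label { text: String, size: u8 }, // struct variant
}

impl Shape {
    // Written by hand once; no per-variant match arms are needed.
    fn variant_eq(&self, other: &Self) -> bool {
        discriminant(self) == discriminant(other)
    }
}

fn main() {
    let a = Shape::Rect(1.0, 2.0);
    let b = Shape::Rect(3.0, 4.0);
    let c = Shape::Label { text: String::from("hi"), size: 12 };

    assert!(a.variant_eq(&b));
    assert!(!a.variant_eq(&c));
    assert!(Shape::Empty.variant_eq(&Shape::Empty));
    assert!(!Shape::Empty.variant_eq(&Shape::Circle(1.0)));
}
```

The macro would then only be worth pursuing if the code has to build on a compiler older than 1.21.0.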

Shepmaster
  • Thanks! I will edit my original post to reflect the correct language. Do you have an idea of the overhead of the double matching? – Ben Ruijl Sep 13 '15 at 20:45
  • The [discriminant_value](http://doc.rust-lang.org/std/intrinsics/fn.discriminant_value.html) intrinsic exists to do this, though being an intrinsic it is unstable. It is at least used when generating code for `deriving`, so at least the ==, < etc. operators are as fast as they could be when automatically derived. – Mar Sep 13 '15 at 21:29
  • how can I make this function be generic over enum types ? or can it be put as a macro ? – shakram02 Mar 20 '17 at 14:22
  • It is still ugly. I would like to have something like `some_variant == SomeEnum::Variant(_)`. Instead, I have to write dozens of helpers. – Evgeni Nabokov Apr 18 '20 at 18:43
  • Why don't you just do `if let SomeEnum::Variant(..) = some_variant { /*...*/ }`? – RecursiveExceptionException Sep 15 '20 at 16:14
  • @RecursiveExceptionException because that's not the goal of the OP. They have two instances of an enum that they wish to compare to each other. Your code solves the problem of "is this one instance this specific variant". That's covered by [How do I conditionally check if an enum is one variant or another?](https://stackoverflow.com/q/51429501/155423) – Shepmaster Sep 15 '20 at 16:20
  • @Shepmaster Should've read the question better, sorry for that! – RecursiveExceptionException Sep 15 '20 at 16:23
  • This has since been optimized and produces much terser assembly output: https://godbolt.org/z/5s3xcrcoK – Kuba Beránek Apr 11 '22 at 15:34