
I have a few std::variant types, each with several alternatives, and I would like to define a visitor class template that takes a variant as its template parameter and automatically declares a pure virtual void operator()(T const&) const for each alternative T in the variant. I could then define subclasses that inherit from instantiations of this visitor template, and the compiler would force them to override every one of those pure virtual methods.

For example:

#include <variant>

using VarA = std::variant<A1, A2, /* ... more alternatives ... */>;
using VarB = std::variant<B1, B2, /* ... more alternatives ... */>;

struct VarAVisitor : Visitor<VarA>
{
    // Must override 'void operator()(T const&) const' for each alternative type 'T' in VarA
};

struct VarBVisitor : Visitor<VarB>
{
    // Must override 'void operator()(T const&) const' for each alternative type 'T' in VarB
};

In short, how would I implement the Visitor class template in the example above?

jinscoe123
  • Is there a reason why you are designing a custom visitor instead of using `std::visit()`? Then you could use the [`overload` trick](https://stackoverflow.com/questions/66961406/) to execute a lambda for each alternative. – Remy Lebeau Jun 05 '22 at 19:09
  • The visitor will be passed as the first argument to `std::visit()`. I want to define these visitor subclasses that are forced to override the pure virtual methods of their base template classes so that the compiler will produce more readable error messages if I forget a particular alternative type. – jinscoe123 Jun 05 '22 at 19:12
  • What you are looking for is probably a variation of the "overload" variant trick. My template kung-fu is weak when it comes to parameter packs, but I don't think I've ever seen syntax that lets you pass a `variant` type as a visitor's template argument and break it apart into its individual alternative types to generate the `operator()` overloads you need. It only works if you pass the alternatives themselves directly as the visitor's template arguments. – Remy Lebeau Jun 05 '22 at 19:56

1 Answer


After some googling and a lot of trial and error, I managed to come up with something that does what I want. I'm sharing the solution here for anyone else who runs into the same issue.

Here is a proof of concept.

#include <iostream>
#include <variant>


template <typename> class Test { };

using Foo = std::variant<
    Test<struct A>,
    Test<struct B>,
    Test<struct C>,
    Test<struct D>
    >;

using Bar = std::variant<
    Test<struct E>,
    Test<struct F>,
    Test<struct G>,
    Test<struct H>,
    Test<struct I>,
    Test<struct J>,
    Test<struct K>,
    Test<struct L>
    >;


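// Declares one pure virtual call operator for a single alternative type T.
template <typename T>
struct DefineVirtualFunctor
{
    virtual int operator()(T const&) const = 0;
};

// Recursively inherits Modifier<T> for every type T in the pack.
// The primary template (empty pack) is the base case and contributes nothing.
template <template <typename> typename Modifier, typename... Rest>
struct ForEach { };
template <template <typename> typename Modifier, typename T, typename... Rest>
struct ForEach<Modifier, T, Rest...> : Modifier<T>, ForEach<Modifier, Rest...> { };

// Only the partial specialization for std::variant is defined. It unpacks the
// variant's alternatives and declares one pure virtual operator() per alternative.
template <typename Variant>
struct Visitor;
template <typename... Alts>
struct Visitor<std::variant<Alts...>> : ForEach<DefineVirtualFunctor, Alts...> { };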


struct FooVisitor final : Visitor<Foo>
{
    int operator()(Test<A> const&) const override { return  0; }
    int operator()(Test<B> const&) const override { return  1; }
    int operator()(Test<C> const&) const override { return  2; }
    int operator()(Test<D> const&) const override { return  3; }
};

struct BarVisitor final : Visitor<Bar>
{
    int operator()(Test<E> const&) const override { return  4; }
    int operator()(Test<F> const&) const override { return  5; }
    int operator()(Test<G> const&) const override { return  6; }
    int operator()(Test<H> const&) const override { return  7; }
    int operator()(Test<I> const&) const override { return  8; }
    int operator()(Test<J> const&) const override { return  9; }
    int operator()(Test<K> const&) const override { return 10; }
    int operator()(Test<L> const&) const override { return 11; }
};


int main(int argc, char const* argv[])
{
    Foo foo;
    Bar bar;
    
    switch (argc) {
    case  0: foo = Foo{ std::in_place_index<0> }; break;
    case  1: foo = Foo{ std::in_place_index<1> }; break;
    case  2: foo = Foo{ std::in_place_index<2> }; break;
    default: foo = Foo{ std::in_place_index<3> }; break;
    }
    switch (argc) {
    case  0: bar = Bar{ std::in_place_index<0> }; break;
    case  1: bar = Bar{ std::in_place_index<1> }; break;
    case  2: bar = Bar{ std::in_place_index<2> }; break;
    case  3: bar = Bar{ std::in_place_index<3> }; break;
    case  4: bar = Bar{ std::in_place_index<4> }; break;
    case  5: bar = Bar{ std::in_place_index<5> }; break;
    case  6: bar = Bar{ std::in_place_index<6> }; break;
    default: bar = Bar{ std::in_place_index<7> }; break;
    }
    
    std::cout << std::visit(FooVisitor{ }, foo) << "\n";
    std::cout << std::visit(BarVisitor{ }, bar) << "\n";

    return 0;
}

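As you can see, the Visitor class template accepts a std::variant type as its template parameter. Only a partial specialization for std::variant<Alts...> is defined, and that is what lets the template unpack the variant's alternative types. Through ForEach, it then inherits from DefineVirtualFunctor<T> for each alternative T, so every alternative contributes one pure virtual operator(), and any child class that inherits from the instantiation must override all of them. If a child class forgets to override one of these methods, it remains abstract, and trying to construct it (for example, to pass it to std::visit()) produces an error like the following.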

$ g++ -std=c++17 -o example example.cc
example.cc: In function ‘int main(int, const char**)’:
example.cc:87:41: error: invalid cast to abstract class type ‘BarVisitor’
   87 |     std::cout << std::visit(BarVisitor{ }, bar) << "\n";
      |                                         ^
example.cc:51:8: note:   because the following virtual functions are pure within ‘BarVisitor’:
   51 | struct BarVisitor final : Visitor<Bar>
      |        ^~~~~~~~~~
example.cc:29:17: note:     ‘int DefineVirtualFunctor<T>::operator()(const T&) const [with T = Test<J>]’
   29 |     virtual int operator()(T const&) const = 0;
      |                 ^~~~~~~~

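This is much easier to understand than the errors the compiler usually generates from inside std::visit() when a visitor lacks an overload for one of the alternatives: it names the derived class and the exact alternative (here Test<J>) that still needs an override.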

jinscoe123