
Background: I'm working on a compiler for a complex language, so there are many different types of AST nodes. I need a way to traverse the AST, and I currently use virtual functions to create iterators for the different node types. For example, operator++ is implemented as:

namespace ast {

class Iterator {
  Iterator operator++() {
    if (++curr_it_)
      return curr_it_;

    auto curr_node_ = /* get the neighbor */;
    if (!curr_node_)
      return nullptr;

    return (curr_it_ = curr_node_->create_iterator());
  }
};

InternalIterator BinaryExpr::create_iterator() { /* ... */ }
InternalIterator ArrayLiteral::create_iterator() { /* ... */ }
InternalIterator FunctionCall::create_iterator() { /* ... */ }

}  // namespace ast

Problem: profiling shows that the call to create_iterator has a 50% branch misprediction rate. What can I do to mitigate it?

P.S. I use valgrind --tool=cachegrind --branch-sim=yes to profile my compiler, and most of the mispredictions come from indirect function calls.

rhanqtl
  • Not a lot, probably. Is this causing a real-world performance problem? – Paul Sanders Jul 27 '23 at 13:29
  • @PaulSanders Yes, it increases the time by 10%. – rhanqtl Jul 27 '23 at 13:31
  • What is inside create_iterator? In other words are you sure it is the virtual call itself or the content of the create? Maybe you need to find a way to traverse the tree without an iterator? (iterators come with creation overhead, so maybe good old direct tree traversal is needed here) – Pepijn Kramer Jul 27 '23 at 13:31
  • Would profile-guided optimization help? https://en.wikipedia.org/wiki/Profile-guided_optimization – Benjamin Bihler Jul 27 '23 at 13:32
  • Note that the input data for the profile-guided optimizations are critical. A bad input can slow down performance for usual cases if it is not representative of real-world user inputs. Note you can specify that a branch is likely or unlikely to be taken but developers are notoriously bad at guessing such probabilities in practice (specified in the GCC documentation) except for branches involving error checks. Besides, note you can aggregate AST nodes for common blocks (or use a more compact AST representation) to mitigate this overhead. – Jérôme Richard Jul 27 '23 at 13:38
  • Make it non-virtual. This may entail making it ugly and non-maintainable, but then most nice things come with a price tag. – n. m. could be an AI Jul 27 '23 at 13:44
  • Are you reading the profiling report properly, or are you just explaining it to us in the wrong way? Branch prediction means trying to predict a conditional jump. Calling a virtual method is not a conditional jump (it always happens, it just jumps to a different place). Maybe the tool points to the preceding `if`, and the prediction for that `if` is bad. In that case [adding `[[unlikely]]`](https://en.cppreference.com/w/cpp/language/attributes/likely) can help, but I would fix the code by changing it so that `curr_node_` is never `nullptr`, making this condition obsolete. – Marek R Jul 27 '23 at 14:21
  • @MarekR Conditional branch prediction and indirect branch prediction are two different types of branch prediction. While technically separate behaviors, there are more similarities than differences between them, and many of the techniques are applicable to both. – Sneftel Jul 27 '23 at 15:24

2 Answers


If you're trying to reduce branch mispredictions, there's only a small number of ways that can be done:

  • Reduce conditional branches by using conditional instructions like CMOV, or use case-specific instruction sequences which don't involve branching. For instance, if your various iterators had the same behavior but differed only in something like data stride, you could remove the virtual call and look the stride up in a LUT or similar. Probably not possible here.

  • Rearrange the data to help the branch predictor. Almost certainly not possible here.

  • Negate the branch condition if a branch is being mispredicted when it isn't in the branch cache. (This is one of the main things PGO would do.) Not applicable if your misprediction rate is 50%, of course.

  • Rearrange the code so that the branch predictor works better. This is extremely fiddly and processor-specific (e.g. you'd be optimizing for, say, a particular model of AMD processor) and unlikely to yield much benefit with modern branch predictors. The one exception here is that short loops [but not TOO short] are easier to predict than long ones, but in this case you don't have a lot of control over that.

That branch is being predicted poorly because branching is what that code does. Branch prediction, at a high level, is a way to dynamically optimize the execution of low-entropy traces. Put differently, it's a way for the CPU to notice and leverage patterns in which branches you take. If there are no patterns to recognize, the predictor has nothing to work with, and tuning for it won't help in your situation.
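To make the first bullet concrete, here is a minimal sketch of a table-driven child walk. It only applies if the iterators really do differ in nothing but data layout, which may not hold for a real AST. The node kinds, the `Layout` table, the flat slot pool and `children_of` are all hypothetical names invented for this illustration, not part of the asker's code.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical node kinds; a real compiler has many more.
enum class Kind : std::uint8_t { BinaryExpr, ArrayLiteral, NamedArgs, Count };

// Per-kind layout of a node's slots: where the first child sits and how far
// apart consecutive children are.
struct Layout {
    std::uint32_t offset;
    std::uint32_t stride;
};

constexpr std::array<Layout, static_cast<std::size_t>(Kind::Count)> kLayout{{
    {0, 1},  // BinaryExpr:   [lhs, rhs]
    {0, 1},  // ArrayLiteral: [elem0, elem1, ...]
    {1, 2},  // NamedArgs:    [name0, value0, name1, value1, ...]
}};

struct Node {
    Kind kind;
    std::uint32_t first_slot;  // index of this node's first slot in the pool
    std::uint32_t slot_count;  // number of slots owned by this node
};

// Collects the child ids of `node`. The same loop serves every kind: the
// per-kind difference is data (a table lookup), not an indirect call.
std::vector<std::uint32_t> children_of(const Node& node,
                                       const std::vector<std::uint32_t>& pool) {
    assert(node.first_slot + node.slot_count <= pool.size());
    const Layout layout = kLayout[static_cast<std::size_t>(node.kind)];
    std::vector<std::uint32_t> out;
    for (std::uint32_t i = layout.offset; i < node.slot_count; i += layout.stride)
        out.push_back(pool[node.first_slot + i]);
    return out;
}
```

The loop's exit condition is still a conditional branch, so this trades an indirect call for a data-dependent loop rather than removing branching altogether.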

Sneftel

I'd remove the virtual method calls completely and use a switch instead.

InternalIterator create_iterator(Node* node) {
    NodeId id = node->id;

    static_assert(NODE_COUNT == 10, "You forgot to add or remove a new case here");
    switch (id) {
        case BINARY_EXPR:   return create_binary_expr_iterator((BinaryExpr*)(node));
        case ARRAY_LITERAL: return create_array_literal_iterator((ArrayLiteral*)(node));
        case FUNCTION_CALL: return create_function_call_iterator((FunctionCall*)(node));
        // The rest ....
    }
}

It's quite straightforward and could most likely improve performance, since you avoid the virtual dispatch and the compiler will probably turn the switch into a jump table.
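For reference, a self-contained version of the switch above might look like the sketch below. The reduced node set, the constructors and the stand-in `InternalIterator` (which only records a child count) are assumptions made so the example compiles on its own; the real types will differ.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical, reduced node set; names mirror the snippet above.
enum NodeId { BINARY_EXPR, ARRAY_LITERAL, FUNCTION_CALL, NODE_COUNT };

struct Node {
    NodeId id;
    explicit Node(NodeId i) : id(i) {}
};
struct BinaryExpr : Node {
    BinaryExpr() : Node(BINARY_EXPR) {}
};
struct ArrayLiteral : Node {
    std::size_t size;
    explicit ArrayLiteral(std::size_t n) : Node(ARRAY_LITERAL), size(n) {}
};
struct FunctionCall : Node {
    std::size_t argc;
    explicit FunctionCall(std::size_t n) : Node(FUNCTION_CALL), argc(n) {}
};

// Stand-in for InternalIterator: here it only records how many children remain.
struct InternalIterator { std::size_t remaining; };

InternalIterator create_binary_expr_iterator(BinaryExpr*)       { return {2}; }
InternalIterator create_array_literal_iterator(ArrayLiteral* n) { return {n->size}; }
InternalIterator create_function_call_iterator(FunctionCall* n) { return {n->argc}; }

InternalIterator create_iterator(Node* node) {
    static_assert(NODE_COUNT == 3, "update the switch when adding or removing a node kind");
    switch (node->id) {
        case BINARY_EXPR:   return create_binary_expr_iterator(static_cast<BinaryExpr*>(node));
        case ARRAY_LITERAL: return create_array_literal_iterator(static_cast<ArrayLiteral*>(node));
        case FUNCTION_CALL: return create_function_call_iterator(static_cast<FunctionCall*>(node));
        case NODE_COUNT:    break;
    }
    assert(false && "unknown node id");
    return {0};
}
```

Listing every enumerator without a `default` lets compilers warn when a new node kind is missing from the switch. Note that a jump table is itself an indirect branch, so this mainly saves the vtable load; whether the misprediction rate changes has to be measured.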

However, you can also traverse the AST using the visitor pattern instead of creating abstract iterators. Combining the visitor pattern with CRTP gives you static dispatch, which removes some branching, and it's quite convenient to work with.

Here's an example of how it could look:

template <class Visitor, class T = void*>
class AstVisitor {
    AstTree& tree;
public:
    explicit AstVisitor(AstTree& tree) : tree(tree) {}

    T begin() { return visit(tree.start); }

protected:
    T visit_ident(Node* node) {
        return static_cast<Visitor*>(this)->visit_ident(&node->ident);
    }

    T visit_lit(Node* node) {
        return static_cast<Visitor*>(this)->visit_lit(&node->lit);
    }

    T visit_type(Node* node) {
        return static_cast<Visitor*>(this)->visit_type(&node->type);
    }

    T visit_bin_op(Node* node) {
        return static_cast<Visitor*>(this)->visit_bin_op(&node->bin_op);
    }

    T visit_var_decl(Node* node) {
        return static_cast<Visitor*>(this)->visit_var_decl(&node->var_decl);
    }

    T visit_if_expr(Node* node) {
        return static_cast<Visitor*>(this)->visit_if_expr(&node->if_expr);
    }

    T visit_param_decl(Node* node) {
        return static_cast<Visitor*>(this)->visit_param_decl(&node->param_decl);
    }

    T visit_fn_decl(Node* node) {
        return static_cast<Visitor*>(this)->visit_fn_decl(&node->fn_decl);
    }

    T visit_block(Node* node) {
        return static_cast<Visitor*>(this)->visit_block(&node->block);
    }

    T visit_expression(size_t i) {
        Node* node = node_at(i);
        auto  id   = node->id;
        switch (id) {
            case Node::IDENT:     return static_cast<Visitor*>(this)->visit_ident(&node->ident);
            case Node::LIT:       return static_cast<Visitor*>(this)->visit_lit(&node->lit);
            case Node::TYPE:      return static_cast<Visitor*>(this)->visit_type(&node->type);
            case Node::BIN_OP:    return static_cast<Visitor*>(this)->visit_bin_op(&node->bin_op);
            case Node::IF_EXPR:   return static_cast<Visitor*>(this)->visit_if_expr(&node->if_expr);
            default:  return nullptr;
        }
    }

    T visit(size_t i) {
        Node* node = node_at(i);
        auto  id   = node->id;
        switch (id) {
            case Node::IDENT:     return static_cast<Visitor*>(this)->visit_ident(&node->ident);
            case Node::LIT:       return static_cast<Visitor*>(this)->visit_lit(&node->lit);
            case Node::TYPE:      return static_cast<Visitor*>(this)->visit_type(&node->type);
            case Node::BIN_OP:    return static_cast<Visitor*>(this)->visit_bin_op(&node->bin_op);
            case Node::VAR_DECL:  return static_cast<Visitor*>(this)->visit_var_decl(&node->var_decl);
            case Node::IF_EXPR:   return static_cast<Visitor*>(this)->visit_if_expr(&node->if_expr);
            case Node::PARAM_DECL: return static_cast<Visitor*>(this)->visit_param_decl(&node->param_decl);
            case Node::FN_DECL:   return static_cast<Visitor*>(this)->visit_fn_decl(&node->fn_decl);
            case Node::BLOCK:     return static_cast<Visitor*>(this)->visit_block(&node->block);
            case Node::MODULE:    return static_cast<Visitor*>(this)->visit_module(&node->module);
            default: return nullptr;
        }
    }

private:
    Node* node_at(u32 i) { return &tree.nodes.at(i); }
};

Implementation:

class AstPrinter : public AstVisitor<AstPrinter> {
public:
    int indentation = 0;

    explicit AstPrinter(AstTree& tree) : AstVisitor(tree) {}

    void indent() {
        printf("%.*s", indentation*4, "                                                                              ");
    }

    void* visit_ident(Node::Ident* node) {
        auto name = name_at(node->name);
        printf("Ident {" SV_FMT "}", SV_ARG(name));
        return nullptr;
    }

    void* visit_lit(Node::Lit* node) {
        auto print_value = [](Node::Lit* node) {
            switch (node->type) {
                case Type::VOID:    return printf("void");
                case Type::INT:     return printf("%lld", node->value.int64);
                case Type::REAL:    return printf("%f",   node->value.float64);
                case Type::NUMBER_OF_TYPE_TYPES:
                    return printf("<invalid>");
            }
        };

        auto type = name_at(node->type);
        printf("Lit { type=" SV_FMT ", value=", SV_ARG(type));
        print_value(node);
        printf("}");
        return nullptr;
    }

    void* visit_type(Node::Type* node) {
        auto type = name_at(node->type);
        printf("Type {" SV_FMT " }", SV_ARG(type));
        return nullptr;
    }

    void* visit_bin_op(Node::BinOp* node) {
        printf("BinOp { left=");
        visit(node->left);
        printf(", operation=%d, right=", node->op);
        visit(node->right);
        printf(" }");
        return nullptr;
    }

    void* visit_var_decl(Node::VarDecl* node) {
        auto name = name_at(node->name);
        auto type = (node->type != INVALID_TYPE) ? name_at(node->type) : Str("<none>", sizeof("<none>") - 1);  // length excludes the terminating '\0'
        printf("VarDecl { name=" SV_FMT ", type=" SV_FMT ", expr=", SV_ARG(name), SV_ARG(type));
        visit_expression(node->expr);
        printf("}");
        return nullptr;
    }

    void* visit_if_expr(Node::IfExpr* node) {
        printf("IfExpr { condition=");
        visit(node->cond);
        printf(", left=");
        visit(node->left);
        printf(", right=");
        visit(node->right);
        printf("}");
        return nullptr;
    }

    void* visit_param_decl(Node::ParamDecl* node) {
        auto name = name_at(node->name);
        auto type = (node->type != INVALID_TYPE) ? name_at(node->type) : Str("<none>", sizeof("<none>") - 1);  // length excludes the terminating '\0'
        printf("ParamDecl { name=" SV_FMT ", type=" SV_FMT ", expr=", SV_ARG(name), SV_ARG(type));
        if (node->expr != NO_EXPR) { visit_expression(node->expr); } else { printf("<none>"); }
        printf("}");
        return nullptr;
    }

    void* visit_fn_decl(Node::FnDecl* node) {
        auto name = name_at(node->name);
        printf("FnDecl { name=" SV_FMT ", ", SV_ARG(name));

        auto params = view_at(node->params);
        for (u32 i = 0; i < params.size(); ++i) {
            auto param = params[i];
            printf("param[%u] = ", i);
            visit(param);
            printf(", ");
        }

        printf("body=");
        visit(node->block);
        printf(" }");

        return nullptr;
    }

    void* visit_block(Node::Block* node) {
        printf("Block {\n");
        indentation += 1;

        auto stmts = view_at(node->stmts);
        for (u32 i = 0; i < stmts.size(); ++i) {
            auto stmt = stmts[i];
            indent();
            printf("stmt[%u] = ", i);
            visit(stmt);
            printf("\n");
        }

        indentation -= 1;
        indent();
        printf("}");

        return nullptr;
    }

    void* visit_module(Node::Module* node) {
        auto name = name_at(node->name);

        printf("Module {\n");
        indentation += 1;

        indent();
        printf("name = " SV_FMT "\n", SV_ARG(name));

        auto stmts = view_at(node->stmts);
        for (u32 i = 0; i < stmts.size(); ++i) {
            auto stmt = stmts[i];
            indent();
            printf("stmt[%u] = ", i);
            visit(stmt);
            printf("\n");
        }

        indentation -= 1;
        indent();
        printf("}");
        return nullptr;
    }
};

This replaces the virtual calls with a single switch on the node id, and the handlers are dispatched statically through CRTP instead of relying on dynamic dispatch.
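Since the example above depends on types that aren't shown (AstTree, Str, SV_FMT and so on), here is a much smaller sketch of the same CRTP idea that compiles on its own. The two-kind `Node` layout and the `Evaluator` visitor are hypothetical and exist only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical flat AST: nodes refer to their children by index.
struct Node {
    enum Id { LIT, BIN_OP } id;
    int value;          // used by LIT
    std::size_t left;   // used by BIN_OP
    std::size_t right;  // used by BIN_OP
};

template <class Derived, class T>
class AstVisitor {
public:
    explicit AstVisitor(const std::vector<Node>& nodes) : nodes_(nodes) {}

    // One runtime switch on the node id; the handler calls resolve at compile
    // time because the concrete visitor type is a template parameter.
    T visit(std::size_t i) {
        const Node& node = nodes_.at(i);
        auto* self = static_cast<Derived*>(this);
        switch (node.id) {
            case Node::LIT:    return self->visit_lit(node);
            case Node::BIN_OP: return self->visit_bin_op(node);
        }
        return T{};
    }

private:
    const std::vector<Node>& nodes_;
};

// A concrete visitor that evaluates the tree, treating BIN_OP as addition.
class Evaluator : public AstVisitor<Evaluator, int> {
public:
    using AstVisitor<Evaluator, int>::AstVisitor;

    int visit_lit(const Node& node) { return node.value; }
    int visit_bin_op(const Node& node) { return visit(node.left) + visit(node.right); }
};
```

The switch on the node id is still a runtime branch; CRTP only removes the virtual call between the base visitor and its handlers, which also lets the compiler inline them.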

Ted Klein Bergman
  • Would you expect a jump table jump target to be better predicted, all else being equal, than a vtable jump target? If so, why? And in what way would you suggest the OP use static dispatching to interpret an AST which is given at runtime? – Sneftel Jul 27 '23 at 16:15
  • Virtual calls are slower since most (all) compilers use a vtable to implement virtual calls, which adds an extra indirection and a greater chance of a cache miss. This switch may improve things since cache misses can be reduced, but it is a bad solution from a code maintenance point of view. There is an alternative: manually implement virtual calls without a vtable (using pointers to functions) with the help of templates. – Marek R Jul 27 '23 at 16:26
  • @Sneftel Yes, because with a vtable you need to do a lookup to find the function, and then a dispatch to the function. This means that even if the CPU looks at the code ahead, it has no way of knowing where to jump until the first lookup is made. With a jump table, the CPU knows where to jump as soon as it knows the value of the node id. This gives both the compiler and the CPU better opportunities to do optimizations. Also, as Marek said, you get better cache utilization as the data is more local and you don't make jumps through pointers. – Ted Klein Bergman Jul 27 '23 at 17:34
  • @Sneftel I rewrote the part about static dispatch, as it was a bit misleading. I meant that you could use it together with the visitor pattern instead of using it with dynamic dispatch (which would require a vtable lookup). Often you can do static dispatch because you know what the type of the node is (or at least should be if your parser did its job correctly). – Ted Klein Bergman Jul 27 '23 at 17:36
  • The branch target predictor doesn't care how many indirections you had to go through to get to your target, though; it just cares where you jumped the last few times. As Marek rightly pointed out doing fewer indirections is kinder to the cache, but it'll have zero effect on the efficacy of branch prediction. – Sneftel Jul 27 '23 at 17:38
  • @Sneftel Yes, you're correct, it has nothing to do with the branch prediction. I read your comment a bit sloppily. It gives the CPU a greater opportunity to know where the function address is and prefetch it. – Ted Klein Bergman Jul 27 '23 at 17:54
  • @MarekR I'd argue it's more maintainable than other methods like interfaces. When your code changes by adding/removing functions, this approach requires less change, as each function can be added or removed in isolation. If you use interfaces, then each function added/removed affects every type in your code base. The opposite is also true: when adding/removing types is more common, interfaces are better. But for a compiler I'd assume that the number of node types stays relatively stable while the functions operating on the nodes change quite a bit. – Ted Klein Bergman Jul 27 '23 at 18:02
  • @TedKleinBergman "If you're arguing, you're losing!" Here is a nice explanation from Uncle Bob of [why the switch-case solution is bad/fragile](https://youtu.be/zHiWqnTWsn4). – Marek R Jul 27 '23 at 18:59
  • @MarekR I do own and have read *Clean Code* by Uncle Bob. I've also watched many of his talks. Here's a discussion from about 4 months ago (quite recently) where the man himself, Uncle Bob, writes what I'm saying: *"And as we both agreed before, and as I wrote in Clean Code (which, by the way, is not the same as your "Clean Code") when operations proliferate more rapidly than types, switch statements are better."* https://github.com/unclebob/cmuratori-discussion/blob/main/cleancodeqa-2.md – Ted Klein Bergman Jul 27 '23 at 19:58