I'd remove the virtual method calls completely and use a switch instead.
// Builds the iterator for `node` by switching on its id instead of going
// through a virtual call. Only three cases are shown here; the remaining
// NodeIds belong where "The rest ...." is.
InternalIterator create_iterator(Node* node) {
    // Breaks the build whenever a NodeId is added or removed, as a reminder
    // to revisit the cases below.
    static_assert(NODE_COUNT == 10, "You forgot to add or remove a new case here");
    switch (node->id) {
        case BINARY_EXPR:   return create_binary_expr_iterator(static_cast<BinaryExpr*>(node));
        case ARRAY_LITERAL: return create_array_literal_iterator(static_cast<ArrayLiteral*>(node));
        case FUNCTION_CALL: return create_function_call_iterator(static_cast<FunctionCall*>(node));
        // The rest ....
    }
    // Deliberately no `default:` label: with -Wswitch the compiler can then
    // report any NodeId not handled above. Falling off the end of a non-void
    // function is undefined behavior, so return explicitly.
    // NOTE(review): assumes InternalIterator is default-constructible and that
    // the concrete node types derive from Node (required by static_cast).
    return InternalIterator{};
}
It's straightforward and will most likely improve performance: you avoid a virtual dispatch, and a dense switch like this is usually compiled into a jump table (measure to confirm).
However, you can also traverse the AST with the visitor pattern instead of creating abstract iterators. Combined with CRTP, the visitor pattern gives you static dispatch, which removes some branching, and it's quite convenient to work with.
Here's an example of how it could look:
template <class Visitor, class T = void*>
class AstVisitor {
AstTree& tree;
public:
explicit AstVisitor(AstTree& tree) : tree(tree) {}
T begin() { return visit(tree.start); }
protected:
T visit_ident(Node* node) {
return static_cast<Visitor*>(this)->visit_ident(&node->ident);
}
T visit_lit(Node* node) {
return static_cast<Visitor*>(this)->visit_lit(&node->lit);
}
T visit_type(Node* node) {
return static_cast<Visitor*>(this)->visit_type(&node->type);
}
T visit_bin_op(Node* node) {
return static_cast<Visitor*>(this)->visit_bin_op(&node->bin_op);
}
T visit_var_decl(Node* node) {
return static_cast<Visitor*>(this)->visit_var_decl(&node->var_decl);
}
T visit_if_expr(Node* node) {
return static_cast<Visitor*>(this)->visit_if_expr(&node->if_expr);
}
T visit_param_decl(Node* node) {
return static_cast<Visitor*>(this)->visit_param_decl(&node->param_decl);
}
T visit_fn_decl(Node* node) {
return static_cast<Visitor*>(this)->visit_fn_decl(&node->fn_decl);
}
T visit_block(Node* node) {
return static_cast<Visitor*>(this)->visit_block(&node->block);
}
T visit_expression(size_t i) {
Node* node = node_at(i);
auto id = node->id;
switch (id) {
case Node::IDENT: return static_cast<Visitor*>(this)->visit_ident(&node->ident);
case Node::LIT: return static_cast<Visitor*>(this)->visit_lit(&node->lit);
case Node::TYPE: return static_cast<Visitor*>(this)->visit_type(&node->type);
case Node::BIN_OP: return static_cast<Visitor*>(this)->visit_bin_op(&node->bin_op);
case Node::IF_EXPR: return static_cast<Visitor*>(this)->visit_if_expr(&node->if_expr);
default: return nullptr;
}
}
T visit(size_t i) {
Node* node = node_at(i);
auto id = node->id;
switch (id) {
case Node::IDENT: return static_cast<Visitor*>(this)->visit_ident(&node->ident);
case Node::LIT: return static_cast<Visitor*>(this)->visit_lit(&node->lit);
case Node::TYPE: return static_cast<Visitor*>(this)->visit_type(&node->type);
case Node::BIN_OP: return static_cast<Visitor*>(this)->visit_bin_op(&node->bin_op);
case Node::VAR_DECL: return static_cast<Visitor*>(this)->visit_var_decl(&node->var_decl);
case Node::IF_EXPR: return static_cast<Visitor*>(this)->visit_if_expr(&node->if_expr);
case Node::PARAM_DECL: return static_cast<Visitor*>(this)->visit_param_decl(&node->param_decl);
case Node::FN_DECL: return static_cast<Visitor*>(this)->visit_fn_decl(&node->fn_decl);
case Node::BLOCK: return static_cast<Visitor*>(this)->visit_block(&node->block);
case Node::MODULE: return static_cast<Visitor*>(this)->visit_module(&node->module);
default: return nullptr;
}
}
private:
Node* node_at(u32 i) { return &tree.nodes.at(i); }
};
Example implementation (an AST printer):
// Example visitor that pretty-prints the AST to stdout with printf.
// Every visit_* returns nullptr, matching AstVisitor's default T = void*.
// NOTE(review): name_at() and view_at() are not defined in this snippet; they
// are assumed to resolve names/types to a Str and ranges to an indexable view.
class AstPrinter : public AstVisitor<AstPrinter> {
public:
    int indentation = 0;
    explicit AstPrinter(AstTree& tree) : AstVisitor(tree) {}

    // Prints 4 spaces per indentation level.
    void indent() {
        // "%*s" pads "" to the given width. The former "%.*s" applied a
        // *precision* to the one-character string " ", so it never printed
        // more than one space.
        printf("%*s", indentation * 4, "");
    }

    void* visit_ident(Node::Ident* node) {
        auto name = name_at(node->name);
        printf("Ident {" SV_FMT "}", SV_ARG(name));
        return nullptr;
    }

    void* visit_lit(Node::Lit* node) {
        auto type = name_at(node->type);
        printf("Lit { type=" SV_FMT ", value=", SV_ARG(type));
        print_lit_value(node);
        printf("}");
        return nullptr;
    }

    void* visit_type(Node::Type* node) {
        auto type = name_at(node->type);
        printf("Type {" SV_FMT " }", SV_ARG(type));
        return nullptr;
    }

    void* visit_bin_op(Node::BinOp* node) {
        printf("BinOp { left=");
        visit(node->left);
        // Cast so %d gets an int even if `op` is an enum (class).
        printf(", operation=%d, right=", static_cast<int>(node->op));
        visit(node->right);
        printf(" }");
        return nullptr;
    }

    void* visit_var_decl(Node::VarDecl* node) {
        auto name = name_at(node->name);
        // `- 1` drops the terminating NUL that sizeof counts. Otherwise a NUL
        // byte is printed after "<none>".
        auto type = (node->type != INVALID_TYPE) ? name_at(node->type) : Str("<none>", sizeof("<none>") - 1);
        printf("VarDecl { name=" SV_FMT ", type=" SV_FMT ", expr=", SV_ARG(name), SV_ARG(type));
        visit_expression(node->expr);
        printf("}");
        return nullptr;
    }

    void* visit_if_expr(Node::IfExpr* node) {
        printf("IfExpr { condition=");
        visit(node->cond);
        printf(", left=");
        visit(node->left);
        printf(", right=");
        visit(node->right);
        printf("}");
        return nullptr;
    }

    void* visit_param_decl(Node::ParamDecl* node) {
        auto name = name_at(node->name);
        // See visit_var_decl: exclude the NUL terminator from the length.
        auto type = (node->type != INVALID_TYPE) ? name_at(node->type) : Str("<none>", sizeof("<none>") - 1);
        printf("ParamDecl { name=" SV_FMT ", type=" SV_FMT ", expr=", SV_ARG(name), SV_ARG(type));
        if (node->expr != NO_EXPR) { visit_expression(node->expr); } else { printf("<none>"); }
        printf("}");
        return nullptr;
    }

    void* visit_fn_decl(Node::FnDecl* node) {
        auto name = name_at(node->name);
        printf("FnDecl { name=" SV_FMT ", ", SV_ARG(name));
        auto params = view_at(node->params);
        for (u32 i = 0; i < params.size(); ++i) {
            printf("param[%u] = ", static_cast<unsigned>(i));
            visit(params[i]);
            printf(", ");
        }
        // Each entry above already ends in ", ". Printing ", body=" here
        // produced a doubled comma, and the brace was never closed.
        printf("body=");
        visit(node->block);
        printf(" }");
        return nullptr;
    }

    void* visit_block(Node::Block* node) {
        printf("Block {\n");
        indentation += 1;
        // No extra indent() here: print_stmts indents every line itself.
        // The former stray call double-indented the first statement.
        print_stmts(view_at(node->stmts));
        indentation -= 1;
        indent();
        printf("}");
        return nullptr;
    }

    void* visit_module(Node::Module* node) {
        auto name = name_at(node->name);
        printf("Module {\n");
        indentation += 1;
        indent();
        printf("name = " SV_FMT "\n", SV_ARG(name));
        print_stmts(view_at(node->stmts));
        indentation -= 1;
        indent();
        printf("}");
        return nullptr;
    }

private:
    // Prints the literal's value according to its type tag.
    static void print_lit_value(Node::Lit* lit) {
        switch (lit->type) {
            case Type::VOID: printf("void"); return;
            // Cast: the int64 field may be `long`, which does not match %lld.
            case Type::INT:  printf("%lld", static_cast<long long>(lit->value.int64)); return;
            case Type::REAL: printf("%f", lit->value.float64); return;
            case Type::NUMBER_OF_TYPE_TYPES: break;
        }
        // Reached for NUMBER_OF_TYPE_TYPES and for any out-of-range tag. The
        // old lambda fell off its end in that case, which is UB.
        printf("<invalid>");
    }

    // Prints one indented "stmt[i] = ..." line per statement index in `stmts`.
    // Shared by visit_block and visit_module.
    template <class View>
    void print_stmts(View stmts) {
        for (u32 i = 0; i < stmts.size(); ++i) {
            indent();
            printf("stmt[%u] = ", static_cast<unsigned>(i));
            visit(stmts[i]);
            printf("\n");
        }
    }
};
This minimizes the amount of branching and uses static dispatch through CRTP instead of relying on dynamic dispatch.