Some slight modifications to @Seanny123's answer, adding an `ignore_args` flag that skips `args` fields during the comparison:
import ast
from itertools import zip_longest
from typing import Union
def compare_ast(
    node1: Union[ast.AST, list[ast.AST]],
    node2: Union[ast.AST, list[ast.AST]],
    ignore_args=False,
) -> bool:
    """Return True if two AST nodes (or lists of nodes) are structurally equal.

    Source-position attributes (``lineno``, ``col_offset``, ``end_lineno``,
    ``end_col_offset``) and expression contexts (``ctx``) are ignored. So the
    same code parsed at different locations compares equal, and so does a name
    used as a load vs. a store.

    Args:
        node1: An ``ast.AST`` node or a list of nodes. On recursive calls this
            may also be a plain field value such as a ``str``, ``int`` or
            ``None``.
        node2: The value to compare against ``node1``.
        ignore_args: If True, skip every field named ``args`` at any depth.
            Examples are the positional arguments of ``ast.Call`` and the
            signature of ``ast.FunctionDef``. Other fields, such as a call's
            ``keywords``, are still compared.

    Returns:
        True if the two values are equal under the rules above, else False.
    """
    # Exact type match: also distinguishes e.g. Constant(1) from Constant(True).
    if type(node1) is not type(node2):
        return False
    if isinstance(node1, ast.AST):
        # Walk only the grammar fields (``_fields``), not ``vars()``.
        # Position info lives in ``_attributes`` and is excluded automatically.
        # Extra attributes that tools attach to nodes (e.g. a ``parent``
        # back-reference) must not be compared, or the recursion can loop forever.
        for field in node1._fields:
            if field == "ctx" or (ignore_args and field == "args"):
                continue
            # Default to None: manually constructed nodes may omit optional fields.
            if not compare_ast(
                getattr(node1, field, None),
                getattr(node2, field, None),
                ignore_args,
            ):
                return False
        return True
    if isinstance(node1, list):
        # Lengths must match explicitly. Some node lists (e.g. Dict.keys)
        # legitimately contain None, so padding the shorter list with None
        # would make unequal lists look equal.
        return len(node1) == len(node2) and all(
            compare_ast(a, b, ignore_args) for a, b in zip(node1, node2)
        )
    return node1 == node2
Example:
import ast

# Parse both snippets and keep the first top-level statement of each.
node1, node2 = (
    ast.parse(snippet).body[0]
    for snippet in ("plt.show()", "plt.show(*some-args)")
)
# With ignore_args=True the differing positional arguments are skipped.
print(compare_ast(node1, node2, ignore_args=True))
Output: True