What is preventing your desired results is that you are iterating over the class's `__bases__`, which lists only the immediate superclasses. If you change your metaclass to iterate over `__mro__` (Python's linearized sequence of all of a class's ancestors), it will work:
```
In [14]: class Abstract(metaclass=Meta):
    ...:     def __call__(self):
    ...:         print("Abstract")
    ...:
    ...: class Base(Abstract):
    ...:     def __call__(self):
    ...:         print("Base")
    ...:
    ...: class Super(Abstract):
    ...:     def __call__(self):
    ...:         print("Super")
    ...:
    ...: class Parent:
    ...:     def __call__(self):
    ...:         print("Parent")
    ...:
    ...: class Child(Parent):
    ...:     def __call__(self):
    ...:         print("Child")
    ...:

In [15]: Child.__mro__
Out[15]: (__main__.Child, __main__.Parent, object)
```
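To make the `__bases__` versus `__mro__` difference concrete, here is a small diamond hierarchy (the class names `A`, `B`, `C`, `D` are just for illustration). `__bases__` only has what the class statement lists, while `__mro__` has every ancestor, in lookup order:

```python
class A:
    pass

class B(A):
    pass

class C(A):
    pass

class D(B, C):
    pass

# __bases__ holds only the classes written in D's class statement
print([cls.__name__ for cls in D.__bases__])  # ['B', 'C']

# __mro__ holds D and every ancestor, in the order attribute lookup visits them
print([cls.__name__ for cls in D.__mro__])    # ['D', 'B', 'C', 'A', 'object']
```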
Anyway, this turns out to be a bit trickier than it seems at first glance, because there are corner cases. What if one of your eligible classes does not define a `__call__`, for example? What if one of the methods does include an ordinary `super()` call? OK, add a marker to avoid unwanted re-entrancy in case someone does call `super()`; but what if it is running in a multithreaded environment and two instances are being created at the same time?
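The multithreading concern is why the metaclass keeps its re-entrancy flag in a `threading.local` object: attributes set on it in one thread are invisible to every other thread. That is also why the flag is read with `getattr(..., "running", False)` rather than a plain attribute access. A standalone check, using the hypothetical names `state`, `seen` and `worker`:

```python
import threading

state = threading.local()
state.running = True  # set only in the main thread

seen = []

def worker():
    # Each thread gets its own attribute namespace, so the flag set in the
    # main thread does not exist here; the getattr default covers that case.
    seen.append(getattr(state, "running", False))

t = threading.Thread(target=worker)
t.start()
t.join()

print(seen)           # [False]
print(state.running)  # True: the main thread still sees its own value
```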
All in all, one has to combine Python's attribute-fetching mechanisms correctly in order to pick the methods from the right classes. I chose to copy the original `__call__` method to another attribute in the class itself, so that it not only stores the original method but also works as a marker for the eligible classes.
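The marker works because there are two different ways to ask a class about an attribute: `hasattr`/`getattr` follow the whole MRO, while `cls.__dict__` only contains what that class body defined itself. A tiny illustration, with hypothetical class and attribute names:

```python
class Marked:
    def _marker(self):
        return "Marked"

class Unmarked(Marked):
    pass

# Lookup through the MRO: the subclass appears to "have" the marker too...
print(hasattr(Unmarked, "_marker"))    # True

# ...but only the class whose body defined it has it in its own __dict__
print("_marker" in Unmarked.__dict__)  # False
print("_marker" in Marked.__dict__)    # True
```

This is why the metaclass checks `MARKER_METHOD in supercls.__dict__` instead of using `hasattr` when deciding which classes take part in the chain.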
Also, note that this works just the same for `__call__` as it would for any other method, so I factored the name `"__call__"` out into a constant to make that explicit (it could be made a list of methods, or all methods whose names have a certain prefix, and so on).
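As a sketch of the "certain prefix" idea, the selection step could be a helper that scans the class body namespace a metaclass receives. The `on_` prefix, `auto_super_names`, the `Recording` metaclass and the `_auto_super_names` attribute are all hypothetical; this only shows which names would be picked, not the wrapping itself:

```python
AUTO_SUPER_PREFIX = "on_"  # hypothetical naming convention

def auto_super_names(namespace):
    """Names in a class body that would get the auto-super treatment."""
    return sorted(
        name
        for name, value in namespace.items()
        if callable(value)
        and (name == "__call__" or name.startswith(AUTO_SUPER_PREFIX))
    )

class Recording(type):
    def __new__(meta, name, bases, attr):
        cls = super().__new__(meta, name, bases, attr)
        # Store the selection so it can be inspected (hypothetical attribute)
        cls._auto_super_names = auto_super_names(attr)
        return cls

class Widget(metaclass=Recording):
    def __call__(self): ...
    def on_click(self): ...
    def helper(self): ...

print(Widget._auto_super_names)  # ['__call__', 'on_click']
```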
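```python
from functools import wraps
from threading import local as threading_local

MARKER_METHOD = "_auto_super_original"
AUTO_SUPER = "__call__"


class Meta(type):
    def __new__(meta, name, bases, attr):
        original_call = attr.pop(AUTO_SUPER, None)
        # One re-entrancy flag per class, private to each thread
        avoid_reentrancy = threading_local()
        avoid_reentrancy.running = False

        @wraps(original_call)
        def recursive_call(self, *args, _wrap_call_mro=None, **kwargs):
            # Threads other than the one that created the class start
            # without a "running" attribute, hence the getattr default
            if getattr(avoid_reentrancy, "running", False):
                return
            avoid_reentrancy.running = True
            # Top-level call: use the instance's full MRO; chained calls
            # receive the remaining slice of it
            mro = _wrap_call_mro or self.__class__.__mro__
            try:
                # Find the next class in the MRO whose body defined the
                # method, and let its wrapper run everything above it first
                for index, supercls in enumerate(mro[1:], 1):
                    if MARKER_METHOD in supercls.__dict__:
                        getattr(supercls, AUTO_SUPER)(
                            self, *args, _wrap_call_mro=mro[index:], **kwargs
                        )
                        break
                # Then run the original method of the class heading the MRO
                getattr(mro[0], MARKER_METHOD)(self, *args, **kwargs)
            finally:
                avoid_reentrancy.running = False

        if original_call is not None:
            # The copy keeps the original method and marks the class as eligible
            attr[MARKER_METHOD] = original_call
        attr[AUTO_SUPER] = recursive_call
        return super().__new__(meta, name, bases, attr)
```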
And this works on the console. I added a few more intermediate classes to cover the corner cases:
```python
class Abstract(metaclass=Meta):
    def __call__(self):
        print("Abstract")

class Base1(Abstract):
    def __call__(self):
        print("Base1")

class Base2(Abstract):
    def __call__(self):
        print("Base2")

class Super(Base1):
    def __call__(self):
        print("Super")

# Not created by Meta, so the chain never runs its __call__
class NonColaborativeParent:
    def __call__(self):
        print("Parent")

# Calls super() explicitly even though the metaclass already chains the calls
class ForgotAndCalledSuper(Super):
    def __call__(self):
        super().__call__()
        print("Forgot and called super")

# Defines no __call__ of its own
class NoCallParent(Super):
    pass

class Child(NoCallParent, ForgotAndCalledSuper, NonColaborativeParent, Base2):
    def __call__(self):
        print("Child")
```
Result:

```
In [96]: Child()()
Abstract
Base2
Base1
Super
Child
Forgot and called super
Child
```