We start off with a function, such as the following:
def funky_function(x1, x2, x3, x4, /):
    return ", ".join(" ".join(str(x).split()) for x in [x1, x2, x3, x4])
After decorating the function, every argument must be passed in its own pair of parentheses, one call per argument.
r = f(1)(2)(3)(4) # GOOD
# r = f(1, 2, 3, 4) # BAD
One potential application is generalizing a decorator.
Some decorators only work on single-argument functions.
We might want a new decorator which works on multi-argument functions.
For example, you could implement function overloading (multiple dispatch) by generalizing functools.singledispatch. Other people have already implemented multimethods, so that is not really what my question is about. However, I wanted to provide an example application as motivation.
from functools import singledispatch

@singledispatch
def fun(arg):
    pass

@fun.register
def _(arg: int):
    print("I am processing an integer")

@fun.register
def _(arg: list):
    print("I am processing a list")
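To illustrate the limitation, here is a minimal sketch showing that `singledispatch` only looks at the type of the first positional argument. The function name `describe` and its parameter names are made up for this illustration.

```python
from functools import singledispatch

@singledispatch
def describe(first, second):
    # Fallback used when no registered type matches the first argument.
    return "default"

@describe.register
def _(first: int, second):
    return "int first"

@describe.register
def _(first: list, second):
    return "list first"

# Only the type of `first` matters; the type of `second` is ignored.
print(describe(1, [2]))    # int first
print(describe([1], 2))    # list first
print(describe("a", 2))    # default
```

A decorator that splits a multi-argument function into a chain of single-argument calls would, in principle, give something like `singledispatch` a chance to dispatch at every step.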
I attempted to write some code to accomplish this task, but it does not exhibit the desired behavior. Ideally, a decorated function becomes a function of one argument which returns another function.
return_value = f(1)(2)(3)(4)
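To make the target concrete, this is roughly what the decorated function should behave like if it were written out by hand as nested closures. It is only a sketch of the desired calling convention, not a general decorator, and the names `funky_curried`, `take_x2`, `take_x3` and `take_x4` are hypothetical.

```python
def funky_curried(x1):
    # Each level accepts exactly one argument and returns the next level.
    def take_x2(x2):
        def take_x3(x3):
            def take_x4(x4):
                # Only the final call runs the original function body.
                return ", ".join(" ".join(str(x).split()) for x in [x1, x2, x3, x4])
            return take_x4
        return take_x3
    return take_x2

print(funky_curried(1)(2)(3)(4))  # 1, 2, 3, 4
```

The intermediate results, such as `funky_curried(1)(2)`, are themselves callables waiting for the next argument.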
Here is some code:
from functools import partial, wraps
from inspect import signature

class SeperatorHelper:
    def __init__(self, func):
        """`func` should be callable"""
        assert callable(func)
        self._func = func

    def __call__(self, arg):
        return type(self)(partial(self._func, arg))

def seperate_args(old_func, /):
    """
    This is a decorator

        @seperate_args
        def foo(x1, x2, x3, x4, /):
            pass

    +------------------+------------------+
    | NEW              | OLD              |
    +------------------+------------------+
    | f(1)(2)(3)(4)    | f(1, 2, 3, 4)    |
    +------------------+------------------+
    """
    new_func = SeperatorHelper(old_func)
    new_func = wraps(old_func)(new_func)
    return new_func
#######################################################
# BEGIN TESTING
#######################################################
@seperate_args
def funky_function(x1, x2, x3, x4, /):
    return ", ".join(" ".join(str(x).split()) for x in [x1, x2, x3, x4])
print("signature == ", signature(funky_function))
func_calls = [
    "funky_function(1)(2)",
    "funky_function(1)(2)(3)(4)",
    "funky_function(1)(2)(3)(4)('extra arg')",
    "funky_function(1)(2)(3)(4)()()()()"
]

for func_call in func_calls:
    try:
        ret_val = eval(func_call)
    except BaseException as exc:
        ret_val = exc
    # convert `ret_val` into a string
    # and eliminate line-feeds, tabs, carriage-returns...
    ret_val = " ".join(str(ret_val).split())
    print(40 * "-")
    print(func_call)
    print("return value".ljust(40), ret_val)
    print(40 * "-")