6

How can I use typing to specify that my function accepts only specific callables? For instance, I would like functionality akin to this:

import typing

def accepted_function1():
    pass
def accepted_function2():
    pass

def function_accepting_functions(foo: Literal[accepted_function1, accepted_function2]):
    foo()
STerliakov
jay
  • One thought was to define function-valued elements in an enumerated type, but that doesn't seem to work as expected. (The elements have their own literal, singleton types which `mypy` doesn't seem to recognize as callable.) – chepner Jun 09 '23 at 15:03
  • @chepner yeah, the only way I have figured how to do this might be to make a parent class and refactor the accepted functions as subclasses that implement __call__ appropriately. Then foo can just accept objects of the type of the parent. That sort of structure may not be immediately intuitive to other collaborators though. – jay Jun 09 '23 at 16:20
  • Not sure what @chepner means, but the obvious idea to use Enum works almost perfectly (one type-ignore comment due to mypy bug, probably, and the rest is fine). See [here](https://mypy-play.net/?mypy=master&python=3.11&flags=strict&gist=2dd8ceca28aedfe78e86c3a30f086281) (gist mine). For python below 3.11 (no `enum.member`) see [this answer](https://stackoverflow.com/a/40339397/14401160). – STerliakov Jun 10 '23 at 13:10
  • Possibly a misunderstanding on my part in using an enum. I assumed that, given something like `class Foo(Enum): A1 = accepted_function1`, that `Foo.A1` should have type `function`, not `Literal[Foo.A1]`, and that `Foo.A1` should be callable. Using `mypy 1.1.1`. – chepner Jun 10 '23 at 15:29
  • (Actually, I can *see* part of my obvious misunderstanding.) – chepner Jun 10 '23 at 15:36

5 Answers

4

Edit: Actual 'Solution'

It's tedious but you can 'literalize' your functions by wrapping them up in an enum. To enable full control over docstrings and parameter specification you'll need a separate enum for each function. Tested with pyright>1.1.310.

from typing import TYPE_CHECKING, Literal
from enum import Enum, member

class _accepted_function1(Enum):
    @staticmethod
    def __call__(x: int): # specify your function signature here
        """My docstring"""
        # write your code here
        pass
    if TYPE_CHECKING:
        _ = __call__
    else:
        _ = member(__call__)


class _accepted_function2(Enum):
    @staticmethod
    def __call__(): # specify your function signature here
        """My docstring"""
        # write your code here
        pass
    if TYPE_CHECKING:
        _ = __call__
    else:
        _ = member(__call__)

accepted_function1 = _accepted_function1._
accepted_function2 = _accepted_function2._

def function_accepting_functions(foo: Literal[accepted_function1, accepted_function2]):
    if foo is accepted_function1:
        foo(3)
    else:
        foo()

function_accepting_functions(accepted_function1)

As far as I know it's not actually possible to do what you want here. The problem is that functions are neither types nor literals. They're objects that are created at runtime. To a type checker, a function is essentially just an instance of its signature, so you can only restrict the type on a signature basis. I can present you with three alternatives though.
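To make the point about signatures concrete, here is a minimal sketch (unrelated_function is a hypothetical name that is not in the question). It suggests that all plain functions share one runtime class, that only identity tells them apart, and that a Callable annotation therefore cannot single out particular functions.

```python
import types
from typing import Callable


def accepted_function1() -> None:
    pass


def accepted_function2() -> None:
    pass


def unrelated_function() -> None:  # hypothetical, not part of the question
    pass


def function_accepting_functions(foo: Callable[[], None]) -> None:
    # Callable[[], None] describes a signature, so all three functions match it.
    foo()


# Every plain function is an instance of the same class ...
same_class = type(accepted_function1) is type(unrelated_function) is types.FunctionType
# ... so only identity distinguishes one function from another.
same_object = accepted_function1 is accepted_function2

print(same_class, same_object)
function_accepting_functions(unrelated_function)  # accepted: the signature matches
```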

1. Using an enum

I'd argue this is the cleanest solution and surely the one I'd recommend, although at least pyright kinda messes things up there. Note you need the member(...) call, as functions aren't converted to enum members by default.

from typing import TYPE_CHECKING
from enum import Enum, member

def accepted_function1():
    pass

def accepted_function2():
    pass

class AcceptedFunctions(Enum):
   
    if TYPE_CHECKING:
        accepted_function1 = accepted_function1
        accepted_function2 = accepted_function2
    else:
        accepted_function1 = member(accepted_function1)
        accepted_function2 = member(accepted_function2)


def function_accepting_functions(foo: AcceptedFunctions):
    foo.value()

function_accepting_functions(AcceptedFunctions.accepted_function1)
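The remark that functions aren't converted to enum members by default can be checked with a small sketch. The class name WithoutMember is hypothetical. Because a plain function is a descriptor, Enum treats it as a method of the class rather than as a member.

```python
from enum import Enum


def accepted_function1() -> str:
    return "one"


class WithoutMember(Enum):
    # No member(...) wrapper here: the function becomes an ordinary
    # attribute (effectively a method) instead of an enum member.
    accepted_function1 = accepted_function1


members = list(WithoutMember)
print(members)
# The attribute is still the original function object, not an enum member.
print(WithoutMember.accepted_function1 is accepted_function1)
```

This is why the enum above wraps each function in member(...), which is available from Python 3.11.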

You could use this approach to 'literalize' your functions. All your functions need to have the same signature, though. If they don't, you'll get a huge headache trying to overload __call__ correctly; I didn't manage to do that, as you cannot have forward refs for enum literals.

from typing import TYPE_CHECKING, Literal
from enum import Enum, member

def _accepted_function1():
    pass

def _accepted_function2():
    pass

class _AcceptedFunction(Enum):

    if TYPE_CHECKING:
        accepted_function1 = _accepted_function1
        accepted_function2 = _accepted_function2
    else:
        accepted_function1 = member(_accepted_function1)
        accepted_function2 = member(_accepted_function2)

    def __call__(self) -> None:
        return self.value()
    
accepted_function1 = _AcceptedFunction.accepted_function1
accepted_function2 = _AcceptedFunction.accepted_function2

def function_accepting_functions(foo: Literal[accepted_function1, accepted_function2]):
    foo()

function_accepting_functions(accepted_function1)

2. Using a decorator and a custom function type

This solution just fools the type checker into treating your accepted_function* as a different type. You'll lose access to any function attributes that way, though. You can implement proxy properties if you want: check typeshed/types.pyi::FunctionType to see which properties exist on functions. I've implemented __closure__ for you, but you could add all the other attributes normal functions have as well.

from typing import Generic, TypeVar, ParamSpec, Any, Callable
from types import CellType
import sys

T = TypeVar("T")
P = ParamSpec("P")

class acceptable_function(Generic[P, T]):
    
    def __init__(self, func: Callable[P, T]) -> None:
        self.__func = func

    def __call__(self, *args: P.args, **kwds: P.kwargs) -> T:
        return self.__func(*args, **kwds)
    
    @property
    def __closure__(self) -> tuple[CellType, ...] | None: 
        return self.__func.__closure__

    # you may add more properties that exist on functions...

@acceptable_function
def accepted_function1():
    pass

@acceptable_function
def accepted_function2():
    pass

def function_accepting_functions(foo: acceptable_function):
    foo()

function_accepting_functions(accepted_function1)
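As a possible shortcut for the proxy properties, functools.update_wrapper can copy the common metadata (__name__, __doc__ and so on) onto the wrapper instance. The sketch below is a simplified, non-generic variant of the class above. The isinstance check and plain_function are additions for illustration and are not part of the original answer.

```python
import functools
from typing import Any, Callable


class acceptable_function:
    def __init__(self, func: Callable[..., Any]) -> None:
        self._func = func
        # Copies __name__, __qualname__, __doc__, __module__ and sets __wrapped__.
        functools.update_wrapper(self, func)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self._func(*args, **kwargs)


@acceptable_function
def accepted_function1() -> str:
    """Docstring of accepted_function1."""
    return "called 1"


def plain_function() -> str:  # hypothetical undecorated function
    return "plain"


def function_accepting_functions(foo: acceptable_function) -> Any:
    # Optional runtime guard that mirrors what the type checker enforces.
    if not isinstance(foo, acceptable_function):
        raise TypeError("foo must be decorated with @acceptable_function")
    return foo()


print(accepted_function1.__name__, accepted_function1.__doc__)
print(function_accepting_functions(accepted_function1))
```

Note that any function can be decorated, so this marks functions as acceptable rather than restricting the parameter to a fixed set.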

3. Add some restrictions on the function signature

If you have control over the signatures of the accepted functions and they all have the same signature you can add a 'secret' parameter to ensure no other functions are passed.


from typing import Protocol

class _Helper:
    # Single leading underscore: a double-underscore name would be
    # name-mangled inside the class body below and fail at runtime.
    pass

class AcceptedFunction(Protocol):

    def __call__(self, *, _secret: _Helper = _Helper()) -> None:
        ...


def accepted_function1(*, _secret: _Helper = _Helper()):
    pass

def accepted_function2(*, _secret: _Helper = _Helper()):
    pass

def function_accepting_functions(foo: AcceptedFunction):
    foo()

function_accepting_functions(accepted_function1)
Robin Gugel
0

The other answers either don't meet the question requirements (since they accept callables of the same signature as accepted_function1 and accepted_function2) or operate at runtime and are not checkable with mypy. Below is a solution allowing only the two specific functions you desire when type checked.

from typing import final

@final
class A:
    @staticmethod
    def function():
        pass

@final
class B:
    @staticmethod
    def function():
        pass

def function_accepting_functions(accepted_class: type[A] | type[B]):
    accepted_class.function()

function_accepting_functions(A)
function_accepting_functions(B)

No class besides A and B can be passed, because accepted_class is a union of only those two and both classes are @final. A.function represents accepted_function1, while B.function represents accepted_function2.
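A short sketch of what this buys you, assuming a hypothetical third class C with the same shape: a static checker should flag the call with C, but nothing prevents it at runtime, so the restriction exists only at type-check time. Return values are added here purely so the calls are observable.

```python
from typing import Type, Union, final


@final
class A:
    @staticmethod
    def function() -> str:
        return "A.function"


@final
class B:
    @staticmethod
    def function() -> str:
        return "B.function"


class C:  # hypothetical outsider with the same shape
    @staticmethod
    def function() -> str:
        return "C.function"


def function_accepting_functions(accepted_class: Union[Type[A], Type[B]]) -> str:
    return accepted_class.function()


print(function_accepting_functions(A))
# A type checker should report an argument error on the next line;
# the ignore comment only keeps this sketch clean under mypy.
print(function_accepting_functions(C))  # type: ignore[arg-type]
```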

Mario Ishac
0

For me the best option would be to create a protocol or an abstract class, although I am not sure that will be immediately intuitive to everyone. With it you can define a nice interface and control which functions are valid for function_accepting_functions.

I will start with an abstract class rather than a protocol, just so that subclasses inherit the __call__ method that makes them callable.

NOTE: I have not used CamelCase for the classes since they will be used as callables.

Using abstract class:

from abc import ABC, abstractmethod

class accepted_function(ABC):
    # abstractstaticmethod is deprecated; stack the two decorators instead
    @staticmethod
    @abstractmethod
    def run_function():
        ...

    @classmethod
    def __call__(cls):
        cls.run_function()


class accepted_function2(accepted_function):
    @staticmethod
    def run_function():
        pass

class accepted_function1(accepted_function):
    @staticmethod
    def run_function():
        pass
    

def function_accepting_functions(foo: accepted_function):
    foo()

With Protocol from the typing module, I would remove __call__, since otherwise every class would need to implement __call__ itself and there would be a lot of repeated code. In that case, just call foo.run_function() instead.

from typing import Protocol

class accepted_function(Protocol):
    @staticmethod
    def run_function():
        ...

class accepted_function2:
    @staticmethod
    def run_function():
        pass


class accepted_function1:
    @staticmethod
    def run_function():
        pass
    

def function_accepting_functions(foo: accepted_function):
    foo.run_function()
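One caveat worth illustrating: a Protocol is structural, so any class that happens to define run_function satisfies it. The sketch below uses typing.runtime_checkable to make this visible at runtime. The lookalike and unrelated classes are hypothetical and not part of the answer.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class accepted_function(Protocol):
    @staticmethod
    def run_function() -> None:
        ...


class accepted_function1:
    @staticmethod
    def run_function() -> None:
        pass


class lookalike:  # hypothetical class that was never meant to be accepted
    @staticmethod
    def run_function() -> None:
        pass


class unrelated:  # hypothetical class without run_function
    pass


# isinstance only checks that the protocol's members are present.
print(isinstance(accepted_function1(), accepted_function))
print(isinstance(lookalike(), accepted_function))
print(isinstance(unrelated(), accepted_function))
```

So the protocol restricts the shape of what is passed, not the identity of the function behind it.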
Lucas M. Uriarte
-2

It is possible, but not with typing (or at least not as easily), because every function is an instance of the built-in function class; this is easier to see if you think of functions as lambdas.

To know whether two functions are the same you just need to compare them, preferably with is. But what does that have to do with type hints? Well, Python gives you "full" access to the function, meaning that you can read the annotation of any given argument and implement logic to handle specific needs.

One drawback is that your IDE may report errors that do not actually exist.

Here I'm comparing against the exact function, but in the same way you could filter by regex or by parent class. One additional note: iterate over the argument names with enumerate, because __annotations__ doesn't include the arguments without a type hint, and it does contain the hints from **kwargs and the return type of the function if given, so additional checks are required.
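A small sketch of what those two attributes contain, using a hypothetical function named example. It also shows that co_varnames lists local variables after the parameters, which is another reason extra checks are needed.

```python
def example(a, b: int, *args, **kwargs: str) -> bool:
    local = 1  # local variables also show up in co_varnames
    return bool(local)


annotations = example.__annotations__
var_names = example.__code__.co_varnames

# Unannotated parameters (a, args) are missing; 'return' is present.
print(annotations)
# Parameters first, then local variables.
print(var_names)
```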

def fake_assert(valid, message):
    if not valid:
        print(message)
    return not valid

def ensure_call_hint(foo):
    hints = foo.__annotations__
    var_names = foo.__code__.co_varnames
    def wrapper(*args, **kwargs):
        for i, arg in enumerate(var_names):
            if arg in hints:
                if isinstance(hints[arg], list | tuple):
                    if callable(args[i]):
                        # assert args[i] in hints[arg], 'Invalid function'
                        if fake_assert(args[i] in hints[arg], '(1) Invalid function'):
                            return lambda *a, **b: None
                    else:
                        # assert type(args[i]) in hints[arg], 'Invalid type'
                        if fake_assert(type(args[i]) in hints[arg], '(2) Invalid type'):
                            return lambda *a, **b: None
                elif callable(args[i]) and callable(hints[arg]):
                    # assert args[i] is hints[arg], 'Invalid function'
                    if fake_assert(args[i] is hints[arg], '(3) Invalid function'):
                        return lambda *a, **b: None
                else:
                    # assert type(args[i]) is hints[arg], 'Invalid type'
                    if fake_assert(type(args[i]) is hints[arg], '(4) Invalid type'):
                        return lambda *a, **b: None

        print('(0) Valid call')
        return foo(*args, **kwargs)
    return wrapper


def faa():
    pass
def foo():
    pass
def fii():
    pass

class Impostor:
    @staticmethod
    def faa():
        pass


@ensure_call_hint
def func(func_arg: (faa, foo, int), var_b: fii):
    if callable(func_arg):
        func_arg()


print('func_arg valid    ', end=''); func(faa, fii)
print('func_arg invalid  ', end=''); func(fii, fii)
print('func_arg impostor ', end=''); func(Impostor.faa, fii)
print('func_arg int      ', end=''); func(0, fii)
print('func_arg str      ', end=''); func('0', fii)

print('var_b valid       ', end=''); func(faa, fii)
print('var_b invalid     ', end=''); func(faa, faa)
print('var_b int         ', end=''); func(faa, 0)
func_arg valid    (0) Valid call
func_arg invalid  (1) Invalid function
func_arg impostor (1) Invalid function
func_arg int      (0) Valid call
func_arg str      (2) Invalid type
var_b valid       (0) Valid call
var_b invalid     (3) Invalid function
var_b int         (4) Invalid type
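If only the identity check is needed, a much smaller runtime sketch is possible. The accepts_only decorator below is hypothetical (it is not a standard-library or typing feature) and raises an exception instead of printing. It checks only the first positional argument, and like the code above it does nothing for static type checkers.

```python
import functools


def accepts_only(*allowed):
    """Hypothetical decorator: the first positional argument must be one of `allowed`."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(callback, *args, **kwargs):
            # Compare by identity so a same-named function elsewhere is rejected.
            if not any(callback is candidate for candidate in allowed):
                raise ValueError(f"{callback!r} is not an accepted function")
            return func(callback, *args, **kwargs)
        return wrapper
    return decorator


def faa():
    return "faa"


def foo():
    return "foo"


class Impostor:
    @staticmethod
    def faa():
        return "impostor"


@accepts_only(faa, foo)
def func(func_arg):
    return func_arg()


print(func(faa))
try:
    func(Impostor.faa)
except ValueError as exc:
    print("rejected:", exc)
```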
SrPanda
-3

I think it's better to use the Callable type from the typing module along with Union or Tuple to specify that your function accepts only specific callables using type hints. An example:

from typing import Callable, Tuple, Union

def accepted_function1():
    pass

def accepted_function2():
    pass

def function_accepting_functions(foo: Union[Callable[[], None], Callable[[], None]]):
    foo()

# Alternatively, using Tuple:
def function_accepting_functions(foo: Tuple[Callable[[], None], Callable[[], None]]):
    foo[0]()

Note that Union and Tuple are just two different ways to achieve similar functionality. You can choose the one that fits your requirements better.

Also, please note that the Literal type hint you mentioned in your question is primarily used for specifying exact literal values, such as specific strings or numbers, rather than callables. In this case, Callable is the appropriate type of hint to use.

Here's an example demonstrating the usage of type hints to specify that a function accepts only specific callables:

from typing import Callable, Union

def add(a: int, b: int) -> int:
    return a + b

def subtract(a: int, b: int) -> int:
    return a - b

def multiply(a: int, b: int) -> int:
    return a * b

def function_accepting_functions(foo: Union[Callable[[int, int], int], Callable[[int, int], int]]):
    result = foo(5, 3)
    print("Result:", result)

# Using add function
function_accepting_functions(add)  # Output: Result: 8

# Using subtract function
function_accepting_functions(subtract)  # Output: Result: 2

# Using multiply function
function_accepting_functions(multiply)  # Output: Result: 15

# Using a function that does not match the specified signature will raise a type error.
# For example, passing a function that takes different arguments or returns a different type.
# function_accepting_functions(len)  # Raises a type error

In this example, function_accepting_functions accepts callables that take two int arguments and return an int. The foo parameter can be any function that matches this signature.

You can pass different functions as arguments to function_accepting_functions, such as add, subtract, or multiply, and the function will invoke the passed function accordingly.

  • "Note that Union and Tuple are just two different ways to achieve similar functionality." What do you mean by that? They're not related in any way. Furthermore, `Union[T, T]` is exactly the same as just `T`. What the OP was asking is how to restrict the things you can call `foo` with a known set of functions, not just any function that matches the signature. – decorator-factory Jun 21 '23 at 09:02
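The comment's point about `Union[T, T]` can be checked with a short sketch. The alias Signature and the function power are hypothetical names introduced for illustration: the union collapses to the single signature, and any function matching that signature is accepted, whether or not it belongs to the intended set.

```python
from typing import Callable, Union

Signature = Callable[[int, int], int]

# A union of two identical members is the member itself.
collapsed = Union[Signature, Signature] == Signature
print(collapsed)


def function_accepting_functions(foo: Union[Signature, Signature]) -> int:
    return foo(5, 3)


def power(a: int, b: int) -> int:  # hypothetical function outside the intended set
    return a ** b


# Accepted, because only the signature is constrained.
print(function_accepting_functions(power))
```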