
I want to write a decorator that "pre-resolves" some expressions in a function. That is,

CONSTS = {'pos': 1}

@resolve(CONSTS)
def f(x: list):
    return x[CONSTS['pos']]

should be the same as if I had written

def f(x: list):
    return x[1]

The only use will be to substitute string expressions from a single dict with numerical values.

Is it possible to do this, perhaps using ast?

(Reason: the transformed function is being passed to another decorator that performs further ast magic, and does not support the usage of dict.)

nth
  • The decorator cannot really change the *inside* of the function it decorates. It can only wrap a call to `f` inside another function, add attributes to `f`, etc. – chepner Jul 21 '23 at 15:17
  • It's definitely possible to do this by creating a new code object with different `co_code` and `names`, although this is probably not the nicest solution – Lecdi Jul 21 '23 at 15:32
  • Why do you want to do this? Could you give us some background? – md2perpe Jul 21 '23 at 16:17
  • @md2perpe The example I gave was a sort of MRE. The real code implements a complicated probabilistic algorithm. To improve readability, and work around the limitations of Python type identifiers, the code makes liberal use of dicts that represent mathematical quantities. For example, the probability $\mathbb{P}(X>i)$ is written in the code as `p['X>i']`, which contains a Numpy vector. Now I need to make the code run faster by switching to `numba.cuda`. However, Numba's CUDA implementation only supports a limited subset of Python. No dicts. – nth Jul 21 '23 at 16:40
  • I am not sure if the answer solves your question. I am not sure from your last comment if the equation initially returns a list or a dictionary. – Lucas M. Uriarte Jul 21 '23 at 17:15

1 Answer


This was a fun exercise in metaprogramming.

Because this uses AST rewriting, inspect.getsource must be able to read the source code of any file you use @resolve in. The code also uses match statements, so it needs Python 3.10 or later.

This also requires you to call resolve with a plain variable name for the constant dictionary, although it can probably support more complex expressions with very few changes.

If you decide to rename resolve (for example to resolve_constants), don't forget to also change ast.Name('resolve') in visit_FunctionDef.

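It should theoretically work in combination with other decorators, but I haven't tested that at all. Note that visit_FunctionDef clears decorator_list, so any decorator written below @resolve is dropped when the function is rebuilt; decorators written above @resolve are applied to the rebuilt function as usual.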

import ast
from functools import partial
import inspect
import sys

CONSTS = {'pos': 1}

class Resolver(ast.NodeTransformer):
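    def __init__(self, consts):
        self.consts = consts
        # Set by visit_FunctionDef; stays None if no @resolve(NAME) decorator is found,
        # in which case visit_Subscript simply matches nothing
        self.consts_name = None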

    def visit_Subscript(self, node):
        match node:
            case ast.Subscript(ast.Name(self.consts_name), ast.Constant(value)):
                return ast.Constant(self.consts[value])
        return super().generic_visit(node)

    def visit_FunctionDef(self, node):
        # prevent infinite invocation of resolve
        # and also record the name of the constants dictionary
        for deco in node.decorator_list:
            match deco:
                case ast.Call(ast.Name('resolve'), [ast.Name(consts_name)]):
                    self.consts_name = consts_name
                    break
        node.decorator_list = []
        return super().generic_visit(node)

def resolve(f, *, dictionary=None):
    if dictionary is None:
        return partial(resolve, dictionary=f)

    # Get the AST from source code
    source = inspect.getsource(f)
    tree = ast.parse(source)

    # Transform the AST
    new_tree = ast.fix_missing_locations(Resolver(dictionary).visit(tree))

    # Execute the AST as part of the original namespace, so it can still access other globals
    ns = sys.modules[f.__module__].__dict__
    exec(compile(new_tree, f.__module__, 'exec'), ns)

    # The desired function is found in the namespace
    return ns[f.__name__]

@resolve(CONSTS)
def f(x: list):
    return x[CONSTS['pos']]

# So you know it's not accessing the dictionary during the function call:
del CONSTS

print(f([1, 2]))
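
If you want to see what the rewrite produces without executing anything, a rough sketch along the following lines may help. It applies the same kind of substitution to a source string and prints the result with ast.unparse (Python 3.9+). The names SubscriptInliner and inline_consts are made up for this illustration and are not part of the code above; isinstance checks are used instead of match only so that the sketch does not need 3.10.

```python
import ast

SOURCE = """
def f(x: list):
    return x[CONSTS['pos']]
"""

class SubscriptInliner(ast.NodeTransformer):
    """Hypothetical helper: replace NAME['key'] with the literal value from consts."""

    def __init__(self, name, consts):
        self.name = name      # name of the dictionary as written in the source
        self.consts = consts  # the dictionary itself

    def visit_Subscript(self, node):
        # Visit children first so nested subscripts are rewritten as well
        self.generic_visit(node)
        if (isinstance(node.value, ast.Name)
                and node.value.id == self.name
                and isinstance(node.slice, ast.Constant)
                and node.slice.value in self.consts):
            new_node = ast.Constant(self.consts[node.slice.value])
            return ast.copy_location(new_node, node)
        return node

def inline_consts(source, name, consts):
    """Return the source text with every name['key'] lookup replaced by its value."""
    tree = SubscriptInliner(name, consts).visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)

rewritten = inline_consts(SOURCE, "CONSTS", {"pos": 1})
print(rewritten)
```

The printed function body should index x with the literal 1 and no longer mention CONSTS, which is the form the downstream decorator needs. Keys that are not in the dictionary are left untouched here rather than raising KeyError; that is a choice made for this sketch, not something the decorator above does.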
Jasmijn