
I have a number of operations I want to "fuse" together. Let's say there are 3 possible operations:

sq = lambda x: x**2
add = lambda x: x+3
mul = lambda x: x*5

I also have an array of operations:

ops = [add, sq, mul, sq]

I can then create a function from these operations:

def generateF(ops):
    def inner(x):
        for op in ops:
            x = op(x)
        return x
    return inner
f = generateF(ops)
f(3) # returns 32400

fastF = lambda x: (5*(x+3)**2)**2

f and fastF do the same thing, but fastF is around 1.7-2 times faster than f on my benchmark, which makes sense. My question is: how can I write a generateF function that returns a function as fast as fastF? The operations are restricted to basic operations like __add__, __mul__, __matmul__, __rrshift__, etc. (essentially most numeric operations). generateF can take as long as you'd like, because it will run before we reach the hot code.

The context is that this is part of my library, so I can define every legal operation and thus know exactly what they are. The operation definitions are not handed to us by the end user at random (the user can only pick the order of the operations), so we can use any outside knowledge about them.

This might seem like premature optimization, but it is not, as f is hot code. Dropping to C is not an option, as the operations can be complex (think PyTorch tensor multiply), and x can be of any type. Currently I'm thinking about modifying Python's bytecode, but that is very unpleasant, as the bytecode specification changes with every Python version, so I wanted to ask here first before diving into that solution.

157 239n
  • I would suggest looking at decorators, I believe they can help you in this case. – Zaid Al Shattle Nov 12 '21 at 21:08
  • @ZaidAlShattle Can you elaborate? I know and use decorators a lot but I can't imagine a solution using them. – 157 239n Nov 12 '21 at 21:09
  • I do actually think that perhaps I was wrong in this case about decorators, but I do have an idea which I will post in a few minutes that should hopefully help – Zaid Al Shattle Nov 12 '21 at 21:11
  • I don't think you can do this. There's no way to extract the body of a function and merge it into a new function, so you're always going to get the overhead of calling 3 functions. – Barmar Nov 12 '21 at 21:14
  • I have added some context which might help to constrain the problem down – 157 239n Nov 12 '21 at 21:15
  • I thought about using string manipulation to set up the calculation function (which would make the calculation itself faster when you `eval()` it), but the string manipulation would take longer, so I am not sure if that fits your use case at all – Zaid Al Shattle Nov 12 '21 at 21:16
  • You could get exactly the same result as `fastF` by generating Python source code and then calling `exec()`/`eval()`/`compile()` on it. Your individual operations would become string manipulations: `sq = lambda x: f"({x})**2"` for example. Start with the innermost `x` actually being `"x"`, then put `"lambda x:"` in front of the result. – jasonharper Nov 12 '21 at 21:16
  • If some of the operations are super long (Pytorch vector ops), then is the slow step *really* the extra function call overhead? Either the function call overhead is the slowest thing happening (in which case, dip into C and accelerate it), or you're running slow Python functions (in which case, the extra calls aren't killing you) – Silvio Mayolo Nov 12 '21 at 21:17
  • Two things working against you are the overhead of calling the lambda functions, which includes creating a new function context object for each, and the assignment to the intermediate variable. Cython may help here. – tdelaney Nov 12 '21 at 21:18
  • @SilvioMayolo Yes, I understand that. I included PyTorch tensor ops just to be consistent with everything. `f` will mostly operate on random generic objects, which includes PyTorch tensors, but it will mostly operate on lightweight objects, so function call overhead here is a lot. – 157 239n Nov 12 '21 at 21:20
  • Beside the point, but [named lambdas are bad practice](/q/38381556/4518341). Use a `def` instead. I'm pretty sure they're equally performant. – wjandrea Nov 12 '21 at 21:25
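For reference, the source-generation approach suggested in the comments above could look something like this sketch (the `*_src` helpers are hypothetical stand-ins for the library's real operations, which would emit expression strings instead of computing values):

```python
# Sketch of the compile-to-source idea: each op builds an expression
# string, and the final lambda is compiled once, up front.
sq_src = lambda x: f"({x})**2"
add_src = lambda x: f"({x})+3"
mul_src = lambda x: f"({x})*5"

def generateF(src_ops):
    expr = "x"
    for op in src_ops:
        expr = op(expr)                  # nest the expressions
    return eval(f"lambda x: {expr}")     # one flat expression, no inner calls

f = generateF([add_src, sq_src, mul_src, sq_src])
f(3)  # returns 32400
```

The resulting function contains no per-operation calls at all, at the cost of maintaining a string-emitting twin of every operation.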

3 Answers


Here is a very hacky version of synthesizing a new function from the bytecode of the given functions. The basic technique is to keep the LOAD_FAST opcode only at the beginning of the first function, and strip off the RETURN_VALUE opcode except at the end of the last function. This leaves the value being manipulated on the stack in between (what were originally) your functions. When you're done, you don't have any function calls.

import dis, inspect

sq = lambda x: x**2
add = lambda x: x+3
mul = lambda x: x*5

ops = [add, sq, mul, sq]

def synthF(ops):
    bytecode = bytearray()
    constants = []
    stacksize = 0
    for i, op in enumerate(ops):
        code = op.__code__
        # works only with functions having one argument and no other vars
        assert code.co_argcount == code.co_nlocals == 1
        assert not code.co_freevars
        stacksize = max(stacksize, code.co_stacksize)
        opcodes = bytearray(code.co_code)
        # starts with LOAD_FAST argument 0 (i.e. we're doing something with our arg)
        assert opcodes[0] == dis.opmap["LOAD_FAST"] and opcodes[1] == 0
        # ends with RETURN_VALUE
        assert opcodes[-2] == dis.opmap["RETURN_VALUE"] and opcodes[-1] == 0
        if bytecode:        # if this isn't our first function, our variable is already on the stack
            opcodes = opcodes[2:]
        # adjust LOAD_CONSTANT opcodes. each function can have constants,
        # but their indexes start at 0 in each function.  since we're
        # putting these into a single function we need to accumulate the
        # constants used in each function and adjust the indexes used in
        # the function's bytecode to access the value by its index in the
        # accumulated list.
        offset = 0
        if bytecode:
            while True:
                none = code.co_consts[0] is None
                offset = opcodes.find(dis.opmap["LOAD_CONST"], offset)
                if offset < 0:
                    break
                if not offset % 2 and (not none or opcodes[offset+1]):
                    opcodes[offset+1] += len(constants) - none
                offset += 2
            # first constant is always None. don't include multiple copies
            # (to be safe, we actually check that)
            constants.extend(code.co_consts[none:])
        else:
            assert code.co_consts[0] is None
            constants.extend(code.co_consts)
        # add our adjusted bytecode, cutting off the RETURN_VALUE opcode
        bytecode.extend(opcodes[:-2])
    bytecode.extend([dis.opmap["RETURN_VALUE"], 0])

    # note: this CodeType signature matches CPython 3.8-3.10;
    # it changes between Python versions
    func = type(ops[0])(type(code)(1, 1, 0, 1, stacksize, inspect.CO_OPTIMIZED, bytes(bytecode),
                tuple(constants), (), ("x",), "<generated>", "<generated>", 0, b''),
                globals())

    return func

f = synthF(ops)
assert f(3) == 32400

Gross, with lots of caveats (called out in the comments), but it works, and it should be about as fast as your expression, since it compiles to virtually the same bytecode. It would need a bit more work to support concatenating more complex functions.
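One way to sanity-check the synthesized function is to disassemble it and compare against a hand-written equivalent; both should be a straight run of LOAD/BINARY opcodes with no call instructions in between:

```python
import dis

# Hand-written equivalent of the fused [add, sq, mul, sq] pipeline
def handwritten(x):
    return (5*(x+3)**2)**2

# Compare this output with dis.dis(f) for the synthesized function:
# neither should contain any CALL opcodes.
dis.dis(handwritten)
```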

kindall

Here's an alternative using chaining. This way, there are only function calls in your generated function, no iteration.

def makeF(ops):
    f = ops[0]
    for op in ops[1:]:
        f = lambda x, op=op, f=f: op(f(x))
    return f

Bad news: it replaces each function call with two, so it's actually slower than your iterative version. :/
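For comparison, the standard-library way to express the same chaining is functools.reduce, which carries the same overhead profile (one extra lambda call per operation); the helper name here is just for illustration:

```python
from functools import reduce

sq = lambda x: x**2
add = lambda x: x+3
mul = lambda x: x*5
ops = [add, sq, mul, sq]

def makeF_reduce(ops):
    # one extra lambda call per op, same cost as the chained version above
    return lambda x: reduce(lambda acc, op: op(acc), ops, x)

g = makeF_reduce(ops)
g(3)  # returns 32400
```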

kindall
  • Can confirm. It's very slightly slower than my iterative version (1.02x slower). But thanks for suggesting it nonetheless – 157 239n Nov 12 '21 at 21:28

As there seems to be no perfect solution, this is what I've settled on. Knowing that most pipelines will be short (4 operations or fewer), I just hard-code those cases to get rid of the for loop.

def generateF(ops):
    l = len(ops)
    if l == 1:
        return ops[0]
    if l == 2:
        a, b = ops
        return lambda x: b(a(x))
    if l == 3:
        a, b, c = ops
        return lambda x: c(b(a(x)))
    if l == 4:
        a, b, c, d = ops
        return lambda x: d(c(b(a(x))))
    def inner(x):
        for op in ops:
            x = op(x)
        return x
    return inner
f = generateF(ops)

This is only 1.4x slower than fastF (originally 1.7-2x slower). If you have any other ideas, I'll consider them.
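A quick way to measure variants like this (a generic timeit sketch; the 1.4x figure above is from the asker's own benchmark):

```python
import timeit

sq = lambda x: x**2
add = lambda x: x+3
mul = lambda x: x*5
a, b, c, d = add, sq, mul, sq

composed = lambda x: d(c(b(a(x))))      # hard-coded composition (4 calls)
inlined  = lambda x: (5*(x+3)**2)**2    # fully inlined expression (0 calls)

t1 = timeit.timeit(lambda: composed(3), number=100_000)
t2 = timeit.timeit(lambda: inlined(3), number=100_000)
print(f"composed: {t1:.3f}s  inlined: {t2:.3f}s")
```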

157 239n