
I would like to be able to execute Python functions one line at a time, so that I can interleave the execution of two (or more) functions arbitrarily.

The first way to achieve this that comes to mind would be to manually insert a yield statement between each line in the function, and then, after calling the function, call next() on the returned generator whenever I want the next line to be executed.

This would work, per se, but it'd be ugly as all hell and doing all that manual work would just feel very wrong. Is there a simpler and/or more automated way to achieve what I want?
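To illustrate what I mean, here's a minimal sketch of the manual approach (the function names and the shared log list are just for the example):

```python
# Sketch of the manual approach: a yield after every line, with the
# caller alternating next() calls to interleave the two functions.
log = []

def func_a():
    log.append('a1'); yield
    log.append('a2'); yield

def func_b():
    log.append('b1'); yield
    log.append('b2'); yield

gen_a, gen_b = func_a(), func_b()
for gen in (gen_a, gen_b, gen_a, gen_b):
    next(gen)

# log is now ['a1', 'b1', 'a2', 'b2'] - the two bodies ran interleaved
```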

I have seen this somewhat similar question, but despite the asker trying to convince people that he doesn't want a debugger, all the answers can be summarized as "use a debugger" (perhaps because he didn't explain why a debugger isn't what he needs). In any case, that question (and certainly its answers) aren't applicable here, since I need to be able to step through many functions simultaneously. As far as I can see, a debugger (or sys.settrace(), which is what debuggers are built on in Python, for that matter) can only be used to step through a single body of code, and can't be used to freely switch between executing multiple bodies of code.

jonrsharpe
Aleksi Torhamo
  • What are you really trying to do? This sort of sounds like you want to have multiple threads or coroutines, but strictly control their execution at a very low level. What rules would decide which thing gets to execute next? – David Maze Aug 17 '18 at 20:32
  • @DavidMaze: Well, even if there was a better way to do what I wanted to do this time, I'd have wanted to know the best way to do line-by-line execution. But the driver for this was writing test cases that interleave the execution of two functions (in many different ways) to test that two objects with a shared resource can't interfere with each other in any situation. (And there's no threads involved, so this is sufficient; I'm not testing for thread-safety, even though it might sound I'm trying to) I have a more detailed explanation at https://github.com/alexer/yieldifier/blob/master/README – Aleksi Torhamo Aug 17 '18 at 20:54
  • (It's in the second/last section, titled "Why?"; I'd have explained here, but the whole thing is too long to fit in a comment, or even two, and I'm bad at being brief :) – Aleksi Torhamo Aug 17 '18 at 20:57
  • @DavidMaze: https://meta.stackoverflow.com/questions/372756/why-was-this-question-put-on-hold-as-too-broad – Aleksi Torhamo Aug 19 '18 at 02:58

1 Answer


Short answer: You can use the ast module to add the yield statements automatically. If you don't want to code it up yourself, you can also use yieldifier - a project which I created for the express purpose of answering this question.

Long answer: Unfortunately, I can't think of a "truly clean" way of doing what you want, but I can at least think of a few automated and less intrusive ways of adding the yield statements, without having to worry about actually adding them to the source code. (Disclaimer for the inevitable nitpickers: yield is an expression (since Python 2.5, anyway), but I use the term "yield statement" to refer to expression statements with yield as the sole expression)

The first method is modifying the AST (Abstract Syntax Tree), and the second is modifying the bytecode. Modifying the AST is easier, but you need access to the source code to generate the AST. Modifying the bytecode is significantly harder: the details change between Python versions (strictly speaking, that's true of the AST too, but to a much lesser degree), and bytecode is an implementation detail of CPython. On the other hand, you don't need access to the source code, and it's possible to execute one bytecode at a time in addition to one line at a time, should you want that.


The AST method

Getting the AST for your function

The first step in using the AST method is getting hold of the source code. You can do this however is easiest in your situation, e.g. by reading the source file directly if you already know the path, or by using inspect.getsource() (though I'd personally get the source for the whole module instead of just the target function, to avoid relying on the line number information being correct).
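As a sketch of the inspect route (using the stdlib json module as a stand-in target, since any module whose source is on disk works):

```python
import inspect
import textwrap

import json  # stand-in for the module containing your target function

# Getting the whole module's source avoids trusting the function's
# line number information
module_source = inspect.getsource(json)

# If you do grab just the function's source, dedent it in case the
# function was indented (e.g. a method), or ast.parse() will choke
func_source = textwrap.dedent(inspect.getsource(json.dumps))
```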

After you have the source code, you can get the AST and locate the target function with something like this:

# Parse the source to an AST
mod_tree = ast.parse(source)
# Locate all the top-level functions in the module's tree
func_trees = {obj.name: obj for obj in mod_tree.body if isinstance(obj, ast.FunctionDef)}
# Get the AST of target_func
func_tree = func_trees['target_func']

Now you have the AST for your function; next, you want to add the yield statements to it.

Modifying the AST

In the AST, there is only one place where it's easy to automatically add the yield statements. Fortunately for us, that place is exactly where we want them anyway - that is, between existing statements. (So, strictly speaking, this isn't line-by-line but statement-by-statement execution, which is usually what you actually want)

If your function only has statements at a single level, or you only want to step through top-level statements, you can simply do something like this:

# Create a new body for the function, where each statement will be followed by a yield
new_body = []
for i, obj in enumerate(func_tree.body):
    new_body.append(obj)
    new_body.append(ast.Expr(value=ast.Yield(value=ast.Num(n=i))))
# Replace the old function body with the new one in the AST
func_tree.body = new_body
# Compile requires that all AST nodes have lineno and col_offset;
# This is easier than defining them manually for each node
ast.fix_missing_locations(func_tree)

Note that ast.fix_missing_locations(), as well as the line before it, both modify the original tree directly. If you need the original tree for something, you could use copy.deepcopy() to copy the whole tree before modification, or create a new ast.FunctionDef node (with all values copied directly from the original, except for body = new_body) and then either manually fill in lineno and col_offset for all newly created nodes, or call ast.fix_missing_locations() only on the newly created ast.Expr nodes.
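Here's a sketch of the copy.deepcopy() route, with a toy source string; note it uses ast.Constant, the modern spelling of the ast.Num used above:

```python
import ast
import copy

source = "def f():\n    x = 1\n    return x\n"
mod_tree = ast.parse(source)
# Keep a pristine copy around before any mutation
pristine = copy.deepcopy(mod_tree)

func_tree = mod_tree.body[0]
new_body = []
for i, obj in enumerate(func_tree.body):
    new_body.append(obj)
    # ast.Constant is the modern equivalent of ast.Num
    new_body.append(ast.Expr(value=ast.Yield(value=ast.Constant(value=i))))
func_tree.body = new_body
ast.fix_missing_locations(mod_tree)

# The pristine copy is untouched by the mutation
assert len(pristine.body[0].body) == 2
assert len(mod_tree.body[0].body) == 4
```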

Another way to do the modification is to subclass ast.NodeTransformer, which will automatically do the modification recursively. That will look something like this:

class Yieldifier(ast.NodeTransformer):
    def __init__(self):
        ast.NodeTransformer.__init__(self)
        self.added = 0

    def generic_visit(self, node):
        ast.NodeTransformer.generic_visit(self, node)
        if isinstance(node, ast.stmt):
            self.added += 1
            return [node, ast.Expr(value=ast.Yield(value=ast.Num(n=self.added)))]
        else:
            return node

Yieldifier().visit(func_tree)
ast.fix_missing_locations(func_tree)

As you can probably guess, this also modifies the tree directly, so the same warning as before applies. Since ast.NodeTransformer modifies the AST recursively and indiscriminately, this will also "yieldify" any nested functions, and even nested classes (which will result in a SyntaxError). If you want or need to avoid this, you should skip calling ast.NodeTransformer.generic_visit(self, node) in those cases.
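As a sketch of that, here's a variant (with hypothetical names) that skips recursion into nested scopes and is applied to the function's body statements directly, so the skip logic never sees the target function itself:

```python
import ast

source = """
def target():
    x = 1
    def helper():
        return x + 1
    return helper()
"""

mod_tree = ast.parse(source)
func_tree = mod_tree.body[0]

class SafeYieldifier(ast.NodeTransformer):
    def __init__(self):
        super().__init__()
        self.added = 0

    def generic_visit(self, node):
        # Don't descend into nested scopes: yieldifying a nested class
        # would be a SyntaxError, and a nested function would silently
        # turn into a generator
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            ast.NodeTransformer.generic_visit(self, node)
        if isinstance(node, ast.stmt):
            self.added += 1
            return [node, ast.Expr(value=ast.Yield(value=ast.Constant(value=self.added)))]
        return node

# Visit the body statements directly, not the FunctionDef itself
yieldifier = SafeYieldifier()
new_body = []
for stmt in func_tree.body:
    result = yieldifier.visit(stmt)
    new_body.extend(result if isinstance(result, list) else [result])
func_tree.body = new_body
ast.fix_missing_locations(mod_tree)

env = {}
exec(compile(mod_tree, '<ast>', 'exec'), env)
```

After this, env['target'] is a generator function whose body yields between statements, while helper remains an ordinary (untouched) nested function.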

Now that you have the modified AST for the function, there are two different approaches to compiling it.

Compiling the modified AST

If you have a modified tree for the whole module, and running the module doesn't have any inconvenient side effects, you can just compile the whole module, and the function should automatically get the proper environment (ie. imported modules etc.). That'll look something like this:

env = {}
exec(compile(mod_tree, source_path, 'exec'), env)
func = env[func_tree.name]

If you only have the modified tree for the function, or there are some other complications, you can also wrap the function tree in a new, empty module, like this: mod_tree = ast.Module(body=[func_tree]). (compile in exec mode expects a whole module, so you can't give it just the function tree; also note that on Python 3.8+, ast.Module additionally requires a type_ignores=[] argument) Note that in this case, you need to manually fill the env dict with the environment the function expects.
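To make the payoff concrete, here's a sketch that yieldifies one function (per top-level statement, as above) and then interleaves two independent executions of it; the worker function, the yieldify helper, and the log argument are all made up for the demo:

```python
import ast

source = """
def worker(log, name):
    log.append(name + ':1')
    log.append(name + ':2')
    log.append(name + ':3')
"""

def yieldify(source, name):
    """Hypothetical helper: add a yield after every top-level statement"""
    mod_tree = ast.parse(source)
    func_tree = mod_tree.body[0]
    new_body = []
    for i, obj in enumerate(func_tree.body):
        new_body.append(obj)
        new_body.append(ast.Expr(value=ast.Yield(value=ast.Constant(value=i))))
    func_tree.body = new_body
    ast.fix_missing_locations(mod_tree)
    env = {}
    exec(compile(mod_tree, '<yieldified>', 'exec'), env)
    return env[name]

worker = yieldify(source, 'worker')

# Two independent executions of the same function, advanced one
# statement at a time, in whatever order we choose
log = []
a = worker(log, 'a')
b = worker(log, 'b')
for gen in (a, b, a, b, a, b):
    next(gen)

# log == ['a:1', 'b:1', 'a:2', 'b:2', 'a:3', 'b:3']
```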

Either way, congratulations! You should now have an automatically "yieldified" version of your function! So, what's the bad news?

Fixing the line numbers (eg. for correct tracebacks)

Using ast.fix_missing_locations() is a lot easier than filling in all the locations manually, but it comes at a cost: it absolutely butchers the line numbering! It copies the line number from the node's parent, so if you have a for loop on line 5 and a print() on line 6, the yield following the print will claim it's on line 5. This might not sound like a big deal, but it confuses the hell out of the compiler, which then seems to try to insert negative numbers into the code object's co_lnotab, which is supposed to contain unsigned bytes. As a result, "one line backwards" becomes "255 lines forwards", and a ten-line function might claim it has thousands of lines. You might think "so what?", and the answer is "nothing"... until you get your first exception, and the traceback either shows nothing (because the file isn't long enough) or the wrong lines of code. So, how do you fix this?

def getmaxloc(node):
    loc = None
    for node_ in ast.walk(node):
        if not hasattr(node_, 'lineno') or not hasattr(node_, 'col_offset'):
            continue
        loc_ = (node_.lineno, node_.col_offset)
        if loc is None or loc_ > loc:
            loc = loc_
    return loc

You can use getmaxloc(node)[0] to get the maximum line number present in a node's tree, and use the same line number for the new nodes you add after it, which guarantees that your new nodes won't cause a jump backwards in line numbers. In our case, you can just give the correct lineno to the ast.Expr node, and then call ast.fix_missing_locations(), which will now copy the correct line number to the rest of the new nodes.
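Putting getmaxloc() to use in the earlier insertion loop might look like this sketch (getmaxloc repeated here so the snippet runs standalone):

```python
import ast

def getmaxloc(node):
    """Same helper as above: the maximum (lineno, col_offset) in the subtree"""
    loc = None
    for node_ in ast.walk(node):
        if not hasattr(node_, 'lineno') or not hasattr(node_, 'col_offset'):
            continue
        loc_ = (node_.lineno, node_.col_offset)
        if loc is None or loc_ > loc:
            loc = loc_
    return loc

source = "def f():\n    for i in range(2):\n        pass\n    return i\n"
mod_tree = ast.parse(source)
func_tree = mod_tree.body[0]

new_body = []
for i, obj in enumerate(func_tree.body):
    new_body.append(obj)
    yield_stmt = ast.Expr(value=ast.Yield(value=ast.Constant(value=i)))
    # Pin the yield to the last location of the preceding statement, so
    # the line numbers never jump backwards
    yield_stmt.lineno, yield_stmt.col_offset = getmaxloc(obj)
    new_body.append(yield_stmt)
func_tree.body = new_body
# fix_missing_locations now copies the *correct* lineno to the new child nodes
ast.fix_missing_locations(mod_tree)

code = compile(mod_tree, '<ast>', 'exec')
```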

While you might think that this could be done more simply with ast.NodeTransformer, by just keeping track of the line number of the latest visited node, ast.NodeTransformer handles the fields of AST nodes in an order that is essentially random for our purposes. For example, the body field of ast.FunctionDef is handled before decorator_list, even though the latter comes first in the source. This means that if you want to make more involved adjustments to the locations (say, if instead of "hiding" the yield statements on the same line so that the line numbers match the source code, you wanted the line numbers to indicate that the yield statements are on their own lines, and so had to update every lineno), you're probably going to have to handle at least each statement-containing node type explicitly and differently.

Putting it all together

Finally, here's a link to a complete self-contained example of one way of doing the above. (I'd have included the code for the example in the answer, too, but I went over the 30k character limit for answers.. Oops. :)

(Note: The ZeroDivisionError that you get at the end, if you run the example, is there by design to demonstrate that the traceback is correct)


The bytecode method

The details of bytecode change between Python versions, and the dis module unfortunately doesn't provide enough tools to do much of anything in a version-independent manner (in strict contrast to the ast module). Since this is the case, I'll be choosing a Python version, namely 3.4, and running with it.

dis.get_instructions(), introduced in Python 3.4, is a step in the right direction - in previous versions, you had to copy-paste the code of dis.disassemble() and modify it yourself if you wanted to process information about the bytecode instead of just printing it out.

Getting to the bytecode

Getting to the bytecode is easy: just use func.__code__.co_code. The raw bytecode probably isn't much help though, so as I alluded to before, we'll use dis.get_instructions(func) to get the information in a more human-readable format. Unfortunately, the dis.Instruction objects returned by the function are immutable, which is impractical for what we're doing, so we'll immediately move the data into a mutable object of our own.
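To get a feel for what dis.get_instructions() returns (it works on any Python >= 3.4, though the exact opcodes vary between versions):

```python
import dis

def example(n):
    return n + 1

# Each Instruction is an immutable namedtuple with fields like opname,
# opcode, arg, argval, offset, starts_line and is_jump_target
opnames = [insn.opname for insn in dis.get_instructions(example)]

# The exact opcode sequence varies by interpreter version, but a plain
# function always ends by returning a value
assert 'RETURN_VALUE' in opnames
```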

Since we're going to be adding to the bytecode, we'll end up breaking all jump target offsets. This is why, before modifying the bytecode and messing up all the offsets, we'll also record the target Instruction object for each jump instruction.

This will look something like this:

class Instruction:
    def __init__(self, name, op, arg=None, argval=None, argrepr=None, offset=None, starts_line=None, is_jump_target=False):
        self.name = name
        self.op = op
        self.arg = arg
        self.argval = argval
        self.argrepr = argrepr
        self.offset = offset
        self.starts_line = starts_line
        self.is_jump_target = is_jump_target
        self.target = None

hasjump = set(dis.hasjrel + dis.hasjabs)
def get_instructions(func):
    """Get the bytecode for the function, in a mutable format, with jump target links"""
    insns = [Instruction(*insn) for insn in dis.get_instructions(func)]
    insn_map = {insn.offset: insn for insn in insns}

    for insn in insns:
        if insn.op in hasjump:
            insn.target = insn_map[insn.argval]

    return insns

Modifying the bytecode

After getting the bytecode into a convenient format, we need to add the yield statements to it. With bytecode, we have a lot more freedom in the placement of the yield statements than with the AST. The Python syntax places restrictions on what you can do and where, but the bytecode really doesn't; for example, you can't write "while buf = f.read(32):", but the bytecode doesn't prevent you from doing the equivalent.

So, what's the bad news? The bytecode giveth, and the bytecode taketh away. The flexibility of bytecode comes back to bite you in the butt when it's time to decide where exactly you want to place your yield statements. It's easy if you want to yield after every bytecode - just insert yield statements between every bytecode. But anything more sophisticated than that will have complications.

Let's look at the disassembly for our test function:

def target(n):
    print('line 1')
    print('line 2')
    for i in range(3):
        print('line 4')
        print('line 5')
    print('line 6')
    math.pi / n

It'll look like this:

  6           0 LOAD_GLOBAL              0 (print)
              3 LOAD_CONST               1 ('line 1')
              6 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
              9 POP_TOP

  7          10 LOAD_GLOBAL              0 (print)
             13 LOAD_CONST               2 ('line 2')
             16 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             19 POP_TOP

  8          20 SETUP_LOOP              40 (to 63)
             23 LOAD_GLOBAL              1 (range)
             26 LOAD_CONST               3 (3)
             29 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             32 GET_ITER
        >>   33 FOR_ITER                26 (to 62)
             36 STORE_FAST               1 (i)

  9          39 LOAD_GLOBAL              0 (print)
             42 LOAD_CONST               4 ('line 4')
             45 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             48 POP_TOP

 10          49 LOAD_GLOBAL              0 (print)
             52 LOAD_CONST               5 ('line 5')
             55 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             58 POP_TOP
             59 JUMP_ABSOLUTE           33
        >>   62 POP_BLOCK

 11     >>   63 LOAD_GLOBAL              0 (print)
             66 LOAD_CONST               6 ('line 6')
             69 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             72 POP_TOP

 12          73 LOAD_GLOBAL              2 (math)
             76 LOAD_ATTR                3 (pi)
             79 LOAD_FAST                0 (n)
             82 BINARY_TRUE_DIVIDE
             83 POP_TOP
             84 LOAD_CONST               0 (None)
             87 RETURN_VALUE

At a quick glance, it looks like it would suffice to locate the first bytecode on each line and add a yield right before it, so that it ends up just after the last bytecode of the previous line. But if you look more closely, you'll notice that the yield that would get inserted after line 10 - before the instruction originally at offset 63 - would not get inserted after the print('line 5'); rather, it would get inserted after the whole for loop! In source code terms, it's the correct line, but the wrong level of indentation. The last line inside the for loop is also the line where the for loop ends, since Python doesn't have an endfor or the like.

What this means is that if you really want to insert the yield right after the last statement in similar situations, things start to get pretty hairy pretty fast. It's one thing to write a hack that probably works most of the time on the exact interpreter version that you're using, but if you want something that you can trust to always work correctly, I'd say at the very minimum, you're going to have to start actually analyzing the control flow and/or the stack state. I'm not saying it's impossible, I've even done it to some degree in the past, but it's a completely different can of worms.

Fortunately, although the simple method places the yield after the whole for loop instead of after the last statement, it all ends up working okay, since a yield also gets inserted just after the start of the for loop, just before the first statement in the body of the loop. I won't pretend to know whether it'll always work out for all CPython 3.4-generated bytecode, but in any case, it should never break anything; the worst that should happen is that you won't get a yield between some statements where you'd have wanted one.

It's also worth mentioning that unlike with the AST, here it'll truly be line-by-line execution. co_lnotab contains an entry for every logical and physical line change (plus extra split entries when a number doesn't fit in a single entry), but when the physical line changes, it's not possible to tell from co_lnotab whether the change is also a statement boundary (and the same goes for discriminating between statements that happen to be exactly 255 bytecodes long and split entries caused by statements longer than that).

dis.findlinestarts() and Instruction.starts_line both only indicate when the physical line changes, but by looking directly at co_lnotab it's possible to yield at all statement boundaries. Either way, you'll always be yielding at a line change (well, okay, lines without any code are exempt), whether it actually ends a statement or not!
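As a sketch of the simple per-line variant, dis.findlinestarts() gives you exactly the (offset, line) pairs where a new line begins, which is where the yields would go:

```python
import dis

def example(n):
    x = n * 2
    y = x + 1
    return y

# (bytecode offset, source line) pairs for each instruction that starts
# a new line; inserting a yield before each of these offsets (except
# offset 0) gives line-by-line stepping
line_starts = list(dis.findlinestarts(example.__code__))
```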

With all that out of the way, let's look at the code:

def new_insn(name, arg=None, argval=None, starts_line=None, is_jump_target=False):
    return Instruction(name, dis.opmap[name], arg, argval, None, None, starts_line, is_jump_target)

co = func.__code__
insns = get_instructions(func)

yieldno = 1
new_consts = list(co.co_consts)
new_insns = []
for insn in insns:
    if insn.starts_line is not None and insn.offset:
        try:
            arg = new_consts.index(yieldno)
        except ValueError:
            arg = len(new_consts)
            new_consts.append(yieldno)
        yieldno += 1
        new_insns.extend([
            new_insn('LOAD_CONST', arg),
            new_insn('YIELD_VALUE'),
            new_insn('POP_TOP'),
        ])
    new_insns.append(insn)

When the code finds an instruction that starts a line, it inserts a yield before it (unless we're at the start, since then there is no previous line where we'd want to append the yield). Like we did with the AST, we yield a sequential number, so we first add the number into the constants used by the bytecode, if needed, and then give the index of the constant as an argument to the relevant instruction.

Fixing the jump target offsets

Like I mentioned right at the start, adding the yield statements will change the offsets of most instructions in the bytecode - this means that most jumps now point to the wrong place. To prepare for this, we added the target attribute to the instructions, and now's the time to make use of that.

First, we have to calculate the new offset of each instruction, and for this, we need to know the size of all instructions. An instruction is 1 byte long if it has no arguments, 3 bytes long if it has an argument that fits in 2 bytes, and 6 bytes long if it has an argument that fits in 4 bytes (well, to be pedantic, if the argument is longer than 2 bytes, the instruction is encoded as two opcodes, with the first one being EXTENDED_ARG, whose argument will contain the upper 2 bytes of the whole 4-byte argument).

So the length of jump instructions is dependent on where they jump, and where they jump is dependent on the length of instructions. Fortunately, if fixing the argument of a jump grows its size, that can only cause other jumps to grow, never shrink, so you can't get into an infinite loop of triggered recalculations.

Without further ado, the code:

def calc_insn_size(insn):
    """Calculate how many bytes the bytecode for the instruction will take"""
    return (6 if insn.arg >= 65536 else 3) if insn.op >= dis.HAVE_ARGUMENT else 1

def _recalc_insn_offsets(insns):
    """Calculate the offset of each instruction in the resulting bytecode"""
    offset = 0
    for insn in insns:
        insn.offset = offset
        offset += calc_insn_size(insn)
    return offset

def _recalc_jump_offsets(insns):
    """Calculate the target offset of each jump instruction

    Return value tells whether this caused the encoding of any jump instruction to change in size
    """
    size_changed = 0
    for insn in insns:
        size = calc_insn_size(insn)
        if insn.op in dis.hasjabs:
            insn.arg = insn.target.offset
            insn.argval = insn.target.offset
        elif insn.op in dis.hasjrel:
            insn.arg = insn.target.offset - (insn.offset + size)
            insn.argval = insn.target.offset
        new_size = calc_insn_size(insn)
        if new_size != size:
            size_changed += 1
    return size_changed

def _reset_jump_offsets(insns):
    """Reset all jump target offsets to 0 (so that jumps will use the smaller encoding by default)"""
    for insn in insns:
        if insn.op in hasjump:
            insn.arg = 0

def fix_offsets(insns):
    """Calculate all instruction and jump target offsets"""
    size = _recalc_insn_offsets(insns)
    _reset_jump_offsets(insns)
    # Updating the jump target offsets might cause the encoding size of some jump instructions to grow
    # If that happens, we have to recalculate the instruction offsets, some of which have grown, which means
    # we have to update the jump targets again. Naturally, this has to be repeated until things settle down.
    while _recalc_jump_offsets(insns):
        size = _recalc_insn_offsets(insns)
    return size

Putting the parts back together

There isn't really much to say about the rest. We have to encode the instructions, which would normally be a "read the docs if you want to figure out the format" type of deal, but in Python's case the instruction encoding is so simple that you can figure it out just by looking at the code for all of two seconds. We also have to encode the line numbers (aka co_lnotab), which is more involved, but since I don't want to recreate its documentation, I'll just point you to Objects/lnotab_notes.txt in the CPython source in case you're interested.

The rest is just stuffing all the updated parts together into a code object and a function, like so:

def encode_insn(insn):
    """Generate bytecode for the instruction"""
    l = [insn.op]
    if insn.op >= dis.HAVE_ARGUMENT:
        l += [insn.arg & 0xff, (insn.arg >> 8) & 0xff]
        if insn.arg >= 65536:
            l = [dis.EXTENDED_ARG, (insn.arg >> 16) & 0xff, (insn.arg >> 24) & 0xff] + l
    return bytes(l)

def calc_lnotab(insns, firstlineno=0):
    """Calculate the line number table for the bytecode"""
    # Details of the format of co_lnotab are explained in Objects/lnotab_notes.txt, so I won't bother repeating all of that
    new_lnotab = []
    prev_offset, prev_lineno = 0, firstlineno
    for insn in insns:
        if insn.starts_line:
            offset, lineno = insn.offset - prev_offset, insn.starts_line - prev_lineno
            prev_offset, prev_lineno = insn.offset, insn.starts_line
            assert (offset > 0 or prev_offset == 0) and lineno > 0
            while offset > 255:
                new_lnotab.extend((255, 0))
                offset -= 255
            while lineno > 255:
                new_lnotab.extend((offset, 255))
                offset = 0
                lineno -= 255
            new_lnotab.extend((offset, lineno))
    return bytes(new_lnotab)

new_lnotab = calc_lnotab(new_insns, co.co_firstlineno)
new_bytecode = b''.join(map(encode_insn, new_insns))

new_code = types.CodeType(
    co.co_argcount,
    co.co_kwonlyargcount,
    co.co_nlocals,
    # To be safe (the stack should usually be empty at points where we yield)
    co.co_stacksize + 1,
    # We added yields, so the function is a generator now
    co.co_flags | 0x20,
    new_bytecode,
    tuple(new_consts),
    co.co_names,
    co.co_varnames,
    co.co_filename,
    co.co_name,
    co.co_firstlineno,
    new_lnotab,
    co.co_freevars,
    co.co_cellvars,
)

new_func = types.FunctionType(
    new_code,
    func.__globals__,
    func.__name__,
    func.__defaults__,
    func.__closure__,
)

Putting it all together

Finally, here's a link to a self-contained example of doing the above. (I'd have included the code for the complete example in the answer, too, but I went over the 30k character limit for answers.. Oops. :)

(Note: The ZeroDivisionError that you get at the end, if you run the example, is there by design to demonstrate that the traceback is correct)

Aleksi Torhamo