Short answer:
You can use the ast module to add the yield statements automatically. If you don't want to code it up yourself, you can also use yieldifier - a project which I created for the express purpose of answering this question.
Long answer:
Unfortunately, I can't think of a "truly clean" way of doing what you want, but I can at least think of a few automated and less intrusive ways of adding the yield statements, without having to worry about actually adding them to the source code. (Disclaimer for the inevitable nitpickers: yield is an expression (since Python 2.5, anyway), but I use the term "yield statement" to refer to expression statements with yield as the sole expression)
The first method is modifying the AST (Abstract Syntax Tree), and the second one is modifying the bytecode. Modifying the AST is easier, but you need access to the source code to generate the AST. Modifying the bytecode is significantly harder: the details change between Python versions (well, that's true for the AST too, strictly speaking, but to a much lesser degree), and bytecode is an implementation detail of CPython. On the other hand, you don't need access to the source code, and it's possible to execute one bytecode at a time in addition to one line at a time, should you want it.
The AST method
Getting the AST for your function
The first step to using the AST method is getting hold of the source code. You can do this however you wish (whatever is easiest depending on your situation), eg. by reading the source file directly, if you already know the path, or using inspect.getsource() (but I'd personally get the source for the whole module instead of just the target function, to avoid relying on the line number information being correct).
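For example, a minimal sketch using inspect (here mymodule and target_func are hypothetical placeholders for your own module and function; note that importing the module executes it):

import inspect
import mymodule  # hypothetical module that contains target_func

# Get the source and the path for the whole module; source_path is used later when compiling
source = inspect.getsource(mymodule)
source_path = inspect.getsourcefile(mymodule)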
After you have the source code, you can get the AST and locate the target function with something like this:
import ast

# Parse the source to an AST
mod_tree = ast.parse(source)
# Locate all the top-level functions in the module's tree
func_trees = {obj.name: obj for obj in mod_tree.body if isinstance(obj, ast.FunctionDef)}
# Get the AST of target_func
func_tree = func_trees['target_func']
Now you have the AST for your function; Next, you want to add the yield statements to the AST.
Modifying the AST
With AST, there is only a single place where it's easy to automatically add the yield statements. Fortunately for us, that place is exactly where we want them, anyway - that is, between existing statements. (So, strictly, this isn't line-by-line execution, but statement-by-statement execution, which is usually what you actually want)
If your function only has statements on a single level or you only want to step through toplevel statements, you can simply do something like this:
# Create a new body for the function, where each statement will be followed by a yield
new_body = []
for i, obj in enumerate(func_tree.body):
    new_body.append(obj)
    new_body.append(ast.Expr(value=ast.Yield(value=ast.Num(n=i))))
# Replace the old function body with the new one in the AST
func_tree.body = new_body
# Compile requires that all AST nodes have lineno and col_offset;
# This is easier than defining them manually for each node
ast.fix_missing_locations(func_tree)
Note that ast.fix_missing_locations(), as well as the line before it, both modify the original tree directly; If you need the original tree for something, you could use copy.deepcopy() to copy the whole tree before modification, or create a new ast.FunctionDef node (with all values copied directly from the original, except for body = new_body), and either manually fill in lineno and col_offset for all newly created nodes, or call ast.fix_missing_locations() only on the newly created ast.Expr nodes.
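As a minimal sketch of the first option (assuming you want to keep an untouched copy of the tree around):

import copy

# Keep a pristine copy of the original tree before any modification
original_func_tree = copy.deepcopy(func_tree)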
Another way to do the modification is to subclass ast.NodeTransformer, which will automatically do the modification recursively. That will look something like this:
class Yieldifier(ast.NodeTransformer):
    def __init__(self):
        ast.NodeTransformer.__init__(self)
        self.added = 0

    def generic_visit(self, node):
        ast.NodeTransformer.generic_visit(self, node)
        if isinstance(node, ast.stmt):
            self.added += 1
            return [node, ast.Expr(value=ast.Yield(value=ast.Num(n=self.added)))]
        else:
            return node

Yieldifier().visit(func_tree)
ast.fix_missing_locations(func_tree)
As you can probably guess, this also modifies the tree directly, so the same warning as before applies. Since ast.NodeTransformer modifies the AST recursively and indiscriminately, this will also "yieldify" any nested functions, and even nested classes (which will result in a SyntaxError). If you want or need to avoid this, you should skip calling ast.NodeTransformer.generic_visit(self, node) in those cases.
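One possible variant along those lines might look like this (a sketch, not the only way to do it; it still recurses into the target function itself, but leaves nested function and class bodies alone):

class ShallowYieldifier(ast.NodeTransformer):
    def __init__(self, root):
        ast.NodeTransformer.__init__(self)
        self.root = root    # the target FunctionDef, which we do want to recurse into
        self.added = 0

    def generic_visit(self, node):
        # Skip the recursive call for nested function/class definitions, so their bodies
        # are left untouched (the definitions themselves still get a yield after them,
        # like any other statement in the enclosing body)
        if node is self.root or not isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            ast.NodeTransformer.generic_visit(self, node)
        if isinstance(node, ast.stmt):
            self.added += 1
            return [node, ast.Expr(value=ast.Yield(value=ast.Num(n=self.added)))]
        else:
            return node

ShallowYieldifier(func_tree).visit(func_tree)
ast.fix_missing_locations(func_tree)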
Now that you have the modified AST for the function, there are two different approaches to compiling it.
Compiling the modified AST
If you have a modified tree for the whole module, and running the module doesn't have any inconvenient side effects, you can just compile the whole module, and the function should automatically get the proper environment (ie. imported modules etc.). That'll look something like this:
env = {}
exec(compile(mod_tree, source_path, 'exec'), env)
func = env[func_tree.name]
If you only have the modified tree for the function, or there are some other complications, you can also wrap the function tree in a new empty module, like this: mod_tree = ast.Module(body=[func_tree]). (compile in exec mode expects a whole module, so you can't just give it the function tree only) Note that in this case, you need to manually fill the env dict with the environment the function expects.
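A minimal sketch of that approach, assuming (purely for the sake of the example) that the function only needs the math module from its enclosing environment (on Python 3.8 and later, ast.Module also needs a type_ignores=[] argument):

import math

# Wrap just the function in a new, otherwise empty module
mod_tree = ast.Module(body=[func_tree])
# Manually provide the environment the function expects
env = {'math': math}
exec(compile(mod_tree, source_path, 'exec'), env)
func = env[func_tree.name]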
Either way, congratulations! You should now have an automatically "yieldified" version of your function! So, what's the bad news?
Fixing the line numbers (eg. for correct tracebacks)
Using ast.fix_missing_locations() is a lot easier than filling all the locations manually, but it comes at a cost: It absolutely butchers the line numbering! It copies the line number from the node's parent, so if you have a for-loop on line 5 and print() on line 6, the yield following the print will claim it's on line 5. This might not sound like a big deal, but it confuses the hell out of the compiler, which then seems to try to insert negative numbers into the code object's co_lnotab, which is supposed to contain unsigned bytes. As a result, "one line backwards" becomes "255 lines forwards", and a ten-line function might claim it has 1000's of lines. You might think, "so what?", and the answer is "nothing"... until you get your first exception, and the traceback either shows nothing (because the file is not long enough) or the wrong lines of code. So, how do you fix this?
def getmaxloc(node):
    loc = None
    for node_ in ast.walk(node):
        if not hasattr(node_, 'lineno') or not hasattr(node_, 'col_offset'):
            continue
        loc_ = (node_.lineno, node_.col_offset)
        if loc is None or loc_ > loc:
            loc = loc_
    return loc
You can use getmaxloc(node)[0] to get the maximum line number present in a node's tree, and use the same line number for the new nodes you add after it, which guarantees that your new nodes won't cause a jump backwards in line numbers. In our case, you can just give the correct lineno to the ast.Expr node, and then call ast.fix_missing_locations(), which will now copy the correct line number to the rest of the new nodes.
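Applied to the simple loop from earlier, that could look something like this (a sketch of the idea, not the full example):

new_body = []
for i, obj in enumerate(func_tree.body):
    new_body.append(obj)
    yield_stmt = ast.Expr(value=ast.Yield(value=ast.Num(n=i)))
    # Reuse the maximum line number found in the preceding statement, so the new
    # node can never cause a backwards jump in the line numbering
    yield_stmt.lineno = getmaxloc(obj)[0]
    new_body.append(yield_stmt)
func_tree.body = new_body
# col_offset, and the locations of the Yield/Num children, still get filled in here
ast.fix_missing_locations(func_tree)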
While you might think that this could be done more simply with ast.NodeTransformer, by just keeping track of the line number of the latest visited node, ast.NodeTransformer handles the fields of AST nodes in an order that is essentially random for our purposes. For example, the body field of ast.FunctionDef is handled before the decorator_list, although the latter occurs first in the source. This means that if you want to make more involved adjustments to the locations (say, if instead of "hiding" the yield statements on the same line to get the line numbers to match the source code, you wanted the line numbers to indicate that the yield statements are on their own lines, and so you'd have to update every lineno), you're probably going to have to handle at least each statement-containing node type explicitly and differently.
Putting it all together
Finally, here's a link to a complete self-contained example of one way of doing the above. (I'd have included the code for the example in the answer, too, but I went over the 30k character limit for answers.. Oops. :)
(Note: The ZeroDivisionError that you get at the end, if you run the example, is there by design to demonstrate that the traceback is correct)
The bytecode method
The details of bytecode change between Python versions, and the dis module unfortunately doesn't provide enough tools to do much of anything in a version-independent manner (in strict contrast to the ast module). Since this is the case, I'll be choosing a Python version, namely 3.4, and running with it.
dis.get_instructions(), introduced in Python 3.4, is a step in the right direction - in previous versions, you had to copy-paste the code of dis.disassemble() and modify it yourself if you wanted to process information about the bytecode instead of just printing it out.
Getting to the bytecode
Getting to the bytecode is easy: Just say func.__code__.co_code. The raw bytecode probably isn't much help though, so like I alluded to before, we'll use dis.get_instructions(func) to get the information in a more human-readable format. Unfortunately, the dis.Instruction objects returned by the function are immutable, which is a bit impractical for what we're doing, so we'll immediately move the data to a mutable object of our own.
Since we're going to be adding to the bytecode, we'll end up breaking all jump target offsets. This is why, before modifying the bytecode and messing up all the offsets, we'll also record the target Instruction object for each jump instruction.
This will look something like this:
import dis

class Instruction:
    def __init__(self, name, op, arg=None, argval=None, argrepr=None, offset=None, starts_line=None, is_jump_target=False):
        self.name = name
        self.op = op
        self.arg = arg
        self.argval = argval
        self.argrepr = argrepr
        self.offset = offset
        self.starts_line = starts_line
        self.is_jump_target = is_jump_target
        self.target = None

hasjump = set(dis.hasjrel + dis.hasjabs)

def get_instructions(func):
    """Get the bytecode for the function, in a mutable format, with jump target links"""
    insns = [Instruction(*insn) for insn in dis.get_instructions(func)]
    insn_map = {insn.offset: insn for insn in insns}
    for insn in insns:
        if insn.op in hasjump:
            insn.target = insn_map[insn.argval]
    return insns
Modifying the bytecode
After getting the bytecode to a convenient format, we need to add the yield statements to it. In the case of bytecode, we have a lot more freedom with the placement of the yield statements than with the AST. The Python syntax places restrictions on what you can do and where, but the bytecode really doesn't; For example, you can't say "while buf = f.read(32):", but the bytecode doesn't prevent you from doing the equivalent.
So, what's the bad news? The bytecode giveth, and the bytecode taketh away. The flexibility of bytecode comes back to bite you in the butt when it's time to decide where exactly you want to place your yield statements. It's easy if you want to yield after every bytecode - just insert yield statements between every bytecode. But anything more sophisticated than that will have complications.
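(For illustration, the "yield after every bytecode" case would boil down to something like the sketch below, using the Instruction class defined above; make_insn is a hypothetical little helper, and the jump target offsets would still need fixing afterwards, as discussed later)

def make_insn(name, arg=None):
    # Hypothetical helper: build a bare Instruction for a given opname
    return Instruction(name, dis.opmap[name], arg)

stepped_insns = []
for insn in get_instructions(func):
    stepped_insns.append(insn)
    # YIELD_VALUE yields the top of the stack to the caller and replaces it with the
    # value sent in (None for a plain next()), which POP_TOP then discards, so the
    # rest of the stack is left exactly as it was
    stepped_insns.extend([make_insn('LOAD_CONST', 0),   # index 0 of co_consts is typically None
                          make_insn('YIELD_VALUE'),
                          make_insn('POP_TOP')])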
Let's look at the disassembly for our test function:
def target(n):
    print('line 1')
    print('line 2')
    for i in range(3):
        print('line 4')
        print('line 5')
    print('line 6')
    math.pi / n
It'll look like this:
  6           0 LOAD_GLOBAL              0 (print)
              3 LOAD_CONST               1 ('line 1')
              6 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
              9 POP_TOP

  7          10 LOAD_GLOBAL              0 (print)
             13 LOAD_CONST               2 ('line 2')
             16 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             19 POP_TOP

  8          20 SETUP_LOOP              40 (to 63)
             23 LOAD_GLOBAL              1 (range)
             26 LOAD_CONST               3 (3)
             29 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             32 GET_ITER
        >>   33 FOR_ITER                26 (to 62)
             36 STORE_FAST               1 (i)

  9          39 LOAD_GLOBAL              0 (print)
             42 LOAD_CONST               4 ('line 4')
             45 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             48 POP_TOP

 10          49 LOAD_GLOBAL              0 (print)
             52 LOAD_CONST               5 ('line 5')
             55 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             58 POP_TOP
             59 JUMP_ABSOLUTE           33
        >>   62 POP_BLOCK

 11     >>   63 LOAD_GLOBAL              0 (print)
             66 LOAD_CONST               6 ('line 6')
             69 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             72 POP_TOP

 12          73 LOAD_GLOBAL              2 (math)
             76 LOAD_ATTR                3 (pi)
             79 LOAD_FAST                0 (n)
             82 BINARY_TRUE_DIVIDE
             83 POP_TOP
             84 LOAD_CONST               0 (None)
             87 RETURN_VALUE
On a quick glance, it looks like it would suffice to locate the first bytecode on each line, and then add a yield right before it, so it would be just after the last bytecode on the previous line. But if you look more closely, you'll notice that the yield that would get inserted after line 10 - before the instruction originally at offset 63 - would not get inserted after the print('line 5') - rather, it would get inserted after the whole for loop! In source code terms, it's the correct line, but wrong level of indentation. The last line inside the for loop is also the line where the for loop ends, since python doesn't have an endfor, or the like.
What this means is that if you really want to insert the yield right after the last statement in similar situations, things start to get pretty hairy pretty fast. It's one thing to write a hack that probably works most of the time on the exact interpreter version that you're using, but if you want something that you can trust to always work correctly, I'd say at the very minimum, you're going to have to start actually analyzing the control flow and/or the stack state. I'm not saying it's impossible, I've even done it to some degree in the past, but it's a completely different can of worms.
Fortunately, although the simple method places the yield after the whole for loop, instead of after the last statement, it all ends up working okay, since a yield also gets inserted just after the start of the for loop, just before the first statement in the body of the loop. I won't pretend to know if it'll always work out for all CPython 3.4-generated bytecode, but in any case, it should never break anything; The worst thing that should happen is that you won't get a yield between some statements, where you'd have wanted it.
It's also worth mentioning that unlike with the AST, here it'll truly be line-by-line execution. co_lnotab contains an entry for every logical and physical line change (and when a number doesn't fit in a single entry), but when the physical line changes, it's not possible to say based on co_lnotab whether it's also a statement boundary (and the same goes for discriminating between statements that happen to be 255 bytecodes long versus split entries caused by statements longer than that). dis.findlinestarts() and Instruction.starts_line both only indicate when the physical line changes, but by looking directly at co_lnotab it's possible to yield at all statement boundaries. Either way, you'll always be yielding at a line change (well, okay, lines without any code are exempt), whether it actually ends a statement or not!
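(If you want to eyeball the difference yourself, here's a quick sketch using the target function from above - the raw table is just pairs of unsigned bytes, a bytecode offset increment followed by a line number increment:)

co = target.__code__
# Raw (byte increment, line increment) pairs straight from co_lnotab
print(list(zip(co.co_lnotab[0::2], co.co_lnotab[1::2])))
# For comparison, the (offset, lineno) pairs that dis reports
print(list(dis.findlinestarts(co)))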
With all that out of the way, let's look at the code:
def new_insn(name, arg=None, argval=None, starts_line=None, is_jump_target=False):
    return Instruction(name, dis.opmap[name], arg, argval, None, None, starts_line, is_jump_target)

co = func.__code__
insns = get_instructions(func)
yieldno = 1
new_consts = list(co.co_consts)
new_insns = []
for insn in insns:
    if insn.starts_line is not None and insn.offset:
        try:
            arg = new_consts.index(yieldno)
        except ValueError:
            arg = len(new_consts)
            new_consts.append(yieldno)
        yieldno += 1
        new_insns.extend([
            new_insn('LOAD_CONST', arg),
            new_insn('YIELD_VALUE'),
            new_insn('POP_TOP'),
        ])
    new_insns.append(insn)
When the code finds an instruction that starts a line, it inserts a yield before it (unless we're at the start, since then there is no previous line where we'd want to append the yield). Like we did with the AST, we yield a sequential number, so we first add the number into the constants used by the bytecode, if needed, and then give the index of the constant as an argument to the relevant instruction.
Fixing the jump target offsets
Like I mentioned right at the start, adding the yield statements will change the offsets of most instructions in the bytecode - this means that most jumps now point to the wrong place. To prepare for this, we added the target attribute to the instructions, and now's the time to make use of that.
First, we have to calculate the new offset of each instruction, and for this, we need to know the size of all instructions. An instruction is 1 byte long if it has no arguments, 3 bytes long if it has an argument that fits in 2 bytes, and 6 bytes long if it has an argument that fits in 4 bytes (well, to be pedantic, if the argument is longer than 2 bytes, the instruction is encoded as two opcodes, with the first one being EXTENDED_ARG, whose argument will contain the upper 2 bytes of the whole 4-byte argument).
So the length of jump instructions is dependent on where they jump, and where they jump is dependent on the length of instructions. Fortunately, if fixing the argument of a jump grows its size, that can only cause other jumps to grow, never shrink, so you can't get into an infinite loop of triggered recalculations.
Without further ado, the code:
def calc_insn_size(insn):
    """Calculate how many bytes the bytecode for the instruction will take"""
    return (6 if insn.arg >= 65536 else 3) if insn.op >= dis.HAVE_ARGUMENT else 1

def _recalc_insn_offsets(insns):
    """Calculate the offset of each instruction in the resulting bytecode"""
    offset = 0
    for insn in insns:
        insn.offset = offset
        offset += calc_insn_size(insn)
    return offset

def _recalc_jump_offsets(insns):
    """Calculate the target offset of each jump instruction

    Return value tells whether this caused the encoding of any jump instruction to change in size
    """
    size_changed = 0
    for insn in insns:
        size = calc_insn_size(insn)
        if insn.op in dis.hasjabs:
            insn.arg = insn.target.offset
            insn.argval = insn.target.offset
        elif insn.op in dis.hasjrel:
            insn.arg = insn.target.offset - (insn.offset + size)
            insn.argval = insn.target.offset
        new_size = calc_insn_size(insn)
        if new_size != size:
            size_changed += 1
    return size_changed

def _reset_jump_offsets(insns):
    """Reset all jump target offsets to 0 (so that jumps will use the smaller encoding by default)"""
    for insn in insns:
        if insn.op in hasjump:
            insn.arg = 0

def fix_offsets(insns):
    """Calculate all instruction and jump target offsets"""
    size = _recalc_insn_offsets(insns)
    _reset_jump_offsets(insns)
    # Updating the jump target offsets might cause the encoding size of some jump instructions to grow
    # If that happens, we have to recalculate the instruction offsets, some of which have grown, which means
    # we have to update the jump targets again. Naturally, this has to be repeated until things settle down.
    while _recalc_jump_offsets(insns):
        size = _recalc_insn_offsets(insns)
    return size
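Before the parts get stuffed back together below, the new instruction list from earlier still has to be run through this machinery, so that the offsets and jump arguments used during encoding are correct - presumably something like:

# Recompute all instruction offsets and fix up the jump targets of the new instruction list
fix_offsets(new_insns)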
Putting the parts back together
There isn't really much to say about the rest. We have to encode the instructions, which would normally be a "read the docs if you want to figure out the format" -type of deal, but in python's case the instruction encoding is so simple you can figure it out just by looking at the code for all of two seconds. We also have to encode the line numbers (aka. co_lnotab), which is more involved, but since I don't want to recreate its documentation here, I'll just point you to it (Objects/lnotab_notes.txt in the CPython sources) in case you're interested.
The rest is just stuffing all the updated parts together into a code object and a function, like so:
import types

def encode_insn(insn):
    """Generate bytecode for the instruction"""
    l = [insn.op]
    if insn.op >= dis.HAVE_ARGUMENT:
        l += [insn.arg & 0xff, (insn.arg >> 8) & 0xff]
        if insn.arg >= 65536:
            l = [dis.EXTENDED_ARG, (insn.arg >> 16) & 0xff, (insn.arg >> 24) & 0xff] + l
    return bytes(l)

def calc_lnotab(insns, firstlineno=0):
    """Calculate the line number table for the bytecode"""
    # Details of the format of co_lnotab are explained in Objects/lnotab_notes.txt, so I won't bother repeating all of that
    new_lnotab = []
    prev_offset, prev_lineno = 0, firstlineno
    for insn in insns:
        if insn.starts_line:
            offset, lineno = insn.offset - prev_offset, insn.starts_line - prev_lineno
            prev_offset, prev_lineno = insn.offset, insn.starts_line
            assert (offset > 0 or prev_offset == 0) and lineno > 0
            while offset > 255:
                new_lnotab.extend((255, 0))
                offset -= 255
            while lineno > 255:
                new_lnotab.extend((offset, 255))
                offset = 0
                lineno -= 255
            new_lnotab.extend((offset, lineno))
    return bytes(new_lnotab)
new_lnotab = calc_lnotab(new_insns, co.co_firstlineno)
new_bytecode = b''.join(map(encode_insn, new_insns))
new_code = types.CodeType(
    co.co_argcount,
    co.co_kwonlyargcount,
    co.co_nlocals,
    # To be safe (the stack should usually be empty at points where we yield)
    co.co_stacksize + 1,
    # We added yields, so the function is a generator now
    co.co_flags | 0x20,
    new_bytecode,
    tuple(new_consts),
    co.co_names,
    co.co_varnames,
    co.co_filename,
    co.co_name,
    co.co_firstlineno,
    new_lnotab,
    co.co_freevars,
    co.co_cellvars,
)
new_func = types.FunctionType(
    new_code,
    func.__globals__,
    func.__name__,
    func.__defaults__,
    func.__closure__,
)
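At this point new_func can be driven like any other generator; a minimal sketch, assuming func was the target function from above and calling it with n=0 so it ends in the intentional ZeroDivisionError:

stepper = new_func(0)
for step in stepper:
    # Each yielded value is one of the sequential numbers we added to the constants
    print('paused after step', step)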
Putting it all together
Finally, here's a link to a self-contained example of doing the above. (I'd have included the code for the complete example in the answer, too, but I went over the 30k character limit for answers.. Oops. :)
(Note: The ZeroDivisionError that you get at the end, if you run the example, is there by design to demonstrate that the traceback is correct)