Basically, what I want to do is take a list of objects and, based on some test, divide it into two lists depending on whether the result is True or False. It's sort of what filter() does, except that instead of deciding whether an item is in or out, the test decides which list/iterator the item goes to. I've made this attempt using itertools.groupby():

import random
from itertools import groupby

class FooObject(object):
    def __init__(self):
        self.key = random.choice((True, False))

    def __repr__(self):
        return '<Foo: %s>' % self.key

foos = [FooObject() for _ in range(10)]
left, right = [], []

groups = groupby(sorted(foos, key=lambda o: o.key), lambda o: o.key)
for k, group in groups:
    if k:
        right = list(group)
    else:
        left = list(group)

print(left)
print(right)

This gets the job done, but I'm just wondering if there is a cleaner/simpler way. I realize I could use filter() (or the equivalent list comprehension) and do it in two passes, but what fun is that?
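
For reference, the two-pass version mentioned above might look roughly like this. It is only a sketch: the Item class and the fixed keys are made up for illustration (a stand-in for FooObject without the randomness), and it is written for Python 3.

```python
class Item(object):
    """Hypothetical stand-in for FooObject, with a fixed key instead of a random one."""
    def __init__(self, key):
        self.key = key

items = [Item(k) for k in (True, False, False, True, False)]

# Two passes over the same list: one comprehension per output list.
right = [o for o in items if o.key]
left = [o for o in items if not o.key]

print(len(left), len(right))
```

The test is evaluated twice per item here, which is the part a single-pass solution would avoid.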

FatalError
2 Answers

Here's a function that consumes the source only once and returns a dictionary-like object. Each member of that object is a generator that yields values from the source as lazily as possible:

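def partition(it, fun):
    it = iter(it)  # accept any iterable, not only an iterator

    class x(object):
        def __init__(self):
            # values pulled from the source that belong to other keys
            self.buf = {}

        def flush(self, val):
            # yield everything buffered for this key, then drop the buffer
            for p in self.buf.get(val, []):
                yield p
            self.buf.pop(val, None)

        def __getitem__(self, val):
            for p in self.flush(val): yield p
            while True:
                try:
                    p = next(it)
                except StopIteration:
                    break
                v = fun(p)
                if v == val:
                    yield p
                else:
                    # not ours: keep it for whoever asks for key v
                    self.buf.setdefault(v, []).append(p)
            # pick up items buffered for this key while we were suspended
            for p in self.flush(val): yield p

    return x()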

Example of use:

def primes(): # note that this is an endless generator
    yield 2
    p, n = [], 3
    while True:
        if all(n % x for x in p):
            p.append(n)
            yield n
        n += 2


p = partition(primes(), lambda x: x % 3)
# each member of p is endless as well

for x in p[1]:
    print(x)
    if x > 200: break

for x in p[2]:
    print(x)
    if x > 200: break
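
With a finite source the buffering is easier to see. The following is a sketch for Python 3 that repeats the function so it runs on its own; the range(10) input and the names zeros, ones and twos are just for illustration.

```python
def partition(it, fun):
    class x(object):
        def __init__(self):
            self.buf = {}

        def flush(self, val):
            for p in self.buf.get(val, []):
                yield p
            self.buf.pop(val, None)

        def __getitem__(self, val):
            for p in self.flush(val): yield p
            while True:
                try:
                    p = next(it)
                except StopIteration:
                    break
                v = fun(p)
                if v == val:
                    yield p
                else:
                    self.buf.setdefault(v, []).append(p)
            for p in self.flush(val): yield p

    return x()

p = partition(iter(range(10)), lambda n: n % 3)

# Draining p[0] reads the whole source; the other values are buffered by key.
zeros = list(p[0])
# These two are then served from the buffer, since the source is exhausted.
ones = list(p[1])
twos = list(p[2])

print(zeros, ones, twos)
```

Note that the source is wrapped in iter() here because this copy of the function calls next() on it directly.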
georg

If you only have 2 buckets, you can use a ternary:

import random

d = {'left': [], 'right': []}
for e in (random.random() for i in range(50)):
    d['left' if e < 0.5 else 'right'].append(e)
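
Applied to the objects from the question, the same idea could pick the target list directly with a conditional expression. This is a sketch for Python 3; the Foo class with fixed keys is a made-up stand-in for FooObject.

```python
class Foo(object):
    """Hypothetical stand-in for the question's FooObject, with a fixed key."""
    def __init__(self, key):
        self.key = key

foos = [Foo(k) for k in (True, False, False, True, False)]

left, right = [], []
for foo in foos:
    # One pass: the conditional expression chooses which list receives the item.
    (right if foo.key else left).append(foo)

print(len(left), len(right))
```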

With more than 2 buckets, either use a function that only returns keys already defined in the dict, or use a collections.defaultdict(list) so that each new key starts with an empty list:

import random
from collections import defaultdict

def f(i):
    return int(i * 10)

DoL = defaultdict(list)
for e in (random.random() for i in range(50)):
    DoL[f(e)].append(e)
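
The loop can also be wrapped in a small reusable function. A sketch for Python 3 follows; bucket is a hypothetical helper name, and the fixed sample values replace random.random() so the result is repeatable.

```python
from collections import defaultdict

def bucket(items, key):
    """Hypothetical helper: group items into lists keyed by key(item)."""
    groups = defaultdict(list)
    for item in items:
        groups[key(item)].append(item)
    return groups

values = [0.05, 0.12, 0.17, 0.5, 0.93]  # fixed sample instead of random values
groups = bucket(values, lambda v: int(v * 10))

for k in sorted(groups):
    print(k, groups[k])
```

Only keys that actually occur end up in the result, so empty buckets are simply absent.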
the wolf