Given an ensemble estimator, I would like to iterate over the contents of its estimators_ field.

The problem is that the structure of this field differs from one estimator class to another.

E.g., for a GradientBoostingClassifier it is a rank-2 numpy.ndarray (so I can use nditer) while for a RandomForestClassifier it is a simple list.

Can I do better than this:

import numpy as np

def iter_estimators(estimators):
    if isinstance(estimators, np.ndarray):
        # nditer yields 0-d arrays; x[()] unwraps each one to the stored object
        return map(lambda x: x[()], np.nditer(estimators, flags=["refs_ok"]))
    return iter(estimators)
sds

2 Answers


I suppose you could use np.asarray to conveniently ensure the object is an ndarray. Then use ndarray.flat to get an iterator over the flattened array.

>>> estimators = model.estimators_
>>> array = np.asarray(estimators)
>>> iterator = array.flat
>>> iterator
<numpy.flatiter at 0x7f84f48f8e00>
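The same idea can be wrapped in a small helper. This is only a sketch: `Stub` is a hypothetical stand-in for a fitted sub-estimator, and `dtype=object` is an assumption added here so that NumPy stores the elements as-is. If the elements are themselves sequence-like, `np.asarray` may still unpack them, so check the resulting shape in that case.

```python
import numpy as np


class Stub:
    """Hypothetical stand-in for a fitted sub-estimator."""


def iter_estimators(estimators):
    # Coerce a list or an ndarray to an object array, then walk it in C order.
    return iter(np.asarray(estimators, dtype=object).flat)


a, b, c, d = Stub(), Stub(), Stub(), Stub()

# A flat list, as in RandomForestClassifier.estimators_
assert list(iter_estimators([a, b])) == [a, b]

# A rank-2 object array, as in GradientBoostingClassifier.estimators_
grid = np.array([[a, b], [c, d]], dtype=object)
assert list(iter_estimators(grid)) == [a, b, c, d]
```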
Matt Eding

A numpy-agnostic solution is

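def iter_nested(obj):
    """Yield the non-iterable leaves of an arbitrarily nested iterable.
    https://stackoverflow.com/q/58615038/850781"""
    try:
        children = iter(obj)
    except TypeError:           # ... object is not iterable
        yield obj
        return
    # Recurse outside the try block so unrelated TypeErrors are not swallowed
    for o1 in children:
        yield from iter_nested(o1)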
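One caveat with this kind of recursion: a one-character string iterates to itself, so any string leaf ends in a RecursionError. A possible variant, sketched here with a hypothetical name `iter_leaves` and a hypothetical `atomic` parameter, treats chosen types as leaves:

```python
def iter_leaves(obj, atomic=(str, bytes)):
    """Yield the leaves of a nested iterable; `atomic` types count as leaves."""
    if isinstance(obj, atomic):
        yield obj
        return
    try:
        children = iter(obj)
    except TypeError:  # obj is not iterable, so it is a leaf
        yield obj
        return
    for child in children:
        yield from iter_leaves(child, atomic)


nested = [1, [2, [3, "ab"]], (4,)]
print(list(iter_leaves(nested)))  # prints [1, 2, 3, 'ab', 4]
```

Any other type that should not be descended into can be added to `atomic` in the same way.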


sds