Given an ensemble estimator, I would like to iterate over the contents of its estimators_ field.
The problem is that the field can have a very different structure. E.g., for a GradientBoostingClassifier it is a rank-2 numpy.ndarray (so I can use nditer), while for a RandomForestClassifier it is a simple list.
Can I do better than this:
import numpy as np

def iter_estimators(estimators):
    if isinstance(estimators, np.ndarray):
        return map(lambda x: x[()], np.nditer(estimators, flags=["refs_ok"]))
    return iter(estimators)
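
One variant I can think of avoids the isinstance check by relying on ndarray.flat, which iterates over the elements of an array of any rank in row-major order. This is only a sketch: the strings below are hypothetical placeholders standing in for fitted sub-estimators, and it assumes the only array-like container involved is a numpy.ndarray.

```python
import numpy as np


def iter_estimators(estimators):
    """Iterate over sub-estimators held in either a list or an ndarray."""
    # ndarray.flat walks the elements of an array of any rank in row-major
    # order; a plain list has no .flat attribute, so it is iterated directly.
    return iter(getattr(estimators, "flat", estimators))


# Placeholders for fitted trees (hypothetical, not real estimators).
as_list = ["tree0", "tree1", "tree2"]

# A rank-2 object array, mimicking the layout described above.
as_array = np.empty((2, 2), dtype=object)
as_array[0, 0], as_array[0, 1] = "a0", "a1"
as_array[1, 0], as_array[1, 1] = "b0", "b1"

from_list = list(iter_estimators(as_list))
from_array = list(iter_estimators(as_array))
```

Both calls return the elements in order, with the array flattened row by row. Whether this counts as "better" is debatable, since it trades an explicit type check for duck typing on the flat attribute.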