I'm trying to understand the `.filter()` method in pandas, and I'm not sure why the code below doesn't work:
```python
# Load data
from sklearn.datasets import load_iris
import pandas as pd
data = load_iris()
df = pd.DataFrame(data.data, columns=data.feature_names)
# Set arbitrary index (is this needed?) and try filtering:
indexed_df = df.copy().set_index('sepal width (cm)')
test = indexed_df.filter(lambda x: x['petal length (cm)'] > 1.4)
```
I get:

```
TypeError: 'function' object is not iterable
```
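For comparison, here is a minimal sketch of how `DataFrame.filter` seems to be meant to be used, based on its documented parameters (`items`, `like`, `regex`, `axis`). As far as I can tell, it selects rows or columns by their *labels*, not by their values. A lambda passed positionally would therefore land in `items`, which pandas then tries to treat as a collection of labels. That would be consistent with the `TypeError` above, although the exact message may differ between pandas versions. The small frame and its row labels `"a"`, `"b"`, `"c"` are hypothetical, not the iris data.

```python
import pandas as pd

# Hypothetical toy frame; column names mimic the iris ones.
df = pd.DataFrame(
    {"sepal width (cm)": [3.5, 3.0, 3.2],
     "petal length (cm)": [1.4, 1.4, 1.3]},
    index=["a", "b", "c"],
)

# DataFrame.filter selects by label, not by value.
cols = df.filter(items=["petal length (cm)"])  # exact column labels
like_cols = df.filter(like="petal")            # column labels containing "petal"
rows = df.filter(items=["a", "c"], axis=0)     # row labels, via axis=0

# A function passed positionally becomes `items`, which is not a
# collection of labels, so pandas raises a TypeError.
try:
    df.filter(lambda x: x["petal length (cm)"] > 1.4)
    error = None
except TypeError as exc:
    error = exc
```

In other words, this `filter` never evaluates a condition on the data. It only matches names against the index or the columns.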
I appreciate there are simpler ways to do this (e.g. Boolean indexing), but for learning purposes I'm trying to understand why `filter` fails here when it works on a `groupby` object.
This works:

```python
filtered_df = df.groupby('petal width (cm)').filter(lambda x: x['sepal width (cm)'].sum() > 50)
```
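The `groupby` call above works because `GroupBy.filter` is a different method that happens to share the name. According to its documentation, it takes a function `func` that receives each group as a sub-DataFrame and must return a single boolean. Rows belonging to groups for which it returns `False` are dropped. The sketch below contrasts that with the row-wise Boolean indexing equivalent of what I was attempting. The toy data and the threshold of 6 are hypothetical, chosen only so the result is easy to check by hand.

```python
import pandas as pd

# Hypothetical toy data standing in for the iris columns.
df = pd.DataFrame({
    "petal width (cm)": [0.2, 0.2, 1.3, 1.3],
    "sepal width (cm)": [3.5, 3.0, 2.8, 2.9],
    "petal length (cm)": [1.4, 1.5, 4.0, 4.1],
})

# GroupBy.filter: func sees one whole group (a sub-DataFrame) and returns a
# single bool; all rows of groups that return False are dropped.
kept_groups = df.groupby("petal width (cm)").filter(
    lambda g: g["sepal width (cm)"].sum() > 6
)

# Row-by-row selection on values uses Boolean indexing instead.
kept_rows = df[df["petal length (cm)"] > 1.4]
```

Here the 0.2 group (sepal widths summing to 6.5) is kept and the 1.3 group (summing to 5.7) is dropped as a whole. The Boolean mask, by contrast, keeps rows 1, 2 and 3 individually.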