
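I have a DataFrame with MultiIndex columns where each cell holds a list of tuples (or NaN). For each top-level column, I want to combine the lists from its sub-columns into a single flat list (flattening the lists, not the tuples), but I'm struggling with it. Here's what I have: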

import numpy as np
import pandas as pd

df = pd.DataFrame([[[(1,'a')],[(6,'b')],np.nan,np.nan],[[(5,'d'),(10,'e')],np.nan,np.nan,[(8,'c')]]])
df.columns = pd.MultiIndex.from_tuples([('a', 0), ('a', 1), ('b', 0), ('b', 1)])

>>> df
                   a             b
                   0         1   0         1
0           [(1, a)]  [(6, b)] NaN       NaN
1  [(5, d), (10, e)]       NaN NaN  [(8, c)]

Desired result:

>>> df
                        a              b
0        [(1, a), (6, b)]     [NaN, NaN]
1  [(5, d), (10, e), NaN]  [NaN, (8, c)]

How do I do this? From this related question, I tried the following:

>>> df.stack(level=1).groupby(level=[0]).agg(lambda x: np.array(list(x)).flatten())
   a  b
0  a  b
1  a  b

>>> df.stack(level=1).groupby(level=[0]).agg(lambda x: np.concatenate(list(x)))
...
Exception: Must produce aggregated value
irene

2 Answers


Here's one way to do it:

# taken from https://stackoverflow.com/questions/12472338/flattening-a-list-recursively
def flatten(S):
    if S == []:
        return S
    if isinstance(S[0], list):
        return flatten(S[0]) + flatten(S[1:])
    return S[:1] + flatten(S[1:])

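# reshape the data to get the desired structure
df2 = (df
     .unstack()                 # Series indexed by (column level 0, column level 1, row)
     .reset_index()             # columns: level_0, level_1, level_2, 0
     .drop(columns='level_1')   # the second column level is no longer needed
     .groupby(['level_0', 'level_2'])[0]
     .apply(list).apply(flatten).unstack().T)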

df2.index.name = None
df2.columns.name = None

print(df2)

                        a              b
0        [(1, a), (6, b)]     [nan, nan]
1  [(5, d), (10, e), nan]  [nan, (8, c)]
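For readability, the same reshape can be written with named levels, so the intermediate columns get meaningful names instead of the auto-generated level_0/level_1/level_2, and the final .T is not needed. This is a sketch, not the answerer's code: the level names group, slot and row are hypothetical labels chosen only for illustration, and it reuses the flatten function defined above.

```python
import numpy as np
import pandas as pd


def flatten(S):
    # Recursive flatten, as defined in this answer
    if S == []:
        return S
    if isinstance(S[0], list):
        return flatten(S[0]) + flatten(S[1:])
    return S[:1] + flatten(S[1:])


df = pd.DataFrame([[[(1, 'a')], [(6, 'b')], np.nan, np.nan],
                   [[(5, 'd'), (10, 'e')], np.nan, np.nan, [(8, 'c')]]])
df.columns = pd.MultiIndex.from_tuples([('a', 0), ('a', 1), ('b', 0), ('b', 1)])

# Hypothetical level names so the later steps can refer to them by name
df.columns.names = ["group", "slot"]
df.index.name = "row"

# One row per (group, slot, row) cell; NaN cells are kept
long = df.unstack().rename("value").reset_index()

df2 = (long.groupby(["row", "group"])["value"]
           .agg(list)          # cell values per (row, group), in slot order
           .map(flatten)       # merge the per-slot lists into one list
           .unstack("group"))  # groups back to columns, rows stay as the index

# Drop the helper names again to match the desired output
df2.index.name = None
df2.columns.name = None
print(df2)
```

Unstacking on "group" rather than on the row level keeps the original row labels as the index directly, which is why no transpose is required.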
YOLO
    A bit too long with several steps, but the logic works. I made a shorter answer above. +1, thank you! – irene Feb 25 '20 at 18:03

Found a one-liner:

Using the custom flatten function given by @YOLO:

>>> df.stack(level=1).groupby(level=0).agg(list).applymap(flatten)
                        a              b
0        [(1, a), (6, b)]     [nan, nan]
1  [(5, d), (10, e), nan]  [nan, (8, c)]

where

def flatten(S):
    if S == []:
        return S
    if isinstance(S[0], list):
        return flatten(S[0]) + flatten(S[1:])
    return S[:1] + flatten(S[1:])
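On pandas 2.1 and later, DataFrame.applymap is deprecated in favor of DataFrame.map, so the one-liner emits a warning there. The sketch below picks whichever method exists. It also swaps the recursive flatten for a hypothetical non-recursive helper, flatten_one_level, which assumes (as in this data) that the grouped lists are nested only one level deep. The recursive version recurses once per element, so very long lists can hit Python's default recursion limit of 1000.

```python
import numpy as np
import pandas as pd


def flatten_one_level(items):
    # Hypothetical helper: splice list elements into the result,
    # keep anything else (e.g. NaN) as a single element
    out = []
    for item in items:
        if isinstance(item, list):
            out.extend(item)
        else:
            out.append(item)
    return out


df = pd.DataFrame([[[(1, 'a')], [(6, 'b')], np.nan, np.nan],
                   [[(5, 'd'), (10, 'e')], np.nan, np.nan, [(8, 'c')]]])
df.columns = pd.MultiIndex.from_tuples([('a', 0), ('a', 1), ('b', 0), ('b', 1)])

# Same grouping as the one-liner: one list of cell values per (row, top-level column)
grouped = df.stack(level=1).groupby(level=0).agg(list)

# DataFrame.map exists from pandas 2.1 on; older versions only have applymap
elementwise = grouped.map if hasattr(grouped, "map") else grouped.applymap
result = elementwise(flatten_one_level)
print(result)
```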
irene