I'm having difficulty getting the following complex dictionary comprehension to work as expected. It's a doubly nested for loop with conditionals.
Let me first explain what I'm doing:
import pandas as pd
dict1 = {'stringA':['ABCDBAABDCBD','BBXB'], 'stringB':['ABDCXXXBDDDD', 'AAAB'], 'num':[42, 13]}
df = pd.DataFrame(dict1)
print(df)
        stringA       stringB  num
0  ABCDBAABDCBD  ABDCXXXBDDDD   42
1          BBXB          AAAB   13
This DataFrame has two columns, stringA and stringB, whose strings contain the characters A, B, C, D and X. By definition, the two strings in each row have the same length.
Based on these two columns, I create one dictionary per row that maps each position in stringA (counting from 0) to the corresponding position in stringB, where the positions in stringB are counted starting from num.
Here's the function I use:
def create_translation(x):
    # Map each position i in stringA to position i + num in stringB
    x['translated_dictionary'] = {i: i + x['num'] for i, e in enumerate(x['stringA'])}
    return x
df2 = df.apply(create_translation, axis=1).groupby('stringA')['translated_dictionary']
df2.head()
0 {0: 42, 1: 43, 2: 44, 3: 45, 4: 46, 5: 47, 6: ...
1 {0: 13, 1: 14, 2: 15, 3: 16}
Name: translated_dictionary, dtype: object
print(df2.head()[0])
{0: 42, 1: 43, 2: 44, 3: 45, 4: 46, 5: 47, 6: 48, 7: 49, 8: 50, 9: 51, 10: 52, 11: 53}
print(df2.head()[1])
{0: 13, 1: 14, 2: 15, 3: 16}
That's correct.
However, these strings also contain 'X' characters, which require special rules: if position i of stringA is an X, don't create a key-value pair for i in the dictionary. If position i of stringB is an X, the value for i should not be i + x['num'] but -500.
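Stated as code, the rule pairs the two strings position by position. A minimal plain-Python sketch, assuming that when both strings have an X at the same position the stringA rule (skip the key) wins; `translate` is a hypothetical helper name, not part of the code above:

```python
def translate(string_a: str, string_b: str, num: int) -> dict:
    """Hypothetical helper: map position i of string_a to i + num, with the X rules.

    Assumption: if both strings have an X at position i, the key is skipped.
    """
    return {
        i: -500 if b == 'X' else i + num  # X in string_b at position i -> -500
        for i, (a, b) in enumerate(zip(string_a, string_b))
        if a != 'X'  # X in string_a at position i -> no key at all
    }

print(translate('ABCDBAABDCBD', 'ABDCXXXBDDDD', 42))
print(translate('BBXB', 'AAAB', 13))  # {0: 13, 1: 14, 3: 16}
```

Because zip walks both strings in lockstep, each character of stringB is checked at its own position rather than through a separate outer loop.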
I tried the following dictionary comprehension:
def try1(x):
    for count, element in enumerate(x['stringB']):
        x['translated_dictionary'] = {i: -500 if element == 'X' else i + x['num'] for i, e in enumerate(x['stringA']) if e != 'X'}
    return x
That gives the wrong answer.
df3 = df.apply(try1, axis=1).groupby('stringA')['translated_dictionary']
print(df3.head()[0]) ## this is wrong!
{0: 42, 1: 43, 2: 44, 3: 45, 4: 46, 5: 47, 6: 48, 7: 49, 8: 50, 9: 51, 10: 52, 11: 53}
print(df3.head()[1]) ## this is correct! There is no key for 2:15!
{0: 13, 1: 14, 3: 16}
There are no -500 values!
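A likely explanation, sketched for row 0 outside pandas: inside the comprehension, `element` is one character fixed by the outer for loop, so each pass builds a dict whose values are either all -500 or all i + num, and each pass overwrites the previous dict. Assuming `return x` runs after the for loop finishes, only the pass for the last character of stringB ('D') survives. The variable names below are illustrative.

```python
# Illustrative reproduction of what try1 does for row 0, without pandas.
string_a, string_b, num = 'ABCDBAABDCBD', 'ABDCXXXBDDDD', 42

for count, element in enumerate(string_b):
    # `element` stays the same for the whole comprehension, so the
    # condition is all-or-nothing for this pass.
    result = {i: -500 if element == 'X' else i + num
              for i, e in enumerate(string_a) if e != 'X'}
    if element == 'X':
        # On these passes the dict really does contain only -500 values...
        assert set(result.values()) == {-500}

# ...but every pass replaces `result`, and the last character of string_b is 'D'.
print(element)  # D
print(result)   # no -500 values, same as the wrong row-0 output
```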
The correct answer is:
print(df3.head()[0])
{0: 42, 1: 43, 2: 44, 3: 45, 4: -500, 5: -500, 6: -500, 7: 49, 8: 50, 9: 51, 10: 52, 11: 53}
print(df3.head()[1])
{0: 13, 1: 14, 3: 16}
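One way to get this output from the DataFrame is a sketch like the following, where `create_translation_x` is a hypothetical name for a fixed try1: it zips stringA and stringB so each stringB character is checked at its own position, and it returns the dict instead of assigning into the row, so a row-wise apply yields a Series of dicts. It also skips the groupby('stringA') step and reads the dicts straight from the column.

```python
import pandas as pd

dict1 = {'stringA': ['ABCDBAABDCBD', 'BBXB'],
         'stringB': ['ABDCXXXBDDDD', 'AAAB'],
         'num': [42, 13]}
df = pd.DataFrame(dict1)

def create_translation_x(row):
    """Hypothetical fix for try1: pair stringA and stringB by position with zip."""
    return {
        i: -500 if b == 'X' else i + row['num']  # X in stringB -> -500
        for i, (a, b) in enumerate(zip(row['stringA'], row['stringB']))
        if a != 'X'  # X in stringA -> no key
    }

# A row-wise apply whose function returns a dict gives a Series of dicts.
df['translated_dictionary'] = df.apply(create_translation_x, axis=1)

print(df.loc[0, 'translated_dictionary'])
print(df.loc[1, 'translated_dictionary'])
```

With the sample data, the two printed dicts match the expected output above.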