Optimized way:
Instead of applying ast.literal_eval many times in a loop, load the input CSV file with the needed structure in one pass: use the converters option of pd.read_csv to turn each string representation of an array in the 'embeddings' column into a "real" array with the numpy.fromstring routine.
import numpy as np
import pandas as pd

df = pd.read_csv("https://testme162.s3.amazonaws.com/cosign_dist.csv",
                 delimiter=',', usecols=["embeddings"], quotechar='"',
                 converters={'embeddings': lambda s: np.fromstring(s[1:-1], sep=',')})
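To see what the converter does in isolation, here is a minimal sketch on an in-memory CSV. The two-row `csv_text` and the helper name `parse_embedding` are made up for illustration; only the `np.fromstring(s[1:-1], sep=',')` call is taken from the snippet above.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical two-row CSV standing in for the real file.
csv_text = 'embeddings\n"[0.1, 0.2, 0.3]"\n"[1.0, -2.0, 0.5]"\n'

def parse_embedding(s):
    # Drop the surrounding brackets, then parse the comma-separated floats.
    return np.fromstring(s[1:-1], sep=',')

toy = pd.read_csv(io.StringIO(csv_text), usecols=["embeddings"],
                  converters={'embeddings': parse_embedding})

# Each cell now holds a NumPy array rather than a string.
first = toy.loc[0, "embeddings"]
print(type(first).__name__, first.dtype, first.shape)
```

Because the converter runs while the file is being parsed, no second pass over the column is needed afterwards.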
In my measurements, numpy.fromstring runs 6 times faster in this step than converters = {'embeddings': ast.literal_eval}, though the latter is, in its turn, still faster than your initial approach.
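If you want to reproduce this kind of comparison on your own machine, a rough sketch with timeit could look like the following. The sample string, its length of 1000 values, and the repeat count are arbitrary assumptions, so the ratio you get may differ from the one quoted above.

```python
import ast
import timeit

import numpy as np

# Hypothetical sample: 1000 random floats rendered as a list string.
rng = np.random.default_rng(0)
sample = str(rng.normal(size=1000).tolist())

def via_literal_eval(s):
    # Build a Python list first, then convert it to an array.
    return np.array(ast.literal_eval(s))

def via_fromstring(s):
    # Strip the brackets and let NumPy parse the numbers directly.
    return np.fromstring(s[1:-1], sep=',')

# Both parsers should produce the same numbers.
same = np.allclose(via_literal_eval(sample), via_fromstring(sample))

t_eval = timeit.timeit(lambda: via_literal_eval(sample), number=200)
t_from = timeit.timeit(lambda: via_fromstring(sample), number=200)
print(f"literal_eval: {t_eval:.4f}s, fromstring: {t_from:.4f}s, same: {same}")
```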
Then, to avoid inserting a new dist<num> column into the dataframe on every iteration, replace the for loop with a single pd.concat call:
df = pd.concat([df] + [df.embeddings.apply(
                           lambda x: np.dot(x, arr)
                           / (np.linalg.norm(x) * np.linalg.norm(arr))
                       ).rename(f'dist{i}')
                       for i, arr in enumerate(df["embeddings"])], axis=1)
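The same pattern can be checked on a tiny frame where the values are easy to verify by hand. The three 2-element vectors below are invented for illustration. The cross-check at the end (normalising the rows and taking one matrix product) is not part of the approach above, just a different way to compute the same cosine similarities for comparison.

```python
import numpy as np
import pandas as pd

# Hypothetical 3-row frame with short vectors in place of real embeddings.
toy = pd.DataFrame({"embeddings": [np.array([1.0, 0.0]),
                                   np.array([0.0, 2.0]),
                                   np.array([3.0, 3.0])]})

# Same pattern as above: one Series of similarities per reference vector.
result = pd.concat(
    [toy] + [toy.embeddings.apply(
                 lambda x: np.dot(x, arr)
                 / (np.linalg.norm(x) * np.linalg.norm(arr))
             ).rename(f'dist{i}')
             for i, arr in enumerate(toy["embeddings"])],
    axis=1)

# Cross-check: normalise the rows, then take all pairwise dot products at once.
m = np.vstack(toy["embeddings"].to_numpy())
unit = m / np.linalg.norm(m, axis=1, keepdims=True)
expected = unit @ unit.T

print(result.drop(columns="embeddings").round(3))
```

Note that each dist<num> column holds a cosine similarity, so every row scores 1.0 against itself, which is the diagonal of ones visible in the output.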
The final result (a fragment showing the first 3 records):
print(df.head(3))
embeddings dist0 dist1 \
0 [-0.009409046731889248, 0.01787922903895378, -... 1.000000 0.824427
1 [-0.0005574452807195485, -0.004265215713530779... 0.824427 1.000000
2 [-0.024396933615207672, -0.0016798048745840788... 0.757717 0.762072
dist2 dist3 dist4 dist5 dist6 dist7 dist8 \
0 0.757717 0.761481 0.858895 0.844244 0.781320 0.830562 0.869494
1 0.762072 0.768355 0.832918 0.813206 0.779384 0.822365 0.831671
2 1.000000 0.775206 0.757655 0.756076 0.770092 0.766206 0.765154
dist9 dist10 dist11 dist12 dist13 dist14 dist15 \
0 0.824993 0.838671 0.863087 0.809240 0.839480 0.852663 0.812553
1 0.843859 0.832757 0.846339 0.797901 0.833095 0.836512 0.794878
2 0.756694 0.765615 0.760532 0.754582 0.759305 0.749540 0.758749
dist16 dist17 dist18 dist19 dist20 dist21 dist22 \
0 0.834376 0.851168 0.853374 0.831500 0.786812 0.840630 0.831902
1 0.831882 0.829072 0.828278 0.773624 0.828781 0.814124 0.852540
2 0.749419 0.750785 0.759364 0.776753 0.761560 0.770000 0.766988
dist23 dist24 dist25 dist26 dist27 dist28 dist29 \
0 0.836812 0.807903 0.822919 0.837386 0.767737 0.815725 0.807334
1 0.799753 0.827189 0.812533 0.822119 0.788155 0.850503 0.843936
2 0.761825 0.773761 0.764308 0.757826 0.755465 0.772704 0.766396
dist30 dist31 dist32 dist33 dist34 dist35 dist36 \
0 0.832612 0.835909 0.819697 0.853597 0.806614 0.805309 0.822521
1 0.852967 0.842627 0.802803 0.860669 0.793716 0.787563 0.788239
2 0.762748 0.763906 0.765716 0.756643 0.766686 0.772603 0.760913
dist37 dist38 dist39 dist40 dist41 dist42 dist43 \
0 0.831307 0.834015 0.821262 0.812144 0.853028 0.849498 0.830675
1 0.845437 0.816868 0.833320 0.808172 0.835293 0.824654 0.856051
2 0.760276 0.754683 0.765499 0.756421 0.755651 0.763656 0.754828
dist44 dist45 dist46 dist47 dist48 ... dist50 dist51 \
0 0.861366 0.802735 0.789774 0.790563 0.827335 ... 0.820754 0.842522
1 0.854080 0.827517 0.839423 0.828683 0.812323 ... 0.802451 0.829247
2 0.760256 0.764869 0.754423 0.757319 0.774664 ... 0.747934 0.793632
dist52 dist53 dist54 dist55 dist56 dist57 dist58 \
0 0.827061 0.814656 0.813548 0.834271 0.818362 0.823394 0.828642
1 0.814514 0.834007 0.784510 0.796033 0.821271 0.821276 0.814710
2 0.759410 0.747319 0.783079 0.759875 0.742791 0.771096 0.759520
dist59 dist60 dist61 dist62 dist63 dist64 dist65 \
0 0.869624 0.840927 0.842052 0.859140 0.859804 0.840041 0.835204
1 0.835696 0.845089 0.810699 0.853660 0.834497 0.828624 0.803920
2 0.764160 0.758037 0.773802 0.762592 0.762257 0.751729 0.758366
dist66 dist67 dist68 dist69 dist70 dist71 dist72 \
0 0.816945 0.852561 0.815066 0.812858 0.844518 0.851627 0.838417
1 0.821947 0.812763 0.765442 0.795368 0.848876 0.831772 0.828389
2 0.759480 0.755786 0.762572 0.756787 0.769603 0.756226 0.750196
dist73 dist74 dist75 dist76 dist77 dist78 dist79 \
0 0.839868 0.846972 0.851668 0.860816 0.880957 0.845313 0.849569
1 0.822491 0.810707 0.812499 0.816586 0.828081 0.826785 0.813240
2 0.757696 0.746333 0.767805 0.759218 0.770810 0.766181 0.768756
dist80 dist81 dist82 dist83 dist84 dist85 dist86 \
0 0.870180 0.862554 0.866397 0.874742 0.899475 0.883464 0.879084
1 0.848113 0.840173 0.814944 0.826645 0.848822 0.818360 0.809330
2 0.763216 0.766606 0.762598 0.754603 0.767628 0.757145 0.774004
dist87 dist88 dist89 dist90 dist91 dist92 dist93 \
0 0.861392 0.874843 0.855589 0.851598 0.849689 0.854272 0.837288
1 0.827020 0.839443 0.822301 0.831517 0.815193 0.827057 0.813251
2 0.771965 0.768978 0.784956 0.768604 0.767573 0.759978 0.772354
dist94 dist95 dist96 dist97 dist98 dist99
0 0.841660 0.868675 0.867444 0.836115 0.829863 0.834038
1 0.799496 0.837142 0.833741 0.791625 0.819392 0.807420
2 0.767823 0.770422 0.756819 0.762370 0.774629 0.777811
[3 rows x 101 columns]