
This code works as expected and calculates the cosine similarity between every pair of embeddings. But it takes a lot of time: I have tens of thousands of records to check, and I am looking for a way to make it quicker.
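For reference, the expression inside the loop below is the cosine similarity of two embeddings a and b. The cosine distance is one minus it, and for n rows the loop evaluates the expression n² times:

```latex
\cos(\mathbf{a}, \mathbf{b}) = \frac{\mathbf{a} \cdot \mathbf{b}}{\lVert \mathbf{a} \rVert \, \lVert \mathbf{b} \rVert},
\qquad
d_{\cos}(\mathbf{a}, \mathbf{b}) = 1 - \cos(\mathbf{a}, \mathbf{b})
```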

import pandas as pd
import numpy as np
from numpy import dot
from numpy.linalg import norm

import ast

df = pd.read_csv("https://testme162.s3.amazonaws.com/cosign_dist.csv")

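# For every row k, compute the similarity of all rows against row k.
# Note: ast.literal_eval re-parses both strings on every single call.
for k, i in enumerate(df["embeddings"]):
    df["dist" + str(k)] = df.embeddings.apply(
        lambda x: dot(ast.literal_eval(x), ast.literal_eval(i))
        / (norm(ast.literal_eval(x)) * norm(ast.literal_eval(i)))
    )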
shantanuo
  • Before your loop, run `df['embeddings'] = df['embeddings'].apply(ast.literal_eval)`. Now the column holds lists, so you can remove all the other `ast.literal_eval(...)` calls. – Corralien Feb 27 '23 at 09:31
  • related: https://stackoverflow.com/a/40593934 – djvg Feb 27 '23 at 09:33
  • Can you embed a toy example of the `df` without resorting to downloading the csv file? – norok2 Feb 27 '23 at 09:57
  • The embeddings list is so big that it will not display / format properly. @norok2 – shantanuo Feb 27 '23 at 10:05
  • @shantanuo At this level it is important to know how big it is and what the data types are. A toy example would show the data types, and you could state the size in a comment. Even better would be code that generates representative data. – norok2 Feb 27 '23 at 10:50
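A minimal sketch of Corralien's suggestion: parse every string once, then run the same loop on the parsed lists. The three-row DataFrame is hypothetical toy data standing in for the downloaded CSV.

```python
import ast

import numpy as np
import pandas as pd
from numpy.linalg import norm

# Hypothetical toy data: embeddings stored as strings, like in the CSV.
df = pd.DataFrame({"embeddings": ["[1.0, 0.0, 0.0]", "[1.0, 1.0, 0.0]", "[0.0, 0.0, 2.0]"]})

# Parse each string exactly once; the column now holds Python lists.
df["embeddings"] = df["embeddings"].apply(ast.literal_eval)

# Same loop as in the question, but without any ast.literal_eval inside it.
for k, i in enumerate(df["embeddings"]):
    df["dist" + str(k)] = df["embeddings"].apply(
        lambda x: np.dot(x, i) / (norm(x) * norm(i))
    )

print(df[["dist0", "dist1", "dist2"]].round(3))
```

This removes the 4·n² string parses, but it still performs n² Python-level function calls.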

2 Answers


Optimized way:

Instead of calling ast.literal_eval many times in a loop, load the CSV file into the needed structure in one pass: use the converters option of read_csv to turn every string representation in the 'embeddings' column into a "real" NumPy array with the numpy.fromstring routine.

df = pd.read_csv("https://testme162.s3.amazonaws.com/cosign_dist.csv", 
                 delimiter=',', usecols=["embeddings"], quotechar='"',
                 converters = {'embeddings': lambda s: np.fromstring(s[1:-1], sep=',')})
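To see what the converter does without downloading the file, here is a small sketch on a hypothetical two-row CSV held in memory: `s[1:-1]` strips the surrounding brackets and `np.fromstring(..., sep=",")` parses the remaining comma-separated numbers into a float array.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical two-row CSV shaped like the real file: each embedding is a
# quoted, bracketed, comma-separated list of floats.
csv_text = '''embeddings
"[0.5, -1.25, 3.0]"
"[2.0, 0.0, -0.5]"
'''

df = pd.read_csv(
    io.StringIO(csv_text),
    usecols=["embeddings"],
    # Drop the brackets, then let NumPy parse the numbers (whitespace is ignored).
    converters={"embeddings": lambda s: np.fromstring(s[1:-1], sep=",")},
)

print(type(df.loc[0, "embeddings"]), df.loc[0, "embeddings"])
```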

In my measurements, numpy.fromstring runs 6 times faster than converters={'embeddings': ast.literal_eval}, though the latter is in turn faster than your initial approach.
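To check the parsing speed-up on your own machine, here is a rough timing sketch on one synthetic embedding string. The vector length (1536) and repeat count are arbitrary assumptions, and absolute timings will vary between machines.

```python
import ast
import timeit

import numpy as np

# Hypothetical synthetic embedding string; 1536 dimensions is only an assumption.
rng = np.random.default_rng(0)
s = "[" + ", ".join(str(x) for x in rng.standard_normal(1536)) + "]"

# Both parsers should produce the same vector before we compare their speed.
a = np.fromstring(s[1:-1], sep=",")
b = np.array(ast.literal_eval(s))
assert np.allclose(a, b)

t_fromstring = timeit.timeit(lambda: np.fromstring(s[1:-1], sep=","), number=200)
t_literal = timeit.timeit(lambda: np.array(ast.literal_eval(s)), number=200)
print(f"fromstring: {t_fromstring:.3f}s  literal_eval: {t_literal:.3f}s")
```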

Then, to avoid inserting a new dist<num> column into the dataframe on every iteration, replace the for loop with a single pd.concat call:

df = pd.concat([df] + [df.embeddings.apply(
                        lambda x: np.dot(x, arr)
                                  / (np.linalg.norm(x) * np.linalg.norm(arr))
                        ).rename(f'dist{i}')
                        for i, arr in enumerate(df["embeddings"])], axis=1)
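The pd.concat version still calls a Python lambda once per pair of rows. A different technique, not taken from either answer, is to stack the embeddings into one matrix, divide each row by its norm once, and obtain every pairwise cosine similarity from a single matrix product. A minimal sketch, assuming the column already holds equal-length NumPy arrays (as the converters above produce); the toy data is hypothetical:

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: the embeddings column already holds NumPy arrays.
df = pd.DataFrame({"embeddings": [np.array([1.0, 0.0, 0.0]),
                                  np.array([1.0, 1.0, 0.0]),
                                  np.array([0.0, 0.0, 2.0])]})

# Stack into an (n_rows, n_dims) matrix.
X = np.vstack(df["embeddings"].tolist())

# Normalize each row once, so a plain dot product becomes a cosine similarity.
X_unit = X / np.linalg.norm(X, axis=1, keepdims=True)

# Entry (i, j) is the cosine similarity between row i and row j.
sim = X_unit @ X_unit.T

dist_df = pd.DataFrame(sim, index=df.index, columns=[f"dist{i}" for i in range(len(df))])
df = pd.concat([df, dist_df], axis=1)
print(df.filter(like="dist").round(3))
```

Two caveats: an all-zero embedding makes the normalization divide by zero, and the full n×n result needs n² × 8 bytes as float64 (about 7.2 GB for 30,000 rows), so for tens of thousands of records you may need float32 or to compute the product in row blocks.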

The final result (first 3 records):

print(df.head(3))

                                          embeddings     dist0     dist1  \
0  [-0.009409046731889248, 0.01787922903895378, -...  1.000000  0.824427   
1  [-0.0005574452807195485, -0.004265215713530779...  0.824427  1.000000   
2  [-0.024396933615207672, -0.0016798048745840788...  0.757717  0.762072   

      dist2     dist3     dist4     dist5     dist6     dist7     dist8  \
0  0.757717  0.761481  0.858895  0.844244  0.781320  0.830562  0.869494   
1  0.762072  0.768355  0.832918  0.813206  0.779384  0.822365  0.831671   
2  1.000000  0.775206  0.757655  0.756076  0.770092  0.766206  0.765154   

      dist9    dist10    dist11    dist12    dist13    dist14    dist15  \
0  0.824993  0.838671  0.863087  0.809240  0.839480  0.852663  0.812553   
1  0.843859  0.832757  0.846339  0.797901  0.833095  0.836512  0.794878   
2  0.756694  0.765615  0.760532  0.754582  0.759305  0.749540  0.758749   

     dist16    dist17    dist18    dist19    dist20    dist21    dist22  \
0  0.834376  0.851168  0.853374  0.831500  0.786812  0.840630  0.831902   
1  0.831882  0.829072  0.828278  0.773624  0.828781  0.814124  0.852540   
2  0.749419  0.750785  0.759364  0.776753  0.761560  0.770000  0.766988   

     dist23    dist24    dist25    dist26    dist27    dist28    dist29  \
0  0.836812  0.807903  0.822919  0.837386  0.767737  0.815725  0.807334   
1  0.799753  0.827189  0.812533  0.822119  0.788155  0.850503  0.843936   
2  0.761825  0.773761  0.764308  0.757826  0.755465  0.772704  0.766396   

     dist30    dist31    dist32    dist33    dist34    dist35    dist36  \
0  0.832612  0.835909  0.819697  0.853597  0.806614  0.805309  0.822521   
1  0.852967  0.842627  0.802803  0.860669  0.793716  0.787563  0.788239   
2  0.762748  0.763906  0.765716  0.756643  0.766686  0.772603  0.760913   

     dist37    dist38    dist39    dist40    dist41    dist42    dist43  \
0  0.831307  0.834015  0.821262  0.812144  0.853028  0.849498  0.830675   
1  0.845437  0.816868  0.833320  0.808172  0.835293  0.824654  0.856051   
2  0.760276  0.754683  0.765499  0.756421  0.755651  0.763656  0.754828   

     dist44    dist45    dist46    dist47    dist48  ...    dist50    dist51  \
0  0.861366  0.802735  0.789774  0.790563  0.827335  ...  0.820754  0.842522   
1  0.854080  0.827517  0.839423  0.828683  0.812323  ...  0.802451  0.829247   
2  0.760256  0.764869  0.754423  0.757319  0.774664  ...  0.747934  0.793632   

     dist52    dist53    dist54    dist55    dist56    dist57    dist58  \
0  0.827061  0.814656  0.813548  0.834271  0.818362  0.823394  0.828642   
1  0.814514  0.834007  0.784510  0.796033  0.821271  0.821276  0.814710   
2  0.759410  0.747319  0.783079  0.759875  0.742791  0.771096  0.759520   

     dist59    dist60    dist61    dist62    dist63    dist64    dist65  \
0  0.869624  0.840927  0.842052  0.859140  0.859804  0.840041  0.835204   
1  0.835696  0.845089  0.810699  0.853660  0.834497  0.828624  0.803920   
2  0.764160  0.758037  0.773802  0.762592  0.762257  0.751729  0.758366   

     dist66    dist67    dist68    dist69    dist70    dist71    dist72  \
0  0.816945  0.852561  0.815066  0.812858  0.844518  0.851627  0.838417   
1  0.821947  0.812763  0.765442  0.795368  0.848876  0.831772  0.828389   
2  0.759480  0.755786  0.762572  0.756787  0.769603  0.756226  0.750196   

     dist73    dist74    dist75    dist76    dist77    dist78    dist79  \
0  0.839868  0.846972  0.851668  0.860816  0.880957  0.845313  0.849569   
1  0.822491  0.810707  0.812499  0.816586  0.828081  0.826785  0.813240   
2  0.757696  0.746333  0.767805  0.759218  0.770810  0.766181  0.768756   

     dist80    dist81    dist82    dist83    dist84    dist85    dist86  \
0  0.870180  0.862554  0.866397  0.874742  0.899475  0.883464  0.879084   
1  0.848113  0.840173  0.814944  0.826645  0.848822  0.818360  0.809330   
2  0.763216  0.766606  0.762598  0.754603  0.767628  0.757145  0.774004   

     dist87    dist88    dist89    dist90    dist91    dist92    dist93  \
0  0.861392  0.874843  0.855589  0.851598  0.849689  0.854272  0.837288   
1  0.827020  0.839443  0.822301  0.831517  0.815193  0.827057  0.813251   
2  0.771965  0.768978  0.784956  0.768604  0.767573  0.759978  0.772354   

     dist94    dist95    dist96    dist97    dist98    dist99  
0  0.841660  0.868675  0.867444  0.836115  0.829863  0.834038  
1  0.799496  0.837142  0.833741  0.791625  0.819392  0.807420  
2  0.767823  0.770422  0.756819  0.762370  0.774629  0.777811  

[3 rows x 101 columns]
RomanPerekhrest

Here is a solution that first builds a matrix (a "database") of all the embeddings and then computes the pairwise similarities with scipy.spatial.distance.cdist; df3 is the result you expected.

import pandas as pd
import numpy as np
from scipy.spatial import distance
from numpy.linalg import norm
import ast
import time


start = time.perf_counter()

df = pd.read_csv("https://testme162.s3.amazonaws.com/cosign_dist.csv")

# Parse every embedding string once and store it as a row of a 2-D array.
num_columns = len(ast.literal_eval(df["embeddings"][0]))
num_lines = len(df["embeddings"])
data_base_embeddings = np.zeros((num_lines, num_columns))

columns_list = []

for k, i in enumerate(df["embeddings"]):
    data_base_embeddings[k, :] = np.array(ast.literal_eval(i))
    columns_list.append("dist" + str(k))

# cdist applies the callable to every pair of rows and returns the
# full (num_lines x num_lines) cosine-similarity matrix.
dist_matrix = distance.cdist(data_base_embeddings, data_base_embeddings,
                             lambda u, v: np.dot(u, v) / (norm(u) * norm(v)))

df2 = pd.DataFrame(data=dist_matrix, columns=columns_list)

df3 = df.join(df2)

end = time.perf_counter()

print("Running time: ", str(end - start), "s")
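The callable passed to cdist is still executed in Python once per pair. SciPy also provides a built-in "cosine" metric that returns the cosine distance (1 minus the similarity) and runs in compiled code. A sketch on hypothetical random data, checking that both forms agree:

```python
import numpy as np
from numpy.linalg import norm
from scipy.spatial import distance

# Hypothetical random embeddings: 50 rows, 64 dimensions.
rng = np.random.default_rng(42)
X = rng.standard_normal((50, 64))

# Same callable as in the answer: cosine similarity evaluated pair by pair in Python.
sim_lambda = distance.cdist(X, X, lambda u, v: np.dot(u, v) / (norm(u) * norm(v)))

# Built-in metric: "cosine" gives the cosine distance, so subtract it from 1.
sim_builtin = 1.0 - distance.cdist(X, X, "cosine")

print(np.abs(sim_lambda - sim_builtin).max())
```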

This gives the following running time, including loading the data:

Running time: 2.0602212000012514 s

Excluding the data loading:

Running time: 0.7129389000001538 s

HMH1013