Gensim v4.x.x simplified a lot of what @gojomo described above, as he also explained in his other answer here. Based on those answers, here's an example of how you can run `most_similar` across multiple processes in a memory-efficient way, with progress reporting via `tqdm`. Swap in your own model/dataset to see how this works at scale.
import multiprocessing
from functools import partial
from typing import Dict, List, Tuple
import tqdm
from gensim.models.word2vec import Word2Vec
from gensim.models.keyedvectors import KeyedVectors
from gensim.test.utils import common_texts
def get_most_similar(
    word: str, keyed_vectors: KeyedVectors, topn: int
) -> List[Tuple[str, float]]:
    """Return the ``topn`` nearest neighbours of ``word``.

    Args:
        word: Key to look up.
        keyed_vectors: Object whose ``most_similar`` method performs the lookup.
        topn: Number of neighbours requested, forwarded as ``topn=``.

    Returns:
        Whatever ``keyed_vectors.most_similar`` returns, or an empty list when
        that call raises ``KeyError``.
    """
    try:
        neighbours = keyed_vectors.most_similar(word, topn=topn)
    except KeyError:
        # A failed lookup yields "no neighbours" instead of aborting the caller.
        neighbours = []
    return neighbours
# Per-process cache of loaded vectors, keyed by file path. Each worker process
# gets its own copy of this dict, so a path is loaded at most once per worker
# instead of once per batch.
_KEYED_VECTORS_CACHE: Dict[str, KeyedVectors] = {}


def _load_keyed_vectors(word_vectors_path: str) -> KeyedVectors:
    """Return the vectors stored at ``word_vectors_path``, loading them once per process."""
    keyed_vectors = _KEYED_VECTORS_CACHE.get(word_vectors_path)
    if keyed_vectors is None:
        # Load the keyedvectors with mmap, so memory isn't duplicated
        keyed_vectors = KeyedVectors.load(word_vectors_path, mmap="r")
        _KEYED_VECTORS_CACHE[word_vectors_path] = keyed_vectors
    return keyed_vectors


def get_most_similar_batch(
    word_batch: List[str], word_vectors_path: str, topn: int
) -> Dict[str, List[Tuple[str, float]]]:
    """Find the ``topn`` most similar entries for every word in a batch.

    The vectors are passed by file path rather than as an object, and are
    loaded with ``mmap="r"`` inside the function. The loaded object is cached
    per process, so repeated calls with the same path reuse it.

    NOTE: because of the cache, a file that is overwritten after its first
    load in a given process is not re-read by that process.

    Args:
        word_batch: Words to look up.
        word_vectors_path: Path of a file written by ``KeyedVectors.save``.
        topn: Number of neighbours to request per word.

    Returns:
        A dict mapping each word in ``word_batch`` to the result of
        ``get_most_similar`` for that word.
    """
    keyed_vectors = _load_keyed_vectors(word_vectors_path)
    return {word: get_most_similar(word, keyed_vectors, topn) for word in word_batch}
def create_batches_from_iterable(iterable, batch_size=1000):
    """Split ``iterable`` into consecutive batches of at most ``batch_size`` items.

    Inputs that support both ``len()`` and indexing are sliced directly, so
    each batch has the type that slicing the input produces (list -> list,
    tuple -> tuple, str -> str). Any other iterable, such as a generator or a
    set, is first materialised into a list and then sliced.

    Args:
        iterable: Items to split.
        batch_size: Maximum number of items per batch; must be at least 1.

    Returns:
        A list of batches in input order. Every batch except possibly the last
        holds exactly ``batch_size`` items. Empty input gives ``[]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # A zero step makes range() fail with an unrelated-looking message and a
    # negative step silently yields no batches at all, so reject both up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # One-shot iterables support neither len() nor slicing; materialise them once.
    if not (hasattr(iterable, "__len__") and hasattr(iterable, "__getitem__")):
        iterable = list(iterable)

    return [
        iterable[start : start + batch_size]
        for start in range(0, len(iterable), batch_size)
    ]
if __name__ == "__main__":
    # Train a small demo model on gensim's bundled toy corpus.
    model = Word2Vec(
        sentences=common_texts, vector_size=100, window=5, min_count=1, workers=4
    )

    # Persist only the word vectors; each worker re-opens this file with mmap.
    word_vectors_path = "word2vec.wordvectors"
    model.wv.save(word_vectors_path)

    # Demo query set: every word in the model's vocabulary.
    words_to_match = list(model.wv.key_to_index)

    # Split the queries into small batches; each batch is one task for the pool.
    batches = create_batches_from_iterable(words_to_match, batch_size=2)

    # Bind the arguments shared by all tasks, leaving only the batch to vary.
    find_similar_for_batch = partial(
        get_most_similar_batch, word_vectors_path=word_vectors_path, topn=5
    )

    words_most_similar = {}
    worker_count = multiprocessing.cpu_count()
    with multiprocessing.Pool(worker_count) as pool, tqdm.tqdm(
        total=len(batches)
    ) as progress:
        # imap hands results back one at a time, so the bar can advance per
        # finished batch instead of jumping to 100% at the very end.
        for batch_result in pool.imap(find_similar_for_batch, batches):
            words_most_similar.update(batch_result)
            progress.update(1)