I want to load a word2vec model and evaluate it by executing word analogy tasks (e.g. a is to b as c is to something?). To do this, first I load my w2v model:
model = Word2VecModel.load(spark.sparkContext, str(sys.argv[1]))
and then I call the mapper to evaluate the model:
rdd_lines = spark.read.text("questions-words.txt").rdd.map(getAnswers)
The getAnswers function reads one line at a time from questions-words.txt, where each line contains a question and the expected answer used to evaluate my model (e.g. Athens Greece Baghdad Iraq, where a=Athens, b=Greece, c=Baghdad and something=Iraq). After reading the line, I build current_question and actual_answer (e.g. current_question=Athens Greece Baghdad and actual_answer=Iraq). I then call the getAnalogy function, which computes the analogy (given the question, it computes the answer). Finally, after computing the analogy, I return the answer and write it to a text file.
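To make the arithmetic behind getAnalogy concrete, here is a minimal pure-Python sketch of the b - a + c idea. The `toy_vectors` values and the `cosine` and `analogy` helpers are made up for illustration; this is not what findSynonyms does internally, only the general principle under the assumption that the nearest vector by cosine similarity is taken as the answer.

```python
import math

# Hypothetical toy embeddings; real word2vec vectors have many more dimensions.
toy_vectors = {
    "Athens":  [1.0, 0.0],
    "Greece":  [1.0, 1.0],
    "Baghdad": [3.0, 0.1],
    "Iraq":    [3.0, 1.1],
    "Paris":   [-2.0, 0.0],
}

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(u, v))
    norm = math.sqrt(sum(x * x for x in u)) * math.sqrt(sum(y * y for y in v))
    return dot / norm if norm else 0.0

def analogy(a, b, c, vectors):
    """Return the word closest to b - a + c, excluding the question words."""
    target = [vb - va + vc
              for va, vb, vc in zip(vectors[a], vectors[b], vectors[c])]
    candidates = [w for w in vectors if w not in (a, b, c)]
    return max(candidates, key=lambda w: cosine(vectors[w], target))

print(analogy("Athens", "Greece", "Baghdad", toy_vectors))  # prints Iraq
```

With these toy values the target vector is [3.0, 1.1], so the closest remaining word is Iraq.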
The problem is that I get the following exception:
Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers.
and I think that it is thrown because I am using the model within the map function. This question is similar to my problem but I do not know how to apply that answer to my code. How can I solve this problem? The following is the full code:
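from __future__ import print_function

import sys

from pyspark.mllib.feature import Word2VecModel
from pyspark.sql import SparkSession


def getAnalogy(s, model):
    # s holds the three question words: s[0]=a, s[1]=b, s[2]=c
    try:
        qry = model.transform(s[0]) - model.transform(s[1]) - model.transform(s[2])
        # (-1)*qry equals b - a + c, the vector expected to be close to the answer
        res = model.findSynonyms((-1)*qry, 5)  # return 5 "synonyms"
        res = [x[0] for x in res]
        # drop the question words themselves from the candidates
        for k in range(0, 3):
            if s[k] in res:
                res.remove(s[k])
        return res[0]
    except ValueError:
        return "NOT FOUND"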
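

def getAnswers(text):
    # text is a Row holding one line, e.g. "Athens Greece Baghdad Iraq"
    tmp = text[0].split(' ', 3)
    answer_list = []
    current_question = " ".join(str(x) for x in tmp[:3])
    actual_answer = tmp[-1]
    # getAnalogy indexes the three question words, so pass them as a list;
    # model is the Word2VecModel loaded on the driver, referenced here inside map
    model_answer = getAnalogy(current_question.split(' '), model)
    # compare strings by value (==), not by identity (is)
    if model_answer == "NOT FOUND":
        answer_list.append("NOT FOUND\n")
    elif model_answer == actual_answer:
        answer_list.append("TRUE\n")
    else:
        answer_list.append("FALSE:\n")
    return answer_list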


if __name__ == "__main__":
    # argv[1] is the model path, argv[2] is the output path
    if len(sys.argv) != 3:
        print("Usage: my_test <model_path> <output_path>", file=sys.stderr)
        sys.exit(-1)
    spark = SparkSession\
        .builder\
        .appName("my_test")\
        .getOrCreate()
    model = Word2VecModel.load(spark.sparkContext, str(sys.argv[1]))
    rdd_lines = spark.read.text("questions-words.txt").rdd.map(getAnswers)
    dataframe = rdd_lines.toDF()
    dataframe.write.text(str(sys.argv[2]))
    spark.stop()
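
Since the exception says the model can only be used on the driver, one possible direction is to keep the scoring loop on the driver instead of inside map. The sketch below shows only that loop in plain Python, so it runs without Spark. The names `score_lines`, `find_answer` and `stub_find_answer` are hypothetical: `find_answer` stands in for a driver-side call that uses the loaded model. It also assumes the question file is small enough to collect to the driver and that section header lines start with a colon.

```python
def score_lines(lines, find_answer):
    """Score analogy lines of the form 'a b c expected' on the driver.

    find_answer is a hypothetical callable that takes [a, b, c] and returns
    the predicted word, or "NOT FOUND" when no answer can be computed.
    """
    results = []
    for line in lines:
        parts = line.split()
        # Assumption: skip blank lines and section headers that start with ":".
        if len(parts) != 4 or parts[0].startswith(":"):
            continue
        question, expected = parts[:3], parts[3]
        predicted = find_answer(question)
        if predicted == "NOT FOUND":
            results.append("NOT FOUND")
        elif predicted == expected:
            results.append("TRUE")
        else:
            results.append("FALSE")
    return results


# Stub lookup, used only so the sketch runs without a real model.
def stub_find_answer(question):
    known = {("Athens", "Greece", "Baghdad"): "Iraq"}
    return known.get(tuple(question), "NOT FOUND")


sample = [
    ": capital-common-countries",
    "Athens Greece Baghdad Iraq",
    "Athens Greece Oslo Norway",
]
print(score_lines(sample, stub_find_answer))  # prints ['TRUE', 'NOT FOUND']
```

In a Spark job the lines could come from something like `[row[0] for row in spark.read.text("questions-words.txt").collect()]`, and `find_answer` could wrap getAnalogy with the loaded model. Both would then run on the driver, so the model is never referenced from worker code. This trades parallelism for simplicity, which may be acceptable if the question file is small.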