In my PySpark application, I have two RDDs:
items - This contains the item ID and item name for all valid items. Approximately 100,000 items.
attributeTable - This contains the fields user ID, item ID and an attribute value for this combination, in that order. There is a certain attribute for each user-item combination in the system. This RDD has several hundred thousand rows.
I would like to discard all rows in the attributeTable RDD that don't correspond to a valid item ID (or name) in the items RDD. In other words, a semi-join on the item ID. For instance, if these were R data frames, I would have done semi_join(attributeTable, items, by="itemID").
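For concreteness, here is a tiny made-up example of the shapes involved (the IDs and values are hypothetical, just to illustrate):

# Hypothetical sample data showing the layout of the two RDDs
items = sc.parallelize([(101, "itemA"), (102, "itemB"), (103, "itemC")])
attributeTable = sc.parallelize([("user1", 101, 0.5),
                                 ("user2", 104, 0.1),   # 104 is not a valid item ID
                                 ("user3", 102, 0.9)])
# Desired result: only the rows whose item ID appears in items, i.e. the user1 and user3 rows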
I tried the following approach first, but found that it takes forever to return (on my local Spark installation running in a VM on my PC). Understandably so, given the huge number of comparisons involved:
# Create a broadcast variable of all valid item IDs so the filter can run on the workers
validItemIDs = sc.broadcast(items.map(lambda (itemID, itemName): itemID).collect())
attributeTable = attributeTable.filter(lambda (userID, itemID, attributes): itemID in set(validItemIDs.value))
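Looking at it again, I suspect part of the slowness is that set(validItemIDs.value) gets rebuilt inside the lambda for every single row. I assume the intended version would build the lookup set once and broadcast that instead, roughly like the untested sketch below, though I haven't checked whether that alone fixes the runtime:

# Sketch (untested): build the lookup set once on the driver, broadcast it,
# and reuse the broadcast set in the filter instead of rebuilding it per row
validItemIDs = sc.broadcast(set(items.map(lambda (itemID, itemName): itemID).collect()))
attributeTable = attributeTable.filter(
    lambda (userID, itemID, attributes): itemID in validItemIDs.value)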
After a bit of fiddling around, I found that the following approach works pretty fast (a minute or so on my system).
# Create a broadcast variable for item ID to item name mapping (dictionary)
itemIdToNameMap = sc.broadcast(items.collectAsMap())
# From the attribute table, remove records that don't correspond to a valid item name.
# First go over all records in the table and add a dummy field indicating whether the item name is valid
# Then, filter out all rows with invalid names. Finally, remove the dummy field we added.
attributeTable = (attributeTable
    .map(lambda (userID, itemID, attributes): (userID, itemID, attributes, itemIdToNameMap.value.get(itemID, 'Invalid')))
    .filter(lambda (userID, itemID, attributes, itemName): itemName != 'Invalid')
    .map(lambda (userID, itemID, attributes, itemName): (userID, itemID, attributes)))
Although this works well enough for my application, it feels like a dirty workaround, and I am pretty sure there must be a cleaner or more idiomatic (and possibly more efficient) way to do this in Spark. What would you suggest? I am new to both Python and Spark, so any RTFM advice will also be helpful if you can point me to the right resources.
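For what it's worth, I suppose the map/filter/map dance above could be collapsed into a single filter on key membership in the broadcast dictionary, something like the untested sketch below, but I don't know whether that is any more idiomatic:

# Sketch (untested): filter directly on key membership in the broadcast dictionary,
# without adding and later dropping a dummy 'Invalid' field
attributeTable = attributeTable.filter(
    lambda (userID, itemID, attributes): itemID in itemIdToNameMap.value)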
My Spark version is 1.3.1.