Context: I am training a Keras RNN on a dataset that is too large to fit in memory. I am using PySpark on an AWS EMR cluster to train the model in batches that are small enough to fit in memory. I was not able to distribute the model using elephas, and I suspect this is related to my model being stateful, though I'm not entirely sure.
The dataframe has one row per user per day elapsed since the day of install (0 to 29). After querying the database I do a number of operations on the dataframe:
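import numpy as np
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.functions import col, lit, udf
from pyspark.sql.types import IntegerType
query = """WITH max_days_elapsed AS (
    SELECT user_id,
           max(days_elapsed) AS max_de
    FROM table
    GROUP BY user_id
)
SELECT table.*
FROM table
LEFT OUTER JOIN max_days_elapsed USING (user_id)
WHERE max_de = 1
  AND days_elapsed < 1"""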
df = read_from_db(query)  # custom helper that queries our database
# Create the features vector column
assembler = VectorAssembler(inputCols=features_list, outputCol="features")
df_vectorized = assembler.transform(df)
# Split users into train and test sets and assign a random batch number to each training user
udf_randint = udf(lambda x: np.random.randint(0, x), IntegerType())
training_users, testing_users = df_vectorized.select("user_id").distinct().randomSplit([0.8, 0.2], seed=123)
training_users = training_users.withColumn("batch_number", udf_randint(lit(N_BATCHES)))
# Create and sort the train and test dataframes
train = df_vectorized.join(training_users, ["user_id"], "inner").select(["user_id", "days_elapsed", "batch_number", "features", "kpi1", "kpi2", "kpi3"])
train = train.sort(["user_id", "days_elapsed"])
test = df_vectorized.join(testing_users, ["user_id"], "inner").select(["user_id", "days_elapsed", "features", "kpi1", "kpi2", "kpi3"])
test = test.sort(["user_id", "days_elapsed"])
The problem is that I cannot filter on batch_number without first caching train. I can filter on any column that exists in the original table in our database, but not on any column I generated in PySpark after querying the database.
For example, this returns only 0, as expected:
train.filter(train["days_elapsed"] == 0).select("days_elapsed").distinct().show()
But all of the following return every batch number from 0 to 9, as if no filter had been applied:
train.filter(train["batch_number"] == 0).select("batch_number").distinct().show()
train.filter(train.batch_number == 0).select("batch_number").distinct().show()
train.filter("batch_number = 0").select("batch_number").distinct().show()
train.filter(col("batch_number") == 0).select("batch_number").distinct().show()
This also does not work:
train.createOrReplaceTempView("train_table")
batch_df = spark.sql("SELECT * FROM train_table WHERE batch_number = 1")
batch_df.select("batch_number").distinct().show()
All of these work if I do train.cache() first. Is that absolutely necessary or is there a way to do this without caching?
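One possible explanation, which is an assumption and not something confirmed above, is that Spark evaluates the plan lazily and may call udf_randint again each time batch_number is needed, so the value tested by the filter and the value later shown by select can differ unless train is cached. The sketch below imitates that behaviour in plain Python, without Spark. The helper names (random_batch, hashed_batch, lazy_filter_then_select) and N_BATCHES = 10 are hypothetical and exist only for illustration.

```python
import random
import zlib

N_BATCHES = 10  # assumed value, matching the 0-9 range reported above


def random_batch(user_id):
    # Stand-in for udf_randint: returns a fresh random value on every call.
    return random.randrange(N_BATCHES)


def hashed_batch(user_id):
    # Deterministic alternative: the same user_id always maps to the same batch.
    return zlib.crc32(str(user_id).encode("utf-8")) % N_BATCHES


def lazy_filter_then_select(user_ids, batch_fn, wanted):
    # Imitates a plan where batch_fn runs once for the filter and
    # again for the projection, with nothing materialised in between.
    kept = [u for u in user_ids if batch_fn(u) == wanted]
    return {batch_fn(u) for u in kept}


users = range(1000)
random.seed(0)  # fixed seed so the illustration is repeatable
print(sorted(lazy_filter_then_select(users, random_batch, 0)))  # typically many batch numbers
print(sorted(lazy_filter_then_select(users, hashed_batch, 0)))  # only the requested batch
```

If this is indeed the cause, deriving batch_number deterministically from user_id (for instance, a hash of user_id modulo N_BATCHES) should make the filter behave consistently without caching. That is an untested suggestion for this pipeline.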