This is exactly what I did last year, so I may be in a good position to answer.
First, here is my Spark implementation of the batch SOM algorithm (it is written in Scala, but most things would be similar in PySpark).
I needed this algorithm for a project, and every implementation I found had at least one of these two problems or limitations:
- they did not really implement the batch SOM algorithm, but used a map averaging method that gave me strange results (abnormal symmetries in the output map)
- they did not use the DataFrame API (pure RDD API) and were not in the Spark ML/MLlib spirit, i.e. with a simple `fit()`/`transform()` API operating over DataFrames.
So I went on to code it myself: the batch SOM algorithm in Spark ML style. The first thing I did was to look at how k-means is implemented in Spark ML because, as you know, the batch SOM is very similar to the k-means algorithm. In fact, I could reuse a large portion of the Spark ML k-means code, but I had to modify the core algorithm and the hyperparameters.
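To make the k-means analogy concrete, here is the standard batch SOM update written out. The notation is mine, not taken from the implementation above: $m_k$ is the prototype of map unit $k$, $c(x_i)$ is the best-matching unit of sample $x_i$, and $h$ is a neighbourhood kernel over the map grid, assumed Gaussian here, with $r_k$ the grid coordinates of unit $k$ and $\sigma(t)$ a radius that shrinks over the iterations.

$$
c(x_i) = \arg\min_j \lVert x_i - m_j^{(t)} \rVert^2, \qquad
h_{k,c}(t) = \exp\!\left(-\frac{\lVert r_k - r_c \rVert^2}{2\,\sigma(t)^2}\right), \qquad
m_k^{(t+1)} = \frac{\sum_{i=1}^{N} h_{k,c(x_i)}(t)\, x_i}{\sum_{i=1}^{N} h_{k,c(x_i)}(t)}
$$

If $h_{k,c}$ is replaced by the indicator $\mathbb{1}[k = c]$, the update reduces to the k-means centroid update. The assignment step is identical in both algorithms; only the weighted averaging changes, which is why so much of the k-means code can be reused.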
Here is a quick summary of how the model is built:
- A `SOMParams` class, containing the SOM hyperparameters (size, training parameters, etc.)
- A `SOM` class, which inherits from Spark's `Estimator` and contains the training algorithm. In particular, it contains a `fit()` method that operates on an input `DataFrame`, where the features are stored as a `spark.ml.linalg.Vector` in a single column. `fit()` then selects this column, unpacks the `DataFrame` to obtain the underlying `RDD[Vector]` of features, and calls the `run()` method on it. This is where all the computation happens and, as you guessed, it uses `RDD`s, accumulators and broadcast variables. Finally, the `fit()` method returns a `SOMModel` object.
- `SOMModel` is a trained SOM model and inherits from Spark's `Transformer`/`Model`. It contains the map prototypes (center vectors) and a `transform()` method that operates on `DataFrame`s by taking an input feature column and adding a new column with the predictions (projection on the map). This is done by a prediction UDF.
- There is also a `SOMTrainingSummary` class that collects information such as the objective function.
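The projection performed by `transform()` boils down to a best-matching-unit search: each input vector is assigned the index of its closest prototype. Below is a minimal plain-Scala sketch of that computation, with no Spark dependency. The function names are hypothetical, and squared Euclidean distance plus a row-major grid layout are assumptions. In a Spark ML model, a function like this would typically be wrapped with `org.apache.spark.sql.functions.udf` and applied to the feature column, converting each `Vector` with `toArray`.

```scala
// Hypothetical helper: squared Euclidean distance between two vectors of equal length.
def sqDist(a: Array[Double], b: Array[Double]): Double = {
  var s = 0.0
  var i = 0
  while (i < a.length) {
    val d = a(i) - b(i)
    s += d * d
    i += 1
  }
  s
}

// Best-matching unit: index of the prototype closest to x (ties go to the lowest index).
def bestMatchingUnit(prototypes: Array[Array[Double]], x: Array[Double]): Int =
  prototypes.indices.minBy(k => sqDist(prototypes(k), x))

// Assumption: units are stored row-major on a grid with `width` columns.
def gridPosition(unit: Int, width: Int): (Int, Int) = (unit / width, unit % width)
```

For example, on a 2x2 map with prototypes `(0,0)`, `(1,0)`, `(0,1)`, `(1,1)`, the vector `(0.9, 0.2)` is projected onto unit 1, i.e. grid cell `(0, 1)`.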
Here are the take-aways:
- There is not really an opposition between `RDD`s and `DataFrame`s (or rather `Dataset`s, but the difference between the two is of no real importance here). They are simply used in different contexts. In fact, a `DataFrame` can be seen as an `RDD` specialized for manipulating structured data organized in columns (such as relational tables), allowing SQL-like operations and optimization of the execution plan (Catalyst optimizer).
- For structured data and select/filter/aggregation operations, DO USE `DataFrame`s, always.
- ...but for more complex tasks such as a machine learning algorithm, you NEED to come back to the `RDD` API and distribute your computations yourself, using `map`/`mapPartitions`/`foreach`/`reduce`/`reduceByKey` and so on. Look at how things are done in MLlib: it's only a nice wrapper around RDD manipulations!
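As an illustration of "distributing the computation yourself", here is a minimal sketch of one batch SOM epoch structured the way a `run()` method over an `RDD` could be: each partition accumulates neighbourhood-weighted sums locally, the partial results are merged with a reduce, and the new prototypes are the resulting ratios. It uses plain Scala collections (a `Seq` of `Seq`s stands in for the RDD's partitions) so it runs without Spark. The Gaussian kernel, the row-major grid and all names (`Partial`, `partitionSums`, `batchEpoch`) are assumptions for illustration, not the author's code.

```scala
import scala.math.exp

// Squared Euclidean distance.
def sqDist(a: Array[Double], b: Array[Double]): Double =
  a.indices.map { i => val d = a(i) - b(i); d * d }.sum

// Best-matching unit: index of the closest prototype.
def bmu(protos: Array[Array[Double]], x: Array[Double]): Int =
  protos.indices.minBy(k => sqDist(protos(k), x))

// Assumed Gaussian neighbourhood between units k and c on a row-major grid with `width` columns.
def neighbourhood(k: Int, c: Int, width: Int, sigma: Double): Double = {
  val dr = (k / width - c / width).toDouble
  val dc = (k % width - c % width).toDouble
  exp(-(dr * dr + dc * dc) / (2 * sigma * sigma))
}

// Per-partition result: weighted feature sums (numerators) and total weights (denominators) per unit.
final case class Partial(num: Array[Array[Double]], den: Array[Double]) {
  def merge(o: Partial): Partial = Partial(
    num.zip(o.num).map { case (a, b) => a.zip(b).map { case (x, y) => x + y } },
    den.zip(o.den).map { case (x, y) => x + y })
}

// What mapPartitions would run: accumulate locally, emit a single Partial per partition.
def partitionSums(part: Iterator[Array[Double]], protos: Array[Array[Double]],
                  width: Int, sigma: Double): Partial = {
  val n = protos.length
  val dim = protos.head.length
  val num = Array.fill(n, dim)(0.0)
  val den = Array.fill(n)(0.0)
  part.foreach { x =>
    val c = bmu(protos, x)
    for (k <- 0 until n) {
      val h = neighbourhood(k, c, width, sigma)
      den(k) += h
      for (j <- 0 until dim) num(k)(j) += h * x(j)
    }
  }
  Partial(num, den)
}

// One batch epoch: partition sums, a reduce, then ratios (units with zero weight keep their prototype).
def batchEpoch(partitions: Seq[Seq[Array[Double]]], protos: Array[Array[Double]],
               width: Int, sigma: Double): Array[Array[Double]] = {
  val total = partitions.map(p => partitionSums(p.iterator, protos, width, sigma)).reduce(_ merge _)
  protos.indices.map { k =>
    if (total.den(k) > 0) total.num(k).map(_ / total.den(k)) else protos(k)
  }.toArray
}
```

With Spark, the same pieces would map onto something like `rdd.mapPartitions(it => Iterator(partitionSums(it, bc.value, width, sigma))).reduce(_ merge _)`, where `bc` is a broadcast of the current prototypes. Emitting one small object per partition, rather than one per row, keeps the amount of data sent to the driver proportional to the map size instead of the dataset size.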
I hope this answers your question. Concerning performance, since you asked for an efficient implementation: I have not run any benchmarks yet, but I use it at work and it processes datasets of 500k to 1M rows in a couple of minutes on the production cluster.