
I have seen a few solutions for unpivoting a Spark dataframe when the number of columns is reasonably low and the column names can be hardcoded. Do you have a scalable solution to unpivot a dataframe with numerous columns?

Below is a toy problem.

Input:

  val df = Seq(
    (1,1,1,0),
    (2,0,0,1)    
  ).toDF("ID","A","B","C")

+---+---+---+---+
| ID|  A|  B|  C|
+---+---+---+---+
|  1|  1|  1|  0|
|  2|  0|  0|  1|
+---+---+---+---+

Expected result:

+---+-----+-----+
| ID|names|count|
+---+-----+-----+
|  1|    A|    1|
|  1|    B|    1|
|  1|    C|    0|
|  2|    A|    0|
|  2|    B|    0|
|  2|    C|    1|
+---+-----+-----+

The solution should be applicable to datasets with N columns to unpivot, where N is large (say 100 columns).

Does this answer your question? [Unpivot in spark-sql/pyspark](https://stackoverflow.com/questions/42465568/unpivot-in-spark-sql-pyspark) – blackbishop Feb 13 '20 at 15:49

No, the answer you link does not apply to the general case I ask about here. – mobupu Feb 13 '20 at 16:25

2 Answers


This should work, assuming you know the list of columns that you want to unpivot on:

import org.apache.spark.sql.{functions => func, _}
// Outside the spark-shell you also need: import spark.implicits._

val df = Seq(
    (1, 1, 1, 0),
    (2, 0, 0, 1)
  ).toDF("ID", "A", "B", "C")

// Columns to unpivot
val cols = Seq("A", "B", "C")

df.select(
    $"ID",
    // Build one struct per column (name + value), collect the structs into
    // an array, and explode the array into one row per column.
    func.explode(
        func.array(
            cols.map(col =>
                func.struct(
                    func.lit(col).alias("names"),
                    func.col(col).alias("count")
                )
            ): _*
        )
    ).alias("v")
  )
  .selectExpr("ID", "v.*")
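
If the column list should not be hardcoded, it can be derived from the schema instead; a minimal sketch, assuming ID is the only column to keep as the identifier:

// Unpivot every column except the identifier column.
val cols = df.columns.filterNot(_ == "ID").toSeq

The explode/struct select above then works unchanged for any number of columns.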
– Mikel San Vicente

This can be done in pure Spark SQL by stacking the columns.

Below is an example in PySpark that can easily be adapted to Scala. The Python code is needed only to dynamically construct the SQL from the relevant fields. I use this approach quite often.

from pyspark.sql.types import *

df = spark.createDataFrame([
  {"id": 1, "A": 1, "B": 1, "C": 0},
  {"id": 2, "A": 0, "B": 0, "C": 1}
],
StructType([StructField("id", IntegerType()), StructField("A", IntegerType()), StructField("B", IntegerType()), StructField("C", IntegerType())]))

def features_to_eav(df, subset=None):

  relevant_columns = subset if subset else df.columns
  n = len(relevant_columns)
  # stack() takes pairs of ('name', value), one pair per column to unpivot
  cols_to_stack = ", ".join(["'{c}', {c}".format(c=c) for c in relevant_columns])
  stack_expression = "stack({}, {}) as (name, value)".format(n, cols_to_stack)

  # Register only the id column plus the columns to unpivot
  df.select(*(["id"] + relevant_columns)).createOrReplaceTempView("features_to_check")

  sql = "select id, {} from features_to_check".format(stack_expression)
  print("Stacking sql:", sql)
  return spark.sql(sql)

features_to_eav(df, ["A", "B", "C"]).show()

The output (note the constructed SQL):

Stacking sql: select id, stack(3, 'A', A, 'B', B, 'C', C) as (name, value) from features_to_check
+---+----+-----+
| id|name|value|
+---+----+-----+
|  1|   A|    1|
|  1|   B|    1|
|  1|   C|    0|
|  2|   A|    0|
|  2|   B|    0|
|  2|   C|    1|
+---+----+-----+
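
Since the Python part only builds the SQL string, the same stack-based approach might look roughly like this in Scala (a sketch; it assumes an active SparkSession in scope as spark and an id column, mirroring the Python version above):

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

// Build the stack() expression dynamically and run it through Spark SQL.
def featuresToEav(spark: SparkSession, df: DataFrame, subset: Seq[String]): DataFrame = {
  // Each column contributes a ('name', value) pair to stack().
  val colsToStack = subset.map(c => s"'$c', $c").mkString(", ")
  val stackExpression = s"stack(${subset.size}, $colsToStack) as (name, value)"

  // Register only the id column plus the columns to unpivot.
  df.select(("id" +: subset).map(col): _*)
    .createOrReplaceTempView("features_to_check")

  spark.sql(s"select id, $stackExpression from features_to_check")
}

// e.g. featuresToEav(spark, df, Seq("A", "B", "C")).show()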
– Vitaliy