I have an input CSV file like the one below:

plant_id,system1_id,system2_id,system3_id
A1,s1-111,s2-111,s3-111
A2,s1-222,s2-222,s3-222
A3,s1-333,s2-333,s3-333

I want to flatten each record into one row per system, like this:

plant_id    system_id     system_name   
A1          s1-111        system1
A1          s2-111        system2
A1          s3-111        system3
A2          s1-222        system1
A2          s2-222        system2
A2          s3-222        system3
A3          s1-333        system1
A3          s2-333        system2
A3          s3-333        system3

Currently I do this by creating a transposed PySpark DataFrame for each system column and then unioning all of those DataFrames at the end, but that takes a long piece of code. Is there a way to do it in just a few lines?
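For clarity, the reshape being asked for is a wide-to-long "unpivot": every input row produces one output row per system column. A minimal plain-Python sketch of that mapping, with no Spark involved and the sample rows hard-coded (`rows`, `system_names`, and `flatten` are illustrative names, not from the question):

```python
# Plain-Python sketch of the wide-to-long reshape (no Spark involved).
# The rows are hard-coded copies of the sample CSV above.
rows = [
    ("A1", "s1-111", "s2-111", "s3-111"),
    ("A2", "s1-222", "s2-222", "s3-222"),
    ("A3", "s1-333", "s2-333", "s3-333"),
]
system_names = ["system1", "system2", "system3"]


def flatten(rows, system_names):
    """Emit one (plant_id, system_id, system_name) tuple per system column."""
    out = []
    for plant_id, *system_ids in rows:
        # Pair each system column value with its system name
        for system_id, name in zip(system_ids, system_names):
            out.append((plant_id, system_id, name))
    return out


flat = flatten(rows, system_names)
```

With the three sample rows this yields nine tuples, in the same order as the desired output table.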

OneCricketeer
Codegator

2 Answers

Use stack:

df2 = df.selectExpr(
    'plant_id',
    """stack(
         3,
         system1_id, 'system1_id', system2_id, 'system2_id', system3_id, 'system3_id')
         as (system_id, system_name)"""
)

df2.show()
+--------+---------+-----------+
|plant_id|system_id|system_name|
+--------+---------+-----------+
|      A1|   s1-111| system1_id|
|      A1|   s2-111| system2_id|
|      A1|   s3-111| system3_id|
|      A2|   s1-222| system1_id|
|      A2|   s2-222| system2_id|
|      A2|   s3-222| system3_id|
|      A3|   s1-333| system1_id|
|      A3|   s2-333| system2_id|
|      A3|   s3-333| system3_id|
+--------+---------+-----------+
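Note that this emits 'system1_id' rather than the 'system1' the question asks for, and the column list is hard-coded. A sketch of building the same stack() expression from a list of columns, deriving the labels by stripping the `_id` suffix (the helper `build_stack_expr` is hypothetical, not a PySpark API, and it assumes the labels contain no single quotes):

```python
def build_stack_expr(value_cols, labels,
                     value_alias="system_id", label_alias="system_name"):
    """Build a Spark SQL stack() expression string that unpivots value_cols.

    Hypothetical helper, not part of PySpark. Each stacked row holds one
    column's value plus a quoted string label for that column.
    """
    if len(value_cols) != len(labels):
        raise ValueError("value_cols and labels must have the same length")
    # Backticks quote the column identifiers; labels become string literals
    pairs = ", ".join(f"`{c}`, '{lbl}'" for c, lbl in zip(value_cols, labels))
    return f"stack({len(value_cols)}, {pairs}) as ({value_alias}, {label_alias})"


value_cols = ["system1_id", "system2_id", "system3_id"]
# Derive 'system1', 'system2', ... to match the question's desired output
labels = [c[: -len("_id")] for c in value_cols]
expr = build_stack_expr(value_cols, labels)
```

The resulting string is meant to be passed to the same call as above, e.g. `df2 = df.selectExpr('plant_id', expr)`; `value_cols` could instead be computed as `[c for c in df.columns if c != 'plant_id']` so that new system columns are picked up automatically.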
mck

1. Preparing the sample input data

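from pyspark.sql import SparkSession, functions as F

# Reuse the active session (e.g. in a notebook) or start a local one
spark = SparkSession.builder.getOrCreate()

sampleData = (('A1', 's1-111', 's2-111', 's3-111'),
              ('A2', 's1-222', 's2-222', 's3-222'),
              ('A3', 's1-333', 's2-222', 's3-333'))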

2. Creating the list of input data columns
columns = ['plant_id','system1_id','system2_id','system3_id']

3. Creating the Spark DataFrame

df = spark.createDataFrame(data=sampleData, schema=columns)
df.show()
+--------+----------+----------+----------+
|plant_id|system1_id|system2_id|system3_id|
+--------+----------+----------+----------+
|      A1|    s1-111|    s2-111|    s3-111|
|      A2|    s1-222|    s2-222|    s3-222|
|      A3|    s1-333|    s2-222|    s3-333|
+--------+----------+----------+----------+

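4. Use the stack() function to turn the system columns into rows. Its syntax is stack(n, expr1, ..., exprk), which separates expr1, ..., exprk into n rows. Here the six arguments are three (column value, label) pairs, so stack(3, ...) emits three two-column rows for every input row, and the alias names those columns system_id and system_name.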
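To see how the arguments are split into rows, here is a simplified pure-Python model of stack(). It is an assumption-level sketch, not Spark's implementation: it ignores type checking and assumes n >= 1. With k arguments it produces n rows of ceil(k / n) columns, padding a short last row with None the way Spark pads with NULL:

```python
import math


def stack(n, *exprs):
    """Simplified pure-Python model of Spark SQL stack(n, expr1, ..., exprk).

    Splits the k arguments into n rows of ceil(k / n) columns and pads any
    missing trailing cells with None (Spark fills them with NULL).
    """
    k = len(exprs)
    width = math.ceil(k / n)
    rows = []
    for i in range(n):
        chunk = list(exprs[i * width:(i + 1) * width])
        chunk += [None] * (width - len(chunk))  # pad a short final row
        rows.append(tuple(chunk))
    return rows


# One input row of the sample data, stacked the same way as in the answer
a1 = stack(3, "s1-111", "system1_id", "s2-111", "system2_id",
           "s3-111", "system3_id")
```

For the A1 row this gives three (system_id, system_name) tuples, and stack(2, 1, 2, 3) gives [(1, 2), (3, None)], mirroring Spark's NULL padding when k is not a multiple of n.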

finalDF = df.select(
    'plant_id',
    F.expr(
        "stack(3, system1_id, 'system1_id', system2_id, 'system2_id', "
        "system3_id, 'system3_id') as (system_id, system_name)"
    ),
)

finalDF.show()
+--------+---------+-----------+
|plant_id|system_id|system_name|
+--------+---------+-----------+
|      A1|   s1-111| system1_id|
|      A1|   s2-111| system2_id|
|      A1|   s3-111| system3_id|
|      A2|   s1-222| system1_id|
|      A2|   s2-222| system2_id|
|      A2|   s3-222| system3_id|
|      A3|   s1-333| system1_id|
|      A3|   s2-222| system2_id|
|      A3|   s3-333| system3_id|
+--------+---------+-----------+

Vijay_Shinde