
I'm attempting to create a new column in my Spark DataFrame that is based on:

  1. a previous value of this column (i.e. the new value in the column is based on the values above it, which in turn are based on...)

  2. a very complex conditional statement (24 different conditions) depending on the values of other columns (and the lagged value of the variable itself)

For example, something like the logic in this loop:

for row in df:  # pseudocode: rows are processed in order
    if row.col1 == "a":
        row.col4 = row.col1 + row.col3
        row.col5 = 11
    if row.col1 == "b":
        if row.col3 == 1:
            row.col4 = lag(row.col4) + row.col1 + row.col2
            row.col5 = 14
        if row.col3 == 0:
            row.col4 = lag(row.col4) + row.col1 + row.col3
            row.col5 = 17
    if row.col1 == "d":
        if row.col3 == 1:
            row.col4 = 99
            row.col5 = 19
    if lag(row.col4) == 99:
        row.col4 = lag(row.col4) + lag(row.col5)
        row.col5 = etc...

(...plus another 21 possible values of c and d)

Example

I want to convert this:

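from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

# note: col1 has ties ("b" appears twice), so this ordering is not fully deterministic
w = Window.orderBy(col("col1").asc())

df = spark.createDataFrame([
    ("a", 2, 0),
    ("b", 3, 1),
    ("b", 4, 0),
    ("d", 5, 1),
    ("e", 6, 0),
    ("f", 7, 1)
], ["col1", "col2", "col3"])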

+----+----+----+
|col1|col2|col3|
+----+----+----+
|   a|   2|   0|
|   b|   3|   1|
|   b|   4|   0|
|   d|   5|   1|
|   e|   6|   0|
|   f|   7|   1|
+----+----+----+

...into this:

+----+----+----+--------+-----------------------------------------------------+-----+---------------------------+
|col1|col2|col3|col4    |(explanation)                                        |col5 |(also uses complex logic)  |
+----+----+----+--------+-----------------------------------------------------+-----+---------------------------+
|   a|   2|   0|a0      |(because (col1==a) ==> col1+col3)                    |11   |                           |
|   b|   3|   1|a0b3    |(because (col1==b & col3==1) ==> lag(col4)+col1+col2)|14   |                           |
|   b|   4|   0|a0b3b0  |(because (col1==b & col3==0) ==> lag(col4)+col1+col3)|17   |                           |
|   d|   5|   1|99      |(because (col1==d & col3==1) ==> 99)                 |19   |                           |
|   e|   6|   0|9919    |(because (lag(col4)==99) ==> lag(col4)+lag(col5))    |e6   |                           |
|   f|   7|   1|etc...  |etc...                                               |etc..|etc...                     |
+----+----+----+--------+-----------------------------------------------------+-----+---------------------------+
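For reference, the transformation above is a scan (a running fold): each row's col4/col5 depends on the current row and on the previous row's *computed* col4/col5. A minimal plain-Python sketch of that recurrence, assuming the rows are already in the intended order and implementing only the conditions spelled out in the question (checked in order, first match wins), reproduces the rows shown. `ROWS` and `step` are hypothetical names, and `accumulate(..., initial=...)` needs Python 3.8+.

```python
from itertools import accumulate

# The question's example rows (col1, col2, col3), assumed already in order.
ROWS = [("a", 2, 0), ("b", 3, 1), ("b", 4, 0), ("d", 5, 1), ("e", 6, 0), ("f", 7, 1)]


def step(prev, row):
    """Return this row's (col4, col5) from the previous row's (col4, col5).

    Only the conditions spelled out in the question are implemented, checked
    in order (first match wins); the other 21 are omitted. A 'b' row that
    arrives before any col4 exists would fail, since prev4 is None then.
    """
    prev4, prev5 = prev
    col1, col2, col3 = row
    if col1 == "a":
        return col1 + str(col3), "11"
    if col1 == "b" and col3 == 1:
        return prev4 + col1 + str(col2), "14"
    if col1 == "b" and col3 == 0:
        return prev4 + col1 + str(col3), "17"
    if col1 == "d" and col3 == 1:
        return "99", "19"
    if prev4 == "99":
        return prev4 + prev5, col1 + str(col2)
    return None, None


# accumulate() feeds each output back into the next call (a scan).
# The initial (None, None) is yielded first, so it is dropped with [1:].
result = list(accumulate(ROWS, step, initial=(None, None)))[1:]
for row, (col4, col5) in zip(ROWS, result):
    print(*row, col4, col5)
```

The key point is that `step` receives its own previous output, which is exactly what a single `withColumn()` expression cannot do.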

Is this at all possible in Spark? Nothing I've tried has worked:

  • I haven't found a way to feed the output of a UDF back into the next UDF calculation
  • The conditional + self-reference makes storing previous values in temporary columns basically impossible.
  • I tried using gigantic when() clauses, but I get tripped up referencing the lagged values of the column itself within the withColumn() statement. Another problem with the when() + lag() approach is that other variables reference the lagged variable, and the lagged variable references other variables (in other words, there is just one lagged value fed into each row, but that value interacts differently with other variables depending on the conditions that row meets).
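The difficulty in the last bullet can be shown without Spark. lag() returns the previous row's value of a column that already exists; it cannot see a value that is still being computed in the same withColumn() call. In this small plain-Python sketch, the hypothetical `lag` helper stands in for Spark's lag(), and only the "keep appending" part of the col4 rule is modeled:

```python
from itertools import accumulate


def lag(values, default=""):
    """Mimic a one-row lag over a list: row i sees the *input* value of row i-1."""
    return [default] + values[:-1]


# Per-row pieces that col4 should keep appending (col1 + col2, or col1 + col3).
pieces = ["a0", "b3", "b0"]

# One when()/lag()-style pass: each row only looks at the previous row's
# existing value, so the history reaches back a single row.
one_pass = [prev + cur for prev, cur in zip(lag(pieces), pieces)]
print(one_pass)

# The recurrence needs the previous row's *computed* value, i.e. a scan.
scanned = list(accumulate(pieces, lambda acc, cur: acc + cur))
print(scanned)
```

The single pass gives `['a0', 'a0b3', 'b3b0']` instead of `['a0', 'a0b3', 'a0b3b0']`. Repeating the lag pass extends the reach by one row each time, so a purely column-wise approach would need roughly as many passes as the longest chain.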
cronoik
ropeladder
  • Yes, that is possible: you need the [lag](https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.lag) window function and [when](https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.functions.when). Based on your current description, there is no need for a UDF. Can you please add a [reproducible example](https://stackoverflow.com/questions/48427185/how-to-make-good-reproducible-apache-spark-examples) to your question? – cronoik Oct 28 '19 at 22:57
  • Updated. See my last bullet point for why using lag() and when() isn't working (at least not with my naive approach). – ropeladder Oct 29 '19 at 14:02
  • Another problem is that one lag() value can depend on another lag(). – cronoik Nov 03 '19 at 02:47

1 Answer


If you are fine with a UDF, then it's simple (I just copied your conditions into the code below). For a non-UDF solution, it depends on how the lag columns appear in the if conditions. Unless you can supply examples with more (or the most complex) if conditions, I'd say a UDF is the simplest way to go.

Side note: the following only works when the data can be properly partitioned, i.e. there should be one or more columns that can be used in groupby() so that all related rows are processed in the same partition. (The code below uses groupby(lit(1)), which puts every row into a single group.)
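The side note is the key constraint: groupby() + collect_list() gathers every row of a group into one array, so the whole chain for that group is computed by a single UDF call. A plain-Python sketch of the same group / sort / fold / flatten pattern, assuming a hypothetical grouping column `grp` (not in the question's data) and a simplified two-rule `step` that stands in for the real conditions:

```python
from itertools import accumulate, groupby
from operator import itemgetter

# Hypothetical input: `grp` is an assumed grouping column (not in the question).
rows = [
    {"grp": "x", "col1": "b", "col2": 3, "col3": 1},
    {"grp": "y", "col1": "a", "col2": 1, "col3": 0},
    {"grp": "x", "col1": "a", "col2": 2, "col3": 0},
    {"grp": "y", "col1": "b", "col2": 2, "col3": 0},
]


def step(prev4, row):
    """Simplified stand-in for the real rules: 'a' starts a chain, 'b' extends it."""
    if row["col1"] == "a":
        return "a" + str(row["col3"])
    extra = row["col2"] if row["col3"] == 1 else row["col3"]
    return prev4 + row["col1"] + str(extra)


def process_group(group_rows):
    """Per-group work done by the UDF: sort, fold, and rebuild the rows."""
    # collect_list() gives no ordering guarantee, hence the explicit sort.
    ordered = sorted(group_rows, key=itemgetter("col2"))
    col4s = accumulate(ordered, step, initial="")
    next(col4s)  # skip the initial "" seed
    return [dict(r, col4=c4) for r, c4 in zip(ordered, col4s)]


# groupby + collect_list ~ gather each group's rows; extend ~ inline() back to rows.
out = []
for _, grp_rows in groupby(sorted(rows, key=itemgetter("grp")), key=itemgetter("grp")):
    out.extend(process_group(list(grp_rows)))

for r in out:
    print(r)
```

In Spark, grouping by a real key column (for example a hypothetical `df.groupby('grp')` instead of `lit(1)`) would let different groups be processed in parallel, while each group's rows still have to fit into a single array.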

from pyspark.sql.functions import udf, lit, collect_list, struct

@udf('array<struct<col1:string,col2:int,col3:int,col4:string,col5:string>>')
def gen_col(rows):
    new_rows = []
    # col4/col5 carry over from the previous row (the "lagged" computed values);
    # initialize them so the first row never hits an unbound variable
    col4, col5 = None, None
    # collect_list() does not guarantee order, so sort by the ordering column first
    for row in sorted(rows, key=lambda x: x.col2):
        if row.col1 == 'a':
            col4 = row.col1 + str(row.col3)
            col5 = '11'
        elif row.col1 == 'b':
            if row.col3 == 1:
                col4 = col4 + row.col1 + str(row.col2)
                col5 = '14'
            elif row.col3 == 0:
                col4 = col4 + row.col1 + str(row.col3)
                col5 = '17'
        elif row.col1 == 'd':
            if row.col3 == 1:
                col4 = '99'
                col5 = '19'
        elif col4 == '99':
            col4 = col4 + col5
            col5 = row.col1 + str(row.col2)
        else:
            col4 = None
            col5 = None
        new_rows.append(dict(col4=col4, col5=col5, **row.asDict()))
    return new_rows

# groupby(lit(1)) puts every row into one group (one UDF call);
# inline() explodes the returned array of structs back into rows
df.groupby(lit(1)) \
    .agg(gen_col(collect_list(struct(df.columns))).alias('new')) \
    .selectExpr('inline(new)') \
    .show()
+----+----+----+------+----+
|col1|col2|col3|  col4|col5|
+----+----+----+------+----+
|   a|   2|   0|    a0|  11|
|   b|   3|   1|  a0b3|  14|
|   b|   4|   0|a0b3b0|  17|
|   d|   5|   1|    99|  19|
|   e|   6|   0|  9919|  e6|
|   f|   7|   1|  null|null|
+----+----+----+------+----+
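With 24 conditions, the if/elif chain in `gen_col` gets long. One option, not taken from the answer, is to express each condition as a (predicate, action) pair and let the first match win. `RULES` and `apply_rules` are hypothetical names; the function could be called inside `gen_col` on `[row.asDict() for row in sorted(rows, key=lambda x: x.col2)]`. Note that flattening the nested ifs changes edge cases: a row with col1 == 'd' and col3 == 0 falls through to later rules here instead of keeping the previous values as the answer's code does.

```python
# Hypothetical rule table: each entry is (predicate, action). Both take the
# current row (a dict) and the previous row's computed (col4, col5).
RULES = [
    (lambda r, p: r["col1"] == "a",
     lambda r, p: (r["col1"] + str(r["col3"]), "11")),
    (lambda r, p: r["col1"] == "b" and r["col3"] == 1,
     lambda r, p: (p[0] + r["col1"] + str(r["col2"]), "14")),
    (lambda r, p: r["col1"] == "b" and r["col3"] == 0,
     lambda r, p: (p[0] + r["col1"] + str(r["col3"]), "17")),
    (lambda r, p: r["col1"] == "d" and r["col3"] == 1,
     lambda r, p: ("99", "19")),
    (lambda r, p: p[0] == "99",
     lambda r, p: (p[0] + p[1], r["col1"] + str(r["col2"]))),
    # ...the remaining conditions would be appended here, in priority order
]


def apply_rules(rows, rules=RULES):
    """First matching rule wins; no match yields (None, None), like the answer's else branch."""
    prev = (None, None)
    out = []
    for r in rows:
        # only the first matching action is evaluated (the generator stops there)
        prev = next((act(r, prev) for pred, act in rules if pred(r, prev)), (None, None))
        out.append(dict(r, col4=prev[0], col5=prev[1]))
    return out


rows = [dict(zip(("col1", "col2", "col3"), t))
        for t in [("a", 2, 0), ("b", 3, 1), ("b", 4, 0), ("d", 5, 1), ("e", 6, 0), ("f", 7, 1)]]
for r in apply_rules(rows):
    print(r)
```

Adding a condition then means adding one entry to the table rather than another nesting level.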
jxc
  • If you are going to use a UDF, you must declare it in Scala. It will be far more efficient because it skips the data (de)serialization from Python to Scala and back again. – ML_TN Nov 04 '19 at 23:33
  • @ML_TN, I agree that using a Scala-based UDF through PySpark SQL would improve performance, but it's not mandatory. In many cases a Python-based UDF is enough for the task and much easier to maintain. :) Just my 2 cents. – jxc Nov 05 '19 at 01:51
  • It's always a matter of choice whether to build a system with the best performance. – ML_TN Nov 05 '19 at 17:05