
I have a PySpark DataFrame in which several columns contain arrays of different lengths. I want to iterate through the relevant columns and clip the array in each row so that they all have the same length; in this example, a length of 3.

This is an example dataframe:

id_1|id_2|id_3|timestamp             |thing1           |thing2       |thing3
A   |b   |c   |[time_0,time_1,time_2]|[1.2,1.1,2.2]    |[1.3,1.5,2.6]|[2.5,3.4,2.9]
A   |b   |d   |[time_0,time_1]       |[5.1,6.1,1.4,1.6]|[5.5,6.2,0.2]|[5.7,6.3]
A   |b   |e   |[time_0,time_1]       |[0.1,0.2,1.1]    |[0.5,0.3,0.3]|[0.9,0.6,0.9,0.4]

So far I have:

def clip_func(x, ts_len, backfill=1500):
    template = [backfill]*ts_len
    template[-len(x):] = x
    x = template
    return x[-1 * ts_len:]

clip = udf(clip_func, ArrayType(DoubleType()))

for c in [x for x in example.columns if 'thing' in x]:
    missing_fill = 3.3
    ans = ans.withColumn(c, clip(c, 3, missing_fill))

But it is not working. If an array is too short, I want to pad it with the missing_fill value.

Cards14
  • I get this error: TypeError: Invalid argument, not a string or column: 24 of type . For column literals, use 'lit', 'array', 'struct' or 'create_map' function. – Cards14 Apr 08 '19 at 14:17
  • have you tried `clip = udf(clip_func, DoubleType())`? the example in the docs uses `IntegerType`, not `ArrayType`, so that would be my only suggestion on what looks wrong here. – Stael Apr 08 '19 at 14:27
  • Your question asks to clip the array to a given length, but based on your code it seems what you really want to do is have the `thing` arrays be the same length as the `timestamp` arrays. Is that correct? – pault Apr 08 '19 at 14:27
  • Anyway your error is because you're passing python literals to the `udf` when you should be passing column literals (`pyspark.sql.functions.lit`) – pault Apr 08 '19 at 14:31
  • `df = df.withColumn(c, clip(col(c),lit(3),lit(missing_fill)))` At least you won't get the error. – cph_sto Apr 08 '19 at 14:55
  • yeah append to beginning is right. I used lit and that stopped giving me an error. – Cards14 Apr 08 '19 at 15:04

1 Answer


Your error is caused by passing 3 and missing_fill to clip as Python literals. As described in this answer, the inputs to a udf are converted to columns, so each argument must be either a Column or a column name given as a string.

You should instead be passing in column literals.

Here is a simplified example DataFrame:

example.show(truncate=False)
#+---+------------------------+--------------------+---------------+--------------------+
#|id |timestamp               |thing1              |thing2         |thing3              |
#+---+------------------------+--------------------+---------------+--------------------+
#|A  |[time_0, time_1, time_2]|[1.2, 1.1, 2.2]     |[1.3, 1.5, 2.6]|[2.5, 3.4, 2.9]     |
#|B  |[time_0, time_1]        |[5.1, 6.1, 1.4, 1.6]|[5.5, 6.2, 0.2]|[5.7, 6.3]          |
#|C  |[time_0, time_1]        |[0.1, 0.2, 1.1]     |[0.5, 0.3, 0.3]|[0.9, 0.6, 0.9, 0.4]|
#+---+------------------------+--------------------+---------------+--------------------+

You just need to make one small change in the arguments passed to the udf:

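from pyspark.sql.functions import lit, udf
from pyspark.sql.types import ArrayType, DoubleType

def clip_func(x, ts_len, backfill):
    # Start from ts_len copies of the fill value, then overwrite the tail with x.
    template = [backfill]*ts_len
    template[-len(x):] = x
    # Keep only the last ts_len elements (drops leading values when x is too long).
    return template[-1 * ts_len:]

clip = udf(clip_func, ArrayType(DoubleType()))

missing_fill = 3.3
ans = example
for c in [x for x in example.columns if 'thing' in x]:
    # Wrap the Python values in lit() so they are passed as column literals.
    ans = ans.withColumn(c, clip(c, lit(3), lit(missing_fill)))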

ans.show(truncate=False)
#+---+------------------------+---------------+---------------+---------------+
#|id |timestamp               |thing1         |thing2         |thing3         |
#+---+------------------------+---------------+---------------+---------------+
#|A  |[time_0, time_1, time_2]|[1.2, 1.1, 2.2]|[1.3, 1.5, 2.6]|[2.5, 3.4, 2.9]|
#|B  |[time_0, time_1]        |[6.1, 1.4, 1.6]|[5.5, 6.2, 0.2]|[3.3, 5.7, 6.3]|
#|C  |[time_0, time_1]        |[0.1, 0.2, 1.1]|[0.5, 0.3, 0.3]|[0.6, 0.9, 0.4]|
#+---+------------------------+---------------+---------------+---------------+

As your udf is currently written:

  • When the array is longer than ts_len, it drops elements from the beginning (left side) and keeps the last ts_len.
  • When the array is shorter than ts_len, it prepends missing_fill to the start of the array until the length reaches ts_len.
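
The two cases above can be checked without Spark by calling the same logic as an ordinary Python function. This is only a sketch: the inputs are taken from rows B and A of the example, and the empty-list case is an extra check that is not part of the original question.

```python
def clip_func(x, ts_len, backfill):
    # Same logic as the udf body, run as plain Python.
    template = [backfill] * ts_len
    template[-len(x):] = x
    return template[-ts_len:]

# Longer than ts_len: leading elements are dropped.
too_long = clip_func([5.1, 6.1, 1.4, 1.6], 3, 3.3)
# Shorter than ts_len: the fill value is added at the start.
too_short = clip_func([5.7, 6.3], 3, 3.3)
# Already the right length: returned unchanged.
exact = clip_func([1.2, 1.1, 2.2], 3, 3.3)
# Empty input: -len(x) is 0, so the slice assignment clears the template.
empty = clip_func([], 3, 3.3)

print(too_long)   # [6.1, 1.4, 1.6]
print(too_short)  # [3.3, 5.7, 6.3]
print(exact)      # [1.2, 1.1, 2.2]
print(empty)      # []
```

Note that an empty array comes back empty rather than filled, so if that can occur in your data you may want to handle it separately.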
pault