
Trying to create a new column using a PySpark UDF, but the values are all null!

Create the DF

data_list = [['a', [1, 2, 3]], ['b', [4, 5, 6]],['c', [2, 4, 6, 8]],['d', [4, 1]],['e', [1,2]]]
all_cols = ['COL1','COL2']
df = sqlContext.createDataFrame(data_list, all_cols)
df.show()
+----+------------+
|COL1|        COL2|
+----+------------+
|   a|   [1, 2, 3]|
|   b|   [4, 5, 6]|
|   c|[2, 4, 6, 8]|
|   d|      [4, 1]|
|   e|      [1, 2]|
+----+------------+

df.printSchema()
root
 |-- COL1: string (nullable = true)
 |-- COL2: array (nullable = true)
 |    |-- element: long (containsNull = true)

Create a function

def cr_pair(idx_src, idx_dest):
    idx_dest.append(idx_dest.pop(0))
    return idx_src, idx_dest
lst1 = [1,2,3]
lst2 = [1,2,3]
cr_pair(lst1, lst2)
([1, 2, 3], [2, 3, 1])

Create and register a UDF

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

from pyspark.sql.types import ArrayType
get_idx_pairs = udf(lambda x: cr_pair(x, x), ArrayType(IntegerType()))

Add a new column to the DF

df = df.select('COL1', 'COL2',  get_idx_pairs('COL2').alias('COL3'))
df.printSchema()
root
 |-- COL1: string (nullable = true)
 |-- COL2: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- COL3: array (nullable = true)
 |    |-- element: integer (containsNull = true)

df.show()
+----+------------+------------+
|COL1|        COL2|        COL3|
+----+------------+------------+
|   a|   [1, 2, 3]|[null, null]|
|   b|   [4, 5, 6]|[null, null]|
|   c|[2, 4, 6, 8]|[null, null]|
|   d|      [4, 1]|[null, null]|
|   e|      [1, 2]|[null, null]|
+----+------------+------------+

Here is where the problem is: I am getting all null values in the COL3 column. The intended outcome should be:

+----+------------+----------------------------+
|COL1|        COL2|                        COL3|
+----+------------+----------------------------+
|   a|   [1, 2, 3]|[[1, 2, 3], [2, 3, 1]]      |
|   b|   [4, 5, 6]|[[4, 5, 6], [5, 6, 4]]      |
|   c|[2, 4, 6, 8]|[[2, 4, 6, 8], [4, 6, 8, 2]]|
|   d|      [4, 1]|[[4, 1], [1, 4]]            |
|   e|      [1, 2]|[[1, 2], [2, 1]]            |
+----+------------+----------------------------+

2 Answers


Your UDF should return ArrayType(ArrayType(IntegerType())) since you are expecting a list of lists in your column; besides, it only needs one parameter:

def cr_pair(idx_src):
    return idx_src, idx_src[1:] + idx_src[:1]

get_idx_pairs = udf(cr_pair, ArrayType(ArrayType(IntegerType())))
df.withColumn('COL3', get_idx_pairs(df['COL2'])).show(5, False)
+----+------------+----------------------------+
|COL1|COL2        |COL3                        |
+----+------------+----------------------------+
|a   |[1, 2, 3]   |[[1, 2, 3], [2, 3, 1]]      |
|b   |[4, 5, 6]   |[[4, 5, 6], [5, 6, 4]]      |
|c   |[2, 4, 6, 8]|[[2, 4, 6, 8], [4, 6, 8, 2]]|
|d   |[4, 1]      |[[4, 1], [1, 4]]            |
|e   |[1, 2]      |[[1, 2], [2, 1]]            |
+----+------------+----------------------------+
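
As a quick local sanity check (a minimal sketch added for illustration, independent of Spark), the revised cr_pair simply pairs the input list with its left-rotated copy and, since slicing builds new lists, never mutates its argument:

def cr_pair(idx_src):
    # Same function as above: pair the input list with a left-rotated copy.
    # Slicing creates new lists, so the input is not modified in place.
    return idx_src, idx_src[1:] + idx_src[:1]

print(cr_pair([1, 2, 3]))   # ([1, 2, 3], [2, 3, 1])
print(cr_pair([4, 1]))      # ([4, 1], [1, 4])

The nulls in the question come from the type mismatch: when a Python UDF's return value cannot be converted to its declared return type (here a tuple of lists versus ArrayType(IntegerType())), Spark fills the column with null instead of raising an error.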
  • Wow, that was quick. Thank you Psidom! Almost there. The outcome of the 'show' is a bit off. See below... well I cannot properly copy the outcome but in front of the list of lists I am getting "[WrappedArray(1, .." – TSAR Jul 23 '18 at 20:03
  • Do you mean the `...`? That's just the print truncation. See the update. – Psidom Jul 23 '18 at 20:06
  • I see your edit. Please don't edit the answer if it doesn't answer the question though. With that said, what python and spark version are you using? I am getting the results fine with spark `2.3.0` and python `3.6.5`. – Psidom Jul 23 '18 at 20:12
  • Python 2.7.13 |Anaconda 4.4.0 (64-bit) and PySpark 2.2.0 – TSAR Jul 23 '18 at 20:13
  • BTW, how do I enter the results or code from within a 'comment'? I have tried it and it always is off 'code format'. That was the reason I edited your reply!! – TSAR Jul 23 '18 at 20:17
  • The whole table might be too long to fit the comment. You could just paste one cell if you want to show what you have. BTW what you see might just be a print issue that spark `2.2.0` has not optimized. – Psidom Jul 23 '18 at 20:19
  • @TSAR are you asking why it shows `WrappedArray(...)` instead of the list of lists as in your intended outcome? – pault Jul 23 '18 at 20:33
  • Yes Pault, this is my question: why does it show WrappedArray(...)?? From what you have said it seems that it is a 'print issue' with spark 2.2.0, correct? – TSAR Jul 26 '18 at 17:36

It seems like what you want to do is circularly shift the elements in your list. Here is a non-udf approach using pyspark.sql.functions.posexplode() (Spark version 2.1 and above):

import pyspark.sql.functions as f
from pyspark.sql import Window

w = Window.partitionBy("COL1", "COL2").orderBy(f.col("pos") == 0, "pos")
df = df.select("*", f.posexplode("COL2"))\
    .select("COL1", "COL2", "pos", f.collect_list("col").over(w).alias('COL3'))\
    .where("pos = 0")\
    .drop("pos")\
    .withColumn("COL3", f.array("COL2", "COL3"))

df.show(truncate=False)
#+----+------------+----------------------------------------------------+
#|COL1|COL2        |COL3                                                |
#+----+------------+----------------------------------------------------+
#|a   |[1, 2, 3]   |[WrappedArray(1, 2, 3), WrappedArray(2, 3, 1)]      |
#|b   |[4, 5, 6]   |[WrappedArray(4, 5, 6), WrappedArray(5, 6, 4)]      |
#|c   |[2, 4, 6, 8]|[WrappedArray(2, 4, 6, 8), WrappedArray(4, 6, 8, 2)]|
#|d   |[4, 1]      |[WrappedArray(4, 1), WrappedArray(1, 4)]            |
#|e   |[1, 2]      |[WrappedArray(1, 2), WrappedArray(2, 1)]            |
#+----+------------+----------------------------------------------------+

Using posexplode will return two columns: the position in the list (pos) and the value (col). The trick here is that we order by f.col("pos") == 0 first and then by "pos". Since False sorts before True, this moves the element at position 0 to the end of the collected list; it is also why the final filter keeps only the pos = 0 row, since that row sorts last and is where collect_list has accumulated the complete rotated list.
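
To make the intermediate step concrete, here is a small sketch (assuming the original df from the question, before COL3 is added) of what the explode produces and how the ordering key rotates the list:

import pyspark.sql.functions as f

# Explode COL2 together with element positions.
# For COL1 = 'a' (COL2 = [1, 2, 3]) this yields the rows
#   (pos=0, col=1), (pos=1, col=2), (pos=2, col=3).
df.select("COL1", f.posexplode("COL2")).show()

# Ordering by (pos == 0, pos) sorts False before True, i.e. positions 1, 2, ..., then 0,
# so collect_list over that ordering produces the left-rotated list [2, 3, 1].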

Though this output prints differently than you would expect for a list of lists in Python, the contents of COL3 are indeed a list of lists of integers.

df.printSchema()
#root
# |-- COL1: string (nullable = true)
# |-- COL2: array (nullable = true)
# |    |-- element: long (containsNull = true)
# |-- COL3: array (nullable = false)
# |    |-- element: array (containsNull = true)
# |    |    |-- element: long (containsNull = true)

Update

The "WrappedArray prefix" is just the way Spark prints nested lists. The underlying array is exactly as you need it. One way to verify this is by calling collect() and inspecting the data:

results = df.collect()
print([(r["COL1"], r["COL3"]) for r in results])
#[(u'a', [[1, 2, 3], [2, 3, 1]]),
# (u'b', [[4, 5, 6], [5, 6, 4]]),
# (u'c', [[2, 4, 6, 8], [4, 6, 8, 2]]),
# (u'd', [[4, 1], [1, 4]]),
# (u'e', [[1, 2], [2, 1]])]

Or if you converted df to a pandas DataFrame:

print(df.toPandas())
#  COL1          COL2                          COL3
#0    a     [1, 2, 3]        ([1, 2, 3], [2, 3, 1])
#1    b     [4, 5, 6]        ([4, 5, 6], [5, 6, 4])
#2    c  [2, 4, 6, 8]  ([2, 4, 6, 8], [4, 6, 8, 2])
#3    d        [4, 1]              ([4, 1], [1, 4])
#4    e        [1, 2]              ([1, 2], [2, 1])
  • This is great Pault, thank you! However, I don't think the 'window function' is needed for my use case. I only need to (yes) get the original list and its 'one-position shift' returned as a pair. I am basically trying to mimic the np.roll functionality from NumPy. If I could get the contents of COL3 without the 'WrappedArray' prefix, that would solve my problem. Thanks again, your effort is much appreciated! – TSAR Jul 23 '18 at 21:15
  • What I am trying to tell you is that the "WrappedArray prefix" is just the way it prints. The contents are exactly as you need them. I'm using the `Window` here because this approach is the only way (that I can think of) to avoid the `udf` (udfs are slower). – pault Jul 23 '18 at 21:17
  • @TSAR see the edit where I collect the data. The result is a list of lists as you would expect. – pault Jul 23 '18 at 21:28
  • @TSAR read [this post](https://stackoverflow.com/questions/38296609/spark-functions-vs-udf-performance) which explains why you usually avoid `udf`s when possible. – pault Jul 24 '18 at 19:59