I have a problem with a user-defined function that is supposed to concatenate values from one dataframe whose index matches the index value of a row in another dataframe.
Here are the simplified dataframes that I am trying to match:
a_df:
+-------+------+
| index | name |
+-------+------+
| 1 | aaa |
| 2 | bbb |
| 3 | ccc |
| 4 | ddd |
| 5 | eee |
+-------+------+
b_df:
+-------+------+
| index | code |
+-------+------+
| 1 | 101 |
| 2 | 102 |
| 3 | 101 |
| 3 | 102 |
| 4 | 103 |
| 4 | 104 |
| 5 | 101 |
+-------+------+
UDF definition & call (imports included for completeness):
> from pyspark.sql.functions import lit
> from pyspark.sql.types import StringType
>
> def concatcodes(index, dataframe):
>     res = dataframe.where(dataframe.index == index).collect()
>     return "|".join([value.code for value in res])
>
> spark.udf.register("concatcodes", concatcodes, StringType())
>
> resultDF = a_df.withColumn("codes", lit(concatcodes(a_df.index, b_df)))
I expect the function to be called once per row of a_df, resulting in the following output:
+-------+------+-------+
| index | name |codes |
+-------+------+-------+
| 1 | aaa |101 |
| 2 | bbb |102 |
| 3 | ccc |101|102|
| 4 | ddd |103|104|
| 5 | eee |101 |
+-------+------+-------+
However, the function seems to be called just once, with the whole column passed as its argument, resulting in the following output:
+-------+------+---------------------------+
| index | name |codes |
+-------+------+---------------------------+
| 1 | aaa |101|102|101|102|103|104|101|
| 2 | bbb |101|102|101|102|103|104|101|
| 3 | ccc |101|102|101|102|103|104|101|
| 4 | ddd |101|102|101|102|103|104|101|
| 5 | eee |101|102|101|102|103|104|101|
+-------+------+---------------------------+
I suppose I am doing something fundamentally wrong in how I call the UDF inside the .withColumn method, but I cannot figure out what. I would very much appreciate someone pointing out the flaw in my logic.