Given a PySpark DataFrame of the form:

+----+--------+
|time|messages|
+----+--------+
| t01|    [m1]|
| t03|[m1, m2]|
| t04|    [m2]|
| t06|    [m3]|
| t07|[m3, m1]|
| t08|    [m1]|
| t11|    [m2]|
| t13|[m2, m4]|
| t15|    [m2]|
| t20|    [m4]|
| t21|      []|
| t22|[m1, m4]|
+----+--------+

I'd like to refactor it to compress runs containing the same message (the order of the output doesn't matter much, but it's sorted here for clarity):

+----------+--------+-------+
|start_time|end_time|message|
+----------+--------+-------+
|       t01|     t03|     m1|
|       t07|     t08|     m1|
|       t22|     t22|     m1|
|       t03|     t04|     m2|
|       t11|     t15|     m2|
|       t06|     t07|     m3|
|       t13|     t13|     m4|
|       t20|     t20|     m4|
|       t22|     t22|     m4|
+----------+--------+-------+

(i.e., treat the messages column as a sequence and identify the start and end of "runs" for each message).
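For reference, the transformation can be pinned down with a small single-machine sketch in plain Python (not Spark). It assumes `rows` is a list of `(time, messages)` pairs already sorted by time; `compress_runs` is a hypothetical helper name, not part of any library.

```python
rows = [
    ("t01", ["m1"]), ("t03", ["m1", "m2"]), ("t04", ["m2"]), ("t06", ["m3"]),
    ("t07", ["m3", "m1"]), ("t08", ["m1"]), ("t11", ["m2"]), ("t13", ["m2", "m4"]),
    ("t15", ["m2"]), ("t20", ["m4"]), ("t21", []), ("t22", ["m1", "m4"]),
]

def compress_runs(rows):
    """Return (start_time, end_time, message) for each maximal run of a message."""
    open_runs = {}  # message -> start_time of its currently open run
    result = []
    prev_time = None
    for time, messages in rows:
        current = set(messages)
        # Messages missing from this row close their run at the previous time
        for message in list(open_runs):
            if message not in current:
                result.append((open_runs.pop(message), prev_time, message))
        # Messages not already running open a new run at this time
        for message in messages:
            if message not in open_runs:
                open_runs[message] = time
        prev_time = time
    # Runs still open after the last row end at the last time
    for message, start in open_runs.items():
        result.append((start, prev_time, message))
    return result

# Sort by message, then start time, to match the table above
runs = sorted(compress_runs(rows), key=lambda r: (r[2], r[0]))
```

This is a single sequential pass, so it only serves as a baseline to check the distributed versions against.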

Is there a clean way to make this transformation in Spark? Currently, I'm dumping this as a 6 GB TSV and processing it imperatively.

I'm open to the possibility of toPandas-ing this and accumulating on the driver if Pandas has a clean way to do this aggregation.
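One possible pandas approach, sketched under the assumption that the data fits on the driver and is sorted by time: explode the lists, number the rows, and start a new run for a message whenever its row number is not exactly one more than the previous row number for that message. The `row` and `run` column names are illustrative.

```python
import pandas as pd

rows = [
    ("t01", ["m1"]), ("t03", ["m1", "m2"]), ("t04", ["m2"]), ("t06", ["m3"]),
    ("t07", ["m3", "m1"]), ("t08", ["m1"]), ("t11", ["m2"]), ("t13", ["m2", "m4"]),
    ("t15", ["m2"]), ("t20", ["m4"]), ("t21", []), ("t22", ["m1", "m4"]),
]
pdf = pd.DataFrame(rows, columns=["time", "messages"])
pdf["row"] = range(len(pdf))  # position in time order defines adjacency

exploded = (
    pdf.explode("messages")                # one row per (time, message); [] becomes NaN
    .dropna(subset=["messages"])
    .rename(columns={"messages": "message"})
    .sort_values(["message", "row"])
    .reset_index(drop=True)
)
# A new run starts where the row number is not the previous one plus 1
exploded["run"] = (exploded.groupby("message")["row"].diff() != 1).cumsum()

runs = (
    exploded.groupby(["message", "run"])
    .agg(start_time=("time", "first"), end_time=("time", "last"))
    .reset_index()[["start_time", "end_time", "message"]]
)
```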

(see my answer below for a naïve baseline implementation).

Jedi
  • Is the number of unique messages very large? I mean, in your case you have m1, m2, m3, and m4. Are there only 4, or could there be more, and in that case, are they predefined or could they be different each time? – Alfilercio May 21 '20 at 14:59
  • Forget it, what version of Spark do you have? – Alfilercio May 21 '20 at 15:19

3 Answers

You can try the following method using forward-filling (Spark 2.4+ is not required):

Step-1: do the following:

  1. for each row ordered by time, find prev_messages and next_messages
  2. explode messages so that each row holds a single message
  3. for each message, if prev_messages is NULL or message is not in prev_messages, then set start=time (see the SQL syntax below):

    IF(prev_messages is NULL or !array_contains(prev_messages, message),time,NULL)
    

    which can be simplified to:

    IF(array_contains(prev_messages, message),NULL,time)
    
  4. and if next_messages is NULL or message is not in next_messages, then set end=time
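The simplification in item 3 relies on SQL NULL semantics: `array_contains` returns NULL when the array is NULL, and `IF` treats a NULL condition as false. A small plain-Python emulation (the helper names are hypothetical, and it assumes the arrays contain no NULL elements) checks that both forms agree:

```python
def array_contains(arr, value):
    # Emulates SQL array_contains: NULL (None) when the array itself is NULL
    if arr is None:
        return None
    return value in arr

def sql_not(x):
    # Three-valued NOT: NOT NULL is NULL
    return None if x is None else not x

def sql_or(a, b):
    # Three-valued OR
    if a is True or b is True:
        return True
    if a is None or b is None:
        return None
    return False

def sql_if(cond, then, otherwise):
    # SQL IF: a NULL condition behaves like false
    return then if cond is True else otherwise

def original(prev, message, time):
    return sql_if(sql_or(prev is None, sql_not(array_contains(prev, message))), time, None)

def simplified(prev, message, time):
    return sql_if(array_contains(prev, message), None, time)

cases = [None, [], ["m1"], ["m2"], ["m1", "m2"]]
results = [(original(p, "m1", "t07"), simplified(p, "m1", "t07")) for p in cases]
```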

Code below:

from pyspark.sql import Window, functions as F

# rows is defined in your own post
df = spark.createDataFrame(rows, ['time', 'messages'])

w1 = Window.partitionBy().orderBy('time')

df1 = df.withColumn('prev_messages', F.lag('messages').over(w1)) \
    .withColumn('next_messages', F.lead('messages').over(w1)) \
    .withColumn('message', F.explode('messages')) \
    .withColumn('start', F.expr("IF(array_contains(prev_messages, message),NULL,time)")) \
    .withColumn('end', F.expr("IF(array_contains(next_messages, message),NULL,time)"))

df1.show()
#+----+--------+-------------+-------------+-------+-----+----+
#|time|messages|prev_messages|next_messages|message|start| end|
#+----+--------+-------------+-------------+-------+-----+----+
#| t01|    [m1]|         null|     [m1, m2]|     m1|  t01|null|
#| t03|[m1, m2]|         [m1]|         [m2]|     m1| null| t03|
#| t03|[m1, m2]|         [m1]|         [m2]|     m2|  t03|null|
#| t04|    [m2]|     [m1, m2]|         [m3]|     m2| null| t04|
#| t06|    [m3]|         [m2]|     [m3, m1]|     m3|  t06|null|
#| t07|[m3, m1]|         [m3]|         [m1]|     m3| null| t07|
#| t07|[m3, m1]|         [m3]|         [m1]|     m1|  t07|null|
#| t08|    [m1]|     [m3, m1]|         [m2]|     m1| null| t08|
#| t11|    [m2]|         [m1]|     [m2, m4]|     m2|  t11|null|
#| t13|[m2, m4]|         [m2]|         [m2]|     m2| null|null|
#| t13|[m2, m4]|         [m2]|         [m2]|     m4|  t13| t13|
#| t15|    [m2]|     [m2, m4]|         [m4]|     m2| null| t15|
#| t20|    [m4]|         [m2]|           []|     m4|  t20| t20|
#| t22|[m1, m4]|           []|         null|     m1|  t22| t22|
#| t22|[m1, m4]|           []|         null|     m4|  t22| t22|
#+----+--------+-------------+-------------+-------+-----+----+
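To make step 1 concrete without a cluster, here is a plain-Python emulation of the same logic on the sample rows. It is only an illustration: `lag`/`lead` are modelled by list positions, which assumes the rows are already sorted by time, and `step1` is a hypothetical name.

```python
rows = [
    ("t01", ["m1"]), ("t03", ["m1", "m2"]), ("t04", ["m2"]), ("t06", ["m3"]),
    ("t07", ["m3", "m1"]), ("t08", ["m1"]), ("t11", ["m2"]), ("t13", ["m2", "m4"]),
    ("t15", ["m2"]), ("t20", ["m4"]), ("t21", []), ("t22", ["m1", "m4"]),
]

def step1(rows):
    out = []
    for i, (time, messages) in enumerate(rows):
        prev_messages = rows[i - 1][1] if i > 0 else None              # lag
        next_messages = rows[i + 1][1] if i + 1 < len(rows) else None  # lead
        for message in messages:  # explode: an empty list yields no rows
            # IF(array_contains(prev_messages, message), NULL, time)
            start = None if prev_messages is not None and message in prev_messages else time
            # IF(array_contains(next_messages, message), NULL, time)
            end = None if next_messages is not None and message in next_messages else time
            out.append((time, message, start, end))
    return out

df1_rows = step1(rows)  # (time, message, start, end)
```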

Step-2: create a WindowSpec partitioned by message and forward-fill the start column.

w2 = Window.partitionBy('message').orderBy('time')

# for illustration purpose, I used a different column-name so that we can 
# compare `start` column before and after ffill
df2 = df1.withColumn('start_new', F.last('start', True).over(w2))
df2.show()
#+----+--------+-------------+-------------+-------+-----+----+---------+
#|time|messages|prev_messages|next_messages|message|start| end|start_new|
#+----+--------+-------------+-------------+-------+-----+----+---------+
#| t01|    [m1]|         null|     [m1, m2]|     m1|  t01|null|      t01|
#| t03|[m1, m2]|         [m1]|         [m2]|     m1| null| t03|      t01|
#| t07|[m3, m1]|         [m3]|         [m1]|     m1|  t07|null|      t07|
#| t08|    [m1]|     [m3, m1]|         [m2]|     m1| null| t08|      t07|
#| t22|[m1, m4]|           []|         null|     m1|  t22| t22|      t22|
#| t03|[m1, m2]|         [m1]|         [m2]|     m2|  t03|null|      t03|
#| t04|    [m2]|     [m1, m2]|         [m3]|     m2| null| t04|      t03|
#| t11|    [m2]|         [m1]|     [m2, m4]|     m2|  t11|null|      t11|
#| t13|[m2, m4]|         [m2]|         [m2]|     m2| null|null|      t11|
#| t15|    [m2]|     [m2, m4]|         [m4]|     m2| null| t15|      t11|
#| t06|    [m3]|         [m2]|     [m3, m1]|     m3|  t06|null|      t06|
#| t07|[m3, m1]|         [m3]|         [m1]|     m3| null| t07|      t06|
#| t13|[m2, m4]|         [m2]|         [m2]|     m4|  t13| t13|      t13|
#| t20|    [m4]|         [m2]|           []|     m4|  t20| t20|      t20|
#| t22|[m1, m4]|           []|         null|     m4|  t22| t22|      t22|
#+----+--------+-------------+-------------+-------+-----+----+---------+
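`F.last('start', True)` over an ordered window uses Spark's default frame, from the start of the partition up to the current row, so it returns the most recent non-null `start` seen so far for that message. A plain-Python sketch of that forward fill, with the step-1 rows from the table above hard-coded as input (`forward_fill_start` is a hypothetical name):

```python
from itertools import groupby

# (time, message, start, end) rows copied from the step-1 table
step1_rows = [
    ("t01", "m1", "t01", None), ("t03", "m1", None, "t03"), ("t03", "m2", "t03", None),
    ("t04", "m2", None, "t04"), ("t06", "m3", "t06", None), ("t07", "m3", None, "t07"),
    ("t07", "m1", "t07", None), ("t08", "m1", None, "t08"), ("t11", "m2", "t11", None),
    ("t13", "m2", None, None), ("t13", "m4", "t13", "t13"), ("t15", "m2", None, "t15"),
    ("t20", "m4", "t20", "t20"), ("t22", "m1", "t22", "t22"), ("t22", "m4", "t22", "t22"),
]

def forward_fill_start(rows):
    filled = []
    # Partition by message and order by time, like w2
    ordered = sorted(rows, key=lambda r: (r[1], r[0]))
    for message, group in groupby(ordered, key=lambda r: r[1]):
        last_start = None  # last non-null start in this partition so far
        for time, _, start, end in group:
            if start is not None:
                last_start = start
            filled.append((time, message, last_start, end))
    return filled

filled = forward_fill_start(step1_rows)
```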

Step-3: remove rows where end is NULL and then select only the required columns:

df2.selectExpr("message", "start_new as start", "end") \
    .filter("end is not NULL") \
    .orderBy("message","start").show()
#+-------+-----+---+
#|message|start|end|
#+-------+-----+---+
#|     m1|  t01|t03|
#|     m1|  t07|t08|
#|     m1|  t22|t22|
#|     m2|  t03|t04|
#|     m2|  t11|t15|
#|     m3|  t06|t07|
#|     m4|  t13|t13|
#|     m4|  t20|t20|
#|     m4|  t22|t22|
#+-------+-----+---+

To summarize the above steps, we have the following:

from pyspark.sql import Window, functions as F

# define two Window Specs
w1 = Window.partitionBy().orderBy('time')
w2 = Window.partitionBy('message').orderBy('time')

df_new = df \
    .withColumn('prev_messages', F.lag('messages').over(w1)) \
    .withColumn('next_messages', F.lead('messages').over(w1)) \
    .withColumn('message', F.explode('messages')) \
    .withColumn('start', F.expr("IF(array_contains(prev_messages, message),NULL,time)")) \
    .withColumn('end', F.expr("IF(array_contains(next_messages, message),NULL,time)")) \
    .withColumn('start', F.last('start', True).over(w2)) \
    .select("message", "start", "end") \
    .filter("end is not NULL")

df_new.orderBy("start").show()
jxc
  • That's very clean. It should work with partitions too, shouldn't it? As long as `w2` has an additional partition on `messages` compared to `w1` (assuming a larger dataframe with many non-overlapping columns). – Jedi May 19 '20 at 12:18
    @Jedi, based on your sample data, `w1` is used only to calculate prev_messages and next_messages. If you can find another column (or columns) that splits the rows so that messages in different splits are independent of each other, then adding that column (or columns) to the `partitionBy()` clause of both w1 and w2 should be enough, and ideal. – jxc May 19 '20 at 13:28
    If, on the other hand, it's a single time series that cannot simply be partitioned to calculate prev_messages and next_messages, we can, for example, use `w1 = Window.partitionBy(F.year('time')).orderBy('time')` to calculate the prev and next messages, but will then have to post-process the first and last rows of each partition, which needs some extra jobs. – jxc May 19 '20 at 13:29

This relies on the array functions added in Spark 2.4 (array_except, array_intersect). explode_outer is a variant of explode that, for an empty array, produces a row with a null value.
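As a rough plain-Python model of that difference (it ignores maps and positional variants; the helper names are hypothetical): `explode` emits one row per element and drops rows whose array is empty or null, while `explode_outer` keeps such rows with a single null.

```python
def explode(rows):
    # One output row per array element; empty or null arrays produce no rows
    return [(key, value) for key, arr in rows for value in (arr or [])]

def explode_outer(rows):
    # Like explode, but an empty or null array yields a single row with None
    return [(key, value) for key, arr in rows for value in (arr or [None])]

sample = [("t20", ["m4"]), ("t21", []), ("t22", None)]
```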

The idea is first to get, for each moment, the array of messages that start and the array of messages that end at that moment (start_of and end_of).

Then we keep only the moments at which some message starts or ends, and explode them into a DataFrame with 3 columns (time, start_of, end_of) holding one row per message start and one row per message end. A moment at which both m1 and m2 start produces 2 start rows; a moment at which m1 both starts and ends produces 2 rows, an m1 start and an m1 end.

Finally, use a window partitioned by message and ordered by time, making sure that if a message starts and ends at the same moment, the start row comes first. This guarantees that every start row is immediately followed by its matching end row, so taking the time of the next row for each start row gives the start and end of each run.
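The start/end pairing idea can be modelled in plain Python as a sketch (not Spark). It assumes the rows are sorted by time and contain no duplicate messages; `runs_from_events` is a hypothetical name. Encoding starts as 0 and ends as 1 in the sort key plays the role of `desc("start_of")`, putting a start before an end at the same time.

```python
rows = [
    ("t01", ["m1"]), ("t03", ["m1", "m2"]), ("t04", ["m2"]), ("t06", ["m3"]),
    ("t07", ["m3", "m1"]), ("t08", ["m1"]), ("t11", ["m2"]), ("t13", ["m2", "m4"]),
    ("t15", ["m2"]), ("t20", ["m4"]), ("t21", []), ("t22", ["m1", "m4"]),
]

def runs_from_events(rows):
    events = []
    for i, (time, messages) in enumerate(rows):
        pre = rows[i - 1][1] if i > 0 else None               # lag(messages, 1)
        post = rows[i + 1][1] if i + 1 < len(rows) else None  # lag(messages, -1)
        # array_except(messages, pre), or all messages when there is no neighbour
        start_of = [m for m in messages if pre is None or m not in pre]
        end_of = [m for m in messages if post is None or m not in post]
        events += [(m, time, 0) for m in start_of]  # 0 sorts first: start before end
        events += [(m, time, 1) for m in end_of]
    events.sort()  # order by message, time, then start rows first
    result = []
    for i, (message, time, kind) in enumerate(events):
        if kind == 0:
            # The next event for this message is the end of the same run
            result.append((message, time, events[i + 1][1]))
    return result

runs = runs_from_events(rows)  # (message, starts_at, ends_at)
```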

A great exercise to think through.

I've written the example in Scala, but it should be easy to translate. Each showAndContinue call prints the DataFrame at that stage to show what it does.

// Assumes spark-shell (a SparkSession named `spark`) and the question's `df`
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

// showAndContinue is not part of Spark: a small helper that prints the DataFrame and passes it through
implicit class ShowAndContinue(df: DataFrame) {
  def showAndContinue: DataFrame = { df.show(); df }
}

val w = Window.partitionBy().orderBy("time")
val w2 = Window.partitionBy("message").orderBy($"time", desc("start_of"))
df.select($"time", $"messages", lag($"messages", 1).over(w).as("pre"), lag("messages", -1).over(w).as("post"))
  .withColumn("start_of", when($"pre".isNotNull, array_except(col("messages"), col("pre"))).otherwise($"messages"))
  .withColumn("end_of",  when($"post".isNotNull, array_except(col("messages"), col("post"))).otherwise($"messages"))
  .filter(size($"start_of") + size($"end_of") > 0)
  .showAndContinue
  .select(explode(array(
    struct($"time", $"start_of", array().as("end_of")),
    struct($"time", array().as("start_of"), $"end_of")
  )).as("elem"))
  .select("elem.*")
  .select($"time", explode_outer($"start_of").as("start_of"), $"end_of")
  .select( $"time", $"start_of", explode_outer($"end_of").as("end_of"))
  .filter($"start_of".isNotNull || $"end_of".isNotNull)
  .showAndContinue
  .withColumn("message", when($"start_of".isNotNull, $"start_of").otherwise($"end_of"))
  .showAndContinue
  .select($"message", when($"start_of".isNotNull, $"time").as("starts_at"), lag($"time", -1).over(w2).as("ends_at"))
  .filter($"starts_at".isNotNull)
  .showAndContinue

The resulting tables:

+----+--------+--------+--------+--------+--------+
|time|messages|     pre|    post|start_of|  end_of|
+----+--------+--------+--------+--------+--------+
| t01|    [m1]|    null|[m1, m2]|    [m1]|      []|
| t03|[m1, m2]|    [m1]|    [m2]|    [m2]|    [m1]|
| t04|    [m2]|[m1, m2]|    [m3]|      []|    [m2]|
| t06|    [m3]|    [m2]|[m3, m1]|    [m3]|      []|
| t07|[m3, m1]|    [m3]|    [m1]|    [m1]|    [m3]|
| t08|    [m1]|[m3, m1]|    [m2]|      []|    [m1]|
| t11|    [m2]|    [m1]|[m2, m4]|    [m2]|      []|
| t13|[m2, m4]|    [m2]|    [m2]|    [m4]|    [m4]|
| t15|    [m2]|[m2, m4]|    [m4]|      []|    [m2]|
| t20|    [m4]|    [m2]|      []|    [m4]|    [m4]|
| t22|[m1, m4]|      []|    null|[m1, m4]|[m1, m4]|
+----+--------+--------+--------+--------+--------+

+----+--------+------+
|time|start_of|end_of|
+----+--------+------+
| t01|      m1|  null|
| t03|      m2|  null|
| t03|    null|    m1|
| t04|    null|    m2|
| t06|      m3|  null|
| t07|      m1|  null|
| t07|    null|    m3|
| t08|    null|    m1|
| t11|      m2|  null|
| t13|      m4|  null|
| t13|    null|    m4|
| t15|    null|    m2|
| t20|      m4|  null|
| t20|    null|    m4|
| t22|      m1|  null|
| t22|      m4|  null|
| t22|    null|    m1|
| t22|    null|    m4|
+----+--------+------+

+----+--------+------+-------+
|time|start_of|end_of|message|
+----+--------+------+-------+
| t01|      m1|  null|     m1|
| t03|      m2|  null|     m2|
| t03|    null|    m1|     m1|
| t04|    null|    m2|     m2|
| t06|      m3|  null|     m3|
| t07|      m1|  null|     m1|
| t07|    null|    m3|     m3|
| t08|    null|    m1|     m1|
| t11|      m2|  null|     m2|
| t13|      m4|  null|     m4|
| t13|    null|    m4|     m4|
| t15|    null|    m2|     m2|
| t20|      m4|  null|     m4|
| t20|    null|    m4|     m4|
| t22|      m1|  null|     m1|
| t22|      m4|  null|     m4|
| t22|    null|    m1|     m1|
| t22|    null|    m4|     m4|
+----+--------+------+-------+

+-------+---------+-------+
|message|starts_at|ends_at|
+-------+---------+-------+
|     m1|      t01|    t03|
|     m1|      t07|    t08|
|     m1|      t22|    t22|
|     m2|      t03|    t04|
|     m2|      t11|    t15|
|     m3|      t06|    t07|
|     m4|      t13|    t13|
|     m4|      t20|    t20|
|     m4|      t22|    t22|
+-------+---------+-------+

It could be optimised by extracting, already in the first table, all the messages that start and end at the same moment, so that their starts and ends don't have to be matched again. Whether that is worth it depends on whether this is a common case or only a small number of cases. With the optimisation it would look like this (using the same windows):
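In plain Python terms (a sketch; `split_single_moment` is a hypothetical name, and the arrays are assumed to have no duplicates), the optimisation moves messages present in both `start_of` and `end_of` into `start_end_here`, mirroring `array_intersect` followed by two `array_except` calls:

```python
def split_single_moment(start_of, end_of):
    # array_intersect(start_of, end_of): messages that start and end at this moment
    start_end_here = [m for m in start_of if m in end_of]
    # array_except(start_of, start_end_here) and array_except(end_of, start_end_here)
    start_only = [m for m in start_of if m not in start_end_here]
    end_only = [m for m in end_of if m not in start_end_here]
    return start_only, end_only, start_end_here
```

Rows whose `start_end_here` is non-empty can then be emitted directly as `(message, time, time)`, and only the remaining starts and ends need the window-based matching.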

val dfStartEndAndFiniteLife = df.select($"time", $"messages", lag($"messages", 1).over(w).as("pre"), lag("messages", -1).over(w).as("post"))
  .withColumn("start_of", when($"pre".isNotNull, array_except(col("messages"), col("pre"))).otherwise($"messages"))
  .withColumn("end_of",  when($"post".isNotNull, array_except(col("messages"), col("post"))).otherwise($"messages"))
  .filter(size($"start_of") + size($"end_of") > 0)
  .withColumn("start_end_here", array_intersect($"start_of", $"end_of"))
  .withColumn("start_of", array_except($"start_of", $"start_end_here"))
  .withColumn("end_of", array_except($"end_of", $"start_end_here"))
  .showAndContinue

val onlyStartEndSameMoment = dfStartEndAndFiniteLife.filter(size($"start_end_here") > 0)
  .select(explode($"start_end_here"), $"time".as("starts_at"), $"time".as("ends_at"))
  .showAndContinue

val startEndDifferentMoment = dfStartEndAndFiniteLife
  .filter(size($"start_of") + size($"end_of") > 0)
  .showAndContinue
  .select(explode(array(
    struct($"time", $"start_of", array().as("end_of")),
    struct($"time", array().as("start_of"), $"end_of")
  )).as("elem"))
  .select("elem.*")
  .select($"time", explode_outer($"start_of").as("start_of"), $"end_of")
  .select( $"time", $"start_of", explode_outer($"end_of").as("end_of"))
  .filter($"start_of".isNotNull || $"end_of".isNotNull)
  .showAndContinue
  .withColumn("message", when($"start_of".isNotNull, $"start_of").otherwise($"end_of"))
  .showAndContinue
  .select($"message", when($"start_of".isNotNull, $"time").as("starts_at"), lag($"time", -1).over(w2).as("ends_at"))
  .filter($"starts_at".isNotNull)
  .showAndContinue

val result = onlyStartEndSameMoment.union(startEndDifferentMoment)

result.orderBy("col", "starts_at").show()

The resulting tables:

+----+--------+--------+--------+--------+------+--------------+
|time|messages|     pre|    post|start_of|end_of|start_end_here|
+----+--------+--------+--------+--------+------+--------------+
| t01|    [m1]|    null|[m1, m2]|    [m1]|    []|            []|
| t03|[m1, m2]|    [m1]|    [m2]|    [m2]|  [m1]|            []|
| t04|    [m2]|[m1, m2]|    [m3]|      []|  [m2]|            []|
| t06|    [m3]|    [m2]|[m3, m1]|    [m3]|    []|            []|
| t07|[m3, m1]|    [m3]|    [m1]|    [m1]|  [m3]|            []|
| t08|    [m1]|[m3, m1]|    [m2]|      []|  [m1]|            []|
| t11|    [m2]|    [m1]|[m2, m4]|    [m2]|    []|            []|
| t13|[m2, m4]|    [m2]|    [m2]|      []|    []|          [m4]|
| t15|    [m2]|[m2, m4]|    [m4]|      []|  [m2]|            []|
| t20|    [m4]|    [m2]|      []|      []|    []|          [m4]|
| t22|[m1, m4]|      []|    null|      []|    []|      [m1, m4]|
+----+--------+--------+--------+--------+------+--------------+

+---+---------+-------+
|col|starts_at|ends_at|
+---+---------+-------+
| m4|      t13|    t13|
| m4|      t20|    t20|
| m1|      t22|    t22|
| m4|      t22|    t22|
+---+---------+-------+

+----+--------+--------+--------+--------+------+--------------+
|time|messages|     pre|    post|start_of|end_of|start_end_here|
+----+--------+--------+--------+--------+------+--------------+
| t01|    [m1]|    null|[m1, m2]|    [m1]|    []|            []|
| t03|[m1, m2]|    [m1]|    [m2]|    [m2]|  [m1]|            []|
| t04|    [m2]|[m1, m2]|    [m3]|      []|  [m2]|            []|
| t06|    [m3]|    [m2]|[m3, m1]|    [m3]|    []|            []|
| t07|[m3, m1]|    [m3]|    [m1]|    [m1]|  [m3]|            []|
| t08|    [m1]|[m3, m1]|    [m2]|      []|  [m1]|            []|
| t11|    [m2]|    [m1]|[m2, m4]|    [m2]|    []|            []|
| t15|    [m2]|[m2, m4]|    [m4]|      []|  [m2]|            []|
+----+--------+--------+--------+--------+------+--------------+

+----+--------+------+
|time|start_of|end_of|
+----+--------+------+
| t01|      m1|  null|
| t03|      m2|  null|
| t03|    null|    m1|
| t04|    null|    m2|
| t06|      m3|  null|
| t07|      m1|  null|
| t07|    null|    m3|
| t08|    null|    m1|
| t11|      m2|  null|
| t15|    null|    m2|
+----+--------+------+

+----+--------+------+-------+
|time|start_of|end_of|message|
+----+--------+------+-------+
| t01|      m1|  null|     m1|
| t03|      m2|  null|     m2|
| t03|    null|    m1|     m1|
| t04|    null|    m2|     m2|
| t06|      m3|  null|     m3|
| t07|      m1|  null|     m1|
| t07|    null|    m3|     m3|
| t08|    null|    m1|     m1|
| t11|      m2|  null|     m2|
| t15|    null|    m2|     m2|
+----+--------+------+-------+

+-------+---------+-------+
|message|starts_at|ends_at|
+-------+---------+-------+
|     m1|      t01|    t03|
|     m1|      t07|    t08|
|     m2|      t03|    t04|
|     m2|      t11|    t15|
|     m3|      t06|    t07|
+-------+---------+-------+

+---+---------+-------+
|col|starts_at|ends_at|
+---+---------+-------+
| m1|      t01|    t03|
| m1|      t07|    t08|
| m1|      t22|    t22|
| m2|      t03|    t04|
| m2|      t11|    t15|
| m3|      t06|    t07|
| m4|      t13|    t13|
| m4|      t20|    t20|
| m4|      t22|    t22|
+---+---------+-------+
Alfilercio

I found a reasonable way to do this that scales well if you can partition when applying window operations (which you should be able to do on any real dataset; I was able to on the one I derived this problem from).

I broke it up into chunks for readability (imports are in the first snippet only).

Setup:

# Need these for the setup
import pandas as pd
from pyspark.sql.types import ArrayType, StringType, StructField, StructType

# We'll need these later
from pyspark.sql.functions import array_except, coalesce, col, explode, from_json, lag, lit, rank
from pyspark.sql.window import Window

rows = [
    ['t01',['m1']],
    ['t03',['m1','m2']],
    ['t04',['m2']],
    ['t06',['m3']],
    ['t07',['m3','m1']],
    ['t08',['m1']],
    ['t11',['m2']],
    ['t13',['m2','m4']],
    ['t15',['m2']],
    ['t20',['m4']],
    ['t21',[]],
    ['t22',['m1','m4']],
]

pdf = pd.DataFrame(rows,columns=['time', 'messages'])
schema = StructType([
    StructField("time", StringType(), True),
    StructField("messages", ArrayType(StringType()), True)
])
df = spark.createDataFrame(pdf,schema=schema)

Order by time, lag and generate a diff of the message arrays to identify the start and end of runs:

w = Window().partitionBy().orderBy('time')
df2 = df.withColumn('messages_lag_1', lag('messages', 1).over(w))\
        .withColumn('end_time', lag('time', 1).over(w))\
        .withColumnRenamed('time', 'start_time')\
        .withColumn('messages_lag_1',          # Replace nulls with []
            coalesce(                          # cargoculted from
                col('messages_lag_1'),         # https://stackoverflow.com/a/57198009
                from_json(lit('[]'), ArrayType(StringType()))
            )
        )\
        .withColumn('message_run_starts', array_except('messages', 'messages_lag_1'))\
        .withColumn('message_run_ends', array_except('messages_lag_1', 'messages'))\
        .drop(*['messages', 'messages_lag_1']) # ^ array_except requires Spark 2.4+

df2.show()

+----------+--------+------------------+----------------+
|start_time|end_time|message_run_starts|message_run_ends|
+----------+--------+------------------+----------------+
|       t01|    null|              [m1]|              []|
|       t03|     t01|              [m2]|              []|
|       t04|     t03|                []|            [m1]|
|       t06|     t04|              [m3]|            [m2]|
|       t07|     t06|              [m1]|              []|
|       t08|     t07|                []|            [m3]|
|       t11|     t08|              [m2]|            [m1]|
|       t13|     t11|              [m4]|              []|
|       t15|     t13|                []|            [m4]|
|       t20|     t15|              [m4]|            [m2]|
|       t21|     t20|                []|            [m4]|
|       t22|     t21|          [m1, m4]|              []|
+----------+--------+------------------+----------------+
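The diff step can be sketched in plain Python (not Spark; `run_boundaries` is a hypothetical name, and rows are assumed sorted by time). The previous messages start as an empty list, which is what the `coalesce` with `[]` achieves for the first row:

```python
rows = [
    ("t01", ["m1"]), ("t03", ["m1", "m2"]), ("t04", ["m2"]), ("t06", ["m3"]),
    ("t07", ["m3", "m1"]), ("t08", ["m1"]), ("t11", ["m2"]), ("t13", ["m2", "m4"]),
    ("t15", ["m2"]), ("t20", ["m4"]), ("t21", []), ("t22", ["m1", "m4"]),
]

def run_boundaries(rows):
    out = []
    prev_time, prev_messages = None, []  # lag() is null for the first row; coalesce to []
    for time, messages in rows:
        starts = [m for m in messages if m not in prev_messages]  # array_except(messages, lag)
        ends = [m for m in prev_messages if m not in messages]    # array_except(lag, messages)
        # (start_time, end_time, message_run_starts, message_run_ends)
        out.append((time, prev_time, starts, ends))
        prev_time, prev_messages = time, messages
    return out

df2_rows = run_boundaries(rows)
```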

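Explode the start and end arrays into separate tables and, in each, rank the runs of every message by time. Then full-join the two tables on message and rank; where end_time is null (a run still open at the last row), copy start_time into end_time. That copy is only correct when such a run consists of the last row alone, as in this sample: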

w_start = Window().partitionBy('message_run_starts').orderBy(col('start_time'))
df3 = df2.withColumn('message_run_starts', explode('message_run_starts')).drop('message_run_ends', 'end_time')
df3 = df3.withColumn('start_row_id',rank().over(w_start))

w_end = Window().partitionBy('message_run_ends').orderBy(col('end_time'))
df4 = df2.withColumn('message_run_ends', explode('message_run_ends')).drop('message_run_starts', 'start_time')
df4 = df4.withColumn('end_row_id',rank().over(w_end))

df_combined = df3\
    .join(df4, (df3.message_run_starts == df4.message_run_ends) & (df3.start_row_id == df4.end_row_id), how='full')\
        .drop(*['message_run_ends','start_row_id','end_row_id'])\
        .withColumn('end_time',coalesce(col('end_time'),col('start_time'))) 

df_combined.show()

+----------+------------------+--------+
|start_time|message_run_starts|end_time|
+----------+------------------+--------+
|       t01|                m1|     t03|
|       t07|                m1|     t08|
|       t22|                m1|     t22|
|       t03|                m2|     t04|
|       t11|                m2|     t15|
|       t06|                m3|     t07|
|       t13|                m4|     t13|
|       t20|                m4|     t20|
|       t22|                m4|     t22|
+----------+------------------+--------+
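The rank-and-join step pairs the k-th start of a message with its k-th end. A plain-Python sketch of that pairing (hypothetical `pair_runs` helper; input hard-coded from the `df2` table above), including the `coalesce` fallback for runs still open at the last row:

```python
from collections import defaultdict

# (start_time, end_time, message_run_starts, message_run_ends) from the df2 table
boundaries = [
    ("t01", None, ["m1"], []), ("t03", "t01", ["m2"], []), ("t04", "t03", [], ["m1"]),
    ("t06", "t04", ["m3"], ["m2"]), ("t07", "t06", ["m1"], []), ("t08", "t07", [], ["m3"]),
    ("t11", "t08", ["m2"], ["m1"]), ("t13", "t11", ["m4"], []), ("t15", "t13", [], ["m4"]),
    ("t20", "t15", ["m4"], ["m2"]), ("t21", "t20", [], ["m4"]), ("t22", "t21", ["m1", "m4"], []),
]

def pair_runs(boundaries):
    starts = defaultdict(list)  # message -> start times in time order (position = rank)
    ends = defaultdict(list)    # message -> end times in time order
    for start_time, end_time, run_starts, run_ends in boundaries:
        for m in run_starts:
            starts[m].append(start_time)
        for m in run_ends:
            ends[m].append(end_time)
    result = []
    for m in sorted(starts):
        for k, start in enumerate(starts[m]):
            # Join on (message, rank); coalesce(end_time, start_time) when no end exists.
            # The fallback is only right if the open run is the last row alone.
            end = ends[m][k] if k < len(ends[m]) else start
            result.append((start, m, end))
    return result

combined = pair_runs(boundaries)  # (start_time, message, end_time)
```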
Jedi