
I am trying to use a Spark DataFrame of a health plan's member IDs and enrollment months to identify "continuous" coverage periods, i.e., stretches in which a member is enrolled for consecutive months.

Below is an example of the data I am working with in PySpark (sc is the SparkSession).

import pandas as pd

df = pd.DataFrame({'memid': ['123a', '123a', '123a', '123a', '123a', '123a',
                             '456b', '456b', '456b', '456b', '456b',
                             '789c', '789c', '789c', '789c', '789c', '789c'], 
                     'month_elig': ['2020-01-01', '2020-02-01', '2020-03-01', '2020-08-01', '2020-09-01', '2021-01-01',
                                    '2020-02-01', '2020-05-01', '2020-06-01', '2020-07-01', '2020-08-01',
                                    '2020-02-01', '2020-03-01', '2020-04-01', '2020-05-01', '2020-06-01', '2020-07-01']})
df['month_elig'] = pd.to_datetime(df['month_elig'])
# Gap in whole calendar months since the member's previous eligible month (0 for a member's first month).
# Counting year*12 + month is exact, so no rounding of an approximate month length is needed.
month_idx = df['month_elig'].dt.year * 12 + df['month_elig'].dt.month
df['gap'] = (month_idx - month_idx.groupby(df['memid']).shift(1)).fillna(0)

scdf = sc.createDataFrame(df)

scdf.show()

#+-----+-------------------+---+
#|memid|         month_elig|gap|
#+-----+-------------------+---+
#| 123a|2020-01-01 00:00:00|0.0|
#| 123a|2020-02-01 00:00:00|1.0|
#| 123a|2020-03-01 00:00:00|1.0|
#| 123a|2020-08-01 00:00:00|5.0|
#| 123a|2020-09-01 00:00:00|1.0|
#| 123a|2021-01-01 00:00:00|4.0|
#| 456b|2020-02-01 00:00:00|0.0|
#| 456b|2020-05-01 00:00:00|3.0|
#| 456b|2020-06-01 00:00:00|1.0|
#| 456b|2020-07-01 00:00:00|1.0|
#| 456b|2020-08-01 00:00:00|1.0|
#| 789c|2020-02-01 00:00:00|0.0|
#| 789c|2020-03-01 00:00:00|1.0|
#| 789c|2020-04-01 00:00:00|1.0|
#| 789c|2020-05-01 00:00:00|1.0|
#| 789c|2020-06-01 00:00:00|1.0|
#| 789c|2020-07-01 00:00:00|1.0|
#+-----+-------------------+---+

If I were doing this in pandas, I would create the unique_coverage_period field with the code below. However, the solution needs to be in Spark because of the size of the data I'm dealing with, and from what I've researched so far (example), an iterative approach like this is not something Spark is really set up to do.

a = 0
b = []
for i in df.gap.tolist():
    # Any gap other than exactly one month starts a new coverage period
    if i != 1:
        a += 1
    b.append(a)

df['unique_coverage_period'] = b

print(df)

#   memid month_elig  gap  unique_coverage_period
#0   123a 2020-01-01  0.0                       1
#1   123a 2020-02-01  1.0                       1
#2   123a 2020-03-01  1.0                       1
#3   123a 2020-08-01  5.0                       2
#4   123a 2020-09-01  1.0                       2
#5   123a 2021-01-01  4.0                       3
#6   456b 2020-02-01  0.0                       4
#7   456b 2020-05-01  3.0                       5
#8   456b 2020-06-01  1.0                       5
#9   456b 2020-07-01  1.0                       5
#10  456b 2020-08-01  1.0                       5
#11  789c 2020-02-01  0.0                       6
#12  789c 2020-03-01  1.0                       6
#13  789c 2020-04-01  1.0                       6
#14  789c 2020-05-01  1.0                       6
#15  789c 2020-06-01  1.0                       6
#16  789c 2020-07-01  1.0                       6
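The loop increments its counter whenever gap is not exactly 1, so it is equivalent to a cumulative sum of that boolean condition. A vectorized pandas sketch of the same idea, with the gap values from the table above typed in directly (and numbering periods across the whole frame, as the loop does):

```python
import pandas as pd

# gap values from the table above, in the same memid / month_elig order
df = pd.DataFrame({
    'memid': ['123a'] * 6 + ['456b'] * 5 + ['789c'] * 6,
    'gap': [0, 1, 1, 5, 1, 4, 0, 3, 1, 1, 1, 0, 1, 1, 1, 1, 1],
})

# Each row whose gap is not exactly 1 starts a new period; the running
# total of that True/False flag numbers the periods, like the loop does.
df['unique_coverage_period'] = (df['gap'] != 1).cumsum()
print(df['unique_coverage_period'].tolist())
# [1, 1, 1, 2, 2, 3, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6]
```

This running sum of a "new period starts here" flag is exactly what a Spark window aggregate can express without iterating row by row.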

mck
bshelt141

2 Answers


You can take a running (cumulative) sum over a window, adding 1 every time gap is not equal to 1:

from pyspark.sql import functions as F, Window

result = scdf.withColumn(
    'flag',
    F.sum((F.col('gap') != 1).cast('int')).over(Window.orderBy('memid', 'month_elig'))
)

result.show()
+-----+-------------------+---+----+
|memid|         month_elig|gap|flag|
+-----+-------------------+---+----+
| 123a|2020-01-01 00:00:00|0.0|   1|
| 123a|2020-02-01 00:00:00|1.0|   1|
| 123a|2020-03-01 00:00:00|1.0|   1|
| 123a|2020-08-01 00:00:00|5.0|   2|
| 123a|2020-09-01 00:00:00|1.0|   2|
| 123a|2021-01-01 00:00:00|4.0|   3|
| 456b|2020-02-01 00:00:00|0.0|   4|
| 456b|2020-05-01 00:00:00|3.0|   5|
| 456b|2020-06-01 00:00:00|1.0|   5|
| 456b|2020-07-01 00:00:00|1.0|   5|
| 456b|2020-08-01 00:00:00|1.0|   5|
| 789c|2020-02-01 00:00:00|0.0|   6|
| 789c|2020-03-01 00:00:00|1.0|   6|
| 789c|2020-04-01 00:00:00|1.0|   6|
| 789c|2020-05-01 00:00:00|1.0|   6|
| 789c|2020-06-01 00:00:00|1.0|   6|
| 789c|2020-07-01 00:00:00|1.0|   6|
+-----+-------------------+---+----+
mck

I've since come up with an alternative approach to identifying unique coverage periods. Although I find the accepted answer by @mck much more legible and straightforward, the approach below ran noticeably faster on my actual, larger data set of 84.6M records.

from pyspark.sql import functions as F, Window

# Create a new DataFrame that retains only the coverage-break months (gap != 1) and numbers them per member
w1 = Window.partitionBy('memid').orderBy(F.col('month_elig'))

scdf1 = scdf \
  .filter(F.col('gap') != 1) \
  .withColumn('rank', F.rank().over(w1)) \
  .select('memid', F.col('month_elig').alias('starter_month'), 'rank')

# Join the two DataFrames on memid and keep only the records where 'month_elig' >= 'starter_month'
scdf2 = scdf.join(scdf1, on='memid', how='inner') \
  .withColumn('starter', F.when(F.col('month_elig') == F.col('starter_month'), 1).otherwise(0)) \
  .filter(F.col('month_elig') >= F.col('starter_month'))

# If 'month_elig' == 'starter_month', keep that row; otherwise keep the latest 'starter_month' for each 'month_elig' record
w2 = Window.partitionBy('memid', 'month_elig').orderBy(F.col('starter').desc(), F.col('rank').desc())

scdf2 = scdf2 \
  .withColumn('rank', F.rank().over(w2)) \
  .filter(F.col('rank') == 1).drop('rank') \
  .withColumn('flag', F.concat(F.col('memid'), F.lit('_'), F.trunc(F.col('starter_month'), 'month'))) \
  .select('memid', 'month_elig', 'gap', 'flag')

scdf2.show()
+-----+-------------------+---+---------------+
|memid|         month_elig|gap|           flag|
+-----+-------------------+---+---------------+
| 789c|2020-02-01 00:00:00|0.0|789c_2020-02-01|
| 789c|2020-03-01 00:00:00|1.0|789c_2020-02-01|
| 789c|2020-04-01 00:00:00|1.0|789c_2020-02-01|
| 789c|2020-05-01 00:00:00|1.0|789c_2020-02-01|
| 789c|2020-06-01 00:00:00|1.0|789c_2020-02-01|
| 789c|2020-07-01 00:00:00|1.0|789c_2020-02-01|
| 123a|2020-01-01 00:00:00|0.0|123a_2020-01-01|
| 123a|2020-02-01 00:00:00|1.0|123a_2020-01-01|
| 123a|2020-03-01 00:00:00|1.0|123a_2020-01-01|
| 123a|2020-08-01 00:00:00|5.0|123a_2020-08-01|
| 123a|2020-09-01 00:00:00|1.0|123a_2020-08-01|
| 123a|2021-01-01 00:00:00|4.0|123a_2021-01-01|
| 456b|2020-02-01 00:00:00|0.0|456b_2020-02-01|
| 456b|2020-05-01 00:00:00|3.0|456b_2020-05-01|
| 456b|2020-06-01 00:00:00|1.0|456b_2020-05-01|
| 456b|2020-07-01 00:00:00|1.0|456b_2020-05-01|
| 456b|2020-08-01 00:00:00|1.0|456b_2020-05-01|
+-----+-------------------+---+---------------+
bshelt141