I have a DataFrame that looks like this:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.functions import to_date

# All date-like columns are loaded as strings and cast to dates below.
TEST_schema = StructType([
    StructField("Date", StringType(), True),
    StructField("START", StringType(), True),
    StructField("quantity", IntegerType(), True),
    StructField("col1", StringType(), True),
    StructField("col2", StringType(), True),
])
TEST_data = [
    ('2020-08-15', '2020-08-19', 1, '2020-08-05', '2020-08-09'),
    ('2020-08-16', '2020-08-19', 2, '2020-08-05', '2020-08-09'),
    ('2020-08-17', '2020-08-19', 3, '2020-08-06', '2020-08-09'),
    ('2020-08-18', '2020-08-19', 4, '2020-08-10', '2020-08-11'),
    ('2020-08-19', '2020-08-19', 5, '2020-08-16', '2020-08-19'),
    ('2020-08-20', '2020-08-19', 6, '2020-08-20', '2020-08-25'),
    ('2020-08-21', '2020-08-19', 7, '2020-08-20', '2020-08-21'),
    ('2020-08-22', '2020-08-19', 8, '2020-08-19', '2020-08-24'),
    ('2020-08-23', '2020-08-19', 9, '2020-08-05', '2020-08-09'),
]
# sqlContext is the one provided by the PySpark shell.
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
# Cast the string columns to DateType.
TEST_df = TEST_df.withColumn("Date", to_date("Date"))\
    .withColumn("START", to_date("START"))\
    .withColumn("col1", to_date("col1"))\
    .withColumn("col2", to_date("col2"))
TEST_df.show()
+----------+----------+--------+----------+----------+
|      Date|     START|quantity|      col1|      col2|
+----------+----------+--------+----------+----------+
|2020-08-15|2020-08-19|       1|2020-08-05|2020-08-09|
|2020-08-16|2020-08-19|       2|2020-08-05|2020-08-09|
|2020-08-17|2020-08-19|       3|2020-08-06|2020-08-09|
|2020-08-18|2020-08-19|       4|2020-08-10|2020-08-11|
|2020-08-19|2020-08-19|       5|2020-08-16|2020-08-19|
|2020-08-20|2020-08-19|       6|2020-08-20|2020-08-25|
|2020-08-21|2020-08-19|       7|2020-08-20|2020-08-21|
|2020-08-22|2020-08-19|       8|2020-08-19|2020-08-24|
|2020-08-23|2020-08-19|       9|2020-08-05|2020-08-09|
+----------+----------+--------+----------+----------+
Here col1 and col2 may not be unique, Date is just an incrementing date (one row per day), and START is a single value (the same in every row).
My logic is: if START == col2,
then want = lag(quantity, offset=datediff(col2, col1), default=0),
otherwise want = 0.
In this case the only row where START == col2 is the one for 2020-08-19, where datediff(col2, col1) is 3 days.
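To make the expected result concrete, here is a small plain-Python sketch of the rule (no Spark involved). The helper name lag_by_datediff is made up for illustration, and it assumes the rows are ordered by Date and that an out-of-range lag falls back to 0:

```python
from datetime import date

# (Date, START, quantity, col1, col2) rows copied from the sample data above.
rows = [
    ("2020-08-15", "2020-08-19", 1, "2020-08-05", "2020-08-09"),
    ("2020-08-16", "2020-08-19", 2, "2020-08-05", "2020-08-09"),
    ("2020-08-17", "2020-08-19", 3, "2020-08-06", "2020-08-09"),
    ("2020-08-18", "2020-08-19", 4, "2020-08-10", "2020-08-11"),
    ("2020-08-19", "2020-08-19", 5, "2020-08-16", "2020-08-19"),
    ("2020-08-20", "2020-08-19", 6, "2020-08-20", "2020-08-25"),
    ("2020-08-21", "2020-08-19", 7, "2020-08-20", "2020-08-21"),
    ("2020-08-22", "2020-08-19", 8, "2020-08-19", "2020-08-24"),
    ("2020-08-23", "2020-08-19", 9, "2020-08-05", "2020-08-09"),
]


def lag_by_datediff(rows):
    """Hypothetical helper: quantity lagged by datediff(col2, col1) rows
    when START == col2, else 0."""
    ordered = sorted(rows, key=lambda r: r[0])  # ISO date strings sort chronologically
    want = []
    for i, (_, start, _, col1, col2) in enumerate(ordered):
        offset = (date.fromisoformat(col2) - date.fromisoformat(col1)).days
        src = i - offset  # index of the row `offset` positions earlier
        if start == col2 and 0 <= src < len(ordered):
            want.append(ordered[src][2])
        else:
            want.append(0)  # condition not met, or the lag runs off the data
    return want


want = lag_by_datediff(rows)
print(want)  # [0, 0, 0, 0, 2, 0, 0, 0, 0]
```

So the row for 2020-08-19 should get want = 2 (the quantity from 3 rows earlier), and every other row should get 0.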
Attempt 1:
from pyspark.sql.functions import when, col, datediff, expr
TEST_df = TEST_df.withColumn('datedifff', datediff(col('col2'), col('col1')))\
    .withColumn('want', expr("IF(START == col2, lag(quantity, datedifff, 0), 0)"))
This fails with an error about a literal: lag seems to accept only a literal offset, not a column such as datedifff.