I have an assignment problem, and I wanted to ask the SO community the best way to implement it for my Spark DataFrame (using Spark 3.1+). I will first describe the problem and then move to the implementation.
Here is the problem: I have up to N tasks and up to N individuals (in the case of this problem, N=10). Each individual has a cost of performing each task, where the min cost is $0 and the max cost is $10. It is sort of a Hungarian algorithm problem with some caveats.
- There will be some instances where there are fewer than 10 tasks and/or fewer than 10 individuals, and it is okay for someone not to be assigned a task (or for a task not to be assigned an individual).
- [The more complex edge case, and the one I am having trouble with] - There may be one task in the list that has the flag multiTask=True (there cannot be more than 1 multiTask, and it is possible there are none). If a worker has a cost less than x for the multiTask, he is automatically assigned to the multiTask, and the multiTask is considered taken during the optimization. I will share a few examples (the same rule is sketched in code after this list). In these examples, the x value for the multiTask is 1.
- If 1 worker out of 10 has a cost of .25 on the multiTask, he is assigned to the multiTask and then the other 9 workers will be assigned to the other 9 tasks
- If 2 workers out of the 10 have a cost < 1 on the multiTask, both of them are assigned to the multiTask and then the other 8 workers will be assigned to 8 of the remaining 9 tasks. 1 task will not be assigned to anyone.
- If all 10 workers have a cost < 1 on the multiTask, all of them are assigned to the multiTask. This is very rare but possible.
- If no workers have a cost < 1 on the multiTask, the multiTask will only be assigned to one person during the optimization to minimize the cost.
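To make the multiTask rule concrete, here is a minimal numpy sketch of just that rule, using the same 3-worker example as the DataFrame below (the names costs, multi_col and x are only illustrative, not part of my actual code):

import numpy as np

# rows = workers (129, 990, 433), columns = tasks (220, 110, 190)
costs = np.array([
    [1.50, 2.90, 0.80],
    [1.80, 0.90, 9.99],
    [1.20, 0.25, 4.99],
])
multi_col = 1   # column index of the multiTask (taskId 110)
x = 1.0         # threshold below which a worker is auto-assigned to the multiTask

# every worker whose multiTask cost is below x is preassigned to it
preassigned = np.where(costs[:, multi_col] < x)[0]
if len(preassigned) == 0:
    # nobody is under the threshold: the multiTask stays in the normal
    # optimization and is assigned to exactly one worker
    print("multiTask left to the optimizer")
else:
    print("workers preassigned to the multiTask:", preassigned)  # -> [1 2]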
Here is what the Spark DataFrame looks like. Note: I am showing an example where N=3 (3 tasks, 3 individuals) for simplicity's sake.
from pyspark.sql import Row
rdd = spark.sparkContext.parallelize([
Row(date='2019-08-01', locationId='z2-NY', workerId=129, taskId=220, cost=1.50, isMultiTask=False),
Row(date='2019-08-01', locationId='z2-NY', workerId=129, taskId=110, cost=2.90, isMultiTask=True),
Row(date='2019-08-01', locationId='z2-NY', workerId=129, taskId=190, cost=0.80, isMultiTask=False),
Row(date='2019-08-01', locationId='z2-NY', workerId=990, taskId=220, cost=1.80, isMultiTask=False),
Row(date='2019-08-01', locationId='z2-NY', workerId=990, taskId=110, cost=0.90, isMultiTask=True),
Row(date='2019-08-01', locationId='z2-NY', workerId=990, taskId=190, cost=9.99, isMultiTask=False),
Row(date='2019-08-01', locationId='z2-NY', workerId=433, taskId=220, cost=1.20, isMultiTask=False),
Row(date='2019-08-01', locationId='z2-NY', workerId=433, taskId=110, cost=0.25, isMultiTask=True),
Row(date='2019-08-01', locationId='z2-NY', workerId=433, taskId=190, cost=4.99, isMultiTask=False)
])
df = spark.createDataFrame(rdd)
You will see there are date and location columns, as I need to solve this assignment problem for every date/location grouping.
I was planning to solve this by assigning each worker and task an "index" based on their IDs using dense_rank(), then using a pandas UDF to populate an N x N numpy array from those indices and invoke the linear_sum_assignment function. However, I don't believe this plan will work because of the 2nd edge case I laid out with the multiTask.
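For reference, this is how I understand linear_sum_assignment to behave on a small, made-up cost matrix (purely illustrative data, not my real costs): it returns the row indices and, for each row, the column chosen so that the total cost is minimized.

import numpy as np
from scipy.optimize import linear_sum_assignment

# toy 3x3 cost matrix: rows = workers, columns = tasks
cost = np.array([
    [4.0, 1.0, 3.0],
    [2.0, 0.0, 5.0],
    [3.0, 2.0, 2.0],
])
row_ind, col_ind = linear_sum_assignment(cost)
print(row_ind, col_ind)               # [0 1 2] [1 0 2]
print(cost[row_ind, col_ind].sum())   # 5.0 (minimum total cost)

And here is what I have so far for that plan: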
from pyspark.sql import Window
from pyspark.sql.functions import dense_rank, col, when

worker_order_window = Window.partitionBy("date", "locationId").orderBy("workerId")
task_order_window = Window.partitionBy("date", "locationId").orderBy("taskId")
# use dense_rank to give each worker and each task an index into the np array for linear_sum_assignment
# subtract 1 because the array is 0-indexed
df = df.withColumn("worker_idx", dense_rank().over(worker_order_window) - 1)
df = df.withColumn("task_idx", dense_rank().over(task_order_window) - 1)
import numpy as np
import pandas as pd
from scipy.optimize import linear_sum_assignment

def linear_assignment_udf(pandas_df: pd.DataFrame) -> pd.DataFrame:
    df_dict = pandas_df.to_dict('records')
    # size the matrix off the largest index in case there are fewer than N workers or tasks
    N = max(pandas_df['worker_idx'].max(), pandas_df['task_idx'].max()) + 1
    arr = np.zeros((N, N))
    for row in df_dict:
        # worker_idx is the row number, task_idx is the column number
        arr[row['worker_idx']][row['task_idx']] = row['cost']
    rids, cids = linear_sum_assignment(arr)
    return_list = []
    # now return a dataframe that says which task_idx each worker was assigned
    for r, c in zip(rids, cids):
        for d in df_dict:
            if d.get('worker_idx') == r:
                d['task_assignment'] = c
                return_list.append(d)
    return pd.DataFrame(return_list)
from pyspark.sql.types import StructType

schema = StructType.fromJson(df.schema.jsonValue()).add('task_assignment', 'integer')
df = df.groupBy("date", "locationId").applyInPandas(linear_assignment_udf, schema)
df = df.withColumn("isAssigned", when(col("task_assignment") == col("task_idx"), True).otherwise(False))
As you can see, this code does not handle the multiTask at all. I would like to solve this as efficiently as possible, so I am not tied to a pandas UDF or to scipy.
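For what it's worth, the rough direction I have been sketching (and am not at all sure is correct or efficient) is to preassign the cheap workers to the multiTask first and only run the optimization on whoever is left. Something along these lines, where assign_with_multitask, multi_col and x are hypothetical names I made up for illustration:

import numpy as np
from scipy.optimize import linear_sum_assignment

def assign_with_multitask(costs, multi_col, x):
    # costs: N x N matrix (rows = workers, cols = tasks)
    # multi_col: column index of the multiTask, or None if there is none
    # returns a dict of worker_idx -> task_idx
    workers = np.arange(costs.shape[0])
    tasks = np.arange(costs.shape[1])
    assignments = {}
    if multi_col is not None:
        cheap = workers[costs[:, multi_col] < x]
        if len(cheap) > 0:
            # preassign every cheap worker to the multiTask, then solve the
            # remaining workers against the remaining tasks
            for w in cheap:
                assignments[int(w)] = int(multi_col)
            rest_w = np.setdiff1d(workers, cheap)
            rest_t = np.delete(tasks, multi_col)
            rids, cids = linear_sum_assignment(costs[np.ix_(rest_w, rest_t)])
            for r, c in zip(rids, cids):
                assignments[int(rest_w[r])] = int(rest_t[c])
            return assignments
    # no multiTask, or nobody under the threshold: plain assignment
    rids, cids = linear_sum_assignment(costs)
    return {int(r): int(c) for r, c in zip(rids, cids)}

I am not sure this is the right way to plug the rule into the per-group UDF (or whether I should drop the UDF approach altogether), which is why I am asking.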