I have a DataFrame with an array column of doubles. In each array, either a single element or a sum of several elements equals a given target value, and I want to extract the elements that either equal the target on their own or can be summed to equal it. I'd like to be able to do this in PySpark.

| Array                  | Target    | NewArray         |
| -----------------------|-----------|------------------|
| [0.0001,2.5,3.0,0.0031]| 0.0032    | [0.0001,0.0031]  |
| [2.5,1.0,0.5,3.0]      | 3.0       | [2.5, 0.5, 3.0]  |
| [1.0,1.0,1.5,1.0]      | 4.5       | [1.0,1.0,1.5,1.0]|

1 Answer

You can encapsulate the logic in a UDF and create NewArray from it. I have borrowed the logic for identifying the array elements that sum to a target value from here.


from pyspark.sql.types import ArrayType, DoubleType
from pyspark.sql.functions import udf
from decimal import Decimal

data = [([0.0001,2.5,3.0,0.0031], 0.0032),
([2.5, 1.0, 0.5, 3.0], 3.0),
([1.0, 1.0, 1.5, 1.0], 4.5), 
([], 1.0),
(None, 1.0),
([1.0,2.0], None),]


df = spark.createDataFrame(data, ("Array", "Target", ))


@udf(returnType=ArrayType(DoubleType()))
def find_values_summing_to_target(array, target):
    def subset_sum(numbers, target, partial, result):
        s = sum(partial)
        # Record the partial subset if its sum equals the target.
        if s == target: 
            result.extend(partial)
        # Stop once the sum reaches or exceeds the target: adding more
        # values cannot help (this assumes the array values are positive).
        if s >= target:
            return
    
        for i in range(len(numbers)):
            n = numbers[i]
            remaining = numbers[i+1:]
            subset_sum(remaining, target, partial + [n], result)
    result = []
    if array is not None and target is not None:
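        # Convert via str() so Decimal sees the digits as written (e.g. 0.0001)
        # rather than the float's binary approximation.
        array = [Decimal(str(a)) for a in array]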
        subset_sum(array, Decimal(str(target)), [], result)
        result = [float(r) for r in result]
    return result

df.withColumn("NewArray", find_values_summing_to_target("Array", "Target")).show(200, False)

Output

+--------------------------+------+--------------------+
|Array                     |Target|NewArray            |
+--------------------------+------+--------------------+
|[1.0E-4, 2.5, 3.0, 0.0031]|0.0032|[1.0E-4, 0.0031]    |
|[2.5, 1.0, 0.5, 3.0]      |3.0   |[2.5, 0.5, 3.0]     |
|[1.0, 1.0, 1.5, 1.0]      |4.5   |[1.0, 1.0, 1.5, 1.0]|
|[]                        |1.0   |[]                  |
|null                      |1.0   |[]                  |
|[1.0, 2.0]                |null  |[]                  |
+--------------------------+------+--------------------+
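
The recursion can also be checked without a Spark session. What follows is only a minimal sketch in plain Python that mirrors the UDF body; the name `values_summing_to_target` is made up for illustration and is not part of the answer's code. Like the UDF, it assumes the array values are positive, since the `s >= goal` pruning would skip valid subsets if negative numbers were present.

```python
from decimal import Decimal


def values_summing_to_target(array, target):
    """Plain-Python mirror of the UDF body (hypothetical helper name)."""
    # Same null handling as the UDF: missing inputs give an empty list.
    if array is None or target is None:
        return []

    result = []

    def subset_sum(numbers, goal, partial):
        s = sum(partial)
        if s == goal:
            # Flatten every matching subset into one list, as the UDF does.
            result.extend(partial)
        if s >= goal:
            return  # assumes positive values, so larger subsets cannot match
        for i, n in enumerate(numbers):
            subset_sum(numbers[i + 1:], goal, partial + [n])

    # Go through str() so the Decimals carry the digits as written.
    subset_sum([Decimal(str(a)) for a in array], Decimal(str(target)), [])
    return [float(r) for r in result]


print(values_summing_to_target([0.0001, 2.5, 3.0, 0.0031], 0.0032))
print(values_summing_to_target([2.5, 1.0, 0.5, 3.0], 3.0))
```

Note that when more than one subset matches (as in the second call), the matches are concatenated into a single flat list rather than returned as separate subsets.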
Nithish
  • Thanks for your help, it's definitely putting me on the right track. However, I'm having trouble at this point: `if s >= target: return`. When it is left in, I get an error: `TypeError: '>=' not supported between instances of 'int' and 'NoneType'`. When I take it out it runs, but it does not return all of the values that sum to the target; it only shows a value when it equals the target by itself. – Alex Triece Nov 24 '21 at 16:25
  • Additionally, the issue could be that the decimals I'm using are much smaller (in the .0031 and .0001 range). I noticed when I substituted the example data with decimals like this it returned empty arrays. Any thoughts on that? – Alex Triece Nov 24 '21 at 16:34
  • For the first issue, I think you have None values in the `target` column. I will update the answer to return an empty array when this happens. – Nithish Nov 24 '21 at 16:36
  • You were absolutely right about that first issue. Changed the na's to 0 and it works fine. However, it doesn't read the smaller decimals. I'm ok with 0's in the target column, so no need to spend too much time on that issue, unless you want to for others' sake. – Alex Triece Nov 24 '21 at 16:44
  • The code in the answer is now `na` or `null` safe. For the precision issue I would need an example; I also tried smaller values, down to 6 decimal digits, and it still works. An example would help me replicate it. – Nithish Nov 24 '21 at 16:50
  • Just changed the top example to show what I'm looking at, really just the first row. When I plug this in, I get correct results for everything except the top row. – Alex Triece Nov 24 '21 at 16:57
  • The problem is due to floating-point precision error: in Python, `0.0001 + 0.0031` is `0.0031999999999999997` (see https://stackoverflow.com/questions/11950819/python-math-is-wrong/11950951#11950951). I have updated the answer to use decimal arithmetic to support your use case. – Nithish Nov 24 '21 at 22:12
  • Thanks, that helps. However, it throws an error with the Decimal() function. Is there something that needs to be imported for that to be recognized? – Alex Triece Nov 29 '21 at 14:33
  • Never mind, I figured it out. It wanted me to specify decimal.Decimal() to run that function. – Alex Triece Nov 29 '21 at 15:17
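
The precision problem discussed in the comments can be reproduced in a few lines outside Spark. This is a small sketch, not part of the answer's code: it shows why the float comparison fails for the first row and why converting through `str()` into `Decimal` fixes it. The `math.isclose` line is an alternative included only for comparison; the answer itself uses `Decimal`.

```python
import math
from decimal import Decimal  # without this import, a bare Decimal(...) raises NameError

# Binary floats cannot represent 0.0001 or 0.0031 exactly, so the sum drifts.
float_sum = 0.0001 + 0.0031
print(float_sum == 0.0032)  # False

# Building Decimals from the string form keeps the digits as written.
decimal_sum = Decimal(str(0.0001)) + Decimal(str(0.0031))
print(decimal_sum == Decimal(str(0.0032)))  # True

# A tolerance-based comparison is another option (not what the answer uses).
print(math.isclose(float_sum, 0.0032))  # True
```

If the module is imported as `import decimal` instead, the class has to be referenced as `decimal.Decimal`, which matches what the last comment reports.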