
I have a Spark data frame in the following format.

df = spark.createDataFrame([(1, 2, 3), (1, 4, 100), (20, 30, 50)],['a', 'b', 'c'])
df.show()

Input:

+---+---+---+
|  a|  b|  c|
+---+---+---+
|  1|  2|  3|
|  1|  4|100|
| 20| 30| 50|
+---+---+---+

I want to add a new column "median" as the median of the columns 'a', 'b', and 'c'. How can I do that in PySpark?

Expected Output:

+---+---+---+------+
|  a|  b|  c|median|
+---+---+---+------+
|  1|  2|  3|     2|
|  1|  4|100|     4|
| 20| 30| 50|    30|
+---+---+---+------+

I'm using Spark 2.3.1.


4 Answers


Define a user-defined function with udf, then use withColumn to add the computed column to the DataFrame:

from numpy import median
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# compute the median of the three column values for a single row
def my_median(a, b, c):
    return int(median([int(a), int(b), int(c)]))

udf_median = udf(my_median, IntegerType())

df_t = df.withColumn('median', udf_median(df['a'], df['b'], df['c']))
df_t.show()
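With the example DataFrame from the question, this should print something like:

+---+---+---+------+
|  a|  b|  c|median|
+---+---+---+------+
|  1|  2|  3|     2|
|  1|  4|100|     4|
| 20| 30| 50|    30|
+---+---+---+------+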
OmG

There is no built-in function for this, but you can easily write one using existing components.

# In Spark < 2.4  replace array_sort with sort_array
# Thanks to @RaphaelRoth for pointing that out
from pyspark.sql.functions import array, array_sort, floor, col, size
from pyspark.sql import Column

def percentile(p, *args):
    def col_(c):
        if isinstance(c, Column):
            return c
        elif isinstance(c, str):
            return col(c)
        else:
            raise TypeError("args should be str or Column, got {}".format(type(c)))

    xs = array_sort(array(*[col_(x) for x in args]))
    n = size(xs)
    h = (n - 1) * p
    i = floor(h).cast("int")
    x0, x1 = xs[i], xs[i + 1]
    return x0 + (h - i) * (x1 - x0)

Example usage:

df.withColumn("median", percentile(0.5, *df.columns)).show()
+---+---+---+------+
|  a|  b|  c|median|
+---+---+---+------+
|  1|  2|  3|   2.0|
|  1|  4|100|   4.0|
| 20| 30| 50|  30.0|
+---+---+---+------+
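As a quick sanity check of the interpolation step, here is a hypothetical 4-column DataFrame (df4 below is just an illustration, not from the question); with an even number of values the result is the midpoint of the two middle elements:

# hypothetical 4-column example
df4 = spark.createDataFrame([(1, 2, 3, 4), (1, 4, 100, 200)], ['a', 'b', 'c', 'd'])
df4.withColumn("median", percentile(0.5, *df4.columns)).show()
+---+---+---+---+------+
|  a|  b|  c|  d|median|
+---+---+---+---+------+
|  1|  2|  3|  4|   2.5|
|  1|  4|100|200|  52.0|
+---+---+---+---+------+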

The same thing can be done in Scala:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column

def percentile(p: Double, args: Column*) = {
    val xs = array_sort(array(args: _*))
    val n = size(xs)
    val h = (n - 1) * p
    val i = floor(h).cast("int")
    val (x0, x1) = (xs(i), xs(i + 1))
    x0 + (h - i) * (x1 - x0)
}

val df = Seq((1, 2, 3), (1, 4, 100), (20, 30, 50)).toDF("a", "b", "c")
df.withColumn("median", percentile(0.5, $"a", $"b", $"c")).show
+---+---+---+------+
|  a|  b|  c|median|
+---+---+---+------+
|  1|  2|  3|   2.0|
|  1|  4|100|   4.0|
| 20| 30| 50|  30.0|
+---+---+---+------+

In Python only, you might also consider a vectorized UDF; in general it is likely to be slower than built-in functions, but superior to a non-vectorized udf:

from pyspark.sql.functions import pandas_udf, PandasUDFType 
from pyspark.sql.types import DoubleType

import pandas as pd
import numpy as np 

def pandas_percentile(p=0.5):
    assert 0 <= p <= 1
    @pandas_udf(DoubleType()) 
    def _(*args): 
        return pd.Series(np.percentile(args, q = p * 100, axis = 0))
    return _


df.withColumn("median", pandas_percentile(0.5)("a", "b", "c")).show()
+---+---+---+------+                                                            
|  a|  b|  c|median|
+---+---+---+------+
|  1|  2|  3|   2.0|
|  1|  4|100|   4.0|
| 20| 30| 50|  30.0|
+---+---+---+------+
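Note that pandas_udf is backed by Apache Arrow, so pyarrow has to be installed for this to work.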
10465355

I have modified OmG's answer slightly to make the UDF dynamic, so it works for any number of columns instead of just 3.

Code:

df = spark.createDataFrame([(1,2,3),(100,1,10),(30,20,50)],['a','b','c'])

import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

def my_median(*args):
    return float(np.median(list(args)))

udf_median = udf(my_median, DoubleType())

df.withColumn('median', udf_median('a','b','c')).show()

Output:

+---+---+---+------+
|  a|  b|  c|median|
+---+---+---+------+
|  1|  2|  3|   2.0|
|100|  1| 10|  10.0|
| 30| 20| 50|  30.0|
+---+---+---+------+
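Since the UDF takes *args, the same udf_median works unchanged for any number of columns; for example, with a hypothetical 4-column DataFrame (df4 is just an illustration):

# hypothetical 4-column example
df4 = spark.createDataFrame([(1, 2, 3, 4)], ['a', 'b', 'c', 'd'])
df4.withColumn('median', udf_median(*df4.columns)).show()
+---+---+---+---+------+
|  a|  b|  c|  d|median|
+---+---+---+---+------+
|  1|  2|  3|  4|   2.5|
+---+---+---+---+------+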

Rinaz Belhaj
df = spark.createDataFrame([(1,2,3),(1,4,100),(20,30,50)],['a','b','c'])

from pyspark.sql.functions import struct, udf, col
from pyspark.sql.types import FloatType
import numpy as np

def find_median(values_list):
    try:
        median = np.median(values_list) #get the median of values in a list in each row
        return round(float(median),2)
    except Exception:
        return None #if there is anything wrong with the given values
median_finder = udf(find_median,FloatType())

df = df.withColumn("List_abc", struct(col('a'),col('b'),col('c')))\
       .withColumn("median",median_finder("List_abc")).drop('List_abc')
df.show()
+---+---+---+------+
|  a|  b|  c|median|
+---+---+---+------+
|  1|  2|  3|   2.0|
|  1|  4|100|   4.0|
| 20| 30| 50|  30.0|
+---+---+---+------+
cph_sto