
I have a dataset with several columns:

  • subject
  • student
  • marks

I want to find the min, max, and median of marks.

    df.groupBy(df.col("subject"), df.col("student"))
            .agg(functions.min(df.col("marks")),
                    functions.max(df.col("marks")));

How do we find the median for the marks column?

I know we can do it in SQL using percentile_approx. Is there a way to do it with Dataset?

Edit: The linked question points to answers using the PySpark API, while this question is specifically about Java.

Sam
  • next time, please use the search function. It's very unlikely that those things haven't been asked and answered before ... and afaik, you can run SQL on `Dataset` or am I wrong here? – UninformedUser Aug 21 '19 at 19:40
  • I did search for related questions, but all of them spoke about running SQL on the Spark context. Hence I specifically mentioned it. I am not sure how we can run SQL on a Dataset, though? – Sam Aug 21 '19 at 20:39
  • I see. It looks like there is no function for it in the Dataset API. I always thought it's possible to run SQL on a `Dataset`. In the end a `DataFrame` is just a `Dataset[Row]`. And using `createOrReplaceTempView` on the `Dataset` and then running SQL with the `SQLContext` doesn't work? – UninformedUser Aug 22 '19 at 02:37
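Following up on the comment thread: registering the Dataset as a temp view and then running plain SQL on it does work, and `percentile_approx` is available directly in Spark SQL. A minimal sketch, assuming a `Dataset<Row>` named `df` with `subject`, `student`, and `marks` columns (the view name `marksView` and class name are made up for illustration):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class MedianViaSql {
    // Register the Dataset as a temporary view, then run ordinary Spark SQL
    // against it; percentile_approx(col, 0.5) gives the approximate median.
    public static Dataset<Row> medians(SparkSession spark, Dataset<Row> df) {
        df.createOrReplaceTempView("marksView");
        return spark.sql(
                "SELECT subject, student, "
              + "  min(marks) AS min, max(marks) AS max, "
              + "  percentile_approx(marks, 0.5) AS median "
              + "FROM marksView "
              + "GROUP BY subject, student");
    }
}
```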

1 Answer


You can use the Hive `percentile_approx` UDF. Something like this:

    import static org.apache.spark.sql.functions.*;
    import static org.apache.spark.sql.types.DataTypes.createStructField;

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.SparkConf;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;

    SparkSession spark = SparkSession
            .builder()
            .config(new SparkConf().setAppName("medianTest").setMaster("local[*]"))
            .getOrCreate();

    StructType schema = DataTypes.createStructType(new StructField[]{
            createStructField("subject", DataTypes.StringType, true),
            createStructField("student", DataTypes.StringType, true),
            createStructField("mark", DataTypes.IntegerType, true)
    });

    List<Row> rows = Arrays.asList(
            RowFactory.create("CS", "Alice", 85),
            RowFactory.create("CS", "Alice", 81),
            RowFactory.create("CS", "Alice", 97),
            RowFactory.create("CS", "Bob", 92),
            RowFactory.create("CS", "Bob", 75),
            RowFactory.create("CS", "Bob", 99),
            RowFactory.create("CS", "Carol", 71),
            RowFactory.create("CS", "Carol", 84),
            RowFactory.create("CS", "Carol", 91)
    );

    Dataset<Row> df = spark.createDataFrame(rows, schema);

    df
            .groupBy("subject", "student")
            .agg(
                    min("mark").as("min"),
                    max("mark").as("max"),
                    callUDF("percentile_approx", col("mark"), lit(0.5)).as("median")
            )
            .show();
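An equivalent way to write the same aggregation without `callUDF` is `functions.expr`, which parses a SQL expression string into a `Column`. A sketch against the same `df` as above (the wrapper class name is made up for illustration):

```java
import static org.apache.spark.sql.functions.expr;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class MedianViaExpr {
    // expr() turns a SQL expression into a Column, so percentile_approx
    // can be used in agg() without going through callUDF.
    public static Dataset<Row> medians(Dataset<Row> df) {
        return df
                .groupBy("subject", "student")
                .agg(
                        expr("min(mark)").alias("min"),
                        expr("max(mark)").alias("max"),
                        expr("percentile_approx(mark, 0.5)").alias("median"));
    }
}
```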
Grisha Weintraub