
I'm using Apache Spark 2.3.1 in Java. I want to count the number of rows in a dataset matching a given condition, by using the agg() method of the Dataset class.

For example, I want to count the number of rows where label equals 1.0 in the following dataset:

SparkSession spark = ...

List<Row> rows = new ArrayList<>();
// id is declared as LongType below, so pass long literals rather than ints
rows.add(RowFactory.create(0L, 0.0));
rows.add(RowFactory.create(1L, 1.0));
rows.add(RowFactory.create(2L, 1.0));

Dataset<Row> ds =
    spark.sqlContext().createDataFrame(rows,
        new StructType(new StructField[] {
            new StructField("id", DataTypes.LongType, false, Metadata.empty()),
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty())}));

My guess is to use the following code:

ds.agg(functions.count(ds.col("label").equalTo(1.0))).show();

However, a wrong result is displayed:

+--------------------+
|count((label = 1.0))|
+--------------------+
|                   3|
+--------------------+

The correct result is, of course, 2.

Is the agg() method not supposed to work this way?

guidoman
  • It looks like your condition isn't working inside agg. What if you filter before counting? `ds.filter(ds.col("label").equalTo(1.0)).count();` – chlebek Nov 20 '19 at 11:29
  • In fact my question is why the condition isn't working inside agg... – guidoman Nov 20 '19 at 11:43

3 Answers


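count in agg() counts only non-null values. The expression label = 1.0 evaluates to true or false for every row, never to null, so all 3 rows are counted. To count only the matching rows, map the non-matching ones to null (Scala syntax):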

 // The 'label symbol syntax needs spark.implicits._ in scope
 import org.apache.spark.sql.functions._
 import spark.implicits._
 ds.agg(count(when('label.equalTo(1.0), 1).otherwise(null))).show()

I've found this solution here https://stackoverflow.com/a/1400115/9687910
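
The null-skipping rule described above can be modelled in plain Java without Spark. This is only an illustrative sketch: the class CountSemantics and its two methods are hypothetical names, and they imitate the semantics of count rather than call any Spark API.

```java
import java.util.Arrays;
import java.util.Objects;

// Hypothetical model of how count(expr) treats its input; not Spark code.
class CountSemantics {

    // Models count(label = 1.0): the comparison yields true or false,
    // never null, so no row is skipped.
    static long countCondition(double[] labels) {
        return Arrays.stream(labels)
                .mapToObj(label -> Boolean.valueOf(label == 1.0))
                .filter(Objects::nonNull) // count(expr) skips only nulls
                .count();
    }

    // Models count(when(label = 1.0, 1)): rows that do not match
    // become null and are therefore skipped.
    static long countWhen(double[] labels) {
        return Arrays.stream(labels)
                .mapToObj(label -> label == 1.0 ? Integer.valueOf(1) : null)
                .filter(Objects::nonNull)
                .count();
    }

    public static void main(String[] args) {
        double[] labels = {0.0, 1.0, 1.0};
        System.out.println(countCondition(labels)); // prints 3
        System.out.println(countWhen(labels));      // prints 2
    }
}
```

With the labels from the question, the first method reproduces the unexpected 3 and the second gives the expected 2.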

chlebek

The agg method isn't supposed to work like this. What you need here is to filter the rows first, then group the data by label and apply an aggregation such as count or max.

ds.filter(ds.col("label").equalTo(1.0)).groupBy("label").agg(functions.count("*").alias("cnt")).show();

It refers to the following documentation.

rbcvl

chlebek's answer is correct.

Using Java syntax:

ds.agg(functions.count(functions.when(ds.col("label").equalTo(1.0), 0))).show();

Note that, when using count, the value argument of the when function doesn't matter as long as it is not null: without an otherwise clause, when yields null for non-matching rows, and count skips nulls.

Another way to accomplish the same would be to output a 1 and sum all results:

ds.agg(functions.sum(functions.when(ds.col("label").equalTo(1.0), 1))).show();

In this case the value must be exactly 1.
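
Why the value matters for sum but not for count can be sketched in plain Java. This is a hypothetical model (the class SumSemantics and its method are illustrative names, not Spark API); it assumes the usual SQL behaviour in which sum ignores nulls and returns null when no non-null value was seen.

```java
// Hypothetical model of sum(when(label = 1.0, value)); not Spark code.
class SumSemantics {

    // Adds `value` once per matching row. Returns null when no row
    // matches, mirroring SQL sum over an all-null input.
    static Long sumWhen(double[] labels, long value) {
        Long total = null;
        for (double label : labels) {
            if (label == 1.0) {
                total = (total == null ? 0L : total) + value;
            }
        }
        return total;
    }

    public static void main(String[] args) {
        double[] labels = {0.0, 1.0, 1.0};
        System.out.println(sumWhen(labels, 1));             // prints 2
        System.out.println(sumWhen(labels, 5));             // prints 10
        System.out.println(sumWhen(new double[]{0.0}, 1));  // prints null
    }
}
```

The last call shows a difference from the count variant: when no row matches, count returns 0, while sum returns null.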

guidoman