I'm using Apache Spark 2.3.1 in Java. I want to count the number of rows in a dataset that match a given condition, using the agg() method of the Dataset class.
For example, I want to count the number of rows where label equals 1.0 in the following dataset:
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.*;

SparkSession spark = ...
// Long literals (0L, 1L, 2L) so the values match the LongType "id" column
List<Row> rows = new ArrayList<>();
rows.add(RowFactory.create(0L, 0.0));
rows.add(RowFactory.create(1L, 1.0));
rows.add(RowFactory.create(2L, 1.0));
Dataset<Row> ds = spark.createDataFrame(rows,
        new StructType(new StructField[] {
            new StructField("id", DataTypes.LongType, false, Metadata.empty()),
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty())}));
My guess is to use the following code:
ds.agg(functions.count(ds.col("label").equalTo(1.0))).show();
However, it displays the wrong result:
+--------------------+
|count((label = 1.0))|
+--------------------+
| 3|
+--------------------+
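The 3 is consistent with SQL COUNT semantics: Spark's count(expr) counts the rows for which expr is non-null, not the rows for which it is true. Here, label = 1.0 evaluates to false, true, true for the three rows; none of those values is null, so every row is counted. The following plain-Java sketch models that rule without any Spark dependency. The class and method names (CountSemantics, equalsOne, countNonNull) are hypothetical and only illustrate the semantics; this is not how Spark implements count().

```java
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

// Hypothetical model of SQL COUNT(expr) semantics; not Spark's implementation.
public class CountSemantics {

    // Per-row value of the expression (label = 1.0): always true or false, never null.
    static List<Boolean> equalsOne(double[] labels) {
        return Arrays.stream(labels)
                .mapToObj(l -> l == 1.0)
                .collect(Collectors.toList());
    }

    // COUNT(expr): number of rows whose expression value is non-null.
    static long countNonNull(List<Boolean> values) {
        return values.stream().filter(Objects::nonNull).count();
    }

    public static void main(String[] args) {
        double[] labels = {0.0, 1.0, 1.0};
        // false is a value, not null, so all three rows are counted.
        System.out.println(countNonNull(equalsOne(labels))); // prints 3
    }
}
```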
The correct result should, of course, be 2.
Is the agg() method not supposed to work this way?
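One way to get 2 while still using agg() is to make non-matching rows null, so that count() skips them: functions.count(functions.when(ds.col("label").equalTo(1.0), true)), since when() without otherwise() yields null when the condition is false. Another common idiom sums a 0/1 indicator, functions.sum(functions.when(ds.col("label").equalTo(1.0), 1).otherwise(0)); outside agg(), ds.filter(ds.col("label").equalTo(1.0)).count() also returns the number of matching rows. The plain-Java sketch below models the first two idioms so their results can be checked without Spark; ConditionalCount and its methods are hypothetical names, not Spark API.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

// Hypothetical plain-Java model of two Spark idioms; not Spark code.
public class ConditionalCount {

    // Models when(cond, true) with no otherwise(): TRUE when matched, else null.
    static List<Boolean> whenTrue(double[] labels, double target) {
        return Arrays.stream(labels)
                .mapToObj(l -> l == target ? Boolean.TRUE : null)
                .collect(Collectors.toList());
    }

    // Models COUNT(expr): counts non-null values only.
    static long count(List<Boolean> values) {
        return values.stream().filter(Objects::nonNull).count();
    }

    // Models SUM(when(cond, 1).otherwise(0)): adds 1 per matching row.
    static long sumIndicator(double[] labels, double target) {
        return Arrays.stream(labels).mapToLong(l -> l == target ? 1L : 0L).sum();
    }

    public static void main(String[] args) {
        double[] labels = {0.0, 1.0, 1.0};
        System.out.println(count(whenTrue(labels, 1.0))); // prints 2
        System.out.println(sumIndicator(labels, 1.0));    // prints 2
    }
}
```

The null-based variant keeps count() as the aggregate, while the indicator sum avoids relying on null handling at all; both should agree for non-null label values like these.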