
I've got an implementation of the k-means algorithm and I would like to make my process faster by using Java 8 streams and multicore processing.

I've got this code in Java 7:

//Step 2: For each point p:
//find nearest clusters c
//assign the point p to the closest cluster c
for (Point p : points) {
   double minDst = Double.MAX_VALUE;
   int minClusterNr = 1;
   for (Cluster c : clusters) {
      double tmpDst = determineDistance(p, c);
      if (tmpDst < minDst) {
         minDst = tmpDst;
         minClusterNr = c.clusterNumber;
      }
   }
   clusters.get(minClusterNr - 1).points.add(p);
}
//Step 3: For each cluster c
//find the central point of all points p in c
//set c to the center point
ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
for (Cluster c : clusters) {
   double newX = 0;
   double newY = 0;
   for (Point p : c.points) {
      newX += p.x;
      newY += p.y;
   }
   newX = newX / c.points.size();
   newY = newY / c.points.size();
   newClusters.add(new Cluster(newX, newY, c.clusterNumber));
}

And I would like to use Java 8 with parallel streams to speed up the process. I have tried a bit and came up with this solution:

points.stream().forEach(p -> {
   minDst = Double.MAX_VALUE; //<- THESE ARE GLOBAL VARIABLES NOW
   minClusterNr = 1;          //<- THESE ARE GLOBAL VARIABLES NOW
   clusters.stream().forEach(c -> {
      double tmpDst = determineDistance(p, c);
      if (tmpDst < minDst) {
         minDst = tmpDst;
         minClusterNr = c.clusterNumber;
      }
   });
   clusters.get(minClusterNr - 1).points.add(p);
});
ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
clusters.stream().forEach(c -> {
   newX = 0;  //<- THESE ARE GLOBAL VARIABLES NOW
   newY = 0;  //<- THESE ARE GLOBAL VARIABLES NOW
   c.points.stream().forEach(p -> {
      newX += p.x;
      newY += p.y;
   });
   newX = newX / c.points.size();
   newY = newY / c.points.size();
   newClusters.add(new Cluster(newX, newY, c.clusterNumber));
});

This solution with streams is considerably faster than the one without. And I was wondering if this already uses multicore processing? Why else would it suddenly be almost twice as fast?

without streams : Elapsed time: 202 msec & with streams : Elapsed time: 116 msec

Also, would it be useful to use parallelStream in any of these methods to speed them up even more? All it does right now is cause ArrayIndexOutOfBoundsException and NullPointerException when I change the stream to stream().parallel().forEach(CODE)

---- EDIT (Added the source code as requested so you try this on your own) ----

--- Clustering.java ---

package algo;

import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.Random;
import java.util.function.BiFunction;

import graphics.SimpleColorFun;

/**
 * An implementation of the k-means-algorithm.
 * <p>
 * Step 0: Determine the max size of the canvas
 * <p>
 * Step 1: Place clusters at random
 * <p>
 * Step 2: For each point p:<br>
 * find nearest clusters c<br>
 * assign the point p to the closest cluster c
 * <p>
 * Step 3: For each cluster c<br>
 * find the central point of all points p in c<br>
 * set c to the center point
 * <p>
 * Stop when none of the cluster x,y values change
 * @author makt
 *
 */
public class Clustering {

   private BiFunction<Integer, Integer, Color> colorFun = new SimpleColorFun();
   //   private BiFunction<Integer, Integer, Color> colorFun = new GrayScaleColorFun();

   public Random rngGenerator = new Random();

   public double max_x;
   public double max_y;
   public double max_xy;

   //---------------------------------
   //TODO: IS IT GOOD TO HAVE THESE VALUES UP HERE?
   double minDst = Double.MAX_VALUE;
   int minClusterNr = 1;

   double newX = 0;
   double newY = 0;
   //----------------------------------

   public boolean workWithStreams = false;

   public ArrayList<ArrayList<Cluster>> allGeneratedClusterLists = new ArrayList<ArrayList<Cluster>>();
   public ArrayList<BufferedImage> allGeneratedImages = new ArrayList<BufferedImage>();

   public Clustering(int seed) {
      rngGenerator.setSeed(seed);
   }

   public Clustering(Random rng) {
      rngGenerator = rng;
   }

   public void setup(int centroidCount, ArrayList<Point> points, int maxIterations) {

      //Step 0: Determine the max size of the canvas
      determineSize(points);

      ArrayList<Cluster> clusters = new ArrayList<Cluster>();
      //Step 1: Place clusters at random
      for (int i = 0; i < centroidCount; i++) {
         clusters.add(new Cluster(rngGenerator.nextInt((int) max_x), rngGenerator.nextInt((int) max_y), i + 1));
      }

      int iterations = 0;

      if (workWithStreams) {
         allGeneratedClusterLists.add(doClusteringWithStreams(points, clusters));
      } else {
         allGeneratedClusterLists.add(doClustering(points, clusters));
      }

      iterations += 1;

      //do until maxIterations is reached or until none of the cluster x and y values change anymore
      while (iterations < maxIterations) {
         //Step 2: happens inside doClustering
         if (workWithStreams) {
            allGeneratedClusterLists.add(doClusteringWithStreams(points, allGeneratedClusterLists.get(iterations - 1)));
         } else {
            allGeneratedClusterLists.add(doClustering(points, allGeneratedClusterLists.get(iterations - 1)));
         }

         if (!didPointsChangeClusters(allGeneratedClusterLists.get(iterations - 1), allGeneratedClusterLists.get(iterations))) {
            break;
         }

         iterations += 1;
      }

      System.out.println("Finished with " + iterations + " out of " + maxIterations + " max iterations");
   }

   /**
    * checks if the cluster x and y values changed compared to the previous x and y values
    * @param previousCluster
    * @param currentCluster
    * @return true if any cluster x or y values changed, false if all of them are the same
    */
   private boolean didPointsChangeClusters(ArrayList<Cluster> previousCluster, ArrayList<Cluster> currentCluster) {
      for (int i = 0; i < previousCluster.size(); i++) {
         if (previousCluster.get(i).x != currentCluster.get(i).x || previousCluster.get(i).y != currentCluster.get(i).y) {
            return true;
         }
      }
      return false;
   }

   /**
    * 
    * @param points - all given points
    * @param clusters - its point list gets filled in this method
    * @return a new Clusters Array which has an <b> empty </b> point list.
    */
   private ArrayList<Cluster> doClustering(ArrayList<Point> points, ArrayList<Cluster> clusters) {
      //Step 2: For each point p:
      //find nearest clusters c
      //assign the point p to the closest cluster c

      for (Point p : points) {
         double minDst = Double.MAX_VALUE;
         int minClusterNr = 1;
         for (Cluster c : clusters) {
            double tmpDst = determineDistance(p, c);
            if (tmpDst < minDst) {
               minDst = tmpDst;
               minClusterNr = c.clusterNumber;
            }
         }
         clusters.get(minClusterNr - 1).points.add(p);
      }

      //Step 3: For each cluster c
      //find the central point of all points p in c
      //set c to the center point
      ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
      for (Cluster c : clusters) {
         double newX = 0;
         double newY = 0;
         for (Point p : c.points) {
            newX += p.x;
            newY += p.y;
         }
         newX = newX / c.points.size();
         newY = newY / c.points.size();
         newClusters.add(new Cluster(newX, newY, c.clusterNumber));
      }

      allGeneratedImages.add(createImage(clusters));

      return newClusters;
   }

   /**
    * Does the same as doClustering but about twice as fast!<br>
    * Uses Java8 streams to achieve this
    * @param points
    * @param clusters
    * @return
    */
   private ArrayList<Cluster> doClusteringWithStreams(ArrayList<Point> points, ArrayList<Cluster> clusters) {
      points.stream().forEach(p -> {
         minDst = Double.MAX_VALUE;
         minClusterNr = 1;
         clusters.stream().forEach(c -> {
            double tmpDst = determineDistance(p, c);
            if (tmpDst < minDst) {
               minDst = tmpDst;
               minClusterNr = c.clusterNumber;
            }
         });
         clusters.get(minClusterNr - 1).points.add(p);
      });

      ArrayList<Cluster> newClusters = new ArrayList<Cluster>();

      clusters.stream().forEach(c -> {
         newX = 0;
         newY = 0;
         c.points.stream().forEach(p -> {
            newX += p.x;
            newY += p.y;
         });
         newX = newX / c.points.size();
         newY = newY / c.points.size();
         newClusters.add(new Cluster(newX, newY, c.clusterNumber));
      });

      allGeneratedImages.add(createImage(clusters));

      return newClusters;
   }

   //draw all centers from clusters
   //draw all points
   //color points according to cluster value
   private BufferedImage createImage(ArrayList<Cluster> clusters) {
      //add 10% of the max size left and right to the image bounds
      //BufferedImage bi = new BufferedImage((int) (max_xy * 1.05), (int) (max_xy * 1.05), BufferedImage.TYPE_BYTE_INDEXED);
      BufferedImage bi = new BufferedImage((int) (max_xy * 1.05), (int) (max_xy * 1.05), BufferedImage.TYPE_INT_ARGB); // support 32-bit RGBA values
      Graphics2D g2d = bi.createGraphics();

      int numClusters = clusters.size();
      for (Cluster c : clusters) {
         //color points according to cluster value
         Color col = colorFun.apply(c.clusterNumber, numClusters);
         //draw all points
         g2d.setColor(col);
         for (Point p : c.points) {
            g2d.fillRect((int) p.x, (int) p.y, (int) (max_xy * 0.02), (int) (max_xy * 0.02));
         }
         //draw all centers from clusters
         g2d.setColor(new Color(160, 80, 80, 200)); // use RGBA: transparency=200
         g2d.fillOval((int) c.x, (int) c.y, (int) (max_xy * 0.03), (int) (max_xy * 0.03));
      }

      return bi;
   }

   /**
    * Calculates the euclidean distance without square root
    * @param p
    * @param c
    * @return
    */
   private double determineDistance(Point p, Cluster c) {
      //math.sqrt not needed because the relative distance does not change by applying the square root
      //        return Math.sqrt(Math.pow((p.x - c.x), 2)+Math.pow((p.y - c.y),2));

      return Math.pow((p.x - c.x), 2) + Math.pow((p.y - c.y), 2);
   }

   //TODO: What if coordinates can also be negative?
   private void determineSize(ArrayList<Point> points) {
      for (Point p : points) {
         if (p.x > max_x) {
            max_x = p.x;
         }
         if (p.y > max_y) {
            max_y = p.y;
         }
      }
      if (max_x > max_y) {
         max_xy = max_x;
      } else {
         max_xy = max_y;
      }
   }

}

--- Point.java ---

package algo;

public class Point {

    public double x;
    public double y;

    public Point(int x, int y) {
        this.x = x;
        this.y = y;
    }

    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }


}

--- Cluster.java ---

package algo;

import java.util.ArrayList;

public class Cluster {

    public double x;
    public double y;

    public int clusterNumber;

    public ArrayList<Point> points = new ArrayList<Point>();

    public Cluster(double x, double y, int clusterNumber) {
        this.x = x;
        this.y = y;
        this.clusterNumber = clusterNumber;
    }

}

--- SimpleColorFun.java ---

package graphics;

import java.awt.Color;
import java.util.function.BiFunction;

/**
 * Simple function for selection a color for a specific cluster identified with an integer-ID.
 * 
 * @author makl, hese
 */
public class SimpleColorFun implements BiFunction<Integer, Integer, Color> {

   /**
    * Selects a color value.
    * @param n current index
    * @param numCol number of color-values possible
    */
   @Override
   public Color apply(Integer n, Integer numCol) {
      Color col = Color.BLACK;
      //color points according to cluster value
      switch (n) {
         case 1:
            col = Color.RED;
            break;
         case 2:
            col = Color.GREEN;
            break;
         case 3:
            col = Color.BLUE;
            break;
         case 4:
            col = Color.ORANGE;
            break;
         case 5:
            col = Color.MAGENTA;
            break;
         case 6:
            col = Color.YELLOW;
            break;
         case 7:
            col = Color.CYAN;
            break;
         case 8:
            col = Color.PINK;
            break;
         case 9:
            col = Color.LIGHT_GRAY;
            break;
         default:
            break;
      }
      return col;
   }

}

--- Main.java --- (REPLACE THE Stopwatch with some other time logging mechanism - I get this one from our working environment)

package main;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import javax.imageio.ImageIO;

import algo.Clustering;
import algo.Point;
import eu.lbase.common.util.Stopwatch;
// import persistence.DataHandler;

public class Main {

   private static final String OUTPUT_DIR = (new File("./output/withoutStream")).getAbsolutePath() + File.separator;
   private static final String OUTPUT_DIR_2 = (new File("./output/withStream")).getAbsolutePath() + File.separator;

   public static void main(String[] args) {
      Random rng = new Random();
      int numPoints = 300;
      int seed = 2;

      ArrayList<Point> points = new ArrayList<Point>();
      rng.setSeed(rng.nextInt());
      for (int i = 0; i < numPoints; i++) {
         points.add(new Point(rng.nextInt(1000), rng.nextInt(1000)));
      }

      Stopwatch stw = Stopwatch.create(TimeUnit.MILLISECONDS);
      {
         // Stopwatch start
         System.out.println("--- Started without streams ---");
         stw.start();

         Clustering algo = new Clustering(seed);
         algo.setup(8, points, 25);

         // Stopwatch stop
         stw.stop();
         System.out.println("--- Finished without streams ---");
         System.out.printf("Elapsed time: %d msec%n%n", stw.getElapsed());

         System.out.printf("Writing images to '%s' ...%n", OUTPUT_DIR);

         deleteOldFiles(new File(OUTPUT_DIR));
         makeImages(OUTPUT_DIR, algo);

         System.out.println("Finished writing.\n");
      }

      {
         System.out.println("--- Started with streams ---");
         stw.start();

         Clustering algo = new Clustering(seed);
         algo.workWithStreams = true;
         algo.setup(8, points, 25);

         // Stopwatch stop
         stw.stop();
         System.out.println("--- Finished with streams ---");
         System.out.printf("Elapsed time: %d msec%n%n", stw.getElapsed());

         System.out.printf("Writing images to '%s' ...%n", OUTPUT_DIR_2);

         deleteOldFiles(new File(OUTPUT_DIR_2));
         makeImages(OUTPUT_DIR_2, algo);

         System.out.println("Finished writing.\n");
      }
   }

   /**
    * creates one image for each iteration in the given directory
    * @param algo
    */
   private static void makeImages(String dir, Clustering algo) {
      int i = 1;
      for (BufferedImage img : algo.allGeneratedImages) {
         try {
            String filename = String.format("%03d.png", i);
            ImageIO.write(img, "png", new File(dir + filename));
         } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
         }
         i++;
      }
   }

   /**
    * deletes old files from the target directory<br>
    * Does <b>not</b> delete directories!
    * @param file - file or directory to delete files from
    * @return true if the given file was deleted, false otherwise (always false for a directory)
    */
   private static boolean deleteOldFiles(File file) {
      File[] allContents = file.listFiles();
      if (allContents != null) {
         for (File f : allContents) {
            deleteOldFiles(f);
         }
      }
      if (!file.isDirectory()) {
         return file.delete();
      }
      return false;
   }

}
Vulkanos
  • Your question is an interesting one. I would so much appreciate if you could paste runnable code. – suenda Feb 09 '18 at 12:08
  • There you go. Edited my question and pasted the entire source code needed to run the algorithm and everything. Thanks for helping :) – Vulkanos Feb 09 '18 at 12:38
  • First, benchmarks must be run separately, since the JVM can already be warmed up by the time the second block runs. Try running the stream version first and the non-stream version second to see whether that was the cause, and run them separately for accurate results. – Marcos Vasconcelos Feb 09 '18 at 12:45
  • Have a look at this [How do I write a correct micro-benchmark in Java?](https://stackoverflow.com/questions/504103/how-do-i-write-a-correct-micro-benchmark-in-java) What many people here will tell you is that benchmarks that only execute once are no good, especially when you do not "warm up" the JVM. The first one might be slower because the compiler did not optimize stuff yet. Flip the two blocks around and test again. – Malte Hartwig Feb 09 '18 at 12:51
  • Marcos and Malte are right. If I run the two versions separately, they take about the same amount of time... So then comes my question: how do I actually make use of parallel streams? – Vulkanos Feb 09 '18 at 12:54
  • Now you went to the other extreme; that’s definitely too much code containing entirely irrelevant stuff. – Holger Feb 09 '18 at 13:01
  • As a side note, the only place where you should refer to `ArrayList` is with `new ArrayList<>()`; at all other places, you should use `List`… – Holger Feb 09 '18 at 13:18

1 Answer


When you want to use Streams efficiently, you should stop using forEach to basically write the same as the loop, and instead, learn about the aggregate operations. See also the comprehensive package documentation.

A thread safe solution may look like

// needs: java.util.Comparator, java.util.List, java.util.stream.Collectors
points.stream().forEach(p -> {
    // find the cluster with the smallest distance to p
    Cluster min = clusters.stream()
        .min(Comparator.comparingDouble(c -> determineDistance(p, c))).get();
    // your original code used the clusterNumber to look up the Cluster in
    // the list, don't know whether this is really necessary
    min = clusters.get(min.clusterNumber - 1);

    // didn't find a better way considering your current code structure
    synchronized(min) {
        min.points.add(p);
    }
});
// compute the new center of each cluster from its assigned points
List<Cluster> newClusters = clusters.stream()
    .map(c -> new Cluster(
        c.points.stream().mapToDouble(p -> p.x).sum()/c.points.size(),
        c.points.stream().mapToDouble(p -> p.y).sum()/c.points.size(),
        c.clusterNumber))
    .collect(Collectors.toList());

but you didn’t provide enough context to test this. There are some open questions, e.g. you used the clusterNumber of the Cluster instance to look back into the clusters list; I don’t know whether the clusterNumber represents the actual list index of the Cluster instance we already have, i.e. if this is an unnecessary redundancy, or has a different meaning.

I also don’t know a better solution than synchronizing on the particular Cluster to make the manipulation of its list thread safe (given your current code structure). This is only needed if you decide to use a parallel stream, i.e. points.parallelStream().forEach(p -> …); other operations are unaffected.
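
If the code structure can change, the shared lists and the synchronized block can be avoided by letting a collector do the grouping. The following is only a sketch under assumptions: Point and Cluster are reduced stand-ins for the question's classes, nearest, assign and step are hypothetical helper names, and keeping the old center for a cluster that receives no points is a choice made here (the question's code would divide by zero and produce NaN instead).

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

class GroupingSketch {

    // Reduced stand-ins for the question's Point and Cluster (assumed shape).
    static class Point {
        final double x, y;
        Point(double x, double y) { this.x = x; this.y = y; }
    }

    static class Cluster {
        final double x, y;
        final int clusterNumber;
        Cluster(double x, double y, int clusterNumber) {
            this.x = x; this.y = y; this.clusterNumber = clusterNumber;
        }
    }

    // Squared Euclidean distance, like determineDistance in the question.
    static double determineDistance(Point p, Cluster c) {
        double dx = p.x - c.x, dy = p.y - c.y;
        return dx * dx + dy * dy;
    }

    // Hypothetical helper: number of the cluster closest to p.
    static int nearest(Point p, List<Cluster> clusters) {
        return clusters.stream()
            .min(Comparator.comparingDouble(c -> determineDistance(p, c)))
            .get().clusterNumber;
    }

    // Hypothetical helper: groups the points by nearest cluster number.
    // The collector owns all mutable state, so no synchronization is needed.
    static Map<Integer, List<Point>> assign(List<Point> points, List<Cluster> clusters) {
        return points.parallelStream()
            .collect(Collectors.groupingBy(p -> nearest(p, clusters)));
    }

    // One iteration: assign the points, then compute the new centers.
    static List<Cluster> step(List<Point> points, List<Cluster> clusters) {
        Map<Integer, List<Point>> byCluster = assign(points, clusters);
        return clusters.stream()
            .map(c -> {
                List<Point> members = byCluster.get(c.clusterNumber);
                if (members == null) {
                    return c; // assumption: an empty cluster keeps its old center
                }
                return new Cluster(
                    members.stream().mapToDouble(p -> p.x).average().getAsDouble(),
                    members.stream().mapToDouble(p -> p.y).average().getAsDouble(),
                    c.clusterNumber);
            })
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Point> points = Arrays.asList(
            new Point(0, 0), new Point(0, 2), new Point(10, 10), new Point(10, 12));
        List<Cluster> clusters = Arrays.asList(new Cluster(1, 1, 1), new Cluster(9, 9, 2));
        for (Cluster c : step(points, clusters)) {
            System.out.println(c.clusterNumber + ": " + c.x + ", " + c.y);
        }
    }
}
```

Because the clusters are never mutated here, switching assign between points.stream() and points.parallelStream() does not affect correctness, only timing.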

You now have several streams you can try in parallel and sequential to find out where you get a benefit or not. Usually, only the outer streams bear a significant benefit, if any…
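
To compare the sequential and parallel variants without the warm-up effect mentioned in the comments, something like the following rough harness could be used. It is a sketch, not a proper benchmark: bestOfNanos and work are hypothetical names, work is a made-up workload standing in for one clustering run, and a dedicated tool such as JMH would give more trustworthy numbers.

```java
import java.util.stream.IntStream;

class TimingSketch {

    // Results are written here so the JIT cannot discard the computation.
    static volatile double sink;

    // Made-up workload standing in for one clustering run (assumption).
    static double work(boolean parallel, int n) {
        IntStream range = IntStream.range(0, n);
        if (parallel) {
            range = range.parallel();
        }
        return range.mapToDouble(i -> Math.sqrt(i)).sum();
    }

    // Runs the task 'warmups' times untimed, then returns the fastest of
    // 'runs' timed executions in nanoseconds.
    static long bestOfNanos(Runnable task, int warmups, int runs) {
        for (int i = 0; i < warmups; i++) {
            task.run();
        }
        long best = Long.MAX_VALUE;
        for (int i = 0; i < runs; i++) {
            long start = System.nanoTime();
            task.run();
            best = Math.min(best, System.nanoTime() - start);
        }
        return best;
    }

    public static void main(String[] args) {
        int n = 1_000_000;
        long seq = bestOfNanos(() -> { sink = work(false, n); }, 20, 10);
        long par = bestOfNanos(() -> { sink = work(true, n); }, 20, 10);
        System.out.printf("sequential: %d us, parallel: %d us%n", seq / 1000, par / 1000);
    }
}
```

In the question's program, the task would be a call to setup with workWithStreams set accordingly, ideally without the image creation, which is unrelated to the stream code being measured.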

Holger
  • Thank you very much Holger! Your solution works and your feedback is highly appreciated! You are right, min = clusters.get(min.clusterNumber - 1); is not needed, as min returns the cluster with the smallest distance to point p. (This is very nice!) However, I have not noticed an improvement in the time it took when using parallel streams compared to the normal streams or normal for-each loops (Java 7). The opposite is the case: parallel streams seem to be a bit slower than the regular ones? – Vulkanos Feb 09 '18 at 13:40
  • That’s why my last statement ends with “if any”. Parallel processing always adds an overhead that must be compensated, i.e. by having really heavy computation or a large amount of data, to get an advantage. Otherwise, there is no benefit. – Holger Feb 09 '18 at 13:42
  • Okay, thank you very very much! You helped me out a lot today! I have learned a bunch! :) – Vulkanos Feb 09 '18 at 13:47