I've got an implementation of the k-means algorithm and I would like to make my process faster by using Java 8 streams and multicore processing.
I've got this code in Java 7:
// Loop-based version of steps 2 and 3. minDst/minClusterNr and newX/newY are declared
// inside the loops, so every point and every cluster works on its own local values.
//Step 2: For each point p:
//find nearest clusters c
//assign the point p to the closest cluster c
for (Point p : points) {
double minDst = Double.MAX_VALUE;
int minClusterNr = 1;
for (Cluster c : clusters) {
double tmpDst = determineDistance(p, c);
if (tmpDst < minDst) {
minDst = tmpDst;
minClusterNr = c.clusterNumber;
}
}
// clusterNumber - 1 is used as a list index, so this assumes clusters are numbered 1..k in list order
clusters.get(minClusterNr - 1).points.add(p);
}
//Step 3: For each cluster c
//find the central point of all points p in c
//set c to the center point
ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
for (Cluster c : clusters) {
double newX = 0;
double newY = 0;
for (Point p : c.points) {
newX += p.x;
newY += p.y;
}
// NOTE(review): if c.points is empty, these divisions compute 0.0 / 0 and yield NaN
newX = newX / c.points.size();
newY = newY / c.points.size();
newClusters.add(new Cluster(newX, newY, c.clusterNumber));
}
I would like to use Java 8 parallel streams to speed up the process. I experimented a bit and came up with this solution:
// Stream-based version of the loops above.
// NOTE(review): minDst, minClusterNr, newX and newY are instance fields shared by every lambda
// call. That is only correct while the streams run sequentially: with .parallel(), several
// threads overwrite these fields concurrently, and ArrayList.add (on c.points and newClusters)
// is not thread-safe either -- a likely cause of the reported exceptions. Computing the values
// inside the lambda and collecting results with a Collector avoids the shared state.
points.stream().forEach(p -> {
minDst = Double.MAX_VALUE; //<- THESE ARE GLOBAL VARIABLES NOW
minClusterNr = 1; //<- THESE ARE GLOBAL VARIABLES NOW
clusters.stream().forEach(c -> {
double tmpDst = determineDistance(p, c);
if (tmpDst < minDst) {
minDst = tmpDst;
minClusterNr = c.clusterNumber;
}
});
// clusterNumber - 1 is used as a list index, so this assumes clusters are numbered 1..k in list order
clusters.get(minClusterNr - 1).points.add(p);
});
ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
clusters.stream().forEach(c -> {
newX = 0; //<- THESE ARE GLOBAL VARIABLES NOW
newY = 0; //<- THESE ARE GLOBAL VARIABLES NOW
c.points.stream().forEach(p -> {
newX += p.x;
newY += p.y;
});
// NOTE(review): if c.points is empty, these divisions compute 0.0 / 0 and yield NaN
newX = newX / c.points.size();
newY = newY / c.points.size();
newClusters.add(new Cluster(newX, newY, c.clusterNumber));
});
This solution with streams is considerably faster than the one without, and I was wondering whether it already uses multicore processing. Why else would it suddenly be almost twice as fast?
without streams : Elapsed time: 202 msec & with streams : Elapsed time: 116 msec
Also, would it be useful to use parallelStream in any of these methods to speed them up even more? Right now, all it does is lead to ArrayIndexOutOfBoundsException and NullPointerException errors when I change the stream to stream().parallel().forEach(CODE).
---- EDIT (Added the source code as requested so you can try this on your own) ----
--- Clustering.java ---
package algo;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import graphics.SimpleColorFun;
/**
 * An implementation of the k-means-algorithm for 2-D points.
 * <p>
 * Step 0: Determine the max size of the canvas
 * <p>
 * Step 1: Place clusters at random
 * <p>
 * Step 2: For each point p:<br>
 * find nearest clusters c<br>
 * assign the point p to the closest cluster c
 * <p>
 * Step 3: For each cluster c<br>
 * find the central point of all points p in c<br>
 * set c to the center point
 * <p>
 * Stop when none of the cluster x,y values change, or when the iteration limit is reached.
 * <p>
 * Both the loop-based and the stream-based implementation keep every per-point and per-cluster
 * working value in local variables, so the stream pipelines have no shared mutable state and may
 * run in parallel (see {@link #useParallelStreams}).
 * @author makt
 *
 */
public class Clustering {
    private BiFunction<Integer, Integer, Color> colorFun = new SimpleColorFun();
    // private BiFunction<Integer, Integer, Color> colorFun = new GrayScaleColorFun();
    public Random rngGenerator = new Random();
    public double max_x;
    public double max_y;
    public double max_xy;
    //---------------------------------
    // These fields used to hold scratch values for the stream implementation. Shared mutable
    // scratch state breaks as soon as streams run in parallel, so the algorithm now uses local
    // variables instead; the fields are no longer written and are kept only so that other code
    // in this package that may reference them still compiles.
    double minDst = Double.MAX_VALUE;
    int minClusterNr = 1;
    double newX = 0;
    double newY = 0;
    //----------------------------------
    public boolean workWithStreams = false;
    /**
     * When {@link #workWithStreams} is set, run the stream pipelines in parallel.
     * Defaults to {@code false} (sequential streams).
     */
    public boolean useParallelStreams = false;
    public ArrayList<ArrayList<Cluster>> allGeneratedClusterLists = new ArrayList<ArrayList<Cluster>>();
    public ArrayList<BufferedImage> allGeneratedImages = new ArrayList<BufferedImage>();

    public Clustering(int seed) {
        rngGenerator.setSeed(seed);
    }

    public Clustering(Random rng) {
        rngGenerator = rng;
    }

    /**
     * Runs k-means on {@code points}: places {@code centroidCount} clusters at random inside the
     * canvas, then repeatedly assigns every point to its nearest cluster and moves each cluster to
     * the mean of its points. Every iteration appends its new cluster list to
     * {@link #allGeneratedClusterLists} and a rendering to {@link #allGeneratedImages}.
     *
     * @param centroidCount number of clusters; they are numbered 1..centroidCount
     * @param points the points to cluster
     * @param maxIterations upper bound on iterations; at least one iteration always runs
     */
    public void setup(int centroidCount, ArrayList<Point> points, int maxIterations) {
        //Step 0: Determine the max size of the canvas
        determineSize(points);
        //Step 1: Place clusters at random
        // Random.nextInt requires a positive bound, so clamp for canvases narrower than 1 unit
        int boundX = Math.max(1, (int) max_x);
        int boundY = Math.max(1, (int) max_y);
        ArrayList<Cluster> clusters = new ArrayList<Cluster>();
        for (int i = 0; i < centroidCount; i++) {
            clusters.add(new Cluster(rngGenerator.nextInt(boundX), rngGenerator.nextInt(boundY), i + 1));
        }
        // Track the current list locally instead of indexing allGeneratedClusterLists by the
        // iteration count, which would pick up stale entries if setup() ran twice on one instance.
        ArrayList<Cluster> current = runIteration(points, clusters);
        allGeneratedClusterLists.add(current);
        int iterations = 1;
        //do until maxIterations is reached or until none of the cluster x and y values change anymore
        while (iterations < maxIterations) {
            //Steps 2 and 3 happen inside runIteration
            ArrayList<Cluster> next = runIteration(points, current);
            allGeneratedClusterLists.add(next);
            if (!didPointsChangeClusters(current, next)) {
                break;
            }
            current = next;
            iterations += 1;
        }
        System.out.println("Finished with " + iterations + " out of " + maxIterations + " max iterations");
    }

    /**
     * Performs one clustering iteration with the implementation selected by {@link #workWithStreams}.
     */
    private ArrayList<Cluster> runIteration(ArrayList<Point> points, ArrayList<Cluster> clusters) {
        return workWithStreams ? doClusteringWithStreams(points, clusters) : doClustering(points, clusters);
    }

    /**
     * checks if the cluster x and y values changed compared to the previous x and y values
     * @param previousCluster cluster list of the previous iteration
     * @param currentCluster cluster list of the current iteration (same size and order)
     * @return true if any cluster x or y values changed, false if all of them are the same
     */
    private boolean didPointsChangeClusters(ArrayList<Cluster> previousCluster, ArrayList<Cluster> currentCluster) {
        for (int i = 0; i < previousCluster.size(); i++) {
            if (previousCluster.get(i).x != currentCluster.get(i).x || previousCluster.get(i).y != currentCluster.get(i).y) {
                return true;
            }
        }
        return false;
    }

    /**
     * Loop-based clustering iteration.
     *
     * @param points - all given points
     * @param clusters - its point list gets filled in this method
     * @return a new Clusters Array which has an <b> empty </b> point list.
     */
    private ArrayList<Cluster> doClustering(ArrayList<Point> points, ArrayList<Cluster> clusters) {
        //Step 2: assign every point p to its closest cluster c
        for (Point p : points) {
            nearestCluster(p, clusters).points.add(p);
        }
        //Step 3: move every cluster to the central point of its points
        ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
        for (Cluster c : clusters) {
            newClusters.add(centerOf(c));
        }
        allGeneratedImages.add(createImage(clusters));
        return newClusters;
    }

    /**
     * Does the same as doClustering using Java 8 streams. The pipelines are free of shared mutable
     * state, so they may run in parallel when {@link #useParallelStreams} is set.
     *
     * @param points - all given points
     * @param clusters - its point list gets filled in this method
     * @return a new Clusters Array which has an <b> empty </b> point list.
     */
    private ArrayList<Cluster> doClusteringWithStreams(ArrayList<Point> points, ArrayList<Cluster> clusters) {
        //Step 2: group the points by their nearest cluster.
        // groupingBy (not groupingByConcurrent) keeps the encounter order of points inside each
        // group, so the result is deterministic even when the stream runs in parallel.
        Stream<Point> pointStream = useParallelStreams ? points.parallelStream() : points.stream();
        Map<Cluster, List<Point>> assignment = pointStream
                .collect(Collectors.groupingBy(p -> nearestCluster(p, clusters)));
        // Fill the point lists from this thread only; ArrayList.add is not thread-safe.
        for (Cluster c : clusters) {
            List<Point> assigned = assignment.get(c);
            if (assigned != null) {
                c.points.addAll(assigned);
            }
        }
        //Step 3: map every cluster to its new center; collect() keeps the cluster order
        Stream<Cluster> clusterStream = useParallelStreams ? clusters.parallelStream() : clusters.stream();
        ArrayList<Cluster> newClusters = clusterStream
                .map(Clustering::centerOf)
                .collect(Collectors.toCollection(ArrayList::new));
        allGeneratedImages.add(createImage(clusters));
        return newClusters;
    }

    /**
     * Finds the cluster closest to {@code p}. Ties go to the earlier cluster in the list; if no
     * distance is smaller than {@code Double.MAX_VALUE}, the first cluster is returned.
     *
     * @throws IndexOutOfBoundsException if {@code clusters} is empty
     */
    private Cluster nearestCluster(Point p, List<Cluster> clusters) {
        Cluster nearest = clusters.get(0);
        double minDistance = Double.MAX_VALUE;
        for (Cluster c : clusters) {
            double distance = determineDistance(p, c);
            if (distance < minDistance) {
                minDistance = distance;
                nearest = c;
            }
        }
        return nearest;
    }

    /**
     * Returns a new cluster with the same number, located at the mean of {@code c}'s points and
     * with an empty point list. A cluster without points keeps its current position: its mean
     * would be 0/0 = NaN, which would never compare equal and so would prevent convergence.
     */
    private static Cluster centerOf(Cluster c) {
        int count = c.points.size();
        if (count == 0) {
            return new Cluster(c.x, c.y, c.clusterNumber);
        }
        double sumX = 0;
        double sumY = 0;
        for (Point p : c.points) {
            sumX += p.x;
            sumY += p.y;
        }
        return new Cluster(sumX / count, sumY / count, c.clusterNumber);
    }

    //draw all centers from clusters
    //draw all points
    //color points according to cluster value
    private BufferedImage createImage(ArrayList<Cluster> clusters) {
        //enlarge the canvas by 5% of the max size; at least 1 px because BufferedImage rejects 0
        int imageSize = Math.max(1, (int) (max_xy * 1.05));
        //BufferedImage bi = new BufferedImage(imageSize, imageSize, BufferedImage.TYPE_BYTE_INDEXED);
        BufferedImage bi = new BufferedImage(imageSize, imageSize, BufferedImage.TYPE_INT_ARGB); // support 32-bit RGBA values
        int numClusters = clusters.size();
        int pointSize = (int) (max_xy * 0.02);
        int centerSize = (int) (max_xy * 0.03);
        Color centerColor = new Color(160, 80, 80, 200); // use RGBA: transparency=200
        Graphics2D g2d = bi.createGraphics();
        try {
            for (Cluster c : clusters) {
                //color points according to cluster value
                g2d.setColor(colorFun.apply(c.clusterNumber, numClusters));
                //draw all points
                for (Point p : c.points) {
                    g2d.fillRect((int) p.x, (int) p.y, pointSize, pointSize);
                }
                //draw all centers from clusters
                g2d.setColor(centerColor);
                g2d.fillOval((int) c.x, (int) c.y, centerSize, centerSize);
            }
        } finally {
            // release the graphics context obtained from createGraphics()
            g2d.dispose();
        }
        return bi;
    }

    /**
     * Calculates the squared euclidean distance (no square root).
     * The square root is not needed because it does not change which distance is smallest.
     * @param p the point
     * @param c the cluster center
     * @return (p.x - c.x)^2 + (p.y - c.y)^2
     */
    private double determineDistance(Point p, Cluster c) {
        double dx = p.x - c.x;
        double dy = p.y - c.y;
        return dx * dx + dy * dy;
    }

    /**
     * Raises max_x / max_y to the largest point coordinates and sets max_xy to the larger of both.
     */
    //TODO: What if coordinates can also be negative?
    private void determineSize(ArrayList<Point> points) {
        for (Point p : points) {
            if (p.x > max_x) {
                max_x = p.x;
            }
            if (p.y > max_y) {
                max_y = p.y;
            }
        }
        max_xy = Math.max(max_x, max_y);
    }
}
--- Point.java ---
package algo;
public class Point {
public double x;
public double y;
public Point(int x, int y) {
this.x = x;
this.y = y;
}
public Point(double x, double y) {
this.x = x;
this.y = y;
}
}
--- Cluster.java ---
package algo;
import java.util.ArrayList;
/**
 * A cluster center with its 1-based number and the points currently assigned to it.
 */
public class Cluster {
    /** Horizontal coordinate of the center. */
    public double x;
    /** Vertical coordinate of the center. */
    public double y;
    /** Identifier of the cluster. */
    public int clusterNumber;
    /** Points assigned to this cluster; starts out empty. */
    public ArrayList<Point> points;

    /**
     * Creates a cluster center with an empty point list.
     * @param x horizontal coordinate of the center
     * @param y vertical coordinate of the center
     * @param clusterNumber identifier of the cluster
     */
    public Cluster(double x, double y, int clusterNumber) {
        this.points = new ArrayList<Point>();
        this.clusterNumber = clusterNumber;
        this.x = x;
        this.y = y;
    }
}
--- SimpleColorFun.java ---
package graphics;
import java.awt.Color;
import java.util.function.BiFunction;
/**
 * Simple function that selects a color for a cluster identified by an integer ID.
 * IDs 1 to 9 map to distinct fixed colors; every other ID is drawn in black.
 *
 * @author makl, hese
 */
public class SimpleColorFun implements BiFunction<Integer, Integer, Color> {
    /** Colors for cluster IDs 1..9, stored at index ID - 1. */
    private static final Color[] PALETTE = {
        Color.RED, Color.GREEN, Color.BLUE, Color.ORANGE, Color.MAGENTA,
        Color.YELLOW, Color.CYAN, Color.PINK, Color.LIGHT_GRAY
    };

    /**
     * Selects a color value.
     * @param n current index (cluster ID); must not be null
     * @param numCol number of color-values possible (not used by this palette)
     * @return the palette color for {@code n}, or black when {@code n} is outside 1..9
     */
    @Override
    public Color apply(Integer n, Integer numCol) {
        int id = n; // unboxing: a null ID throws NullPointerException, as switching on it would
        boolean inPalette = id >= 1 && id <= PALETTE.length;
        return inPalette ? PALETTE[id - 1] : Color.BLACK;
    }
}
--- Main.java --- (REPLACE THE Stopwatch with some other timing mechanism — I get it from our working environment)
package main;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import javax.imageio.ImageIO;
import algo.Clustering;
import algo.Point;
import eu.lbase.common.util.Stopwatch;
// import persistence.DataHandler;
/**
 * Benchmark driver: clusters the same point set once with the loop-based and once with the
 * stream-based implementation, prints the elapsed time of each run and writes one PNG per
 * iteration into a separate output directory per run.
 */
public class Main {
    private static final String OUTPUT_DIR = (new File("./output/withoutStream")).getAbsolutePath() + File.separator;
    private static final String OUTPUT_DIR_2 = (new File("./output/withStream")).getAbsolutePath() + File.separator;

    public static void main(String[] args) {
        Random rng = new Random();
        int numPoints = 300;
        int seed = 2;
        ArrayList<Point> points = new ArrayList<Point>();
        // the point set is different on every run; 'seed' only fixes the initial cluster placement
        rng.setSeed(rng.nextInt());
        for (int i = 0; i < numPoints; i++) {
            points.add(new Point(rng.nextInt(1000), rng.nextInt(1000)));
        }
        Stopwatch stw = Stopwatch.create(TimeUnit.MILLISECONDS);
        // NOTE(review): both variants run in the same JVM, so the first one measured likely also
        // pays for class loading and JIT warm-up. Swap the order or add an untimed warm-up run
        // before comparing the timings. Also confirm that Stopwatch.start() resets the elapsed
        // time; otherwise the second figure includes the first.
        runAndReport("without streams", false, points, seed, stw, OUTPUT_DIR);
        runAndReport("with streams", true, points, seed, stw, OUTPUT_DIR_2);
    }

    /**
     * Clusters {@code points} with one implementation, times the clustering and writes its images.
     *
     * @param label text used in the progress messages, e.g. "with streams"
     * @param withStreams value assigned to {@link Clustering#workWithStreams}
     * @param points the points to cluster
     * @param seed seed for the initial cluster placement
     * @param stw stopwatch used to time the clustering (image writing is not timed)
     * @param dir output directory for the images, ending with a file separator
     */
    private static void runAndReport(String label, boolean withStreams, ArrayList<Point> points, int seed,
            Stopwatch stw, String dir) {
        System.out.println("--- Started " + label + " ---");
        stw.start();
        Clustering algo = new Clustering(seed);
        algo.workWithStreams = withStreams;
        algo.setup(8, points, 25);
        stw.stop();
        System.out.println("--- Finished " + label + " ---");
        System.out.printf("Elapsed time: %d msec%n%n", stw.getElapsed());
        System.out.printf("Writing images to '%s' ...%n", dir);
        deleteOldFiles(new File(dir));
        makeImages(dir, algo);
        System.out.println("Finished writing.\n");
    }

    /**
     * creates one image for each iteration in the given directory (001.png, 002.png, ...),
     * creating the directory first if it does not exist yet
     * @param dir target directory
     * @param algo clustering run whose images are written
     */
    private static void makeImages(String dir, Clustering algo) {
        File outputDir = new File(dir);
        // ImageIO.write cannot create missing parent directories, so make sure the target exists
        if (!outputDir.isDirectory() && !outputDir.mkdirs()) {
            System.err.printf("Could not create output directory '%s'%n", outputDir);
            return;
        }
        int i = 1;
        for (BufferedImage img : algo.allGeneratedImages) {
            File target = new File(outputDir, String.format("%03d.png", i));
            try {
                ImageIO.write(img, "png", target);
            } catch (IOException e) {
                System.err.printf("Could not write '%s'%n", target);
                e.printStackTrace();
            }
            i++;
        }
    }

    /**
     * Recursively deletes all files below {@code file}, or {@code file} itself if it is not a
     * directory.<br>
     * Does <b>not</b> delete directories!
     * @param file - file or directory to delete files from
     * @return true if {@code file} is not a directory and was deleted, false otherwise
     */
    private static boolean deleteOldFiles(File file) {
        File[] allContents = file.listFiles();
        if (allContents != null) {
            for (File f : allContents) {
                deleteOldFiles(f);
            }
        }
        if (!file.isDirectory()) {
            return file.delete();
        }
        return false;
    }
}