0

Today I tried to write in Python a program that I had already made in Java, to isolate the background from a text image. And I can't figure out why my Python program is much slower to execute than my Java one (it takes less than a second in Java, and more than 5 minutes in Python), although they almost do exactly the same thing.

I'd like to mention that I am quite new to Python, so I'm sorry if I made an obvious mistake.

Here's the Python code:

import random
import math
from PIL import Image
from pythonds.basic.stack import Stack

def isolatebackground(image, max_distance):
    newimage = image.copy()
    (width, height) = image.size

    w = random.randrange(width);
    h = random.randrange(height);

    # Preparing unvisited pixels
    unvisitedpixels = []

    for x in range(width):
        for y in range(height):
            unvisitedpixels.append((x,y))

    # Background pixels
    background = []

    stack = Stack()
    stack.push((x,y))
    (width, height) = image.size
    pixels = image.load()
    unvisitedpixels.remove((x,y))

    possibleneighbors = [i for i in range(8)]
    translations = [(-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0)]

    # Region growing algorithm
    while not stack.isEmpty():
        (currentred, currentgreen, currentblue, currentalpha) = pixels[x, y]
        foundneighbor = False
        unvisitedneighbors = possibleneighbors.copy()

        while len(unvisitedneighbors) > 0:
            n = random.choice(unvisitedneighbors)
            unvisitedneighbors.remove(n)
            (i, j) = translations[n]
            newX = x + i
            newY = y + j
            if newX >= 0 and newX < width and newY >= 0 and newY < height:
                if (newX, newY) in unvisitedpixels:
                    unvisitedpixels.remove((newX, newY))
                    (newred, newgreen, newblue, newalpha) = pixels[newX, newY]

                    distance = math.sqrt((newred - currentred) ** 2 + (newgreen - currentgreen) ** 2 + (newblue - currentblue) ** 2)

                    if distance <= max_distance:
                        foundneighbor = True
                        background.append((newX, newY))

                        stack.push((newX, newY))
                        (x, y) = (newX, newY)

        if not foundneighbor:
            (x,y) = stack.pop()

    for p in background:
        newimage.putpixel(p, (255, 0, 0))

    return newimage


if __name__ == '__main__':
    image = Image.open("TestImage3.png")
    newimage = isolatebackground(image, 5)
    newimage.show()

The Java code:

Main.java:

import java.awt.image.BufferedImage;
import java.io.File;
import javax.imageio.ImageIO;
import java.io.IOException;

import java.awt.Color;
import java.awt.Graphics;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.Stack;
import java.util.TreeMap;
import javax.imageio.ImageIO;

public class Main
{
    public static void main(String[] args) throws IOException
    {
        BufferedImage im = readImage("TestImage3.png");

        BufferedImage split_im = removeBackground(im, 5);

        File f1 = new File("test.png");

        ImageIO.write(split_im, "png", f1);
    }

    public static BufferedImage readImage(String path) throws IOException
    {
        File image_file = new File(path);

        BufferedImage image = ImageIO.read(image_file);

        return image;
    }

    public static BufferedImage removeBackground (BufferedImage image, double threshold)
    {
        BufferedImage split_im = new BufferedImage(image.getWidth(), image.getHeight(), BufferedImage.TYPE_INT_ARGB);
        Graphics g = split_im.getGraphics();
        g.drawImage(image, 0, 0, null);
        g.dispose();

        int x = (int)(image.getWidth() * Math.random());
        int y = (int)(image.getHeight() * Math.random());

        split_im.setRGB(x, y, Color.RED.getRGB());

        int[][] translations = {{-1, -1}, {0, -1}, {1, -1}, {1, 0}, {1, 1}, {0, 1}, {-1, 1}, {-1, 0}};
        List<Integer> possibleNeighbors = new ArrayList<>();

        for (int i = 0; i < 8; i++)
        {
            possibleNeighbors.add(i);
        }

        Map<Coordinates, Integer> unvisitedPixels = new HashMap<>();

        for (int h = 0; h < image.getHeight(); h++)
        {
            for (int w = 0; w < image.getWidth(); w++)
            {
                unvisitedPixels.put(new Coordinates(w, h), image.getRGB(w, h));
            }
        }

        Coordinates initialCoord = new Coordinates(x, y);

        unvisitedPixels.remove(initialCoord);

        Stack visitedPixels = new Stack();
        visitedPixels.push(initialCoord);

        while (!visitedPixels.empty())
        {
            Color c1 = new Color(image.getRGB(x, y));
            int r1 = c1.getRed();
            int g1 = c1.getGreen();
            int b1 = c1.getBlue();
            boolean foundNeighbor = false;

            List<Integer> unvisitedNeighbors = new ArrayList<>(possibleNeighbors);

            while (!unvisitedNeighbors.isEmpty())
            {
                int n = (int)(8 * Math.random());

                if (unvisitedNeighbors.contains(n))
                {
                    unvisitedNeighbors.remove(unvisitedNeighbors.indexOf(n));
                    int[] translation = translations[n];
                    int newX = x + translation[0];
                    int newY = y + translation[1];

                    if (newX >= 0 && newX < image.getWidth() && newY >= 0 && newY < image.getHeight())
                    {
                        Coordinates coord = new Coordinates(newX, newY);
                        Integer value = unvisitedPixels.get(coord);

                        if (value != null)
                        {
                            Color c2 = new Color(value);
                            int r2 = c2.getRed();
                            int g2 = c2.getGreen();
                            int b2 = c2.getBlue();

                            double distance = Math.sqrt(Math.pow(r1 - r2, 2) + Math.pow(g1 - g2, 2) + Math.pow(b1 - b2, 2));

                            if (distance <= threshold)
                            {
                                foundNeighbor = true;
                                split_im.setRGB(x, y, Color.RED.getRGB());
                                unvisitedPixels.remove(coord);

                                x = newX;
                                y = newY;
                                visitedPixels.push(coord);
                            }
                        }
                    }
                }
            }

            if (!foundNeighbor)
            {
                Coordinates newCoord = (Coordinates)visitedPixels.pop();
                x = newCoord.getX();
                y = newCoord.getY();
            }
        }

        return split_im;
    }
}

Coordinates.java:

public class Coordinates
{
    private int x;
    private int y;

    public Coordinates(int x, int y) {
        this.x = x;
        this.y = y;
    }

    public int getX() {
        return x;
    }

    public int getY() {
        return y;
    }

    @Override
    public String toString() {
        return "Coordinates{" + "x=" + x + ", y=" + y + '}';
    }

    @Override
    public int hashCode() {
        int hash = 7;
        hash = 73 * hash + this.x;
        hash = 73 * hash + this.y;
        return hash;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == null)
        {
            return false;
        }
        else if (!(obj instanceof Coordinates))
        {
            return false;
        }
        else
        {
            Coordinates other = (Coordinates)obj;

            return this.x == other.getX() && this.y == other.getY();
        }
    }

}

I used this image as a test : Test image (it has a low resolution, otherwise it would take ages for the Python program to analyze it).

Do you know what could lead to such poor performance in Python?

Thanks!

  • 3
    I think this question is probably way too broad. Unless somebody could easily spot problems in your code, you basically expect to **read and understand** two complex pieces of software; to then come up with assumption about performance issues. I think you should rather step back and you, yourself should learn how **profile** java / python programs; and then **measure** yourself where time is spent. – GhostCat Dec 26 '16 at 20:34
  • Okay, but what do you mean with "profile java / python programs"? – LeChocdesGitans Dec 26 '16 at 20:36
  • `profiler` is special program (function in IDE) to measure time used by elements/functions in code. wikipedia: [Profiling (computer_programming)](https://en.wikipedia.org/wiki/Profiling_(computer_programming)) – furas Dec 26 '16 at 20:40
  • Use the profiler and isolate what part of the code takes so long. https://docs.python.org/2/library/profile.html – Klaus D. Dec 26 '16 at 20:43
  • 1
    These both are not entirely same code. You are using different libraries, different languages (with different compilers), and above all the different logic. There are huge number of factors responsible for the change in performance – Moinuddin Quadri Dec 26 '16 at 20:44
  • I will see what I can do with a profiler. Moinuddin Quadri, when I said "almost exactly the same code", I meant that I tried to use the same logic in both languages. And I'm wondering if this lack of performance is my fault (and then I can do something about it), or if the Python interpreter I'm using is responsible. – LeChocdesGitans Dec 26 '16 at 20:56
  • I think I finally found what the problem was. Using a profiler, I noticed that removing elements from list unvisitedpixels took almost half of the execution time. I replaced it by a dictionnary, and now it taked about 2s to analyze the image I put in the original post. It's still a bit slower than my Java, but it's much better! – LeChocdesGitans Dec 26 '16 at 23:09

0 Answers0