-1

The following program is a solution to the Marching Square problem in Python:

from typing import List

def GetCaseId(Point_A_data: float, Point_B_data: float,
              Point_C_data: float, Point_D_data: float,
              threshold):
    caseId = 0
    if (Point_A_data >= threshold):
        caseId |= 1
    if (Point_B_data >= threshold):
        caseId |= 2
    if (Point_C_data >= threshold):
        caseId |= 4
    if (Point_D_data >= threshold):
        caseId |= 8
    return caseId


def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
             a: float, b: float, c: float, d: float,
             threshold: float):
    lines = []
    caseId = GetCaseId(a, b, c, d, threshold)

    if caseId in (0, 15):
        return []

    if caseId in (1, 14, 10):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_B[1]
        qX = Point_D[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (2, 13, 5):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_A[1]
        qX = Point_C[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (3, 12):
        pX = Point_A[0]
        pY = (Point_A[1] + Point_D[1]) / 2
        qX = Point_C[0]
        qY = (Point_B[1] + Point_C[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (4, 11, 10):
        pX = (Point_C[0] + Point_D[0]) / 2
        pY = Point_D[1]
        qX = Point_B[0]
        qY = (Point_B[1] + Point_C[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    elif caseId in (6, 9):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_A[1]
        qX = (Point_C[0] + Point_D[0]) / 2
        qY = Point_C[1]

        line = (pX, pY, qX, qY)

        lines.append(line)

    elif caseId in (7, 8, 5):
        pX = (Point_C[0] + Point_D[0]) / 2
        pY = Point_C[1]
        qX = Point_A[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    return lines


def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
    linesList = []

    Height = len(y_int_list)  # rows
    Width = len(x_int_list)  # cols

    if ((Width == len(data_2d_list[0])) and (Height == len(data_2d_list))):

        for j in range(Height - 1):  # rows
            for i in range(Width - 1):  # cols
                point_a_data_float = data_2d_list[j + 1][i]
                point_b_data_float = data_2d_list[j + 1][i + 1]
                point_c_data_float = data_2d_list[j][i + 1]
                point_d_data_float = data_2d_list[j][i]

                point_A = [x_int_list[i], y_int_list[j + 1]]
                point_B = [x_int_list[i + 1], y_int_list[j + 1]]
                point_C = [x_int_list[i + 1], y_int_list[j]]
                point_D = [x_int_list[i], y_int_list[j]]

                for threshold_item in threshold_list:
                    list = GetLines(point_A, point_B, point_C, point_D,
                                    point_a_data_float, point_b_data_float, point_c_data_float, point_d_data_float,
                                    threshold_item)
                    linesList = linesList + list
    else:
        raise AssertionError

    return [linesList]

The problem with this source code is - it takes ages to generate an output.

I.e. using the following driver program:

import drawSvg as draw_svg

N_int = 800
N2_float = N_int / 8
x_int_vector = [i for i in range(N_int)]
y_int_vector = [i for i in range(N_int)]

matrix_256x256 = [[(math.sin(i / N2_float) * math.sin(j / N2_float)) for i in range(N_int)] for j in range(N_int)]

fill = "#2591a3"
drawing = draw_svg.Drawing(N_int, N_int, displayInline=False)

threshold_float_list = [0.2, 0.4, 0.6, 0.8]
collection = marching_square(x_int_vector, y_int_vector, matrix_256x256, threshold_float_list)
for line_set in collection:
    for line in line_set:
        drawing.append(draw_svg.Line(line[0], line[1], line[2], line[3], stroke='red'))
     # END of line
# END of line_set
drawing.saveSvg('example.svg') 

The code becomes horribly slow for practical use.

How can I speed up the code?

N.B. marching_square()'s signature must not be changed.

user366312
  • 16,949
  • 65
  • 235
  • 452

1 Answers1

1

Got ~10x speed-up

  1. Removed extending list which was biggest bottleneck (using this trick to concatenate list of lists)
  2. Applied numba to GetCaseId which was second bottleneck
from typing import List
import numba
import functools
import operator

@numba.jit(nopython=True)
def GetCaseId(Point_A_data: float, Point_B_data: float,
              Point_C_data: float, Point_D_data: float,
              threshold):
    caseId = 0
    if (Point_A_data >= threshold):
        caseId |= 1
    if (Point_B_data >= threshold):
        caseId |= 2
    if (Point_C_data >= threshold):
        caseId |= 4
    if (Point_D_data >= threshold):
        caseId |= 8
    return caseId


def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
             a: float, b: float, c: float, d: float,
             threshold: float):
    lines = []
    caseId = GetCaseId(a, b, c, d, threshold)

    if caseId in (0, 15):
        return None

    if caseId in (1, 14, 10):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_B[1]
        qX = Point_D[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (2, 13, 5):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_A[1]
        qX = Point_C[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (3, 12):
        pX = Point_A[0]
        pY = (Point_A[1] + Point_D[1]) / 2
        qX = Point_C[0]
        qY = (Point_B[1] + Point_C[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (4, 11, 10):
        pX = (Point_C[0] + Point_D[0]) / 2
        pY = Point_D[1]
        qX = Point_B[0]
        qY = (Point_B[1] + Point_C[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    elif caseId in (6, 9):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_A[1]
        qX = (Point_C[0] + Point_D[0]) / 2
        qY = Point_C[1]

        line = (pX, pY, qX, qY)

        lines.append(line)

    elif caseId in (7, 8, 5):
        pX = (Point_C[0] + Point_D[0]) / 2
        pY = Point_C[1]
        qX = Point_A[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    return lines


def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
    linesList = []

    Height = len(y_int_list)  # rows
    Width = len(x_int_list)  # cols

    if ((Width == len(data_2d_list[0])) and (Height == len(data_2d_list))):

        for j in range(Height - 1):  # rows
            for i in range(Width - 1):  # cols
                point_a_data_float = data_2d_list[j + 1][i]
                point_b_data_float = data_2d_list[j + 1][i + 1]
                point_c_data_float = data_2d_list[j][i + 1]
                point_d_data_float = data_2d_list[j][i]

                point_A = [x_int_list[i], y_int_list[j + 1]]
                point_B = [x_int_list[i + 1], y_int_list[j + 1]]
                point_C = [x_int_list[i + 1], y_int_list[j]]
                point_D = [x_int_list[i], y_int_list[j]]

                for threshold_item in threshold_list:
                    list = GetLines(point_A, point_B, point_C, point_D,
                                    point_a_data_float, point_b_data_float, point_c_data_float, point_d_data_float,
                                    threshold_item)
                    if list:
                        linesList.append(list)

    else:
        raise AssertionError

    return functools.reduce(operator.iconcat, linesList, [])
dankal444
  • 3,172
  • 1
  • 23
  • 35