
I would like to implement a feature in a Python PyQt GUI, similar to one I have seen in AstroImageJ, that lets you adjust the contrast of an image.

[GIF: AstroImageJ contrast adjustment example]

I am new to Python and haven't found any way to do this yet. Maybe matplotlib widgets or artists offer such functionality?

Also, sorry if the title is confusing. I welcome any suggestions for improvement!

eyllanesc
mapf
  • Matplotlib provides [sliders](https://www.google.de/search?q=matplotlib+slider). PyQt provides [sliders](https://www.google.de/search?q=pyqt+slider) as well. – ImportanceOfBeingErnest Jan 21 '19 at 14:45
  • Hi, thank you for the suggestion! I looked at matplotlib sliders already and as far as I can tell, they only have very limited functionality, i.e., when I use such a slider, they plot over the whole histogram. I haven't found a way for them to appear anywhere similar to what I would like. – mapf Jan 21 '19 at 16:18
  • I suppose it isn't particularly clear what exactly "what I would like" would be, and especially where the problem of achieving that lies. – ImportanceOfBeingErnest Jan 21 '19 at 16:48
  • If you don't depend on matplotlib: pyqtgraph has this with [LinearRegionItem](http://www.pyqtgraph.org/documentation/graphicsItems/linearregionitem.html) – Jonas Jan 21 '19 at 22:20
  • @Jonas: Interesting! I would have to try and see if I can make it work. – mapf Jan 22 '19 at 08:08
  • @ImportanceOfBeingErnest: Sorry if I didn't make it clear enough. So basically, I would like something that looks similar to what you see in the gif, i.e. having some kind of handle/picker/grabbing thing (I'm not sure about the correct terminology here) on the axes that you can drag along and having at least a line in the histogram that indicates the position. If I can also somehow make the excluded areas gray that would be great because it's a nice additional visual clue but for now it's not essential I think. – mapf Jan 22 '19 at 08:17
  • You literally want your output to look like the one shown in the image? That's cumbersome. One can of course mimic the slider handles via a polygon of 5 points, then subclass the `Slider` class to add those. It seems possible, but honestly, I doubt anyone will do this work for you for free. – ImportanceOfBeingErnest Jan 22 '19 at 23:57
  • I see, thank you. It's a shame that a similar feature doesn't seem to exist already, but I guess it's too specific. I wasn't expecting someone to build this tool basically from scratch for free, though. Apparently, that's what you end up having to do, but I didn't know that before. – mapf Jan 23 '19 at 08:10

1 Answer


I think I have finally come up with a solution, although it can probably be improved a lot. I am posting the code for a working example here; it is rather long because it also includes methods to zoom and pan in the image. If anyone would like to try it out and give me some feedback, you are very welcome!
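Before the full listing, here is a minimal sketch of the idea behind the two draggable handles. It is not taken from the code below: `clamp_handles` and `apply_window` are hypothetical helper names, and the sketch assumes positive values (as on the log-scaled histogram axis) and a plain linear stretch between the two limits.

```python
def clamp_handles(low, high, axis_min, axis_max, gap=1e-3):
    """Keep both handles inside the axis range and the lower one below the upper one.

    Assumes positive values, as on a log-scaled histogram axis.
    """
    low = max(low, axis_min)
    high = min(high, axis_max)
    if low >= high:
        # Simplification: always nudge the lower handle just below the upper one.
        low = high * (1 - gap)
    return low, high


def apply_window(values, low, high):
    """Linearly map values so that low -> 0.0 and high -> 1.0, clipping outside."""
    span = high - low
    return [min(1.0, max(0.0, (v - low) / span)) for v in values]


# The lower handle was dragged past the upper one, so it is pulled back.
low, high = clamp_handles(0.5, 0.2, axis_min=0.01, axis_max=1.0)
print(low < high)

# Pixels below the window become 0.0, pixels above it become 1.0.
print(apply_window([0.0, 0.25, 0.5, 1.0], low=0.25, high=0.75))
```

In the listing below, the same two limits are handed to `set_clim` on the image artist, so the windowing itself is done by matplotlib rather than by hand.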

Also sometimes, when I close the window, I get this error message several times:

Exception ignored in: <function WeakMethod.__new__.<locals>._cb at 0x00000193A3D7C7B8>
Traceback (most recent call last):
  File "C:\Users\mapf\Anaconda3\lib\weakref.py", line 58, in _cb
  File "C:\Users\mapf\Anaconda3\lib\site-packages\matplotlib\cbook\__init__.py", line 182, in _remove_proxy
  File "C:\Users\mapf\Anaconda3\lib\weakref.py", line 74, in __eq__
TypeError: isinstance() arg 2 must be a type or tuple of types

This is what it looks like:

[Image: screenshot of the resulting window]

Here is the code:

import sys
import numpy as np
import copy
import matplotlib.pyplot as plt
from matplotlib.text import Annotation
import matplotlib.patheffects as PathEffects
from matplotlib.backends.backend_qt5agg import \
    FigureCanvasQTAgg as FigureCanvas
from matplotlib.patches import Rectangle

from PyQt5.QtWidgets import QDialog, QApplication, QGridLayout
from astropy.visualization import ImageNormalize, LinearStretch, ZScaleInterval


class IDAnnotation(Annotation):
    def __init__(
            self, text, position, ha='center', rotation=0, fontsize=15,
            picker=False, zorder=3, clip_on=True, identifier='',
            verticalalignment='baseline'
    ):
        super().__init__(
            text, position, ha=ha, rotation=rotation, fontsize=fontsize,
            picker=picker, zorder=zorder, clip_on=clip_on,
            verticalalignment=verticalalignment
        )
        self._id = identifier

    def get_id(self):
        return self._id

    def set_id(self, identifier):
        self._id = identifier


class ImageFigure:
    def __init__(self, image):
        self.fig, self.ax = plt.subplots()
        self.canvas = FigureCanvas(self.fig)
        self.base_image = image
        self.base_image[np.where(self.base_image < 0)] = 0
        self.image = copy.deepcopy(self.base_image)
        self.norm = ImageNormalize(
            self.image, stretch=LinearStretch(),
            interval=ZScaleInterval()
        )
        self.image_artist = self.ax.imshow(
            image, cmap='gray', interpolation='nearest', norm=self.norm
        )
        self.clim = self.image_artist.get_clim()
        self.base_scale = 2.0
        self.base_xlim = self.ax.get_xlim()
        self.base_ylim = self.ax.get_ylim()
        self.new_xlim = [0, 1]
        self.new_ylim = [0, 1]
        self.x_press = 0
        self.y_press = 0
        self.fig.canvas.mpl_connect('scroll_event', self.zoom)
        self.fig.canvas.mpl_connect('button_press_event', self.pan_press)
        self.fig.canvas.mpl_connect('motion_notify_event', self.pan_move)

        self.hist = np.hstack(self.base_image)
        self.hist = np.delete(self.hist, np.where(self.hist == 0))
        self.contrast = HistogramFigure(self.hist, self.clim)
        # self.contrast.fig.canvas.mpl_connect(
        #     'button_release_event', self.adjust_contrast
        # )
        self.contrast.fig.canvas.mpl_connect(
            'motion_notify_event', self.adjust_contrast
        )

    def adjust_contrast(self, event):
        self.contrast.on_move_event(event)
        low_in = self.contrast.lclim
        high_in = self.contrast.uclim

        self.image_artist.set_clim(low_in, high_in)
        self.canvas.draw_idle()

    def zoom(self, event):
        xdata = event.xdata
        ydata = event.ydata
        if xdata is None or ydata is None:
            pass
        else:
            cur_xlim = self.ax.get_xlim()
            cur_ylim = self.ax.get_ylim()
            x_left = xdata - cur_xlim[0]
            x_right = cur_xlim[1] - xdata
            y_top = ydata - cur_ylim[0]
            y_bottom = cur_ylim[1] - ydata
            if event.button == 'up':
                scale_factor = 1 / self.base_scale
            elif event.button == 'down':
                scale_factor = self.base_scale
            else:
                scale_factor = 1

            new_xlim = [
                xdata - x_left*scale_factor, xdata + x_right*scale_factor
            ]
            new_ylim = [
                ydata - y_top*scale_factor, ydata + y_bottom*scale_factor
            ]

            # intercept new plot parameters if they are out of bound
            self.new_xlim, self.new_ylim = check_limits(
                self.base_xlim, self.base_ylim, new_xlim, new_ylim
            )

            self.ax.set_xlim(self.new_xlim)
            self.ax.set_ylim(self.new_ylim)
            self.canvas.draw()

    def pan_press(self, event):
        if event.button == 1:
            if event.xdata is None or event.ydata is None:
                pass
            else:
                self.x_press = event.xdata
                self.y_press = event.ydata

    def pan_move(self, event):
        if event.button == 1:
            xdata = event.xdata
            ydata = event.ydata
            if xdata is None or ydata is None:
                pass
            else:
                cur_xlim = self.ax.get_xlim()
                cur_ylim = self.ax.get_ylim()
                dx = xdata - self.x_press
                dy = ydata - self.y_press
                new_xlim = [cur_xlim[0] - dx, cur_xlim[1] - dx]
                new_ylim = [cur_ylim[0] - dy, cur_ylim[1] - dy]

                # intercept new plot parameters that are out of bound
                new_xlim, new_ylim = check_limits(
                    self.base_xlim, self.base_ylim, new_xlim, new_ylim
                )

                self.ax.set_xlim(new_xlim)
                self.ax.set_ylim(new_ylim)
                self.canvas.draw()


class HistogramFigure:
    def __init__(self, image, clim):
        self.fig, self.ax = plt.subplots()
        self.canvas = FigureCanvas(self.fig)
        self.image = image
        self.clim = clim
        self.uclim = self.clim[1]
        self.lclim = self.clim[0]
        self.nbins = 20
        self.dragged = None
        self.pick_pos = None
        self.uclim_hightlight = False
        self.lclim_hightlight = False
        self.dummy_patches = [False, False]
        self.cropped_patches_index = [0, 0]
        self.canvas.setMaximumHeight(100)
        self.fig.subplots_adjust(left=0.07, right=0.98, bottom=0.1, top=0.75)
        self.ax.tick_params(
            axis="both", labelsize=6, left=True, top=True, labelleft=True,
            labeltop=True, bottom=False, labelbottom=False
        )
        self.ax.tick_params(which='minor', bottom=False, top=True)
        self.bins = np.geomspace(
            min(self.image), max(self.image), self.nbins
        )
        _, _, self.patches = self.ax.hist(
            self.image, bins=self.bins, log=True, zorder=1
        )
        # `nonposx` was renamed to `nonpositive` in Matplotlib 3.3
        self.ax.set_xscale("log", nonpositive='clip')
        self.color_patches()

        self.ax.margins(0, 0.1)
        self.uclim_marker = IDAnnotation(
            r'$\blacktriangledown$',
            (self.uclim, self.ax.get_ylim()[1]/6),
            ha='center', fontsize=15, picker=True, zorder=3, clip_on=False,
            identifier='uclim'
        )
        self.lclim_marker = IDAnnotation(
            r'$\blacktriangle$',
            (self.lclim+self.ax.get_xlim()[0], self.ax.get_ylim()[0]*16),
            ha='center', verticalalignment='top', fontsize=15, picker=True,
            zorder=2, clip_on=False, identifier='lclim'
        )
        self.ax.add_artist(self.uclim_marker)
        self.ax.add_artist(self.lclim_marker)

        self.fig.canvas.mpl_connect('pick_event', self.on_pick_event)
        self.fig.canvas.mpl_connect(
            'motion_notify_event', self.highlight_picker
        )
        self.fig.canvas.mpl_connect(
            'button_release_event', self.on_release_event
        )
        self.fig.canvas.mpl_connect(
            'button_press_event', self.on_button_press_event
        )

        self.canvas.draw()

    def color_patches(self):
        j = 0
        i = self.bins[j]
        overlap = False
        while i < self.lclim:
            self.patches[j].set_facecolor('gray')
            j += 1
            i = self.bins[j]
        if j > 0:
            self.cropped_patches_index[0] = j - 1
            self.patches[j - 1].set_width(self.lclim - self.bins[j - 1])
            self.patches[j - 1].set_facecolor('gray')
            if self.uclim <= self.bins[j]:
                width = self.uclim - self.lclim
                overlap = True
            else:
                width = self.bins[j] - self.lclim
            if self.dummy_patches[0]:
                self.dummy_patches[0].set_xy(
                    (self.lclim, self.patches[j - 1].get_y())
                )
                self.dummy_patches[0].set_width(width)
                self.dummy_patches[0].set_height(
                    self.patches[j - 1].get_height())
            else:
                self.dummy_patches[0] = Rectangle(
                    (self.lclim, self.patches[j - 1].get_y()),
                    width=width, linewidth=0,
                    height=self.patches[j - 1].get_height(), color='c'
                )
                self.ax.add_artist(self.dummy_patches[0])
        if not overlap:
            while np.logical_and(
                    i < np.max(self.bins), i < self.uclim
            ):
                self.patches[j].set_facecolor('c')
                j += 1
                i = self.bins[j]
            self.cropped_patches_index[1] = j-1
            self.patches[j-1].set_width(self.uclim - self.bins[j-1])
            self.patches[j-1].set_facecolor('c')
        if self.dummy_patches[1]:
            self.dummy_patches[1].set_xy(
                (self.uclim, self.patches[j-1].get_y())
            )
            self.dummy_patches[1].set_width(self.bins[j]-self.uclim)
            self.dummy_patches[1].set_height(self.patches[j-1].get_height())
        else:
            self.dummy_patches[1] = Rectangle(
                (self.uclim, self.patches[j-1].get_y()),
                width=self.bins[j]-self.uclim, linewidth=0,
                height=self.patches[j-1].get_height(), color='gray'
            )
            # add the new patch only once, when it is first created
            self.ax.add_artist(self.dummy_patches[1])
        while i < max(self.bins):
            self.patches[j].set_facecolor('gray')
            j += 1
            i = self.bins[j]

    def add_dummy(self, j, colors, limit):
        if colors[0] == 'gray':
            idx = 0
        else:
            idx = 1
        self.cropped_patches_index[idx] = j
        self.patches[j].set_width(limit - self.bins[j])
        self.patches[j].set_facecolor(colors[0])
        self.dummy_patches[idx].set_xy((limit, self.patches[j].get_y()))
        self.dummy_patches[idx].set_width(self.bins[j]-limit)
        self.dummy_patches[idx].set_height(self.patches[j].get_height())
        # self.dummy_patches[0] = Rectangle(
        #     (limit, self.patches[j].get_y()),
        #     width=self.bins[j]-limit, linewidth=0,
        #     height=self.patches[j].get_height(),
        #     color=colors[1]
        # )
        # self.ax.add_artist(self.dummy_patches[0])

    def on_pick_event(self, event):
        """
            Store which text object was picked and where the pick event occurred.
        """
        if isinstance(event.artist, Annotation):
            self.dragged = event.artist
            inv = self.ax.transData.inverted()
            self.pick_pos = inv.transform(
                (event.mouseevent.x, event.mouseevent.y)
            )[0]
            if self.pick_pos < self.ax.get_xlim()[0]:
                self.pick_pos = self.ax.get_xlim()[0]
            if self.pick_pos > self.ax.get_xlim()[1]:
                self.pick_pos = self.ax.get_xlim()[1]
        return True

    def on_button_press_event(self, event):
        if np.logical_and(
            event.button == 1,
            self.lclim_marker.contains(event)[0]
            == self.uclim_marker.contains(event)[0]
        ):
            inv = self.ax.transData.inverted()
            self.pick_pos = inv.transform(
                (event.x, event.y)
            )[0]

    def on_release_event(self, _):
        if self.dragged is not None:
            self.dragged = None

    def on_move_event(self, event):
        """Update text position and redraw"""
        if event.button == 1:
            inv = self.ax.transData.inverted()
            new_pos = (inv.transform((event.x, event.y))[0])
            if self.dragged is not None:
                old_pos = self.dragged.get_position()
                if self.dragged.get_id() == 'lclim':
                    if new_pos < self.ax.get_xlim()[0]:
                        new_pos = self.ax.get_xlim()[0]
                    self.lclim = new_pos
                    if self.lclim > self.uclim:
                        self.lclim = self.uclim*0.999
                    self.dragged.set_position(
                        (self.lclim, old_pos[1])
                    )
                    self.patches[
                        self.cropped_patches_index[0]].set_width(
                        self.bins[self.cropped_patches_index[0] + 1]
                        - self.bins[self.cropped_patches_index[0]]
                    )
                elif self.dragged.get_id() == 'uclim':
                    if new_pos > self.ax.get_xlim()[1]:
                        new_pos = self.ax.get_xlim()[1]
                    self.uclim = new_pos
                    if self.uclim < self.lclim:
                        self.uclim = self.lclim*1.001
                    self.dragged.set_position(
                        (self.uclim, old_pos[1])
                    )
                    self.patches[
                        self.cropped_patches_index[1]].set_width(
                        self.bins[self.cropped_patches_index[1] + 1]
                        - self.bins[self.cropped_patches_index[1]]
                    )
                else:
                    pass

                # self.dummy_patches = []

                self.color_patches()

                self.ax.figure.canvas.draw()
            else:
                pass

        return True

    def highlight_picker(self, event):
        if event.button == 1:
            pass
        else:
            if self.uclim_marker.contains(event)[0]:
                if not self.uclim_hightlight:
                    self.uclim_hightlight = True
                    self.uclim_marker.set_path_effects(
                        [PathEffects.withStroke(linewidth=2, foreground="c")]
                    )
                    self.ax.figure.canvas.draw()
                else:
                    pass
            else:
                if self.uclim_hightlight:
                    self.uclim_hightlight = False
                    self.uclim_marker.set_path_effects(
                        [PathEffects.Normal()]
                    )
                    self.ax.figure.canvas.draw()
                else:
                    pass

            if self.lclim_marker.contains(event)[0]:
                if self.lclim_hightlight:
                    pass
                else:
                    self.lclim_hightlight = True
                    self.lclim_marker.set_path_effects(
                        [PathEffects.withStroke(linewidth=2, foreground="c")]
                    )
                    self.ax.figure.canvas.draw()
            else:
                if self.lclim_hightlight:
                    self.lclim_hightlight = False
                    self.lclim_marker.set_path_effects(
                        [PathEffects.Normal()]
                    )
                    self.ax.figure.canvas.draw()
                else:
                    pass

        return True


class MainWindow(QDialog):
    def __init__(self):
        super().__init__()
        self.img = np.random.random((500, 500))
        self.layout = None
        self.image = None
        self.contrast = None

        self.create_widgets()

    def create_widgets(self):
        self.layout = QGridLayout(self)
        self.image = ImageFigure(self.img)
        self.contrast = self.image.contrast

        self.layout.addWidget(self.image.canvas, 0, 0)
        self.layout.addWidget(self.contrast.canvas, 1, 0)


def check_limits(base_xlim, base_ylim, new_xlim, new_ylim):
    if new_xlim[0] < base_xlim[0]:
        overlap = base_xlim[0] - new_xlim[0]
        new_xlim[0] = base_xlim[0]
        if new_xlim[1] + overlap > base_xlim[1]:
            new_xlim[1] = base_xlim[1]
        else:
            new_xlim[1] += overlap
    if new_xlim[1] > base_xlim[1]:
        overlap = new_xlim[1] - base_xlim[1]
        new_xlim[1] = base_xlim[1]
        if new_xlim[0] - overlap < base_xlim[0]:
            new_xlim[0] = base_xlim[0]
        else:
            new_xlim[0] -= overlap
    if new_ylim[1] < base_ylim[1]:
        overlap = base_ylim[1] - new_ylim[1]
        new_ylim[1] = base_ylim[1]
        if new_ylim[0] + overlap > base_ylim[0]:
            new_ylim[0] = base_ylim[0]
        else:
            new_ylim[0] += overlap
    if new_ylim[0] > base_ylim[0]:
        overlap = new_ylim[0] - base_ylim[0]
        new_ylim[0] = base_ylim[0]
        if new_ylim[1] - overlap < base_ylim[1]:
            new_ylim[1] = base_ylim[1]
        else:
            new_ylim[1] -= overlap

    return new_xlim, new_ylim


if __name__ == '__main__':
    app = QApplication(sys.argv)
    GUI = MainWindow()
    GUI.show()
    sys.exit(app.exec_())
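
The arithmetic in `zoom` and `check_limits` above can be tried on its own, without a GUI. The following is only a sketch: `zoom_limits` and `clamp_interval` are hypothetical names, and `clamp_interval` assumes an ascending interval, whereas the y-axis of an `imshow` plot is normally inverted, which is why `check_limits` treats x and y differently.

```python
def zoom_limits(lim, center, scale):
    """Scale an axis interval about `center`.

    The point `center` keeps its relative position inside the interval,
    so the pixel under the cursor stays under the cursor.
    """
    lo, hi = lim
    return [center - (center - lo) * scale, center + (hi - center) * scale]


def clamp_interval(lim, base):
    """Shift an ascending interval back inside `base`, shrinking it if it is too wide."""
    lo, hi = lim
    base_lo, base_hi = base
    width = min(hi - lo, base_hi - base_lo)
    lo = min(max(lo, base_lo), base_hi - width)
    return [lo, lo + width]


# Zoom in by a factor of two around x = 25 on a 0..100 axis.
print(zoom_limits([0, 100], center=25, scale=0.5))

# A pan that would leave the image is shifted back to the edge.
print(clamp_interval([-10, 40], base=[0, 100]))
```

Keeping this logic in small pure functions makes it possible to check the edge cases (zooming out past the full image, panning past a border) without having to click through the window.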
mapf