I think I have finally found a solution, although it can probably be improved a lot. I am posting the code for a working example here; however, it is rather long and also includes some other methods to zoom and pan in the image. If anyone would like to try it out and give me some feedback, you are very welcome!
Also, sometimes when I close the window, I get this error message several times:
Exception ignored in: <function WeakMethod.__new__.<locals>._cb at 0x00000193A3D7C7B8>
Traceback (most recent call last):
File "C:\Users\mapf\Anaconda3\lib\weakref.py", line 58, in _cb
File "C:\Users\mapf\Anaconda3\lib\site-packages\matplotlib\cbook\__init__.py", line 182, in _remove_proxy
File "C:\Users\mapf\Anaconda3\lib\weakref.py", line 74, in __eq__
TypeError: isinstance() arg 2 must be a type or tuple of types
This is what it looks like:

Here is the code:
import copy
import sys

import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
import numpy as np
from astropy.visualization import ImageNormalize, LinearStretch, ZScaleInterval
from matplotlib.backends.backend_qt5agg import \
    FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from matplotlib.text import Annotation
from PyQt5.QtWidgets import QApplication, QDialog, QGridLayout
class IDAnnotation(Annotation):
    """Text annotation that also carries a string identifier.

    Behaves like :class:`matplotlib.text.Annotation`; the extra
    ``identifier`` is stored on the instance and can be read back with
    :meth:`get_id` (the histogram uses it to tell its two contrast
    markers apart).
    """

    def __init__(
        self, text, position, ha='center', rotation=0, fontsize=15,
        picker=False, zorder=3, clip_on=True, identifier='',
        verticalalignment='baseline'
    ):
        # Collect the text styling options once and forward them unchanged
        # to the Annotation constructor.
        text_options = dict(
            ha=ha,
            rotation=rotation,
            fontsize=fontsize,
            picker=picker,
            zorder=zorder,
            clip_on=clip_on,
            verticalalignment=verticalalignment,
        )
        super().__init__(text, position, **text_options)
        self._id = identifier

    def get_id(self):
        """Return the identifier given at construction or via set_id."""
        return self._id

    def set_id(self, identifier):
        """Replace the stored identifier with *identifier*."""
        self._id = identifier
class ImageFigure:
    """Interactive image view with scroll-zoom, left-drag pan and a linked
    contrast histogram.

    The image is displayed in gray with a linear stretch over a z-scale
    interval. A :class:`HistogramFigure` of the non-zero pixel values is
    created, and mouse motion over it is used to update the image's colour
    limits (see :meth:`adjust_contrast`).
    """

    def __init__(self, image):
        """Build the figure, its Qt canvas and the mouse-event connections.

        Parameters
        ----------
        image : array_like
            Pixel data to display. A private copy is taken and its negative
            values are set to 0; the caller's array is not modified.
        """
        # A bare Figure (not pyplot) because the figure is embedded in a Qt
        # widget through FigureCanvas; pyplot would additionally register it
        # with its own global figure manager.
        self.fig = Figure()
        self.ax = self.fig.add_subplot(111)
        self.canvas = FigureCanvas(self.fig)

        # Copy first so clipping negatives does not mutate the caller's array.
        self.base_image = np.array(image, copy=True)
        self.base_image[self.base_image < 0] = 0
        self.image = self.base_image.copy()
        self.norm = ImageNormalize(
            self.image, stretch=LinearStretch(),
            interval=ZScaleInterval()
        )
        self.image_artist = self.ax.imshow(
            self.image, cmap='gray', interpolation='nearest', norm=self.norm
        )
        self.clim = self.image_artist.get_clim()

        self.base_scale = 2.0  # zoom factor per scroll step
        # Full-view limits; zoom and pan are clamped to these.
        self.base_xlim = self.ax.get_xlim()
        self.base_ylim = self.ax.get_ylim()
        self.new_xlim = [0, 1]
        self.new_ylim = [0, 1]
        # Data coordinates of the last left-button press; None when that
        # press was outside the axes (or before any press), disabling panning.
        self.x_press = None
        self.y_press = None
        self.fig.canvas.mpl_connect('scroll_event', self.zoom)
        self.fig.canvas.mpl_connect('button_press_event', self.pan_press)
        self.fig.canvas.mpl_connect('motion_notify_event', self.pan_move)

        # Histogram only the non-zero pixels: HistogramFigure bins with
        # np.geomspace, which cannot start at 0.
        flat = self.base_image.ravel()
        self.hist = flat[flat != 0]
        self.contrast = HistogramFigure(self.hist, self.clim)
        # Mouse motion over the histogram may move a contrast marker; the
        # histogram decides whether a drag is in progress.
        self.contrast.fig.canvas.mpl_connect(
            'motion_notify_event', self.adjust_contrast
        )

    def adjust_contrast(self, event):
        """Forward a histogram motion *event* and apply the resulting limits
        to the displayed image."""
        self.contrast.on_move_event(event)
        self.image_artist.set_clim(self.contrast.lclim, self.contrast.uclim)
        self.canvas.draw_idle()

    def zoom(self, event):
        """Zoom about the cursor on a scroll *event*.

        Scrolling 'up' zooms in by ``base_scale``, 'down' zooms out; the new
        limits are clamped to the full view by ``check_limits``.
        """
        xdata = event.xdata
        ydata = event.ydata
        if xdata is None or ydata is None:
            return  # scrolled outside the axes
        if event.button == 'up':
            scale_factor = 1 / self.base_scale
        elif event.button == 'down':
            scale_factor = self.base_scale
        else:
            scale_factor = 1
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()
        # Scale each side's distance from the cursor so the point under the
        # cursor stays fixed.
        x_left = xdata - cur_xlim[0]
        x_right = cur_xlim[1] - xdata
        y_top = ydata - cur_ylim[0]
        y_bottom = cur_ylim[1] - ydata
        new_xlim = [
            xdata - x_left * scale_factor, xdata + x_right * scale_factor
        ]
        new_ylim = [
            ydata - y_top * scale_factor, ydata + y_bottom * scale_factor
        ]
        self.new_xlim, self.new_ylim = check_limits(
            self.base_xlim, self.base_ylim, new_xlim, new_ylim
        )
        self.ax.set_xlim(self.new_xlim)
        self.ax.set_ylim(self.new_ylim)
        self.canvas.draw_idle()

    def pan_press(self, event):
        """Remember where a left-button press happened, as the pan anchor.

        A press outside the axes clears the anchor (xdata/ydata are None
        there), so a drag that starts off-image does not pan relative to a
        stale point.
        """
        if event.button != 1:
            return
        self.x_press = event.xdata
        self.y_press = event.ydata

    def pan_move(self, event):
        """Pan so the anchored data point follows the cursor while the left
        button is held; limits are clamped by ``check_limits``."""
        if event.button != 1 or self.x_press is None or self.y_press is None:
            return
        xdata = event.xdata
        ydata = event.ydata
        if xdata is None or ydata is None:
            return  # cursor left the axes
        cur_xlim = self.ax.get_xlim()
        cur_ylim = self.ax.get_ylim()
        dx = xdata - self.x_press
        dy = ydata - self.y_press
        new_xlim = [cur_xlim[0] - dx, cur_xlim[1] - dx]
        new_ylim = [cur_ylim[0] - dy, cur_ylim[1] - dy]
        new_xlim, new_ylim = check_limits(
            self.base_xlim, self.base_ylim, new_xlim, new_ylim
        )
        self.ax.set_xlim(new_xlim)
        self.ax.set_ylim(new_ylim)
        self.canvas.draw_idle()
class HistogramFigure:
    """Log-scaled histogram of pixel values with two draggable contrast markers.

    ``lclim`` (lower limit) is marked by an upward triangle below the bars and
    ``uclim`` (upper limit) by a downward triangle above them. Bars between
    the limits are drawn cyan ('c'), the rest gray. The bar containing each
    limit is cropped at that limit and the remainder of its bin is covered by
    a "dummy" rectangle, so the colour change happens exactly at the limit.
    """

    def __init__(self, image, clim):
        """Build the histogram canvas and connect its mouse handlers.

        Parameters
        ----------
        image : array_like
            Pixel values to histogram. Presumably 1-D and strictly positive
            (ImageFigure passes its non-zero pixels); the bins come from
            ``np.geomspace``, which cannot start at zero.
        clim : sequence of two floats
            Initial ``(lower, upper)`` contrast limits.
        """
        # A bare Figure (not pyplot) because the figure is embedded in a Qt
        # widget through FigureCanvas; pyplot would additionally register it
        # with its own global figure manager.
        self.fig = Figure()
        self.ax = self.fig.add_subplot(111)
        self.canvas = FigureCanvas(self.fig)
        self.image = image
        self.clim = clim
        self.uclim = self.clim[1]
        self.lclim = self.clim[0]
        self.nbins = 20
        self.dragged = None  # marker being dragged, set by on_pick_event
        self.pick_pos = None  # x (data coords) of the last pick/press
        # Hover state of each marker (attribute names, typo included, kept
        # unchanged for compatibility).
        self.uclim_hightlight = False
        self.lclim_hightlight = False
        # Rectangles filling the split bins at lclim / uclim; False until
        # color_patches first creates them.
        self.dummy_patches = [False, False]
        # Indices of the bars currently cropped at lclim / uclim.
        self.cropped_patches_index = [0, 0]
        self.canvas.setMaximumHeight(100)
        self.fig.subplots_adjust(left=0.07, right=0.98, bottom=0.1, top=0.75)
        self.ax.tick_params(
            axis="both", labelsize=6, left=True, top=True, labelleft=True,
            labeltop=True, bottom=False, labelbottom=False
        )
        self.ax.tick_params(which='minor', bottom=False, top=True)
        # np.min/np.max reduce in C; the builtins would iterate every pixel
        # in Python.
        self.bins = np.geomspace(
            np.min(self.image), np.max(self.image), self.nbins
        )
        _, _, self.patches = self.ax.hist(
            self.image, bins=self.bins, log=True, zorder=1
        )
        # Every x value drawn here (bin edges, markers, dummies) is positive,
        # so the non-positive handling option is irrelevant. It is omitted
        # because its old keyword name (`nonposx`) is not accepted by newer
        # Matplotlib releases.
        self.ax.set_xscale("log")
        self.color_patches()
        self.ax.margins(0, 0.1)
        self.uclim_marker = IDAnnotation(
            r'$\blacktriangledown$',
            (self.uclim, self.ax.get_ylim()[1] / 6),
            ha='center', fontsize=15, picker=True, zorder=3, clip_on=False,
            identifier='uclim'
        )
        # NOTE(review): the lower marker starts at lclim + xlim[0] rather than
        # at lclim, so it is initially offset from self.lclim; confirm this
        # is intended.
        self.lclim_marker = IDAnnotation(
            r'$\blacktriangle$',
            (self.lclim + self.ax.get_xlim()[0], self.ax.get_ylim()[0] * 16),
            ha='center', verticalalignment='top', fontsize=15, picker=True,
            zorder=2, clip_on=False, identifier='lclim'
        )
        self.ax.add_artist(self.uclim_marker)
        self.ax.add_artist(self.lclim_marker)
        self.fig.canvas.mpl_connect('pick_event', self.on_pick_event)
        self.fig.canvas.mpl_connect(
            'motion_notify_event', self.highlight_picker
        )
        self.fig.canvas.mpl_connect(
            'button_release_event', self.on_release_event
        )
        self.fig.canvas.mpl_connect(
            'button_press_event', self.on_button_press_event
        )
        # Synchronous draw: marker hit-testing (contains) needs a renderer.
        self.canvas.draw()

    def _place_dummy(self, idx, x, width, ref_patch, color):
        """Show dummy rectangle *idx* at *x* with *width*, matching
        *ref_patch*'s vertical extent; create it (in *color*) on first use."""
        dummy = self.dummy_patches[idx]
        if dummy:
            dummy.set_xy((x, ref_patch.get_y()))
            dummy.set_width(width)
            dummy.set_height(ref_patch.get_height())
            dummy.set_visible(True)
        else:
            self.dummy_patches[idx] = Rectangle(
                (x, ref_patch.get_y()),
                width=width, linewidth=0,
                height=ref_patch.get_height(), color=color
            )
            self.ax.add_artist(self.dummy_patches[idx])

    def _hide_dummy(self, idx):
        """Hide dummy rectangle *idx* if it has been created."""
        if self.dummy_patches[idx]:
            self.dummy_patches[idx].set_visible(False)

    def _restore_patch_width(self, idx):
        """Undo the cropping of bar *idx* so it spans its full bin again."""
        self.patches[idx].set_width(self.bins[idx + 1] - self.bins[idx])

    def color_patches(self):
        """Recolour and crop the bars for the current ``lclim``/``uclim``.

        Bars entirely below lclim or above uclim become gray, those in
        between cyan. The bar containing lclim is cropped to end at lclim and
        a cyan dummy covers the rest of that bin (only up to uclim when both
        limits fall in the same bin). The bar containing uclim is cropped at
        uclim, and a gray dummy covers the rest of its bin.
        """
        last_edge = self.bins[-1]
        j = 0
        edge = self.bins[j]
        overlap = False
        # Gray bars whose left edge lies below the lower limit.
        while edge < self.lclim:
            self.patches[j].set_facecolor('gray')
            j += 1
            edge = self.bins[j]
        if j > 0:
            # Bar j - 1 contains lclim: crop it there, cyan dummy after it.
            self.cropped_patches_index[0] = j - 1
            self.patches[j - 1].set_width(self.lclim - self.bins[j - 1])
            self.patches[j - 1].set_facecolor('gray')
            if self.uclim <= self.bins[j]:
                # Both limits fall inside the same bin.
                width = self.uclim - self.lclim
                overlap = True
            else:
                width = self.bins[j] - self.lclim
            self._place_dummy(0, self.lclim, width, self.patches[j - 1], 'c')
        else:
            # lclim is at or below the first edge: no bin is split, so a
            # dummy left over from an earlier position must not stay visible.
            self._hide_dummy(0)
        if overlap:
            # Gray remainder of the shared bin, from uclim to its right edge;
            # without it that stretch of the bar would be left blank.
            self._place_dummy(
                1, self.uclim, self.bins[j] - self.uclim,
                self.patches[j - 1], 'gray'
            )
        else:
            # Cyan bars lying between the two limits.
            while edge < last_edge and edge < self.uclim:
                self.patches[j].set_facecolor('c')
                j += 1
                edge = self.bins[j]
            # Bar j - 1 contains uclim: crop it there, gray dummy after it.
            self.cropped_patches_index[1] = j - 1
            self.patches[j - 1].set_width(self.uclim - self.bins[j - 1])
            self.patches[j - 1].set_facecolor('c')
            self._place_dummy(
                1, self.uclim, self.bins[j] - self.uclim,
                self.patches[j - 1], 'gray'
            )
        # Gray bars above the upper limit.
        while edge < last_edge:
            self.patches[j].set_facecolor('gray')
            j += 1
            edge = self.bins[j]

    def add_dummy(self, j, colors, limit):
        """Crop bar *j* at *limit* and move the matching dummy after it.

        NOTE(review): not called anywhere in this class; it also assumes the
        dummy rectangle already exists (it is created by color_patches).
        """
        if colors[0] == 'gray':
            idx = 0
        else:
            idx = 1
        self.cropped_patches_index[idx] = j
        self.patches[j].set_width(limit - self.bins[j])
        self.patches[j].set_facecolor(colors[0])
        self.dummy_patches[idx].set_xy((limit, self.patches[j].get_y()))
        self.dummy_patches[idx].set_width(self.bins[j] - limit)
        self.dummy_patches[idx].set_height(self.patches[j].get_height())

    def on_pick_event(self, event):
        """
        Store which text object was picked and where the pick event occurred
        (x in data coordinates, clamped to the current x-limits).
        """
        if isinstance(event.artist, Annotation):
            self.dragged = event.artist
            inv = self.ax.transData.inverted()
            pos = inv.transform(
                (event.mouseevent.x, event.mouseevent.y)
            )[0]
            xmin, xmax = self.ax.get_xlim()
            self.pick_pos = min(max(pos, xmin), xmax)
        return True

    def on_button_press_event(self, event):
        """On a left click that hits both markers or neither, record its x
        position (data coordinates) in ``pick_pos``.

        NOTE(review): pick_pos is stored but never read in this class.
        """
        if event.button != 1:
            return
        on_lower = self.lclim_marker.contains(event)[0]
        on_upper = self.uclim_marker.contains(event)[0]
        if on_lower == on_upper:
            inv = self.ax.transData.inverted()
            self.pick_pos = inv.transform((event.x, event.y))[0]

    def on_release_event(self, _):
        """End any marker drag."""
        self.dragged = None

    def on_move_event(self, event):
        """Update the dragged marker's position and redraw.

        Only acts while the left button is held and a marker is being
        dragged. The new limit is clamped to the axes' x-range and kept on
        its own side of the other limit.
        """
        if event.button == 1 and self.dragged is not None:
            inv = self.ax.transData.inverted()
            new_pos = inv.transform((event.x, event.y))[0]
            old_pos = self.dragged.get_position()
            xmin, xmax = self.ax.get_xlim()
            marker_id = self.dragged.get_id()
            if marker_id == 'lclim':
                self.lclim = max(new_pos, xmin)
                if self.lclim > self.uclim:
                    self.lclim = self.uclim * 0.999
                self.dragged.set_position((self.lclim, old_pos[1]))
                # The previously cropped bar gets recropped by color_patches.
                self._restore_patch_width(self.cropped_patches_index[0])
            elif marker_id == 'uclim':
                self.uclim = min(new_pos, xmax)
                if self.uclim < self.lclim:
                    self.uclim = self.lclim * 1.001
                self.dragged.set_position((self.uclim, old_pos[1]))
                self._restore_patch_width(self.cropped_patches_index[1])
            self.color_patches()
            self.canvas.draw_idle()
        return True

    def _update_hover(self, marker, highlighted, event):
        """Return whether *event* is over *marker*, switching the marker's
        outline on/off when that differs from *highlighted*."""
        hovered = bool(marker.contains(event)[0])
        if hovered != highlighted:
            if hovered:
                marker.set_path_effects(
                    [PathEffects.withStroke(linewidth=2, foreground="c")]
                )
            else:
                marker.set_path_effects([PathEffects.Normal()])
        return hovered

    def highlight_picker(self, event):
        """Outline a marker while the (unpressed) mouse hovers over it."""
        if event.button == 1:
            return True  # dragging: keep the current highlight
        upper = self._update_hover(
            self.uclim_marker, self.uclim_hightlight, event
        )
        lower = self._update_hover(
            self.lclim_marker, self.lclim_hightlight, event
        )
        if (upper, lower) != (self.uclim_hightlight, self.lclim_hightlight):
            self.uclim_hightlight = upper
            self.lclim_hightlight = lower
            self.canvas.draw_idle()
        return True
class MainWindow(QDialog):
    """Dialog showing an ImageFigure with its contrast histogram below it."""

    def __init__(self, image=None):
        """Create the dialog and its widgets.

        Parameters
        ----------
        image : array_like, optional
            Image to display. Defaults to a 500x500 array of uniform random
            values (demo data).
        """
        super().__init__()
        self.img = np.random.random((500, 500)) if image is None else image
        # Named grid_layout rather than layout: an instance attribute called
        # `layout` would shadow QWidget.layout() on this dialog.
        self.grid_layout = None
        self.image = None
        self.contrast = None
        self.create_widgets()

    def create_widgets(self):
        """Create the figures and stack the image above the histogram."""
        # Passing self installs the grid layout on this dialog.
        self.grid_layout = QGridLayout(self)
        self.image = ImageFigure(self.img)
        self.contrast = self.image.contrast
        self.grid_layout.addWidget(self.image.canvas, 0, 0)
        self.grid_layout.addWidget(self.contrast.canvas, 1, 0)
def _clamp_span(base, new):
    """Shift/clip the two-element list *new* in place so it lies within *base*.

    The ends are matched by value (the low end of *new* against the low end
    of *base*), so a normal axis and an inverted one (e.g. an image y-axis)
    are handled the same way.
    """
    lo_i, hi_i = (0, 1) if base[0] <= base[1] else (1, 0)
    lo, hi = base[lo_i], base[hi_i]
    if new[lo_i] < lo:
        # Past the low edge: slide the span back inside, clipping the far end.
        shift = lo - new[lo_i]
        new[lo_i] = lo
        new[hi_i] = min(new[hi_i] + shift, hi)
    if new[hi_i] > hi:
        # Past the high edge: slide back, clipping the low end.
        shift = new[hi_i] - hi
        new[hi_i] = hi
        new[lo_i] = max(new[lo_i] - shift, lo)


def check_limits(base_xlim, base_ylim, new_xlim, new_ylim):
    """Keep requested axis limits inside the full-view limits.

    A span that sticks out on one side is shifted back inside, preserving
    its size; if it is larger than the full view, it is clipped to it.
    Either axis may be inverted (as an ``imshow`` y-axis usually is).

    Parameters
    ----------
    base_xlim, base_ylim : sequence of two floats
        Full-view ``(left, right)`` / ``(bottom, top)`` limits.
    new_xlim, new_ylim : sequence of two floats
        Requested limits, in the same order and orientation as the base ones.

    Returns
    -------
    (list, list)
        Adjusted x and y limits as new lists; the arguments are not modified.
    """
    new_xlim = list(new_xlim)
    new_ylim = list(new_ylim)
    _clamp_span(base_xlim, new_xlim)
    _clamp_span(base_ylim, new_ylim)
    return new_xlim, new_ylim
if __name__ == '__main__':
    # Qt start-up: create the application, show the dialog, then run the
    # event loop and exit the process with its return code.
    qt_app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)