I found this to constrain the graph only move along the x axis. But I also wanted to let the graph only move on range 0-100 when I drag the pan. I tried this method, but not working well on histogram.
I want the histogram only can drag pan in the range of the data not out of range on x axis after scroll to zoom in or out.
import matplotlib
import matplotlib.pyplot as plt
import mplcursors
import numpy as np
import torch
class My_Axes(matplotlib.axes.Axes):
name = "My_Axes"
def drag_pan(self, button, key, x, y):
matplotlib.axes.Axes.drag_pan(self, button, 'x', x, y) # pretend key=='x'
matplotlib.projections.register_projection(My_Axes)
class Restrictor:
def __init__(self, ax, x=None, y=None):
self.ax = ax
self.x = x
self.y = y
self.initial_xlim = ax.get_xlim() # Store the initial x-axis range
def apply_restrictions(self):
if self.x is not None:
self.ax.set_xlim(self.x[0], self.x[1])
if self.y is not None:
self.ax.set_ylim(self.y[0], self.y[1])
def set_x_restriction(self, x):
self.x = x
self.apply_restrictions()
def set_y_restriction(self, y):
self.y = y
self.apply_restrictions()
def set_restrictions(self, x=None, y=None):
if x is not None:
self.x = x
if y is not None:
self.y = y
self.apply_restrictions()
def plot_histogram(data, bin_size):
figure = plt.figure()
ax = figure.add_subplot(111, projection="My_Axes")
frequencies, bins, patches = ax.hist(data.flatten(), bins=bin_size)
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.set_title('Histogram')
cmap = plt.get_cmap('viridis')
bin_colors = cmap(np.linspace(0, 1, len(patches)))
for patch, color in zip(patches, bin_colors):
patch.set_facecolor(color)
cursor = mplcursors.cursor(patches)
cursor.connect("add", lambda sel: sel.annotation.set_text(f"Range: [{sel.artist[sel.index].get_x()}, {sel.artist[sel.index].get_x() + sel.artist[sel.index].get_width()}]\nFrequency: {int(frequencies[sel.index])}"))
min_bin = min(bins)
max_bin = max(bins)
ax.set(xlim=(min_bin, max_bin), ylim=(0, None), autoscale_on=False)
res = Restrictor(ax, x=(min_bin, max_bin), y=(0, None))
ax.restrictor = res # Store the Restrictor instance as an attribute of the axes
plt.gca().figure.canvas.mpl_connect('scroll_event', zoom_graph)
plt.show()
def zoom_graph(event):
if event.xdata is not None and event.ydata is not None:
ax = plt.gca()
xlim = ax.get_xlim()
initial_x_range = ax.restrictor.initial_xlim # Get the stored initial x-axis range from the Restrictor
scale_factor = 1.5 if event.button == 'up' else 1 / 1.5
new_xlim = (
(xlim[0] - event.xdata) * scale_factor + event.xdata,
(xlim[1] - event.xdata) * scale_factor + event.xdata
)
# # Check if the new x-axis range is within the initial range
# if new_xlim[0] < initial_x_range[0]:
# new_xlim = (initial_x_range[0], initial_x_range[0] + (xlim[1] - xlim[0]))
# if new_xlim[1] > initial_x_range[1]:
# new_xlim = (initial_x_range[1] - (xlim[1] - xlim[0]), initial_x_range[1])
ax.set_xlim(new_xlim)
ax.figure.canvas.draw()
# Usage example:
bin_size = 20
data = torch.randn(512, 512).numpy()
plot_histogram(data, bin_size)