I am trying to integrate new code within existing code written by someone else and I have encountered some problems. The existing code uses matplotlib
to make a GUI plotter that can plot various waveforms given an input file. I want to be able to hover over any of the traces on the graph and have an annotation box display which line it is (imagine having 30 lines on one graph and not being able to distinguish them from one another). I found this code in the first answer to the question "Possible to make labels appear when hovering over a point in matplotlib?":
Here is the code:
# Demo data for the hover example: 15 random points, each with a one-letter
# name and an integer category in 1..4 that selects the marker colour.
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(1)  # fixed seed -> the same points on every run

x = np.random.rand(15)
y = np.random.rand(15)
names = np.array(list("ABCDEFGHIJKLMNO"))
c = np.random.randint(1, 5, size=15)

# Categories 1..4 are spread across the red-yellow-green colour map.
norm = plt.Normalize(1, 4)
cmap = plt.cm.RdYlGn

fig, ax = plt.subplots()
# pyplot-level scatter (draws on the current axes, i.e. ``ax``).
sc = plt.scatter(x, y, c=c, s=100, cmap=cmap, norm=norm)

# A single annotation reused for every hover: the handler moves it, rewrites
# its text and toggles its visibility. Text sits 20pt right/up of the point.
annot = ax.annotate(
    "",
    xy=(0, 0),
    xytext=(20, 20),
    textcoords="offset points",
    bbox=dict(boxstyle="round", fc="w"),
    arrowprops=dict(arrowstyle="->"),
)
annot.set_visible(False)  # hidden until the cursor is over a point
def update_annot(ind):
    """Point the shared annotation at the hovered scatter point(s).

    ``ind`` is the details dict returned by ``sc.contains(event)``; its
    ``"ind"`` entry lists the indices of the points under the cursor.
    The arrow is anchored on the first of them. The text is
    "<indices>, <names>" (each part space-separated). The box is tinted
    with that first point's colour at 0.4 alpha.
    """
    hits = ind["ind"]
    first = hits[0]
    annot.xy = sc.get_offsets()[first]

    index_part = " ".join(str(i) for i in hits)
    name_part = " ".join(names[i] for i in hits)
    annot.set_text("{}, {}".format(index_part, name_part))

    box = annot.get_bbox_patch()
    box.set_facecolor(cmap(norm(c[first])))
    box.set_alpha(0.4)
def hover(event):
    """Show or hide the shared annotation as the mouse moves.

    Registered for the canvas ``"motion_notify_event"``. When the cursor is
    inside ``ax`` and over at least one point of ``sc``, the annotation is
    refreshed via ``update_annot`` and shown. Otherwise a visible annotation
    is hidden, including when the cursor has left ``ax``. The canvas is
    redrawn lazily with ``draw_idle`` whenever the annotation is shown or
    hidden.
    """
    was_visible = annot.get_visible()

    # Only hit-test inside our axes; outside it nothing can be hovered.
    if event.inaxes == ax:
        over_point, details = sc.contains(event)
    else:
        over_point, details = False, None

    if over_point:
        update_annot(details)
        annot.set_visible(True)
        fig.canvas.draw_idle()
    elif was_visible:
        # Also reached when the cursor leaves ``ax`` while the box is shown,
        # so the annotation no longer lingers outside the plot area.
        annot.set_visible(False)
        fig.canvas.draw_idle()
# Route every mouse-motion event on the canvas to ``hover``. The returned
# connection id allows ``fig.canvas.mpl_disconnect(hover_cid)`` later.
hover_cid = fig.canvas.mpl_connect("motion_notify_event", hover)
plt.show()  # enter the GUI event loop
The existing code defines `ax` inside a plotting function. I would paste the whole function here if it weren't so long, but here is a snippet (it reuses some of the example code above):
# Excerpt from the plotting function; the enclosing def and if/elif are not shown.
else:
print ('The label is: %s' % label)
# NOTE(review): the subplot position is passed as the string '111'; newer
# matplotlib versions deprecate non-integer positions - the integer 111 is the
# usual spelling. Confirm against the installed matplotlib version.
ax = plt.subplot('111')
axesDict[labelKey] = ax
#end if
#
# Annotation attached to this axes. xy=(-20,20) is in data coordinates; with
# xytext=None the text position defaults to xy, which textcoords="offset points"
# then interprets as an offset of (-20, 20) points from that data point.
annot = ax.annotate("", xy=(-20,20), xytext=(None),textcoords="offset points",
bbox=dict(fc="b"),
arrowprops=dict(arrowstyle="->"))
# NOTE(review): visible immediately with empty text; hover annotations are
# normally created hidden and shown by the motion callback.
annot.set_visible(True)
# mpl_connect calls the callback with the event as its only argument, so
# hover receives neither ax, annot nor label from here.
fig.canvas.mpl_connect("motion_notify_event", hover)
#
The problem is that I don't know how to pass `ax` into the `hover` function: `mpl_connect` calls the callback with the event as its only argument, so I cannot pass extra arguments to it myself.
I am really, really new to Python, and working on existing code of this size has been a challenge. Perhaps I am thinking about the implementation incorrectly; if so, please feel free to point that out! I am sure I will have more questions, but this is a good start. Thank you in advance for your help and time.
EDIT: Here is the long plotting function (only the first part of it, which is the part I am working with):
# Build the frequency-response figure(s) for the currently filtered waveforms:
# a first pass picks/creates a figure and subplot per waveform and prepares its
# data and line format; a second pass formats every subplot (grid, ticks,
# limits, axis labels, title, legend) and optionally saves a PNG.
# refreshPlotAxes: when True, re-apply the configured x/y limits instead of
# restoring the axis stored in plotAxisDictionary.
# NOTE(review): this paste has lost its indentation and, per the author, is
# only the first part of the function; the plotting call and the code that
# fills several dictionaries read in the second pass appear to be elided.
def plotData(self, refreshPlotAxes = False):
if len(self.waveformObjectList) == 0:
print ('no waveforms to plot')
return
#end if
startFigureNumber = self.startFigureNumber
nextFigureNumber = startFigureNumber
# ASCII data cycles through the default markers; other data types use no marker.
if self.fileDataTypeMode == 'ascii':
markerArray = self.defaultMarkerArray
else:
markerArray = ['']
waveformIndexList = self.getFilteredWaveformObjectIndexList()
###################### First Plot #############################
if self.plotFreqResp:
firstLoop = True
markerIndex = 0
#which labels are in each figure
xAxisLabelDictionary = {}
yAxisLabelDictionary = {}
subplotDictionary = {} #subplots for each figure
plotAxisDictionary = {} #plot axis for every subplot
#at the moment, I don't support multiple figures and multiple subplots at the same time,
#but I might someday
logXDictDict = {}
logYDictDict = {}
plotFilename = 'blank_freqresp.png'
plotAxisList = []
numberOfFigures = 0
numberOfSubPlots = 0
numberOfLabels = 0
#set up the plots
axesDict = {}
labelList = []
for waveformObj in self.waveformObjectList:
label = waveformObj.label
labelPieces = label.split('_')
labelList.append(labelPieces)
#end for waveformObj
# Collect the '_'-separated label pieces shared by every waveform label, so
# enableShortenedLabels can strip them from each label below.
commonLabelPieces = []
if len(labelList) > 1:
labelPieces0 = labelList[0]
for labelPiece in labelPieces0:
isCommon = True
for labelPieces in labelList:
if labelPieces.count(labelPiece) == 0:
isCommon = False
break
#end if
#end for
if isCommon:
commonLabelPieces.append(labelPiece)
#end if
#end for labelPiece
#end if
# First pass: one iteration per filtered waveform.
for waveformIndex in waveformIndexList:
waveformObj = self.waveformObjectList[waveformIndex]
plotFilename = waveformObj.filename
[plotFilename, ext] = os.path.splitext(plotFilename)
plotFilename += '_freqresp.png'
# A new figure is started on the first pass, or for every waveform when
# separate plots without subplots are requested.
# NOTE(review): firstLoop is only cleared on the skip path below in this
# excerpt; presumably the elided code clears it after plotting - confirm.
if firstLoop or (self.plot1SeparatePlots and not self.plot1SubPlots):
currentFigureNumber = nextFigureNumber
logXDictDict[currentFigureNumber] = {}
logYDictDict[currentFigureNumber] = {}
nextFigureNumber += 1
numberOfFigures += 1
figureTxt = 'Figure %d - %s' % (currentFigureNumber, self.appTitle)
fig = plt.figure(figureTxt, figsize=self.cwPlotSize)
#end if
label = waveformObj.getLabel(shortLabel = self.shortLabel, includeXLabel = self.showXInLabel)
shortLabel = waveformObj.getLabel(shortLabel = True, includeXLabel = self.showXInLabel)
if self.enableShortenedLabels:
label = waveformObj.label
labelPieces = label.split('_')
uniqueLabelPieces = []
for labelPiece in labelPieces:
if commonLabelPieces.count(labelPiece) == 0:
uniqueLabelPieces.append(labelPiece)
#end if
#end for
label = '_'.join(uniqueLabelPieces)
label += '(' + shortLabel + ')'
#end if
# NOTE(review): the bare except below silently ignores any error raised by
# hasReference() or the string formatting; a narrower exception would be safer.
try:
if waveformObj.hasReference():
label += '%s%s @ %s' % (waveformObj.referenceWaveformOperation, waveformObj.referenceWaveform, waveformObj.referenceWaveformFreq)
#end if
except:
pass
[xAxisLabel, yAxisLabel] = waveformObj.axisLabels()[0:2]
if xAxisLabel == 'none':
xAxisLabel = waveformObj.getDataLabels()[0]
if yAxisLabel == 'none' or yAxisLabel == 'mag':
yAxisLabel = waveformObj.getDataLabels()[1]
## print ('data labels = %s' % str([xAxisLabel, yAxisLabel]))
## print ('shortLabel = %s' % shortLabel
## print ('label = %s' % label
#when there is just one subplot (the default), it's designated '111'
subplotString = '1'
logX = self.logHorizontalAxis
dbY = self.dBVerticalAxis
# With subplots enabled, the first regex in plot1SubPlotDict['filter'] that
# matches shortLabel selects the subplot; '000' means no match, and such
# waveforms are skipped.
if self.plot1SubPlots:
subplotString = '000'
for subplotNum in self.plot1SubPlotDict['filter'].keys():
matchList = self.plot1SubPlotDict['filter'][subplotNum]
for matchItem in matchList:
if re.search(matchItem, shortLabel):
subplotString = subplotNum
break
#end if
#end for
#end for
if subplotString == '000':
firstLoop = False
continue
try:
logX = self.plot1SubPlotDict['xlog'][subplotString]
except:
pass
try:
dbY = self.plot1SubPlotDict['ydb'][subplotString]
except:
pass
#end if
# if waveformObj.yUnits.lower().count('db'):
# yData = waveformObj.getNormalizeddBVector()
# logY = False
# 'bits'/'data' units are plotted as raw magnitude on a linear axis; everything
# else is normalized, in dB or linear/log depending on dbY.
if waveformObj.yUnits.lower().count('bits') or \
waveformObj.yUnits.lower().count('data'):
yData = waveformObj.getMagnitudeVector()
logY = False
dbY = False
forceLinearYAxis = True
else:
forceLinearYAxis = False
if dbY:
yData = waveformObj.getNormalizeddBVector(self.absoluteValueForDB)
logY = False
else:
yData = waveformObj.getNormalizedMagnitudeVector()
logY = self.logVerticalAxis
#end if
#end if
fData = waveformObj.getFreqVector()
# labelKey identifies one subplot of one figure: '<figure>_<subplot>'.
labelKey = str(currentFigureNumber) + '_' + subplotString
if not labelKey in xAxisLabelDictionary:
xAxisLabelDictionary[labelKey] = []
if not labelKey in yAxisLabelDictionary:
yAxisLabelDictionary[labelKey] = []
if not currentFigureNumber in subplotDictionary:
subplotDictionary[currentFigureNumber] = []
xAxisLabelDictionary[labelKey].append(xAxisLabel)
yAxisLabelDictionary[labelKey].append(yAxisLabel)
# Look for a plot1Format entry whose regex key matches the waveform's yLabel,
# raw label or display label; it overrides format, width, marker size,
# legend label and marker colour.
plot1FormatMatchesKey = False
for key in self.plot1Format.keys():
if re.search(key, waveformObj.yLabel) or re.search(key, waveformObj.label):
plot1FormatMatchesKey = True
break
elif re.search(key, label):
plot1FormatMatchesKey = True
break
#end if
#end for key
if plot1FormatMatchesKey:
pltFormatText = self.plot1Format[key][0]
pltLineWidth = self.plot1Format[key][1]
pltMarkerSize = self.plot1Format[key][2]
allowLabel = self.plot1Format[key][3]
if len(self.plot1Format[key]) > 4:
markerColor = self.plot1Format[key][4]
else:
markerColor = -1
if pltFormatText is None:
pltFormatText = markerArray[markerIndex]+'-'
markerIndex += 1
if pltLineWidth < 0:
pltLineWidth = self.defaultLineWidth
if pltMarkerSize < 0:
pltMarkerSize = self.defaultMarkerSize
if not allowLabel:
label = ''
if markerColor != -1:
markerEdgeColor = None
markerEdgeWidth = self.defaultMarkerEdgeWidth
markerFaceColor = markerColor
else:
markerEdgeColor = None
markerEdgeWidth = self.defaultMarkerEdgeWidth
markerFaceColor = None
#end if
else:
pltFormatText = markerArray[markerIndex] + self.defaultLinePattern
markerIndex += 1
pltLineWidth = self.defaultLineWidth
pltMarkerSize = self.defaultMarkerSize
markerEdgeColor = None
markerEdgeWidth = self.defaultMarkerEdgeWidth
markerFaceColor = None
#end if
if markerIndex >= len(markerArray):
markerIndex = 0
# Reuse the axes already created for this labelKey; otherwise create it
# (grid-placed subplot, or a single full-figure subplot).
if labelKey in axesDict:
try:
plt.sca(axesDict[labelKey])
except:
print ('something went wrong with subplot label %s' % labelKey)
print ('probably due to overlapping subplots.')
print ('make adjustments to the figInfoDict items')
#end try
elif self.plot1SubPlots:
gridShape = self.plot1SubPlotDict['gridShape']
subplotInfo = self.plot1SubPlotDict['figInfoDict'][subplotString]
ax = plt.subplot2grid(gridShape, subplotInfo[0], subplotInfo[1], subplotInfo[2])
axesDict[labelKey] = ax
else:
print ("Made it inside else condition")
print ('The label is: %s' % label)
# NOTE(review): '111' is passed as a string; newer matplotlib deprecates
# non-integer positions - plt.subplot(111) is the usual form. Confirm
# against the installed version.
ax = plt.subplot('111')
axesDict[labelKey] = ax
#end if
#
# NOTE(review): when labelKey was already in axesDict, `ax` is not reassigned
# above and still refers to whichever axes was created last.
# xy=(-20,20) is in data coordinates; with xytext=None the text position
# defaults to xy, which textcoords="offset points" reads as a (-20, 20) pt
# offset. The annotation is visible immediately with empty text; hover
# annotations are usually created hidden and shown by the callback.
annot = ax.annotate("", xy=(-20,20), xytext=(None),textcoords="offset points",
bbox=dict(fc="b"),
arrowprops=dict(arrowstyle="->"))
annot.set_visible(True)
# NOTE(review): late-binding closure bug. The lambda looks up `annot` and
# `label` when it is *called*, not when it is created, so after this loop
# every registered handler sees the values from the last waveform processed.
# Bind them at creation time instead, e.g.
#     h = lambda event, annot=annot, label=label: hover(event, annot, label)
# Also: one handler is connected per waveform on the same canvas, so every
# motion event runs all of them. The hover(event, annot, label) called here is
# defined elsewhere and must hit-test this trace's plotted line, which is not
# visible in this excerpt.
h = lambda x: hover(x, annot, label)
fig.canvas.mpl_connect("motion_notify_event", h)
#
# NOTE(review): the next line has lost its leading '#'; as written it is a
# SyntaxError.
format the plots
# Second pass: format every subplot of every figure created above.
for p in range(numberOfFigures):
figureNumber = p + startFigureNumber
figureTxt = 'Figure %d - %s' % (figureNumber, self.appTitle)
plt.figure(figureTxt)
if not figureNumber in subplotDictionary:
continue
# NOTE(review): subplotDictionary lists, plotAxisDictionary and the
# logXDictDict/logYDictDict per-subplot entries are read below but never
# filled in the visible excerpt; presumably the elided plotting code does it.
for subplotString in subplotDictionary[figureNumber]:
labelKey = str(figureNumber) + '_' + subplotString
try:
plt.sca(axesDict[labelKey])
except:
print ('something went wrong with subplot label %s' % labelKey)
print ('probably due to overlapping subplots.')
print ('make adjustments to the figInfoDict items')
continue
#end try
#plt.subplot(subplotString)
plotAxis = plotAxisDictionary[labelKey]
#print ('start misc plot settings';
plt.grid(self.plot1Grid, 'both')
plot1YticksList = self.plot1YticksList
plot1XticksList = self.plot1XticksList
plot1YLimits = self.cwPlotYLimits
plot1XLimits = self.cwPlotXLimits
vcursors = []
logX = logXDictDict[figureNumber][subplotString]
logY = logYDictDict[figureNumber][subplotString]
enablePlotXLabel = True
legendEnable = True
# Per-subplot overrides from plot1SubPlotDict; missing entries keep the
# figure-wide defaults (each bare except swallows the lookup failure).
if self.plot1SubPlots:
if not logY:
try:
plot1YticksList = self.plot1SubPlotDict['yticks'][subplotString]
except:
pass
else:
plot1YticksList = []
#end if
if not logX:
try:
plot1XticksList = self.plot1SubPlotDict['xticks'][subplotString]
except:
pass
else:
plot1XticksList = []
#end if
try:
plot1YLimits = self.plot1SubPlotDict['ylimits'][subplotString]
except:
pass
try:
plot1XLimits = self.plot1SubPlotDict['xlimits'][subplotString]
except:
pass
try:
vcursors = self.plot1SubPlotDict['vcursors'][subplotString]
except:
pass
try:
enablePlotXLabel = self.plot1SubPlotDict['xLabelEnable'][subplotString]
except:
pass
#end
try:
legendEnable = self.plot1SubPlotDict['legendEnable'][subplotString]
except:
pass
#end
#end if
# On a log y axis, drop tick lists or limits that include non-positive values.
if logY:
for tick in plot1YticksList:
if tick <= 0:
plot1YticksList = []
break
#end if
#end for
if len(plot1YLimits) == 2:
if plot1YLimits[0] <= 0:
plot1YLimits = []
#end if
#end if
#end if
if len(plot1YticksList):
plt.yticks(plot1YticksList)
if len(plot1XticksList):
plt.xticks(plot1XticksList)
# (0,1,0,1) is treated as "no stored axis": apply configured limits instead.
if plotAxis == (0.0,1.0,0.0,1.0) or refreshPlotAxes:
if len(plot1YLimits) == 2:
plt.ylim(plot1YLimits)
if len(plot1XLimits) == 2:
plt.xlim(plot1XLimits)
else:
plt.axis(plotAxis)
#end if
# Draw each vertical cursor as a line spanning the current y limits.
if len(vcursors):
ylimits = plt.ylim()
for x in vcursors:
plt.plot([x,x], ylimits, self.vcursorFormatText, linewidth = self.vcursorWidth)
# Axis labels: the distinct labels collected for this subplot, comma-joined.
yAxisLabelListSet = list(set(yAxisLabelDictionary[labelKey]))
if len(yAxisLabelListSet) == 1:
yAxisLabel = yAxisLabelDictionary[labelKey][0]
elif len(yAxisLabelListSet) > 1:
yAxisLabel = yAxisLabelListSet[0]
for buf in yAxisLabelListSet[1:]:
yAxisLabel += ',' + buf
#end for
else:
yAxisLabel = ''
#end if
xAxisLabelListSet = list(set(xAxisLabelDictionary[labelKey]))
if len(xAxisLabelListSet) == 1:
xAxisLabel = xAxisLabelDictionary[labelKey][0]
elif len(xAxisLabelListSet) > 1:
xAxisLabel = xAxisLabelListSet[0]
for buf in xAxisLabelListSet[1:]:
xAxisLabel += ',' + buf
#end for
else:
xAxisLabel = ''
#end if
# NOTE(review): forceLinearYAxis, dbY and waveformObj here (and in the title
# below) are whatever the last iteration of the first loop left behind, not
# values specific to this subplot.
if not forceLinearYAxis:
if dbY:
if not waveformObj.yUnits.lower().count('db'):
yAxisLabel += ' (dB)'
else:
yAxisLabel += ' (lin)'
#end if
plt.ylabel(yAxisLabel)
if enablePlotXLabel:
plt.xlabel(xAxisLabel)
else:
xtickList = plt.xticks()[0]
plt.xticks(xtickList, '')
#end if
prop=matplotlib.font_manager.FontProperties(size=self.legendFontSize)
if self.shortLabel:
plt.title(waveformObj.filename, fontsize=12)
#end if
if self.cwPlotLegend and legendEnable:
plt.legend(loc=self.plot1LegendLocation,prop=prop,borderpad=0.3,labelspacing=0.1,handletextpad=0,numpoints=self.numLegendPoints)
#end if
#print ('done'
#end for subplotString
plt.draw()
# NOTE(review): savefig writes only the current figure (the last one selected
# in the loop above), under a name derived from the last waveform's file.
if self.savePlotAsImage:
plt.savefig(plotFilename, format='png')