I am trying to integrate new code into existing code written by someone else, and I have run into some problems. The existing code uses matplotlib to build a GUI plotter that plots various waveforms read from an input file. I want to be able to hover over any trace on the graph and have an annotation box show which line it is (imagine 30 lines on one graph with no way to tell them apart). I found this code in the first answer to: Possible to make labels appear when hovering over a point in matplotlib?

Here is the code:

import matplotlib.pyplot as plt
import numpy as np; np.random.seed(1)

x = np.random.rand(15)
y = np.random.rand(15)
names = np.array(list("ABCDEFGHIJKLMNO"))
c = np.random.randint(1,5,size=15)

norm = plt.Normalize(1,4)
cmap = plt.cm.RdYlGn

fig,ax = plt.subplots()
sc = plt.scatter(x,y,c=c, s=100, cmap=cmap, norm=norm)

annot = ax.annotate("", xy=(0,0), xytext=(20,20),textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"),
                    arrowprops=dict(arrowstyle="->"))
annot.set_visible(False)

def update_annot(ind):

    pos = sc.get_offsets()[ind["ind"][0]]
    annot.xy = pos
    text = "{}, {}".format(" ".join(list(map(str,ind["ind"]))),
                           " ".join([names[n] for n in ind["ind"]]))
    annot.set_text(text)
    annot.get_bbox_patch().set_facecolor(cmap(norm(c[ind["ind"][0]])))
    annot.get_bbox_patch().set_alpha(0.4)


def hover(event):
    vis = annot.get_visible()
    if event.inaxes == ax:
        cont, ind = sc.contains(event)
        if cont:
            update_annot(ind)
            annot.set_visible(True)
            fig.canvas.draw_idle()
        else:
            if vis:
                annot.set_visible(False)
                fig.canvas.draw_idle()

fig.canvas.mpl_connect("motion_notify_event", hover)

plt.show()
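In the code above, `update_annot` indexes `ind["ind"]` because `sc.contains(event)` returns a pair: a boolean saying whether the cursor is over any point, and a dict whose `"ind"` entry is an array with the indices of every point under the cursor. A headless sketch to check this, assuming the Agg backend, three hand-placed points instead of random data, and a synthetic `MouseEvent` in place of real mouse movement:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend: no window needed for this check
import matplotlib.pyplot as plt
from matplotlib.backend_bases import MouseEvent

x = [0.2, 0.5, 0.8]
y = [0.2, 0.5, 0.8]

fig, ax = plt.subplots()
sc = ax.scatter(x, y, s=100)
fig.canvas.draw()  # settle autoscaled limits so data->pixel transforms are final

# Pretend the mouse is exactly over point 1, in display (pixel) coordinates.
px, py = ax.transData.transform((x[1], y[1]))
event = MouseEvent("motion_notify_event", fig.canvas, px, py)

hit, details = sc.contains(event)
print(hit, details["ind"].tolist())  # whether anything was hit, and which indices
```

With overlapping markers, `details["ind"]` can hold several indices, which is why the original `update_annot` joins all of them into the text.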

The existing code defines `ax` inside a plotting function. The whole function is too long to paste here, so here is a snippet of it (it already includes some of the code above):

            else:
                print ('The label is: %s' % label)
                ax = plt.subplot(111)
                axesDict[labelKey] = ax
            #end if
#
            annot = ax.annotate("", xy=(-20,20), xytext=(None),textcoords="offset points",
                                bbox=dict(fc="b"),
                                arrowprops=dict(arrowstyle="->"))
            annot.set_visible(True)

            fig.canvas.mpl_connect("motion_notify_event", hover)
#

The problem is that I don't know how to get `ax` into the `hover` function: because of how `mpl_connect` works, I never call `hover` myself (matplotlib calls it with the event as the only argument), so I cannot pass it extra arguments.
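For reference, matplotlib calls a function registered with `mpl_connect` with exactly one argument, the event. For mouse events that object already carries the axes under the cursor as `event.inaxes` (`None` outside every axes) plus the data coordinates `event.xdata` and `event.ydata`, so the handler can often look up what it needs from the event instead of receiving `ax` as a parameter. A minimal sketch of that idea, assuming two toy subplots and a dict of annotations keyed by axes (a made-up structure, similar in spirit to the existing `axesDict`):

```python
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.plot([0, 1], [0, 1])
ax2.plot([0, 1], [1, 0])

# One annotation per axes, looked up later through the axes object itself.
annots = {}
for ax in (ax1, ax2):
    annot = ax.annotate("", xy=(0, 0), xytext=(10, 10), textcoords="offset points",
                        bbox=dict(boxstyle="round", fc="w"))
    annot.set_visible(False)
    annots[ax] = annot


def hover(event):
    """matplotlib passes only the MouseEvent; the axes come from event.inaxes."""
    for ax, annot in annots.items():
        if ax is event.inaxes:  # inaxes is None when the cursor is outside every axes
            annot.xy = (event.xdata, event.ydata)
            annot.set_text(f"x={event.xdata:.2f}, y={event.ydata:.2f}")
            annot.set_visible(True)
        else:
            annot.set_visible(False)
    event.canvas.draw_idle()


fig.canvas.mpl_connect("motion_notify_event", hover)
plt.show()
```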

I am very new to Python, and working on existing code of this size has been a challenge. Perhaps I am thinking about the implementation incorrectly; if so, please feel free to point that out! I am sure I will have more questions, but this is a good start. Thank you in advance for your help and time.

EDIT: Here is the long plotting function (this is only the first part of it that I deal with):

    def plotData(self, refreshPlotAxes = False):

        if len(self.waveformObjectList) == 0:
            print ('no waveforms to plot')
            return
        #end if

        startFigureNumber = self.startFigureNumber
        nextFigureNumber = startFigureNumber

        if self.fileDataTypeMode == 'ascii':
            markerArray = self.defaultMarkerArray
        else:
            markerArray = ['']

        waveformIndexList = self.getFilteredWaveformObjectIndexList()

        ###################### First Plot #############################

        if self.plotFreqResp:
            firstLoop = True
            markerIndex = 0
            #which labels are in each figure
            xAxisLabelDictionary = {}
            yAxisLabelDictionary = {}
            subplotDictionary = {}   #subplots for each figure
            plotAxisDictionary = {}  #plot axis for every subplot
            #at the moment, I don't support multiple figures and multiple subplots at the same time,
            #but I might someday
            logXDictDict = {}
            logYDictDict = {}

            plotFilename = 'blank_freqresp.png'

            plotAxisList = []
            numberOfFigures = 0
            numberOfSubPlots = 0
            numberOfLabels = 0

            #set up the plots
            axesDict = {}

            labelList = []
            for waveformObj in self.waveformObjectList:
                label = waveformObj.label
                labelPieces = label.split('_')
                labelList.append(labelPieces)
            #end for waveformObj

            commonLabelPieces = []
            if len(labelList) > 1:
                labelPieces0 = labelList[0]
                for labelPiece in labelPieces0:
                    isCommon = True
                    for labelPieces in labelList:
                        if labelPieces.count(labelPiece) == 0:
                            isCommon = False
                            break
                        #end if
                    #end for
                    if isCommon:
                        commonLabelPieces.append(labelPiece)
                    #end if
                #end for labelPiece
            #end if

            for waveformIndex in waveformIndexList:
                waveformObj = self.waveformObjectList[waveformIndex]

                plotFilename = waveformObj.filename
                [plotFilename, ext] = os.path.splitext(plotFilename)
                plotFilename += '_freqresp.png'

                if firstLoop or (self.plot1SeparatePlots and not self.plot1SubPlots):
                    currentFigureNumber = nextFigureNumber
                    logXDictDict[currentFigureNumber] = {}
                    logYDictDict[currentFigureNumber] = {}
                    nextFigureNumber += 1
                    numberOfFigures += 1
                    figureTxt = 'Figure %d - %s' % (currentFigureNumber, self.appTitle)
                    fig = plt.figure(figureTxt, figsize=self.cwPlotSize)
                #end if

                label = waveformObj.getLabel(shortLabel = self.shortLabel, includeXLabel = self.showXInLabel)
                shortLabel = waveformObj.getLabel(shortLabel = True, includeXLabel = self.showXInLabel)

                if self.enableShortenedLabels:
                    label = waveformObj.label
                    labelPieces = label.split('_')
                    uniqueLabelPieces = []
                    for labelPiece in labelPieces:
                        if commonLabelPieces.count(labelPiece) == 0:
                            uniqueLabelPieces.append(labelPiece)
                        #end if
                    #end for
                    label = '_'.join(uniqueLabelPieces)
                    label += '(' + shortLabel + ')'
                #end if

                try:
                    if waveformObj.hasReference():
                        label += '%s%s @ %s' % (waveformObj.referenceWaveformOperation, waveformObj.referenceWaveform, waveformObj.referenceWaveformFreq)
                    #end if
                except:
                    pass

                [xAxisLabel, yAxisLabel] = waveformObj.axisLabels()[0:2]
                if xAxisLabel == 'none':
                    xAxisLabel = waveformObj.getDataLabels()[0]
                if yAxisLabel == 'none' or yAxisLabel == 'mag':
                    yAxisLabel = waveformObj.getDataLabels()[1]

##                print ('data labels = %s' % str([xAxisLabel, yAxisLabel]))
##                print ('shortLabel = %s' % shortLabel
##                print ('label = %s' % label

                #when there is just one subplot (the default), it's designated '111'
                subplotString = '1'
                logX = self.logHorizontalAxis
                dbY = self.dBVerticalAxis
                if self.plot1SubPlots:
                    subplotString = '000'
                    for subplotNum in self.plot1SubPlotDict['filter'].keys():
                        matchList = self.plot1SubPlotDict['filter'][subplotNum]
                        for matchItem in matchList:
                            if re.search(matchItem, shortLabel):
                                subplotString = subplotNum
                                break
                            #end if
                        #end for
                    #end for

                    if subplotString == '000':
                        firstLoop = False
                        continue

                    try:
                        logX = self.plot1SubPlotDict['xlog'][subplotString]
                    except:
                        pass

                    try:
                        dbY = self.plot1SubPlotDict['ydb'][subplotString]
                    except:
                        pass

                #end if

#                if waveformObj.yUnits.lower().count('db'):
#                    yData = waveformObj.getNormalizeddBVector()
#                    logY = False
                if waveformObj.yUnits.lower().count('bits') or \
                     waveformObj.yUnits.lower().count('data'):
                    yData = waveformObj.getMagnitudeVector()
                    logY = False
                    dbY = False
                    forceLinearYAxis = True
                else:
                    forceLinearYAxis = False
                    if dbY:
                        yData = waveformObj.getNormalizeddBVector(self.absoluteValueForDB)
                        logY = False
                    else:
                        yData = waveformObj.getNormalizedMagnitudeVector()
                        logY = self.logVerticalAxis
                    #end if
                #end if

                fData = waveformObj.getFreqVector()

                labelKey = str(currentFigureNumber) + '_' + subplotString

                if not labelKey in xAxisLabelDictionary:
                    xAxisLabelDictionary[labelKey] = []
                if not labelKey in yAxisLabelDictionary:
                    yAxisLabelDictionary[labelKey] = []
                if not currentFigureNumber in subplotDictionary:
                    subplotDictionary[currentFigureNumber] = []

                xAxisLabelDictionary[labelKey].append(xAxisLabel)
                yAxisLabelDictionary[labelKey].append(yAxisLabel)

                plot1FormatMatchesKey = False
                for key in self.plot1Format.keys():

                    if re.search(key, waveformObj.yLabel) or re.search(key, waveformObj.label):
                        plot1FormatMatchesKey = True
                        break
                    elif re.search(key, label):
                        plot1FormatMatchesKey = True
                        break
                    #end if
                #end for key

                if plot1FormatMatchesKey:
                    pltFormatText = self.plot1Format[key][0]
                    pltLineWidth = self.plot1Format[key][1]
                    pltMarkerSize = self.plot1Format[key][2]
                    allowLabel = self.plot1Format[key][3]
                    if len(self.plot1Format[key]) > 4:
                        markerColor = self.plot1Format[key][4]
                    else:
                        markerColor = -1

                    if pltFormatText is None:
                        pltFormatText = markerArray[markerIndex]+'-'
                        markerIndex += 1
                    if pltLineWidth < 0:
                        pltLineWidth = self.defaultLineWidth
                    if pltMarkerSize < 0:
                        pltMarkerSize = self.defaultMarkerSize
                    if not allowLabel:
                        label = ''
                    if markerColor != -1:
                        markerEdgeColor = None
                        markerEdgeWidth = self.defaultMarkerEdgeWidth
                        markerFaceColor = markerColor
                    else:
                        markerEdgeColor = None
                        markerEdgeWidth = self.defaultMarkerEdgeWidth
                        markerFaceColor = None
                    #end if

                else:
                    pltFormatText = markerArray[markerIndex] + self.defaultLinePattern
                    markerIndex += 1
                    pltLineWidth = self.defaultLineWidth
                    pltMarkerSize = self.defaultMarkerSize
                    markerEdgeColor = None
                    markerEdgeWidth = self.defaultMarkerEdgeWidth
                    markerFaceColor = None
                #end if

                if markerIndex >= len(markerArray):
                    markerIndex = 0

                if labelKey in axesDict:
                    try:
                        plt.sca(axesDict[labelKey])
                    except:
                        print ('something went wrong with subplot label %s' % labelKey)
                        print ('probably due to overlapping subplots.')
                        print ('make adjustments to the figInfoDict items')
                    #end try
                elif self.plot1SubPlots:
                    gridShape = self.plot1SubPlotDict['gridShape']
                    subplotInfo = self.plot1SubPlotDict['figInfoDict'][subplotString]
                    ax = plt.subplot2grid(gridShape, subplotInfo[0], subplotInfo[1], subplotInfo[2])
                    axesDict[labelKey] = ax
                else:
                    print ("Made it inside else condition")
                    print ('The label is: %s' % label)
                    ax = plt.subplot(111)
                    axesDict[labelKey] = ax
                #end if

    #
                annot = ax.annotate("", xy=(-20,20), xytext=(None),textcoords="offset points",
                                    bbox=dict(fc="b"),
                                    arrowprops=dict(arrowstyle="->"))
                annot.set_visible(True)

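                # bind the current annot and label as default arguments; a plain
                # lambda would look them up only when called, after the loop has ended
                h = lambda event, annot=annot, label=label: hover(event, annot, label)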

                fig.canvas.mpl_connect("motion_notify_event", h)
#

            #format the plots

            for p in range(numberOfFigures):
                figureNumber = p + startFigureNumber

                figureTxt = 'Figure %d - %s' % (figureNumber, self.appTitle)
                plt.figure(figureTxt)

                if not figureNumber in subplotDictionary:
                    continue

                for subplotString in subplotDictionary[figureNumber]:

                    labelKey = str(figureNumber) + '_' + subplotString
                    try:
                        plt.sca(axesDict[labelKey])
                    except:
                        print ('something went wrong with subplot label %s' % labelKey)
                        print ('probably due to overlapping subplots.')
                        print ('make adjustments to the figInfoDict items')
                        continue
                    #end try
                    #plt.subplot(subplotString)
                    plotAxis = plotAxisDictionary[labelKey]
                    #print ('start misc plot settings';
                    plt.grid(self.plot1Grid, 'both')

                    plot1YticksList = self.plot1YticksList
                    plot1XticksList = self.plot1XticksList
                    plot1YLimits = self.cwPlotYLimits
                    plot1XLimits = self.cwPlotXLimits
                    vcursors = []

                    logX = logXDictDict[figureNumber][subplotString]
                    logY = logYDictDict[figureNumber][subplotString]

                    enablePlotXLabel = True
                    legendEnable = True

                    if self.plot1SubPlots:
                        if not logY:
                            try:
                                plot1YticksList = self.plot1SubPlotDict['yticks'][subplotString]
                            except:
                                pass
                        else:
                            plot1YticksList = []
                        #end if

                        if not logX:
                            try:
                                plot1XticksList = self.plot1SubPlotDict['xticks'][subplotString]
                            except:
                                pass
                        else:
                            plot1XticksList = []
                        #end if

                        try:
                            plot1YLimits = self.plot1SubPlotDict['ylimits'][subplotString]
                        except:
                            pass

                        try:
                            plot1XLimits = self.plot1SubPlotDict['xlimits'][subplotString]
                        except:
                            pass

                        try:
                            vcursors = self.plot1SubPlotDict['vcursors'][subplotString]
                        except:
                            pass

                        try:
                            enablePlotXLabel = self.plot1SubPlotDict['xLabelEnable'][subplotString]
                        except:
                            pass
                        #end

                        try:
                            legendEnable = self.plot1SubPlotDict['legendEnable'][subplotString]
                        except:
                            pass
                        #end

                    #end if

                    if logY:
                        for tick in plot1YticksList:
                            if tick <= 0:
                                plot1YticksList = []
                                break
                            #end if
                        #end for
                        if len(plot1YLimits) == 2:
                            if plot1YLimits[0] <= 0:
                                plot1YLimits = []
                            #end if
                        #end if
                    #end if

                    if len(plot1YticksList):
                        plt.yticks(plot1YticksList)
                    if len(plot1XticksList):
                        plt.xticks(plot1XticksList)

                    if plotAxis == (0.0,1.0,0.0,1.0) or refreshPlotAxes:
                        if len(plot1YLimits) == 2:
                            plt.ylim(plot1YLimits)

                        if len(plot1XLimits) == 2:
                            plt.xlim(plot1XLimits)
                    else:
                        plt.axis(plotAxis)
                    #end if

                    if len(vcursors):
                        ylimits = plt.ylim()
                        for x in vcursors:
                            plt.plot([x,x], ylimits, self.vcursorFormatText, linewidth = self.vcursorWidth)

                    yAxisLabelListSet = list(set(yAxisLabelDictionary[labelKey]))
                    if len(yAxisLabelListSet) == 1:
                        yAxisLabel = yAxisLabelDictionary[labelKey][0]
                    elif len(yAxisLabelListSet) > 1:
                        yAxisLabel = yAxisLabelListSet[0]
                        for buf in yAxisLabelListSet[1:]:
                            yAxisLabel += ',' + buf
                        #end for
                    else:
                        yAxisLabel = ''
                    #end if

                    xAxisLabelListSet = list(set(xAxisLabelDictionary[labelKey]))
                    if len(xAxisLabelListSet) == 1:
                        xAxisLabel = xAxisLabelDictionary[labelKey][0]
                    elif len(xAxisLabelListSet) > 1:
                        xAxisLabel = xAxisLabelListSet[0]
                        for buf in xAxisLabelListSet[1:]:
                            xAxisLabel += ',' + buf
                        #end for
                    else:
                        xAxisLabel = ''
                    #end if

                    if not forceLinearYAxis:
                        if dbY:
                            if not waveformObj.yUnits.lower().count('db'):
                                yAxisLabel += ' (dB)'
                        else:
                            yAxisLabel += ' (lin)'
                    #end if

                    plt.ylabel(yAxisLabel)
                    if enablePlotXLabel:
                        plt.xlabel(xAxisLabel)
                    else:
                        xtickList = plt.xticks()[0]
                        plt.xticks(xtickList, '')
                    #end if

                    prop=matplotlib.font_manager.FontProperties(size=self.legendFontSize)
                    if self.shortLabel:
                        plt.title(waveformObj.filename, fontsize=12)
                    #end if
                    if self.cwPlotLegend and legendEnable:
                        plt.legend(loc=self.plot1LegendLocation,prop=prop,borderpad=0.3,labelspacing=0.1,handletextpad=0,numpoints=self.numLegendPoints)
                    #end if
                    #print ('done'

                #end for subplotString

            plt.draw()

            if self.savePlotAsImage:
                plt.savefig(plotFilename, format='png')
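The loop above connects a new `motion_notify_event` handler for every waveform. A lambda that refers to loop variables such as `annot` and `label` without binding them reads their values when it is called, not when it is defined, so every handler created that way ends up using the annotation and label from the last iteration (a plausible cause of the "annotation only works on the last line drawn" problem mentioned in the comments below). Binding the current value as a default argument, or with `functools.partial`, fixes each handler to its own iteration. A pure-Python sketch of the difference, where `hover` is a stand-in that only formats a string, not the real handler:

```python
import functools

labels = ["wave_a", "wave_b", "wave_c"]


def hover(event, label):
    """Stand-in for the real handler: just report which label it was given."""
    return f"{event}: {label}"


late, early, partials = [], [], []
for label in labels:
    # `label` is looked up when the lambda runs -> always the loop's final value
    late.append(lambda event: hover(event, label))
    # default arguments are evaluated right now -> each lambda keeps its own label
    early.append(lambda event, label=label: hover(event, label))
    # functools.partial also stores the current value immediately
    partials.append(functools.partial(hover, label=label))

print([h("move") for h in late])      # ['move: wave_c', 'move: wave_c', 'move: wave_c']
print([h("move") for h in early])     # ['move: wave_a', 'move: wave_b', 'move: wave_c']
print([h("move") for h in partials])  # ['move: wave_a', 'move: wave_b', 'move: wave_c']
```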
  • `ax` is used inside the `hover` function only to check if the event is inside of that axes. You do not strictly need that check, or even if you need it (mostly for performance reasons) you could replace it by something different. However, it's hard to know what exactly you need here. Do you have several axes? Or only one? – ImportanceOfBeingErnest Dec 18 '18 at 19:58
  • @ImportanceOfBeingErnest I have one axis as of now that I deal with. I took that out in the code and the result is beautiful. Thank you for your suggestion. The only thing is that I still need to pass in `annot` to the hover function. Reuben answered below using lambda. Perhaps this would work? – Dymaxion Fuller Dec 18 '18 at 21:19
  • Yes, it probably would. However, I have the feeling that it wouldn't be necessary either. If you have one annotation there is no need to pass it to the function - usually. We don't have the full code to verify. – ImportanceOfBeingErnest Dec 18 '18 at 21:25
  • @ImportanceOfBeingErnest I have added the code. It is quite long. Thank you, again. – Dymaxion Fuller Dec 18 '18 at 21:37
  • @ImportanceOfBeingErnest The last thing that confuses me is the `lines`. I need to use the `Line2D` object in matplotlib to check if the event is contained in the axes, correct? – Dymaxion Fuller Dec 18 '18 at 21:40
  • I don't find anything that is plotted in your code. Actually, it just occurs to me that while you are linking to an answer for scatter, you want to annotate a line? Do you want to annotate each point of that line separately? Does the annotation differ from the line's label? – ImportanceOfBeingErnest Dec 18 '18 at 21:57
  • @ImportanceOfBeingErnest I apologize for the lack of clarity. The lines are imported in from files and each line corresponds to a specific label (which is the file it came from). I want the user to be able to hover over a specific datapoint on a line and see the label. At the end of the code is a `plt.draw()`. I did not add this because it is after a second and third plot are drawn (plots I am not concerned with at the moment) and it involves another 400 lines of code. – Dymaxion Fuller Dec 18 '18 at 23:35
  • @ImportanceOfBeingErnest I added the formatting of the plot code – Dymaxion Fuller Dec 18 '18 at 23:55
  • @ImportanceOfBeingErnest I figured a lot of the issues out, but I am still struggling with one that I have described in another post: https://stackoverflow.com/questions/53857468/annotation-only-works-on-the-last-line-drawn. I had not realized you wrote that original code. Perhaps you could help answer at least part of that question since it relates directly to that. Thank you! – Dymaxion Fuller Dec 19 '18 at 18:49
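For the situation described in these comments (many labelled `Line2D` traces in one axes, where hovering should name the trace under the cursor), one option is a single annotation per axes and a single handler that asks each line in `ax.get_lines()` whether the cursor is over it via `Line2D.contains(event)`, then shows `line.get_label()`. A minimal sketch, assuming synthetic sine waveforms with made-up labels in place of the data loaded from files; `make_line_hover` is a hypothetical helper, not part of matplotlib or of the original code:

```python
import matplotlib.pyplot as plt
import numpy as np


def make_line_hover(ax):
    """Return (annotation, handler) that names the line under the cursor."""
    annot = ax.annotate("", xy=(0, 0), xytext=(15, 15), textcoords="offset points",
                        bbox=dict(boxstyle="round", fc="w"),
                        arrowprops=dict(arrowstyle="->"))
    annot.set_visible(False)

    def on_move(event):
        if event.inaxes is not ax:
            return
        for line in ax.get_lines():
            hit, details = line.contains(event)
            if hit:
                # For a drawn line, "ind" holds segment indices; anchor the
                # arrow at the start of the first segment near the cursor.
                i = details["ind"][0]
                annot.xy = (line.get_xdata()[i], line.get_ydata()[i])
                annot.set_text(line.get_label())
                annot.set_visible(True)
                event.canvas.draw_idle()
                return
        if annot.get_visible():
            annot.set_visible(False)
            event.canvas.draw_idle()

    return annot, on_move


t = np.linspace(0, 1, 200)
fig, ax = plt.subplots()
for k in range(5):
    ax.plot(t, np.sin(2 * np.pi * (k + 1) * t) + 3 * k, label=f"waveform_{k}")

annot, on_move = make_line_hover(ax)
fig.canvas.mpl_connect("motion_notify_event", on_move)
plt.show()
```

Because there is only one handler and one annotation per axes, nothing has to be bound per line, which sidesteps the loop/lambda issue. When traces overlap, the first matching line in `get_lines()` order supplies the label.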

1 Answer

You could do something like this:

h = lambda x: hover(x, ax)
fig.canvas.mpl_connect("motion_notify_event", h)

Then change your hover function to:

def hover(event, ax):
    ...
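Filled in with the scatter example from the question, the suggestion might look like the following minimal sketch. Only `ax` travels through the lambda, as in the answer; `sc`, `annot`, `names` and `fig` are still module-level names, the question's `update_annot` is folded into `hover`, and the colormap styling is dropped for brevity:

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(1)
x = np.random.rand(15)
y = np.random.rand(15)
names = np.array(list("ABCDEFGHIJKLMNO"))

fig, ax = plt.subplots()
sc = ax.scatter(x, y, s=100)
annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"),
                    arrowprops=dict(arrowstyle="->"))
annot.set_visible(False)


def hover(event, ax):
    """Same logic as in the question, but `ax` arrives as a parameter."""
    if event.inaxes != ax:
        return
    cont, ind = sc.contains(event)
    if cont:
        annot.xy = sc.get_offsets()[ind["ind"][0]]
        annot.set_text(", ".join(names[n] for n in ind["ind"]))
        annot.set_visible(True)
        fig.canvas.draw_idle()
    elif annot.get_visible():
        annot.set_visible(False)
        fig.canvas.draw_idle()


# mpl_connect only ever passes the event; the lambda supplies ax.
h = lambda event: hover(event, ax)
fig.canvas.mpl_connect("motion_notify_event", h)
plt.show()
```

`functools.partial(hover, ax=ax)` would do the same job as the lambda here.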