
I would like to produce a specific type of visualization: a rather simple dot plot with a twist, namely that both axes are categorical variables (i.e. ordinal or non-numerical values). This complicates matters rather than making them easier.

To illustrate the question, I will use a small example dataset, modified from seaborn.load_dataset("tips") and defined as follows:

import pandas
from io import StringIO
df = """total_bill |  tip  |    sex | smoker | day |   time | size
             16.99 | 1.01  |   Male |     No | Mon | Dinner |    2
             10.34 | 1.66  |   Male |     No | Sun | Dinner |    3
             21.01 | 3.50  |   Male |     No | Sun | Dinner |    3
             23.68 | 3.31  |   Male |     No | Sun | Dinner |    2
             24.59 | 3.61  | Female |     No | Sun | Dinner |    4
             25.29 | 4.71  | Female |     No | Mon | Lunch  |    4
              8.77 | 2.00  | Female |     No | Tue | Lunch  |    2
             26.88 | 3.12  |   Male |     No | Wed | Lunch  |    4
             15.04 | 3.96  |   Male |     No | Sat | Lunch  |    2
             14.78 | 3.23  |   Male |     No | Sun | Lunch  |    2"""
df = pandas.read_csv(StringIO(df.replace(' ','')), sep="|", header=0)
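As a quick check of how much overlap this small table produces, the rows that share each (time, sex) cell can be counted with a groupby. The sketch rebuilds only those two columns of the table so that it runs on its own; cells and overlaps are illustrative names, not part of the original code.

```python
import pandas as pd

# Only the two categorical columns of the ten-row table above are needed here
cells = pd.DataFrame({
    "time": ["Dinner"] * 5 + ["Lunch"] * 5,
    "sex": ["Male"] * 4 + ["Female"] * 3 + ["Male"] * 3,
})

# Number of points that would be drawn on top of each other in each cell
overlaps = cells.groupby(["time", "sex"]).size()
print(overlaps)
print("largest cell:", overlaps.max())
```

For this table the busiest cell is (Dinner, Male) with four points, which is exactly the largest case handled by the hand-written offset table in the code further down.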

My first approach was to call seaborn like this:

import seaborn
axes = seaborn.pointplot(x="time", y="sex", data=df)

This fails with:

ValueError: Neither the `x` nor `y` variable appears to be numeric.

So do the equivalent seaborn.stripplot and seaborn.swarmplot calls. It does work, however, if one of the variables is categorical and the other one is numerical: seaborn.pointplot(x="total_bill", y="sex", data=df) works, but is not what I want.

I also tried a scatter plot:

axes = seaborn.scatterplot(x="time", y="sex", size="day", data=df,
                           x_jitter=True, y_jitter=True)

This produces the following graph, which contains no jitter and has all the dots overlapping, making it useless:

[Image: SeabornScatterPlot]

Do you know of an elegant approach or library that could solve this problem?

I started writing something myself (included below), but the implementation is suboptimal and limited by the number of points that can overlap at the same spot: it currently fails if more than 4 points overlap.

# Modules #
import seaborn, pandas
import matplotlib.ticker

################################################################################
def amount_to_offets(amount):
    """A function that takes an amount of overlapping points (e.g. 3)
    and returns a list of offset (jittered) coordinates, one for each
    point.

    It follows the logic that two points are displayed side by side:

    2 ->  * *

    Three points are organized in a triangle

    3 ->   *
          * *

    Four points are arranged in a square, and so on.

    4 ->  * *
          * *
    """
    assert isinstance(amount, int)
    solutions = {
        1: [( 0.0,  0.0)],
        2: [(-0.5,  0.0), ( 0.5,  0.0)],
        3: [(-0.5, -0.5), ( 0.0,  0.5), ( 0.5, -0.5)],
        4: [(-0.5, -0.5), ( 0.5,  0.5), ( 0.5, -0.5), (-0.5,  0.5)],
    }
    return solutions[amount]

################################################################################
class JitterDotplot(object):

    def __init__(self, data, x_col='time', y_col='sex', z_col='tip'):
        self.data = data
        self.x_col = x_col
        self.y_col = y_col
        self.z_col = z_col

    def plot(self, **kwargs):
        # Load data #
        self.df = self.data.copy()

        # Assign numerical values to the categorical data #
        # So that ['Dinner', 'Lunch'] becomes [0, 1] etc. #
        self.x_values = self.df[self.x_col].unique()
        self.y_values = self.df[self.y_col].unique()
        self.x_mapping = dict(zip(self.x_values, range(len(self.x_values))))
        self.y_mapping = dict(zip(self.y_values, range(len(self.y_values))))
        # Store the codes as floats so the fractional offsets below can be added in place #
        self.df[self.x_col] = self.df[self.x_col].map(self.x_mapping).astype(float)
        self.df[self.y_col] = self.df[self.y_col].map(self.y_mapping).astype(float)

        # Offset points that are overlapping in the same location #
        # So that (2.0, 3.0) becomes (1.975, 2.975) for instance #
        cols = [self.x_col, self.y_col]
        scaling_factor = 0.05
        for values, df_view in self.df.groupby(cols):
            offsets = amount_to_offets(len(df_view))
            offsets = pandas.DataFrame(offsets, index=df_view.index, columns=cols)
            offsets *= scaling_factor
            self.df.loc[offsets.index, cols] += offsets

        # Plot a standard scatter plot #
        g = seaborn.scatterplot(x=self.x_col, y=self.y_col, size=self.z_col, data=self.df, **kwargs)

        # Put exactly one tick at each category position #
        # (fixed ticks; a single MaxNLocator must not be shared by two axes) #
        g.set_xticks(range(len(self.x_values)))
        g.set_yticks(range(len(self.y_values)))
        g.grid(False)

        # Expand the axis limits for x and y #
        margin = 0.4
        xmin, xmax, ymin, ymax = g.get_xlim() + g.get_ylim()
        g.set_xlim(xmin-margin, xmax+margin)
        g.set_ylim(ymin-margin, ymax+margin)

        # Label the ticks with the original categorical names #
        g.set_xticklabels(self.x_values)
        g.set_yticklabels(self.y_values)

        # Return for display in notebooks for instance #
        return g

################################################################################
# Graph #
graph = JitterDotplot(data=df)
axes  = graph.plot()
axes.figure.savefig('jitter_dotplot.png')

[Image: JitterDotPlot]
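The hard limit of four overlapping points comes from the lookup table in amount_to_offets. A minimal sketch of a drop-in replacement (grid_offsets and spacing are hypothetical names, not from the question) lays any number of points out on a roughly square grid, filling rows from the bottom and centring a shorter last row. With spacing=1.0 it yields the same positions as the hand-written table for 1 to 4 points, though in a different order.

```python
import math


def grid_offsets(amount, spacing=1.0):
    """Return `amount` (dx, dy) offsets on a roughly square grid centred on (0, 0).

    Points are filled row by row from the bottom; a shorter last row is
    centred horizontally. With spacing=1.0, amounts 1 to 4 give the same
    positions as the hand-written table in amount_to_offets.
    """
    if not isinstance(amount, int) or amount < 1:
        raise ValueError("amount must be a positive integer")
    ncols = math.ceil(math.sqrt(amount))  # points per full row
    nrows = math.ceil(amount / ncols)     # number of rows needed
    offsets = []
    for i in range(amount):
        row, col = divmod(i, ncols)
        in_row = min(ncols, amount - row * ncols)  # points in this row
        dx = (col - (in_row - 1) / 2) * spacing
        dy = (row - (nrows - 1) / 2) * spacing
        offsets.append((dx, dy))
    return offsets
```

Inside JitterDotplot.plot the call would become offsets = grid_offsets(len(df_view)). Because the grid widens with the group size (about sqrt(n) points per side), the fixed scaling_factor of 0.05 may need to shrink for large groups so that neighbouring cells do not run into each other.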

xApple
  • Have you looked at `seaborn.catplot`? – iamchoosinganame May 28 '19 at 17:50
  • Potentially useful here: [Making a non-overlapping bubble chart in Matplotlib (circle packing)](https://stackoverflow.com/questions/46131572/making-a-non-overlapping-bubble-chart-in-matplotlib-circle-packing) – ImportanceOfBeingErnest May 28 '19 at 18:39
  • @iamchoosinganame Yes, I have looked into `catplot`; as I said in my question, I tried all the seaborn options. Besides, `catplot` is just a method to create FacetGrids. – xApple May 29 '19 at 08:45
  • @ImportanceOfBeingErnest Thanks for that link; it seems very promising. I found another linked question: https://stackoverflow.com/questions/8671808 – xApple May 29 '19 at 08:46
  • @iayork It's specifically mentioned in my question that `swarmplot` returns the error `ValueError: Neither the 'x' nor 'y' variable appears to be numeric`. – xApple May 29 '19 at 13:13
  • Can you just convert to a number (`df['Sex_numerical'] = df[['sex',]].applymap(lambda x:1 if x=='Male' else 0)`), swarm plot that, and rename the tick labels? – iayork May 29 '19 at 14:40

1 Answer


You could first convert time and sex to the categorical dtype and then add a small jitter to their integer category codes:

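import numpy as np
import pandas as pd
import seaborn as sns

# Convert both columns to pandas' categorical dtype
df['sex'] = pd.Categorical(df['sex'])
df['time'] = pd.Categorical(df['time'])

# Plot the integer category codes plus a small uniform jitter
axes = sns.scatterplot(x=df.time.cat.codes + np.random.uniform(-0.1, 0.1, len(df)),
                       y=df.sex.cat.codes + np.random.uniform(-0.1, 0.1, len(df)),
                       size=df.tip)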

Output:

[Image: output plot with random jitter]
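With this approach the axes show the raw codes (0, 1) rather than the category names. A small sketch that wraps the same idea and also returns the category names, so the tick labels can be restored, is given below; jittered_codes is a hypothetical helper, and it assumes NumPy's Generator API (np.random.default_rng) so that a seed makes the jitter reproducible. As an aside, seaborn's documentation for scatterplot has described x_jitter and y_jitter as "currently non-functional", which matches the question's observation that they had no effect.

```python
import numpy as np
import pandas as pd


def jittered_codes(values, width=0.1, seed=None):
    """Return (positions, categories) for a sequence of categorical values.

    positions: integer category codes plus uniform noise in [-width, width)
    categories: category names in code order (sorted, which is what
    pd.Categorical does by default), usable as tick labels
    """
    cat = pd.Categorical(values)
    rng = np.random.default_rng(seed)  # seeded generator -> reproducible jitter
    positions = cat.codes + rng.uniform(-width, width, len(cat))
    return positions, list(cat.categories)


# Possible usage with the question's df (seaborn imported as sns):
# x, x_names = jittered_codes(df["time"], seed=0)
# y, y_names = jittered_codes(df["sex"], seed=1)
# ax = sns.scatterplot(x=x, y=y, size=df["tip"])
# ax.set_xticks(range(len(x_names)))
# ax.set_xticklabels(x_names)
# ax.set_yticks(range(len(y_names)))
# ax.set_yticklabels(y_names)
```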

Building on that idea, you can replace the random offsets (np.random) in the code above with deterministic ones that place the points of each group evenly on a small circle. For example:

# grouping (observed=True keeps only the category combinations present in the data)
groups = df.groupby(['time', 'sex'], observed=True)

# compute the number of samples per group
num_samples = groups.tip.transform('size')

# enumerate the samples within each (time, sex) group and turn the rank into an angle
sample_ranks = groups.cumcount() * (2*np.pi) / num_samples

# compute the offset (a point alone in its group stays at the centre)
x_offsets = np.where(num_samples.eq(1), 0, np.cos(sample_ranks) * 0.03)
y_offsets = np.where(num_samples.eq(1), 0, np.sin(sample_ranks) * 0.03)

# plot
axes = sns.scatterplot(x=df.time.cat.codes + x_offsets,
                       y=df.sex.cat.codes + y_offsets,
                       size=df.tip)

Output:

[Image: output plot with points placed on small circles]

Quang Hoang
  • Ah yes, that's an elegant idea to place them on a circle, but that will probably only work up to about 10 points. For greater numbers of points, one would need to place them on a grid, I think. I'll continue developing, thanks. – xApple May 29 '19 at 09:13
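Besides a rectangular grid, another way around the single circle's limit is Vogel's "sunflower" spiral. This technique is swapped in here and is not taken from the answer. The k-th point is placed at a radius proportional to sqrt(k) and at an angle of k times the golden angle (about 137.5 degrees), so the points fill a disc instead of crowding onto one ring. Below is a minimal sketch returning offsets in the same (dx, dy) format as amount_to_offets. The names sunflower_offsets and spacing are hypothetical, and the result would still need scaling (for example by the question's scaling_factor) before being added to the category codes.

```python
import math


def sunflower_offsets(amount, spacing=1.0):
    """Return `amount` (dx, dy) offsets on Vogel's sunflower spiral.

    Point k (k = 0, 1, ...) sits at radius spacing * sqrt(k) and angle
    k * golden_angle, so successive points are spread around a disc
    instead of being stacked on a single ring.
    """
    golden_angle = math.pi * (3 - math.sqrt(5))  # about 2.39996 rad (~137.5 degrees)
    offsets = []
    for k in range(amount):
        radius = spacing * math.sqrt(k)
        theta = k * golden_angle
        offsets.append((radius * math.cos(theta), radius * math.sin(theta)))
    return offsets
```

Since the outermost point lies at radius spacing * sqrt(amount - 1), choosing spacing = R / sqrt(amount - 1) for groups of two or more keeps every cluster within a disc of radius R, whatever its size.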