My Favorite Plots Using Matplotlib - Part I

Before I start, let me just be clear and say that there are many, many great things that one can do with Matplotlib. And most importantly, similar results can be achieved in multiple different ways. What you will see in this article is by no means any sort of universal truth, and if you like to visualise your data in a different way, then that’s perfectly fine!

This is not a tutorial

The purpose of this article isn’t to go through the fundamentals of Matplotlib and Seaborn, but simply for me to share some of the plotting functions that I have been using the most while working as a data analyst.

I will also include some data pre-processing functions that can make your life much easier and hopefully help you create beautiful and meaningful plots.

Pick your poison

There are mainly two different schools of thought when it comes to creating and rendering charts using Matplotlib or Seaborn (or maybe three if we include people who only visualise their data using the Pandas .plot() method).

  • Many people prefer using plt.subplots() and unpacking the figure and axes objects from the tuple that this function returns:
fig, ax = plt.subplots()

ax.plot(x, y)

  • Though I can perfectly understand why one might want to do that, I’ve always been more inclined to build my charts this way:
plt.figure()

plt.plot(x,y)

plt.show()

All the code that you will find throughout this article will be written using the second approach, but I guess it wouldn’t take much effort to make it fit into the .subplots() way.
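For reference, here is a minimal sketch of the same line chart written both ways. The x and y values are made up, and the non-interactive Agg backend is an assumption so that the snippet runs outside a notebook (which is also why there is no plt.show() call). In both cases Matplotlib ends up with a Figure holding a single Axes; the pyplot version simply lets Matplotlib keep track of the "current" figure and axes for us.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no GUI needed
import matplotlib.pyplot as plt

x = [1, 2, 3, 4]        # made-up sample data
y = [10, 20, 15, 25]

# 1) object-oriented style: unpack the Figure and Axes returned by plt.subplots()
fig_oo, ax = plt.subplots()
ax.plot(x, y)
ax.set_title("subplots() style")

# 2) pyplot (state-machine) style: plt.* calls act on the current figure/axes
fig_state = plt.figure()
plt.plot(x, y)
plt.title("plt.figure() style")
ax_state = plt.gca()  # the Axes that plt.plot() implicitly created

# each figure now holds one Axes with one line drawn from the same data
lines_oo = ax.get_lines()
lines_state = ax_state.get_lines()
```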

Prepare your data first

We’ll be using a random video games dataset that I found on Kaggle a long time ago.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def getDataframe(csv_file):
    df = pd.read_csv(csv_file)
    return df

df = getDataframe("videogamessales.csv")
df.sample(5)

This is what our raw data looks like. We have some categorical and numerical variables, nothing fancy, and nothing that needs any preprocessing.


Now let’s take a look at the four main aggregating functions that I like to use:

  • This might seem a bit overkill, but I like to have a function that counts the elements within a Pandas series and returns the output as a DataFrame.
from collections import Counter

def getCount(serie,howmany):
    # count every value, then keep the "howmany" most common ones as (label, count) pairs
    counted = Counter(serie).most_common(howmany)
    labels = [l for l,v in counted]
    values = [v for l,v in counted]
    result = pd.DataFrame({"Labels":labels,"Values":values})
    return result

getCount(df["Publisher"],6)


Now, to be fair, it’d probably be much easier to write x = df["Publisher"].value_counts().keys()[:6] and then y = df["Publisher"].value_counts().values[:6], but I like to have a DataFrame on which I can apply methods like .query() if needed.
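For comparison, a value_counts() version that still returns a DataFrame could look like the sketch below. getCountVC and the toy series are hypothetical, not part of the original code. One difference worth knowing: value_counts() drops missing values by default (dropna=True), whereas Counter counts whatever it is given.

```python
import pandas as pd

def getCountVC(serie, howmany):
    # hypothetical variant of getCount() built on Series.value_counts();
    # value_counts() sorts by frequency and drops missing values by default
    counts = serie.value_counts().head(howmany)
    return pd.DataFrame({"Labels": counts.index.to_list(), "Values": counts.to_list()})

# toy data (assumption, not the Kaggle file)
publishers = pd.Series(["EA", "Ubisoft", "EA", "THQ", "EA", "Ubisoft", None])
top2 = getCountVC(publishers, 2)
```

Since the result is a DataFrame, .query() still works on it, e.g. top2.query("Values > 2").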

  • The next function does something similar, but returns how many rows there are for each serie2 value within each unique serie1 value. This is particularly useful for stacked bar plots. As you probably already know if you’ve read some other articles on this blog, I’m not a big fan of multi-level indexing, and that’s what the .rename() and .reset_index() calls (lines 3 and 4 of the function) ungracefully try to get rid of.
def getDoubleCount(data,serie1,serie2):
    count = data.groupby([serie1,serie2])[[serie2]].count()
    count.rename(columns={serie2:"volume"}, inplace=True)
    count.reset_index(drop=False, inplace=True)
    return count

getDoubleCount(df,"Publisher","Platform")
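If selecting the grouping column again ([[serie2]]) feels awkward, .size() offers another route. The sketch below (getDoubleCountSize is a hypothetical name, and the data is made up) counts rows per pair and flattens the MultiIndex in a single reset_index() call; on data like this it should give the same table as getDoubleCount().

```python
import pandas as pd

def getDoubleCountSize(data, serie1, serie2):
    # hypothetical variant of getDoubleCount(): .size() counts rows per (serie1, serie2) pair,
    # and reset_index(name=...) flattens the MultiIndex and names the count column in one go
    return data.groupby([serie1, serie2]).size().reset_index(name="volume")

toy = pd.DataFrame({
    "Publisher": ["EA", "EA", "EA", "Ubisoft"],
    "Platform":  ["PS2", "PS2", "X360", "Wii"],
})
counts = getDoubleCountSize(toy, "Publisher", "Platform")
```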


  • While we’re here, we might also want to create a function that allows us to do some SQL-like aggregations, freely picking whichever calculation we want to do.
def getGroupBy(data,serie,aggr):
    grouped = data.groupby(serie, as_index=False).agg(aggr)
    return grouped.round(2)

getGroupBy(df,"Platform", {"NA_Sales":"mean","EU_Sales":"max","JP_Sales":"min"})
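Because getGroupBy() passes aggr straight to .agg(), its output keeps the original column names, which can get confusing once a column is aggregated in more than one way. Pandas’ named aggregation (output_name=(column, function)) is one way around that. The sketch below calls it directly on made-up data rather than through getGroupBy(), whose dict argument would have to become keyword arguments.

```python
import pandas as pd

toy = pd.DataFrame({
    "Platform": ["PS2", "PS2", "Wii"],
    "NA_Sales": [1.0, 2.0, 4.0],
    "EU_Sales": [0.5, 1.5, 3.0],
})

# named aggregation: output_column=(input_column, function),
# so the same input column can be aggregated twice under distinct names
summary = (
    toy.groupby("Platform", as_index=False)
       .agg(NA_mean=("NA_Sales", "mean"), NA_max=("NA_Sales", "max"), EU_min=("EU_Sales", "min"))
       .round(2)
)
```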


  • And finally, we’ll sometimes need a function that turns the unique values of one series into separate columns (one-hot encoding, via pd.get_dummies()), then adds up those columns for each unique value of another series. If this isn’t clear (it isn’t to me as I’m writing these lines), then let’s see a concrete example:
def getDummies(dataframe,filtered,pivoted,howmany,ind):
    # one-hot encode the "pivoted" column(s): one column per unique value
    data = dataframe.filter(filtered)
    dummies = pd.get_dummies(data, prefix="count", columns=pivoted)
    for c in dummies.columns.to_list()[1:]:
        dummies.rename(columns={c: c.replace("count_","")}, inplace=True)

    # keep only the "howmany" most frequent values, plus the ind column in first position
    top = {}
    for c in dummies.columns.to_list()[1:]:
        top[c] = dummies[c].sum()
    top = sorted(top, key=top.get, reverse=True)[:howmany]
    top.insert(0, ind)
    dummies = dummies.filter(top)

    # one row per unique ind value: sum the dummy columns
    agg = {}
    for c in dummies.columns.to_list()[1:]:
        agg[c] = "sum"
    dummies = dummies.groupby(ind, as_index=False).agg(agg)
    return dummies

# dummies table for the top 6 platforms
dummies = getDummies(df,["Publisher","Platform"],["Platform"],6,"Publisher")
dummies


Alright, that should hopefully make more sense now!
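As an aside, pd.crosstab() can build the same kind of counts table in one call. The sketch below is a hypothetical alternative (getCrosstab is not from the original code, and the data is made up); it keeps the howmany most frequent columns by sorting the column totals, which should match what getDummies() does.

```python
import pandas as pd

def getCrosstab(dataframe, pivoted, howmany, ind):
    # hypothetical alternative to getDummies(): one row per `ind` value,
    # one column per `pivoted` value, each cell counting the matching rows
    table = pd.crosstab(dataframe[ind], dataframe[pivoted])
    # keep the `howmany` most frequent pivoted values overall
    top = table.sum().sort_values(ascending=False).index[:howmany]
    table = table[top].reset_index()
    table.columns.name = None  # drop the leftover axis label from crosstab
    return table

toy = pd.DataFrame({
    "Publisher": ["EA", "EA", "EA", "Ubisoft", "Ubisoft", "Ubisoft"],
    "Platform":  ["PS2", "PS2", "Wii", "PS2", "Wii", "DS"],
})
table = getCrosstab(toy, "Platform", 2, "Publisher")
```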

It’s all about style

  • Templates

Both Matplotlib and Seaborn have their own styling templates, which can be called as follows:

from matplotlib import style

style.use("whichever style you choose to use")

Or

import seaborn as sns

sns.set_style("whichever style you choose to use")  # Seaborn's own names: darkgrid, whitegrid, dark, white, ticks

I personally like to use Seaborn’s "whitegrid", but Matplotlib’s "fivethirtyeight" is also a widely popular choice.

The full list of templates for Matplotlib is available here, and there for Seaborn.
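The style names that your own Matplotlib installation accepts can also be listed from Python, and a style can be applied to a single block of code with a context manager rather than globally. A minimal sketch, assuming the Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend, no GUI needed
import matplotlib.pyplot as plt

# every style sheet name that style.use() / plt.style.use() accepts on this install
available = plt.style.available

before = plt.rcParams["axes.facecolor"]
# apply a style temporarily instead of globally
with plt.style.context("fivethirtyeight"):
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3], [3, 1, 2])
    inside = plt.rcParams["axes.facecolor"]
# leaving the with block restores the previous settings
after = plt.rcParams["axes.facecolor"]
```

Seaborn’s styles go through sns.set_style() (or the sns.axes_style() context manager) instead of this list.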

  • Colormaps

From a colour scheme perspective, we’ll be using the “Blues” palette (or “Blues_r” for reversed colours). This palette is part of Cynthia Brewer’s ColorBrewer colour schemes, and you should make sure to check out the comprehensive list of supported colormaps if you want to add colour palettes to your charts.

Note: who is Cynthia Brewer?
Cynthia Brewer is an American professor of geography who, in 2002, developed the famous ColorBrewer online tool, which provides colour schemes for maps (many of them colourblind-friendly) that serve as the basis for many modern data visualisation tools.
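To see what the "_r" suffix does in practice, the sketch below looks up both versions in Matplotlib’s colormap registry (matplotlib.colormaps, available since Matplotlib 3.5) and samples a five-colour discrete palette with sns.color_palette(), similar to what Seaborn does when a bar chart gets palette="Blues_r".

```python
import numpy as np
import matplotlib
import seaborn as sns

blues = matplotlib.colormaps["Blues"]      # registry lookup (Matplotlib >= 3.5)
blues_r = matplotlib.colormaps["Blues_r"]  # the same colormap, reversed

# colormaps map a float in [0, 1] to an RGBA tuple
lightest = blues(0.0)   # low end of "Blues": almost white
darkest = blues(1.0)    # high end of "Blues": dark blue

# Seaborn can sample a discrete palette from the same colormap, e.g. 5 bar colours
palette = sns.color_palette("Blues_r", 5)
```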


  • RC parameters

Oh, one last thing that I find quite important: setting high-level rcParams, which, once loaded into memory at runtime, apply to all of our plots. RC stands for runtime configuration, and it will save you some time, as these parameters affect every Matplotlib and Seaborn chart drawn afterwards (unless a later call such as sns.set_style() overrides some of them).

rc = {
    "figure.figsize": (18,5),
    "axes.edgecolor": "white",
    "font.family": "monospace",
    "font.size": 13
}

plt.rcParams.update(rc)

Here, for instance, we’re setting the default figure size for each plot, the font family, and so on. Please note that these settings can still be overridden for an individual figure if needed, for example by passing figsize to plt.figure(). You can find a comprehensive list of rcParams on this page.
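Both kinds of override are sketched below, assuming the Agg backend: an explicit figsize on a single figure, and plt.rc_context(), which applies rcParams temporarily and restores the previous values when the with block ends.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend
import matplotlib.pyplot as plt

# global defaults, a subset of the rc dictionary above
plt.rcParams.update({"figure.figsize": (18, 5), "font.size": 13})

# 1) per-figure override: an explicit figsize beats the rcParams default
fig_small = plt.figure(figsize=(6, 4))

# 2) temporary override for a block of code; previous values come back afterwards
with plt.rc_context({"font.size": 8}):
    fig_ctx = plt.figure()
    font_inside = plt.rcParams["font.size"]

fig_default = plt.figure()
font_after = plt.rcParams["font.size"]
```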

Show me the money

It seems like we are finally ready to create some plots and explore our dataset!

  • Categorical data

Using the data aggregation functions we defined earlier, we can start with two side-by-side bar charts to show the most common values for df["Publisher"] and df["Platform"].

def getBarCharts(x1,y1,x2,y2,my_title1,my_title2):
    fig = plt.figure()

    plt.subplot(121)
    left = sns.barplot(x=x1, y=y1, orient="h", palette="Blues_r")
    sns.despine(left=True, right=True, top=True, bottom=True)
    plt.xlabel("")
    plt.ylabel("")
    left.set_title(my_title1)

    plt.subplot(122)
    right = sns.barplot(x=x2, y=y2, orient="h", palette="Blues_r")
    sns.despine(left=True, right=True, top=True, bottom=True)
    plt.xlabel("")
    plt.ylabel("")
    right.set_title(my_title2)

    plt.subplots_adjust(wspace = 0.2, hspace = 0.4, top = 0.9)
    plt.show()

left = getCount(df["Publisher"], 7)
right = getCount(df["Platform"], 7)

getBarCharts(left["Values"], left["Labels"],
             right["Values"], right["Labels"],
             "Top publishers", "Top platforms"
             )


The left and right variables hold the aggregated DataFrames created with our getCount() function. Please note that as the bars are set to be horizontal (orient="h"), the x parameter in sns.barplot() takes the numerical values found under ["Values"], while the y parameter contains the categorical values found under ["Labels"].

To split the figure into side-by-side subplots, we call plt.subplot() (singular, not plt.subplots()) with a three-digit number such as 121, which is shorthand for plt.subplot(1, 2, 1): the first digit (here 1) is the number of rows, the second (here 2) is the number of columns, and the third (here 1, then 2) is the index of the subplot we are currently drawing on, counted from 1, left to right and top to bottom. Further down, we use plt.subplots_adjust() to refine the spacing between the subplots.
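A quick way to check this on your own machine is to ask each Axes where it sits in its grid. A minimal sketch, assuming the Agg backend; note that get_subplotspec() reports 0-based rows and columns, while the subplot index starts at 1.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend
import matplotlib.pyplot as plt

fig = plt.figure()
ax_a = plt.subplot(121)      # 1 row, 2 columns, 1st slot (left)
ax_b = plt.subplot(1, 2, 2)  # same grid, 2nd slot (right): the long form of 122

fig2 = plt.figure()
ax_c = plt.subplot(2, 3, 5)  # 2 rows, 3 columns, 5th slot: bottom row, middle column

spec_a = ax_a.get_subplotspec()
spec_c = ax_c.get_subplotspec()
```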

Next comes a set of stacked horizontal bars showing, for each video game publisher, how many of its games were released on each platform. This is where the getDummies() function that we defined earlier comes in handy.

def getStackedBars(data,x,title):
    data.plot(
        kind="barh",
        x=x,
        stacked=True,
        legend=True,
        title=title,
        cmap="Blues_r"
        );
    # pandas' legend= only takes True, False or "reverse", so the location is set separately
    plt.legend(loc="upper right")

d = dummies.query("Publisher in ('Electronic Arts','Namco Bandai Games','Activision','THQ','Ubisoft')")
getStackedBars(d,"Publisher","test")


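What you see above should be self-explanatory. We’re using Pandas’ built-in .plot() method, as it makes visualising stacked bars much easier. If you’re wondering why I added a semicolon ";" after the closing parenthesis, it’s a Jupyter habit: when a plotting call is the last expression in a notebook cell, the semicolon stops the notebook from echoing the Axes object that the call returns. Strictly speaking, it has no effect inside a function body like this one, since the function itself returns None. Without it, a bare plotting call at the end of a cell prints the following, which actually isn’t an error message, but it just doesn’t look good, I think: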

<matplotlib.axes._subplots.AxesSubplot at xxxxx>
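If you’d rather tweak the chart afterwards (legend position, axis labels and so on), one option is to return the Axes object that DataFrame.plot() creates. getStackedBarsAx below is a hypothetical variant, run on a made-up counts table shaped like the dummies output and assuming the Agg backend. Since pandas’ legend= argument only accepts True, False or "reverse", the legend position is set on the returned Axes.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend
import matplotlib.pyplot as plt
import pandas as pd

def getStackedBarsAx(data, x, title):
    # hypothetical variant of getStackedBars() that hands back the Axes
    ax = data.plot(kind="barh", x=x, stacked=True, title=title, cmap="Blues_r")
    # move the legend that pandas created to the upper right corner
    ax.legend(loc="upper right")
    return ax

toy = pd.DataFrame({
    "Publisher": ["EA", "Ubisoft", "THQ"],
    "PS2": [5, 3, 2],
    "Wii": [1, 4, 2],
})
ax = getStackedBarsAx(toy, "Publisher", "Platforms per publisher")
```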

Another thing we can do with this dataset is visualise the most recurring tokens using a word cloud, or rather two side-by-side word clouds. To get these, we first need to import the wordcloud library, and then join the values of df["Publisher"] and df["Name"] into two long strings. We don’t need to tokenise the text ourselves to remove stopwords: WordCloud drops the words in its built-in STOPWORDS set by default, and a larger set can be passed through its stopwords parameter if generic words still dominate.

from wordcloud import WordCloud, STOPWORDS

def getWordCloud(text_left,text_right,title):
    # join each series into one long string, skipping missing values
    text_left = " ".join(text_left.dropna().astype(str))
    text_right = " ".join(text_right.dropna().astype(str))
    # shared settings for both clouds (STOPWORDS is also WordCloud's default)
    settings = dict(max_words=200, width=600, height=400, background_color="white", margin=0,
                    max_font_size=80, min_font_size=10, colormap="Blues_r", stopwords=STOPWORDS)
    wc_left = WordCloud(**settings).generate(text_left)
    wc_right = WordCloud(**settings).generate(text_right)
    plt.figure()
    plt.suptitle(title, size=20, y=1)
    plt.subplot(1,2,1)
    plt.imshow(wc_left)
    plt.axis("off")
    plt.subplot(1,2,2)
    plt.imshow(wc_right)
    plt.axis("off")
    plt.show()

getWordCloud(df["Publisher"].dropna(),df["Name"],"Top keywords")
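If you want to check the numbers behind the clouds, the most frequent tokens can also be counted directly with Pandas. The helper below is hypothetical and deliberately simple (lower-casing plus whitespace splitting, on made-up titles); WordCloud applies its own tokenisation rules, such as plural normalisation and bigram collocations, so its counts won’t match these exactly.

```python
import pandas as pd

def getTopTokens(serie, howmany, stopwords=frozenset()):
    # hypothetical helper: lower-case, split on whitespace, one token per row
    tokens = serie.dropna().astype(str).str.lower().str.split().explode()
    # drop stopwords, then count the remaining tokens
    tokens = tokens[~tokens.isin(stopwords)]
    counts = tokens.value_counts().head(howmany)
    return pd.DataFrame({"Labels": counts.index.to_list(), "Values": counts.to_list()})

names = pd.Series(["FIFA Soccer 10", "FIFA Soccer 11", "The Sims 3", None, "Soccer Stars of the World"])
top = getTopTokens(names, 2, stopwords={"the", "of"})
```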


  • Distribution

We will also want to visually assess the distribution of numerical variables, using either a set of histograms, ridge lines, or violin plots.

def getHistPlots(x1,x2,title1,title2):
    fig = plt.figure()

    plt.subplot(121)
    left = sns.histplot(x1, kde=True, bins=10, palette="Blues_r")
    sns.despine(left=True, right=True, top=True, bottom=True)
    plt.xlabel("")
    plt.ylabel("")
    left.set_title(title1)

    plt.subplot(122)
    right = sns.histplot(x2, kde=True, bins=10, palette="Blues_r")
    sns.despine(left=True, right=True, top=True, bottom=True)
    plt.xlabel("")
    plt.ylabel("")
    right.set_title(title2)

    plt.subplots_adjust(wspace = 0.2, hspace = 0.4, top = 0.9)
    plt.show()

left = df.query("Publisher == 'Ubisoft'")
right = df.query("Publisher == 'Electronic Arts'")

getHistPlots(left["Year"],right["Year"],"Years (Ubisoft)","Years (EA)")


To be fair, the logic behind these side-by-side plotting functions is pretty much identical. What really changes is the sns.type_of_plot_you_want_to_get() method that we end up picking. You’ll notice that .histplot() allows us to manually choose the number of bins, as well as to enable kde=True, which will:

"(...) compute a kernel density estimate to smooth the distribution and show on the plot as (one or more) line(s)." (Source: Seaborn’s histplot documentation)
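As a rough illustration of what kde=True adds, here is a minimal Gaussian kernel density estimate in NumPy: one normal bump centred on each observation, averaged, with the bandwidth picked by Scott’s rule of thumb. Seaborn’s default also uses a Gaussian kernel with Scott’s rule, but it adds options such as bw_adjust and cut, so treat this as a sketch rather than a re-implementation. The release years are made up.

```python
import numpy as np

def gaussian_kde_1d(samples, grid):
    # minimal Gaussian KDE: average of one normal bump per observation
    samples = np.asarray(samples, dtype=float)
    n = samples.size
    # Scott's rule of thumb for the bandwidth in one dimension
    h = samples.std(ddof=1) * n ** (-1 / 5)
    z = (grid[:, None] - samples[None, :]) / h
    bumps = np.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)
    # dividing by n * h makes the curve integrate to 1
    return bumps.sum(axis=1) / (n * h)

years = np.array([2001, 2003, 2003, 2004, 2007, 2008, 2008, 2008, 2010])  # made-up release years
grid = np.linspace(1990, 2020, 3001)
density = gaussian_kde_1d(years, grid)
```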

We’re almost there! I’m personally not a big fan of violin plots, but I have been asked on several occasions to incorporate this type of visualisation into reports that we would provide for other data-focused teams or for leadership, and I thought I might as well share the simple function I ended up writing.

def getViolinPlots(data,x,y,title):
    plt.figure()
    # violinplot (not boxplot) draws the mirrored density shapes
    sns.violinplot(data=data, x=x, y=y, orient="v", palette="Blues_r")
    plt.title(title)
    sns.despine(left=True, right=True, top=True, bottom=True)
    plt.show()

violin = df.query("Publisher in ('Ubisoft','Electronic Arts')")
getViolinPlots(violin,"Publisher","Year","Release years for Ubisoft and EA")


  • Correlation

I’ve kept for the last part of this article the two plotting functions that I have used the most as a data analyst over the past couple of years, especially for technical presentations or when investigating the dependence between multiple variables, such as the unique values of the same series.

def getCorrPlot(data,x,y,howmany,title):
    # aggregation: keep only the rows whose y value is among the "howmany" most frequent
    corr = data.filter([x,y])
    dataframe = getCount(data[y],howmany)
    dataframe = dataframe["Labels"].to_list()
    corr = corr.loc[corr[y].isin(dataframe)]

    # dummies and re-aggregation: one column per y value, one row per x value
    corr = pd.get_dummies(corr,prefix="",prefix_sep="",columns=[y])
    aggregate = {}
    for col in corr.columns[1:]:
        aggregate[col] = "sum"
    corr = corr.groupby(x, as_index=False).agg(aggregate)

    # log-transform the counts, log(1 + n), so a few very large values don't dominate
    for col in corr.columns[1:]:
        corr[col] = np.log(corr[col] + 1)

    # correlation matrix between the y values, clustered hierarchically
    corr.drop(columns=[x],inplace=True)
    corr = corr.corr()
    g = sns.clustermap(corr, figsize=(10,10), cmap="Blues")
    sns.despine(left=True, right=True, top=True, bottom=True)
    # clustermap draws its own figure, so the title goes on that figure
    g.fig.suptitle(title, y=1.02)
    plt.show()

getCorrPlot(df,"Publisher","Platform",10,"Correlation plot")


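I find Seaborn’s .clustermap() method particularly suited to small and medium-sized datasets, as a larger number of variables can make the cluster trees (dendrograms) along the top and left of the chart difficult to read. A clustermap combines hierarchical clustering with a traditional heat map. I would strongly advise you to rescale skewed count variables before running this type of plot, typically with a log transform (see the line that says np.log(corr[col] + 1)): it keeps a handful of very large counts from dominating the correlations, whereas a purely linear rescaling such as min-max normalisation would leave a Pearson correlation matrix unchanged.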
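A toy example with made-up counts shows why the log transform matters while a linear rescaling doesn’t: min-max normalisation leaves the Pearson correlation untouched, whereas log(1 + n), the transform used in getCorrPlot() (np.log1p computes the same thing), changes it because it shrinks the one very large value.

```python
import numpy as np
import pandas as pd

# made-up counts of games per publisher on two platforms, with one very large publisher
counts = pd.DataFrame({
    "PS2": [1, 2, 3, 4, 200],
    "Wii": [5, 4, 3, 2, 6],
})

raw = counts.corr().loc["PS2", "Wii"]

# min-max normalisation: a linear rescaling of each column
minmax = (counts - counts.min()) / (counts.max() - counts.min())
scaled = minmax.corr().loc["PS2", "Wii"]

# log(1 + n), as in getCorrPlot()
logged = np.log1p(counts).corr().loc["PS2", "Wii"]
```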

For larger datasets, I would recommend giving NetworkX a try. This is a great library for anyone who wants to study the structure, dynamics, and functions of a given graph or network. I won’t spend much time on graph theory, first because it’s not what we are discussing in this article, but also because I really wouldn’t consider myself an expert in that field. If you want to understand the basics of graph theory, I encourage you to start by visiting this website.

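import networkx as nx

# rows with a missing publisher or year would become NaN nodes or NaN edge widths
n = df.filter(["Publisher", "Platform", "Year"]).dropna()

def getNetwPlot(data, serie1, serie2, serie3):
    # one node per unique serie1/serie2 value; the remaining columns become edge attributes
    G = nx.from_pandas_edgelist(data, serie1, serie2, edge_attr=True)
    edgelist = nx.to_edgelist(G)
    # one colour value per node, spread evenly over the colormap
    colors = [i/len(G.nodes) for i in range(len(G.nodes))]
    plt.figure(figsize=(12,8))
    nx.draw(
        G,
        with_labels=True,
        node_size=[v * 200 for v in dict(G.degree()).values()],  # bigger nodes for better-connected nodes
        width=[v[2][serie3] / 500 for v in edgelist],
        font_size=10,
        node_color=colors,
        cmap="BuPu"
    )
    plt.title("NetworkX")
    plt.show()

getNetwPlot(n, "Publisher", "Platform", "Year")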


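To create the network chart above, we start by selecting the columns that we want to visualise. NetworkX’s from_pandas_edgelist() then turns each row into an edge between a publisher and a platform (keeping the year as an edge attribute), and to_edgelist() gives us back those edges with their attributes, which we use to set the edge widths. Setting the colors variable is entirely optional, but I would recommend adjusting the scaling factors in the node_size and width list comprehensions (* 200 and / 500) if the nodes or edges don’t show properly (for instance, try * 50 and re-run the function).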
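One caveat: nx.Graph keeps a single edge per publisher-platform pair, so when several rows share a pair, each row’s attributes overwrite the previous ones and the edge ends up carrying the last year only. If you’d rather have edge widths reflect how many games sit behind each pair, a possible sketch (hypothetical, on made-up data) is to aggregate the pairs first:

```python
import networkx as nx
import pandas as pd

toy = pd.DataFrame({
    "Publisher": ["EA", "EA", "EA", "Ubisoft"],
    "Platform":  ["PS2", "PS2", "Wii", "Wii"],
    "Year":      [2004, 2006, 2008, 2009],
})

# nx.Graph keeps one edge per pair: the second EA/PS2 row overwrites the first one's Year
G_raw = nx.from_pandas_edgelist(toy, "Publisher", "Platform", edge_attr=True)

# hypothetical alternative: count the games behind each publisher-platform pair first
edges = toy.groupby(["Publisher", "Platform"]).size().reset_index(name="games")
G = nx.from_pandas_edgelist(edges, "Publisher", "Platform", edge_attr="games")

# edge widths proportional to the number of games (the factor 2 is arbitrary)
widths = [d["games"] * 2 for _, _, d in G.edges(data=True)]
```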

Thanks for reading!