Matplotlib

Matplotlib is a widely used plotting library for scientific computing in Python.

The best way to learn Matplotlib is to browse through galleries until you find something you like, and then copy it. But to make sense of what you read, you need to know the basic structure of a plot, which is what this tutorial describes.

Building a plot

Plotting data

In matplotlib, the principal object we work with is an Axes. This is a 2d area in which we can draw points, lines, bars, text, etc. It has x and y scales with tickmarks, and it can have legends and labels.

To create a simple plot consisting of a single Axes, we call plt.subplots().

Calling plt.subplots() actually returns a pair of (Figure,Axes). Most of the time, we only care about the Axes object. The Figure object is for controlling the layout when there are multiple subplots.

import matplotlib.pyplot as plt

fig,ax = plt.subplots()
... # draw things on the Axes
plt.show()

Draw a line graph of y1 as a function of x, and a scatter plot of y2 as a function of x, using the drawing commands

ax.plot(x, y1, linestyle='--', alpha=0.7)
ax.scatter(x, y2, marker='+', color='red')

import numpy as np

x = np.linspace(0,10,100)
y1 = np.sin(x)
y2 = np.random.normal(loc=np.sin(x), scale=0.1)

Solution:

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0,10,100)
y1 = np.sin(x)
y2 = np.random.normal(loc=np.sin(x), scale=0.1)

fig,ax = plt.subplots()
ax.plot(x, y1, linestyle='--', alpha=0.7)
ax.scatter(x, y2, marker='+', color='red')

plt.show()

Legends

We should always label our plots! For simple plots, matplotlib makes it easy:

  1. Use ax.plot(…, label="mylabel")
  2. Call plt.legend()

Label the line "theory" and the points "experiment".

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0,10,100)
y1 = np.sin(x)
y2 = np.random.normal(loc=np.sin(x), scale=0.1)

fig,ax = plt.subplots()
ax.plot(x, y1, linestyle='--', alpha=0.7)
ax.scatter(x, y2, marker='+', color='red')

plt.show()

Solution:

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0,10,100)
y1 = np.sin(x)
y2 = np.random.normal(loc=np.sin(x), scale=0.1)

fig,ax = plt.subplots()
ax.plot(x, y1, linestyle='--', alpha=0.7, label='theory')
ax.scatter(x, y2, marker='+', color='red', label='experiment')
plt.legend()

plt.show()

Exercise

The code below lists the top finalists from Eurovision 2022.

Plot a horizontal bar graph to show each country's points, using the drawing command

ax.barh(ev2022['country'], ev2022['points'])

There is a long list of plots we can draw. The documentation lists them all, but it's better to browse matplotlib galleries and copy the one you want.

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'points': [631, 466, 459, 438, 312, 268]}

Solution:

import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'points': [631, 466, 459, 438, 312, 268]}

fig,ax = plt.subplots()
ax.barh(ev2022['country'], ev2022['points'])
plt.show()

Subplots

A plot with multiple subplots is a great way to show how some complex output depends on a parameter. Let's compare Eurovision jury versus audience scores.

To create a plot with multiple subplots,

fig,axes = plt.subplots(nrows,ncols)

The axes we get back will be either a single Axes object, or a 1-D NumPy array of Axes, or a 2-D array, depending on nrows and ncols. (The defaults are nrows=ncols=1, which is why plt.subplots() gives us a simple plot with a single Axes.)
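
Here is a small sketch of how the return value changes shape with nrows and ncols. The variable names are mine, and the Agg backend line is only there so the code runs without a display. The squeeze=False option is an optional plt.subplots argument that asks for a 2-D array every time, which can make looping code simpler.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig1, single = plt.subplots()        # one Axes object
fig2, row = plt.subplots(1, 2)       # 1-D array of 2 Axes
fig3, grid = plt.subplots(2, 3)      # 2-D array, indexed as grid[r, c]
fig4, always_2d = plt.subplots(1, 2, squeeze=False)  # 2-D even for one row

print(type(single).__name__)
print(row.shape, grid.shape, always_2d.shape)

# .flat visits every Axes in a 2-D grid, which is handy in a for loop
for i, ax in enumerate(grid.flat):
    ax.set_title(f"subplot {i}")
```

If you don't know in advance how many rows and columns you'll have, squeeze=False plus .flat means the rest of your code doesn't need special cases.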

Modify the subplots call to give us two Axes side-by-side, using fig,axes = plt.subplots(1,2). Show jury points on axes[0] and audience points on axes[1].

import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,ax = plt.subplots()
ax.barh(ev2022['country'], ev2022['jury'])
plt.show()

Solution:

import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2)
axes[0].barh(ev2022['country'], ev2022['jury'])
axes[1].barh(ev2022['country'], ev2022['audience'])
plt.show()

DRY

When we have several subplots, it's error-prone to have long blocks of copy-pasted code. It's better to use a for loop instead, so that we only need to write out our plotting commands once.

(This is good general-purpose programming advice. It's called DRY, Don't Repeat Yourself.)

Modify this code to use a for loop, rather than repeating the barh. Hint: use zip.

import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2)
axes[0].barh(ev2022['country'], ev2022['jury'])
axes[1].barh(ev2022['country'], ev2022['audience'])
plt.show()

Solution:

import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
plt.show()

Making it look good

Graphical legibility

The plots we've produced so far are pretty ugly, and an ugly plot won't convey your message!

Code for plotting often starts out with a tiny core of the sort we've seen that does the actual plotting (subplots, and things drawn on the subplots) … and then it gets swamped by a huge number of tweaks to make the plots legible, i.e. easily understandable by your reader.

I recommend adding in these tweaks systematically, one by one, checking how it looks each time. The following pages show my typical steps.

Labels

To label the x and y axes,

ax.set_xlabel(label)
ax.set_ylabel(label)

To label a subplot,

ax.set_title(label)

To label the entire plot,

fig.suptitle(label)

Label the subplots "Jury" and "Audience". Label the x-axis of each subplot "points".

import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
plt.show()

Solution:

import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])

axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')

plt.show()

Axis limits

Matplotlib will usually pick axis limits automatically to fit everything we're drawing. If it hasn't done a good job, we can set them manually:

ax.set_xlim([low,high])
ax.set_ylim([low,high])

To specify that subplots should use the same horizontal or vertical axis,

plt.subplots(…, sharex=bool, sharey=bool)
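
To see what sharing does, here is a minimal sketch with made-up numbers. It assumes two side-by-side subplots with sharex=True, sets the x-limits on the first one only, and then reads the limits back from the second.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, sharex=True)
axes[0].barh(["a", "b"], [120, 450])
axes[1].barh(["a", "b"], [30, 60])

# Setting the limits on one subplot...
axes[0].set_xlim([0, 500])
# ...also changes the other one, because the x-axis is shared
print(axes[1].get_xlim())

# The y-axes are not shared here, so they can be set independently
axes[1].set_ylim([-1, 2])
print(axes[0].get_ylim(), axes[1].get_ylim())
```

Sharing is worth it when readers should compare bar lengths across subplots. Leave it off when each subplot has its own natural range.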

Update this figure so that the two subplots share the same y-axis. What do you notice about the y-axis tick labels?

import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])

axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')

plt.show()

Solution:

import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2, sharey=True)
# Since the y-axis is shared, no need to draw y-tick labels on both subplots

for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])

axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')

plt.show()

Tick marks

Once we've got the axis limits we want, we may like to adjust tick positions and/or labels.

ax.set_xticks([x1,x2,…])
ax.set_xticklabels([lbl1,lbl2,…])
# similarly for set_yticks, set_yticklabels
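
A minimal sketch of the two commands working together, with made-up data and labels. The labels are applied in order to the tick positions, so it's safest to call set_xticks first and then set_xticklabels with a list of the same length.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2, 3], [10, 40, 20, 30])

# Put ticks at chosen positions, then replace the numbers with our own text
ax.set_xticks([0, 1, 2, 3])
ax.set_xticklabels(["Q1", "Q2", "Q3", "Q4"])

# Same idea for the y-axis; without set_yticklabels the numbers are shown
ax.set_yticks([10, 20, 30, 40])

fig.canvas.draw()  # render once so the tick labels are up to date
print([t.get_text() for t in ax.get_xticklabels()])
```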

Set ticks [0,50,…] for the first subplot, and [0,100,…] for the second. Also, make the ticks more elegant, and draw gridlines, with these commands:

ax.tick_params(axis='x', labelrotation=-60)
ax.grid(axis='x', alpha=0.5)

import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2, sharey=True)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])

axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')

plt.show()

Solution:

import numpy as np
import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2, sharey=True)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])

axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')

axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)

plt.show()

Size and layout

To set the overall size of the plot,

plt.subplots(…, figsize=(width,height))

The width and height are nominally measured in inches, but the actual output depends on what dpi your computer thinks it's using.

Another useful command, to adjust the spacing of the subplots to accommodate all the tick labels, is

fig.set_tight_layout(True)

To control the ratio of subplot sizes, or the space between them, pass in a gridspec_kw option to plt.subplots, for example

plt.subplots(…, 
    gridspec_kw={'width_ratios': [0.3, 0.5]})
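
To make the inches, dpi and width-ratio ideas concrete, here is a sketch that reads the sizes back from the figure. I'm assuming dpi=100 (passed explicitly so the arithmetic is predictable) and no tight layout, so the subplot widths follow the ratios exactly.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(6, 3), dpi=100,
                         gridspec_kw={'width_ratios': [0.3, 0.5]})

# Size in inches, and the pixel size that results at this dpi
print(fig.get_size_inches())          # width, height in inches
print(fig.canvas.get_width_height())  # width, height in pixels = inches * dpi

# Each Axes position is a box in figure coordinates (fractions of the figure)
w0 = axes[0].get_position().width
w1 = axes[1].get_position().width
print(w0 / w1)  # follows the ratio 0.3 : 0.5
```

Only the ratio of the width_ratios entries matters, so [0.3, 0.5] and [3, 5] give the same layout.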

Set the overall figure size to be 6 inches wide and 3 inches tall, and fix the clipping of the y-tickmarks.

import numpy as np
import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2, sharey=True)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])

axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')

axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)

plt.show()

Solution:

import numpy as np
import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2, sharey=True, figsize=(6,3))
fig.set_tight_layout(True)

for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])

axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')

axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)

plt.show()

Colour theme

We can control the overall colour theme:

plt.style.use('dark_background')
plt.style.use('default')

We can also set the background colour of an individual Axes, or of the entire Figure:

ax.set_facecolor(col)
fig,axes = plt.subplots(…, facecolor=col)

There are several ways to specify colours in matplotlib. Examples: (0.1, 0.2, 0.5), '#0f0f0f', '0.9', 'blue'.
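
If you're not sure what a colour specification means, matplotlib.colors.to_rgba will translate it for you. This sketch runs a few specifications through it; note that hex strings need a leading '#', and that a string holding a number between 0 and 1 is read as a grey level.

```python
from matplotlib.colors import to_rgba

# Each specification resolves to a (red, green, blue, alpha) tuple in [0, 1]
specs = [(0.1, 0.2, 0.5), '#0f0f0f', '0.9', 'blue']
for spec in specs:
    print(repr(spec), '->', to_rgba(spec))
```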

Set the background color of the Axes to be '0.9', which means 90% grey.

import numpy as np
import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2, sharey=True, figsize=(6,3))
fig.set_tight_layout(True)

for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])

axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')

axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)

plt.show()

Solution:

import numpy as np
import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2, sharey=True, figsize=(6,3))
fig.set_tight_layout(True)

for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])

axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')

axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)
    ax.set_facecolor('0.9')

plt.show()

Saving the plot

To save a plot, run this just before the final plt.show() command:

plt.savefig('filename')

It will format the output appropriately based on the filename's suffix (.svg, .png, .pdf, etc.).

I usually use some extra options, to trim off any extra space around the outside.

plt.savefig('filename', 
            bbox_inches='tight', 
            pad_inches=0)

It won't do any good using savefig in this tutorial, because there's no way to access the file it creates :-(

import numpy as np
import matplotlib.pyplot as plt

ev2022 = {
    'country': ['ukraine','uk','spain','sweden','serbia','italy'],
    'jury': [192, 283, 231, 258, 87, 158],
    'audience': [439, 183, 228, 180, 225, 110]}
    
fig,axes = plt.subplots(1,2, sharey=True, figsize=(6,3))
fig.set_tight_layout(True)

for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])

axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')

axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)
    ax.set_facecolor('0.9')

plt.savefig('ev2022.pdf', bbox_inches='tight', pad_inches=0)
plt.show()
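
If you run the code on your own machine, you can check that the suffix really does pick the format. This sketch saves one made-up figure three times into a temporary folder (the folder and the file name demo are just placeholders for illustration) and prints the size of each file.

```python
import os
import tempfile
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 3))
ax.barh(['a', 'b', 'c'], [3, 1, 2])

outdir = tempfile.mkdtemp()  # placeholder output folder, just for this sketch
saved = {}
for suffix in ['png', 'pdf', 'svg']:
    path = os.path.join(outdir, 'demo.' + suffix)
    plt.savefig(path, bbox_inches='tight', pad_inches=0)
    saved[suffix] = os.path.getsize(path)  # file size in bytes

print(saved)
```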

Summary

I like to build up plots step by step, adding pieces in order, and checking at each step what the plot looks like. If I add everything in one go, chances are it won't work, and I won't know which bit went wrong. Here are the steps we've covered:

  1. Prepare the data to plot
  2. Set up the Axes with plt.subplots(nrows,ncols)
  3. Use Axes methods to draw data inside each subplot — points, lines, bars, etc.
  4. Label each subplot, and label the x-axis and y-axis
  5. Set the x-axis and y-axis limits, and specify whether the axes should be shared between subplots
  6. Set tick marks
  7. Adjust the overall size and layout
  8. Tweak the colours
  9. Save the plot

After going through these steps, you may like to fine-tune the appearance — font sizes etc. All the matplotlib commands have a huge number of options, but it's hard to see any rhyme or reason behind the API, so you'll need to make frequent use of search engines and ChatGPT, and maybe if things get desperate even the documentation.

Working with data

Loading from csv

We may have a csv file with data we want to plot. The easiest way to import it is with the pandas package.

import pandas
dataframe = pandas.read_csv(url)

This returns a DataFrame object, which represents tabular data (think Excel spreadsheet), and behaves like a dictionary of columns.
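
Here is a tiny sketch of the "dictionary of columns" idea. To keep it self-contained I'm reading from an in-memory string with made-up values rather than a real URL; read_csv accepts either.

```python
import io
import pandas

# Stand-in for a csv file; read_csv accepts a file-like object as well as a URL
csv_text = io.StringIO("x,y\n1,10\n2,20\n3,30\n")
df = pandas.read_csv(csv_text)

print(list(df.columns))   # column names, like dictionary keys
print(df['y'])            # one column, selected like a dictionary entry
print(len(df))            # number of rows
```

A column such as df['y'] can be passed straight to matplotlib drawing commands, just like a list or a numpy array.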

The code below fetches the dataset behind xkcd comic #2048.

Plot a scatter plot of xkcd['y'] against xkcd['x'].

import pandas

xkcd = pandas.read_csv('https://www.cl.cam.ac.uk/teaching/current/DataSci/data/xkcd.csv')
xkcd

Solution:

import pandas
import matplotlib.pyplot as plt

xkcd = pandas.read_csv('https://www.cl.cam.ac.uk/teaching/current/DataSci/data/xkcd.csv')

fig,ax = plt.subplots()
ax.scatter(xkcd['x'], xkcd['y'])
plt.show()

Aggregating data

Pandas is good for processing data, e.g. tabulations.

Here's an example. We load a csv that has one row per stop-and-search incident in Cambridgeshire in 2021. The groupby command below tabulates it by the 'object_of_search' column.

Pandas is a vast data-handling library, and it's a lifesaver for anyone doing practical data science. You can work through a pandas tutorial if you want to learn more.

Several of the plots in the gallery use pandas to pre-process data. The code below is just to show you what sort of code to ignore, if all you want to learn about is matplotlib plotting code!
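
If you'd like to see what the tabulation does without downloading anything, here is a sketch on a handful of made-up rows (not the real dataset). I've used size() for the counting, which gives the same counts as apply(len).

```python
import pandas

# Made-up rows, one per incident (not the real dataset)
incidents = pandas.DataFrame({
    'object_of_search': ['drugs', 'weapons', 'drugs', 'stolen goods', 'drugs'],
    'outcome': ['find', 'nothing', 'nothing', 'find', 'nothing']})

# Count rows per category, then turn the result back into a two-column table
df = incidents.groupby('object_of_search').size().reset_index(name='n')
print(df)
```

The result has one row per distinct value of 'object_of_search' and a column 'n' of counts, which is the shape that ax.barh wants.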

Draw a bar chart, with horizontal bars of length df['n'], and y-labels df['object_of_search'].

import pandas
    
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)
df = stopsearch.groupby(['object_of_search']).apply(len).reset_index(name='n')

df

Solution:

import pandas
import matplotlib.pyplot as plt

url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)
df = stopsearch.groupby(['object_of_search']).apply(len).reset_index(name='n')

fig,ax = plt.subplots(figsize=(6,3))
ax.barh(df['object_of_search'], df['n'])
fig.tight_layout()
plt.show()

Plot gallery

Finding inspiration

The best way to learn matplotlib is to browse through plot galleries until you find something you like, and then copy it.

The following pages showcase some examples.

Another useful source: the official Matplotlib gallery.

Multiple lines, legends, datetime

If we draw multiple datasets and give them each a label,

ax.plot(…, label='label')

then it's easy to create a legend out of these labelled items: simply call

ax.legend()

I prefer to position my legend outside the Axes, using

ax.legend(loc='center left', 
    bbox_to_anchor=(1.05,0.5), 
    edgecolor='none')
fig.set_tight_layout(True)
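
A runnable sketch of the outside-the-Axes legend, using two made-up curves. The bbox_to_anchor coordinates are fractions of the Axes, so x=1.05 sits just past the right-hand edge. I've used fig.tight_layout() here, which is the one-off form of the same layout adjustment.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(x, np.sin(x), label='sin')
ax.plot(x, np.cos(x), label='cos')

# Pin the legend's centre-left point just beyond the right edge of the Axes
legend = ax.legend(loc='center left',
                   bbox_to_anchor=(1.05, 0.5),
                   edgecolor='none')
fig.tight_layout()  # ask matplotlib to adjust the spacing around the Axes

print([t.get_text() for t in legend.get_texts()])
```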

Dates and times

Matplotlib understands datetimes. In the code below, the datetime column is a string (that's what read_csv gives us), so we first convert it into a datetime object that pandas and matplotlib can work with.

import numpy as np
import pandas
import matplotlib.pyplot as plt

url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)
stopsearch['date'] = pandas.to_datetime(stopsearch.datetime).dt.date
stopsearch['outcome'] = np.where(stopsearch.outcome=='A no further action disposal', 'nothing', 'find')

# Number of events per date, sorted by timestamp
# (if timestamps were unsorted, the line would wiggle backwards and forwards)
df = stopsearch.groupby(['date','outcome']).apply(len).unstack(fill_value=0).reset_index()
df = df.iloc[np.argsort(df.date)]

fig,ax = plt.subplots(figsize=(7,2.5))
ax.plot(df.date, df.find + df.nothing, label='stops', linewidth=1, color='0.5')
ax.plot(df.date, df.find, label='find', linewidth=3)
ax.set_title('Number of stop-and-search incidents per day')
ax.legend()

# Some magic to improve tick labels for an entire figure
fig.autofmt_xdate(bottom=0.2, rotation=-30, ha='left')

plt.show()

Histogram, annotations

Histogram

ax.hist(values, bins=int)

Line annotations

To draw a straight horizontal or vertical line or range,

ax.axhline(y)
ax.axhspan(ymin, ymax)
ax.axvline(x)
ax.axvspan(xmin, xmax)

For sloped lines, I think it's simplest to create a dummy dataframe with a row for the start and a row for the end, and just use ax.plot.
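
Here is a sketch of these annotation commands on an empty plot, with made-up coordinates. For the sloped line I've skipped the dataframe and passed the start and end points to ax.plot as two short lists, which amounts to the same thing.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.axhline(0.5, color='0.5')   # horizontal line across the whole Axes
ax.axvspan(2, 4, alpha=0.2)    # shaded vertical band from x=2 to x=4

# Sloped line: one point for the start, one for the end
(segment,) = ax.plot([0, 10], [0, 5], linestyle='dashed', color='black')
print(segment.get_xdata(), segment.get_ydata())
```

Recent versions of matplotlib also have ax.axline, which draws an infinite line through two points; I mention it only as something to look up if the two-point trick feels clumsy.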

Text annotation

ax.text(x, y, text)

Unlike the standard data-plotting commands such as ax.scatter, here x, y and text MUST NOT be vectors: each call places a single piece of text at a single point.

To control the alignment of the text with respect to the x and y coordinates, use e.g. ha='left', va='top'.
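
Since ax.text only handles one label per call, labelling several points needs a loop. A minimal sketch, with made-up points and names:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt

xs = [1, 2, 3]
ys = [2, 4, 3]
names = ['first', 'second', 'third']  # made-up labels

fig, ax = plt.subplots()
ax.scatter(xs, ys)  # scatter takes whole vectors
for x, y, name in zip(xs, ys, names):
    # text takes one position and one string at a time
    ax.text(x, y, name, ha='left', va='bottom')

print(len(ax.texts))
```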

import pandas
import matplotlib.pyplot as plt

url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_202110w1.csv'
stopsearch = pandas.read_csv(url)

fig,ax = plt.subplots()
ax.hist(stopsearch.location_latitude, bins=45, alpha=.5, edgecolor='white')
ax.axvline(x=51.5285582, linestyle='dashed', color='black')
ax.text(51.5285582, 2500, 'London', ha='left', va='top')

ax.set_title('Stop-and-search incidents, Nov 2021')

plt.show()

Scatter plot, colour scale

Matplotlib has several built-in colour scales.

Choose a colour scale with either of

cmap = plt.get_cmap('scalename')
cmap = plt.get_cmap('scalename', n)

In the first case, cmap is a function from [0,1] to colours. In the second case, it maps {0,1,…,n-1} to colours.
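
The two cases can be sketched like this. I've picked 'viridis' and n=3 purely for illustration. One thing to watch: a colour map treats integers as indices and floats as fractions of the scale, so 1 and 1.0 are not the same request.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt

# Continuous use: pass a float in [0,1]
cmap = plt.get_cmap('viridis')
print(cmap(0.0), cmap(0.5), cmap(1.0))  # each is an (r, g, b, a) tuple

# Discrete use: ask for n colours, then pass an integer index 0..n-1
n = 3
cmap_n = plt.get_cmap('viridis', n)
print([cmap_n(i) for i in range(n)])

# Careful: the int 1 is the second colour, the float 1.0 is the top of the scale
print(cmap_n(1) == cmap_n(1.0))
```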

The code below plots stop-and-search incidents, colour-coded by police force. It uses the discrete colour scale 'Set2' (which only has eight distinct colours, so some forces share the same colour). The scatter plot shows us roughly an outline of England and Wales, colour-coded by region.

We can also achieve the same outcome without using a ‘for’ loop, by calling ax.scatter once and passing in the c argument,

ax.scatter(…, c=colourlist)

The second listing below shows how it's done.

import numpy as np
import pandas
import matplotlib.pyplot as plt

url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_202110w1.csv'
stopsearch = pandas.read_csv(url)

levels = stopsearch.force.unique()
cmap = plt.get_cmap('Set2', len(levels))

fig,ax = plt.subplots()
for i,lev in enumerate(levels):
    df = stopsearch.loc[stopsearch.force == lev]
    ax.scatter(df.location_longitude, df.location_latitude, s=3, alpha=.5, color=cmap(i))

# Set the aspect ratio, based on the UK’s average latitude
ax.set_aspect(1/np.cos(54/360*2*np.pi))

# Pick coordinates to show (I chose these after seeing the plot first)
ax.set_xlim([-5,2])
ax.set_ylim([50.2, 55.8])

# Get rid of the tick marks and the outer frame
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')

plt.show()

Using a single ax.scatter call:

import numpy as np
import pandas
import matplotlib.pyplot as plt

url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_202110w1.csv'
stopsearch = pandas.read_csv(url)

levels = stopsearch.force.unique()
cols = plt.get_cmap('Set2', len(levels))

forceI = stopsearch.force.replace({lev:i for i,lev in enumerate(levels)})
forceC = [cols(i) for i in forceI]

fig,ax = plt.subplots()
ax.scatter(stopsearch.location_longitude,
           stopsearch.location_latitude,
           s = 3,
           alpha = .5,
           c = forceC)

# Set the aspect ratio, based on the UK’s average latitude
ax.set_aspect(1/np.cos(54/360*2*np.pi))

# Pick coordinates to show (I chose these after seeing the plot first)
ax.set_xlim([-5,2])
ax.set_ylim([50.2, 55.8])

# Get rid of the tick marks and the outer frame
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')

plt.show()

Heatmap, colour bar

To plot a heatmap, in other words to plot an array of values using coloured rectangles for each cell,

ax.imshow(array, 
    origin='lower', 
    extent=(left,right,bottom,top),
    cmap='colscale', vmin=lo, vmax=hi)

This draws the array inside the rectangle specified by extent, with array[0,0] in the bottom left. The values are coloured with the specified colour scale, with the colour scale running from lo to hi.

To add a legend for the colour scale,

plt.colorbar(im, ax=ax)

where im is the return value from imshow, and where ax is the Axes or list of Axes from which to steal space to squeeze in the legend.
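
A stripped-down sketch with a made-up 2×3 array, so you can see how origin, extent, vmin and vmax fit together. Note that ax is passed to plt.colorbar by keyword; the second positional argument of colorbar means something else (an Axes to draw the colour bar into).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import numpy as np
import matplotlib.pyplot as plt

array = np.array([[0, 1, 2],
                  [3, 4, 5]])  # 2 rows, 3 columns

fig, ax = plt.subplots(figsize=(4, 3))
im = ax.imshow(array,
               origin='lower',       # array[0,0] goes in the bottom left
               extent=(0, 3, 0, 2),  # left, right, bottom, top in data coordinates
               cmap='Blues', vmin=0, vmax=5)
cbar = plt.colorbar(im, ax=ax)       # steals space from ax for the colour bar
cbar.set_label('value')

print(im.get_extent(), im.get_clim())
```

After the call, the figure contains two Axes: the original one and a narrow one holding the colour bar.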

import numpy as np
import pandas
import matplotlib.pyplot as plt

url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)

# Extract hour and day-of-week for each stop-and-search incident
stopsearch['datetime'] = pandas.to_datetime(stopsearch.datetime)
stopsearch['date'] = stopsearch.datetime.dt.date
stopsearch['day'] = stopsearch.datetime.dt.dayofweek # Mon=0, Sun=6
stopsearch['hour'] = stopsearch.datetime.dt.hour # 0 .. 23

# Count number of stops by day-of-week and hour-of-day
num_stops = stopsearch.groupby(['day','hour']).apply(len).reset_index(name='n')
# Normalize to get stops per day
# (since some days appear more than others, in this dataset)
num_days = stopsearch.groupby('day')['date'].apply(lambda x: (x.max()-x.min()).days/7 + 1).reset_index(name='d')
df = num_stops.merge(num_days, on='day')
df['n_per_day'] = df.n / df.d
df = df.set_index(['day','hour']).n_per_day.unstack(fill_value=0)

fig,ax = plt.subplots(figsize=(10,3))
columns,rows = df.columns, df.index
im = ax.imshow(df.values, 
               origin='lower', 
               extent=(0, 24, -0.5, 6.5), 
               cmap='Blues', vmin=0, vmax=1)
ax.set_yticks(np.arange(0,7))
ax.invert_yaxis()
ax.set_yticklabels(['Mon','Tue','Wed','Thu','Fri','Sat','Sun'])
ax.set_xticks([0,6,12,18])

ax.set_title('Number of stop-and-search incidents per hour')

plt.colorbar(im, ax=ax)
plt.show()

Next steps

Beyond matplotlib

Matplotlib is a powerful plotting library. There are many topics that this tutorial hasn't touched on, such as interaction and animation.

But it has a creaking API …

Alternatives

You might get fed up and switch to seaborn, which is a matplotlib wrapper, or to plotly.

In my opinion these aren't complete enough to fully replace matplotlib, and it's a needless pain to learn two packages. For serious plotting I switch to another language, R, which has a far superior plotting library called ggplot2.

import numpy as np
import matplotlib.pyplot as plt

n = 5000
π = np.pi
k = np.random.choice(4, p=[.6,.3,.05,.05], size=n)
t = np.random.uniform(size=n)
x = np.column_stack([np.sin(2*π*t), 0.55*np.sin(2*π*(0.4*t+0.3)), -0.3*np.ones(n), 0.3*np.ones(n)])
y = np.column_stack([np.cos(2*π*t), 0.55*np.cos(2*π*(0.4*t+0.3)), 0.3*np.ones(n), 0.3*np.ones(n)])
xy = np.column_stack([x[np.arange(n), k], y[np.arange(n), k]])
xy = np.random.normal(loc=xy, scale=.08)
x,y = xy[:,0], xy[:,1]

fig,((ax_x,dummy),(ax_xy,ax_y)) = plt.subplots(2,2, figsize=(4,4), sharex='col', sharey='row',
                                          gridspec_kw={'height_ratios':[1,2], 'width_ratios':[2,1]})
dummy.remove()
ax_xy.scatter(xy[:,0], xy[:,1], s=3, alpha=.1)
ax_x.hist(x, density=True, bins=60)
ax_y.hist(y, density=True, bins=60, orientation='horizontal')
fig.suptitle("Well done!", x=0.12, ha='left')
plt.show()