Matplotlib
Matplotlib is a widely used plotting library for scientific computing in Python.
The best way to learn Matplotlib is to browse through galleries until you find
something you like, and then copy it. But to make sense of what you read,
you need to know the basic structure of a plot, which is what this tutorial describes.
Building a plot
Plotting data
In matplotlib, the principal object we work with is an Axes. This is a 2d area in
which we can draw points, lines, bars, text, etc.
It has x and y
scales with tickmarks, and it can have legends and labels.
To create a simple plot consisting of a single Axes, we call plt.subplots().
Calling plt.subplots() actually returns a pair of (Figure,Axes).
Most of the time, we only care about the Axes object.
The Figure object is for controlling the layout when there are multiple subplots.
import matplotlib.pyplot as plt
fig,ax = plt.subplots()
... # draw things on the Axes
plt.show()
Draw a line graph of y1 as a function of x,
and a scatter plot of y2 as a function of x,
using the drawing commands
ax.plot(x, y1, linestyle='--', alpha=0.7)
ax.scatter(x, y2, marker='+', color='red')
import numpy as np
x = np.linspace(0,10,100)
y1 = np.sin(x)
y2 = np.random.normal(loc=np.sin(x), scale=0.1)
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0,10,100)
y1 = np.sin(x)
y2 = np.random.normal(loc=np.sin(x), scale=0.1)
fig,ax = plt.subplots()
ax.plot(x, y1, linestyle='--', alpha=0.7)
ax.scatter(x, y2, marker='+', color='red')
plt.show()
Legends
We should always label our plots! For simple plots, matplotlib makes it easy:
- Use ax.plot(…, label="mylabel")
- Call plt.legend()
Label the line "theory" and the points "experiment".
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0,10,100)
y1 = np.sin(x)
y2 = np.random.normal(loc=np.sin(x), scale=0.1)
fig,ax = plt.subplots()
ax.plot(x, y1, linestyle='--', alpha=0.7)
ax.scatter(x, y2, marker='+', color='red')
plt.show()
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0,10,100)
y1 = np.sin(x)
y2 = np.random.normal(loc=np.sin(x), scale=0.1)
fig,ax = plt.subplots()
ax.plot(x, y1, linestyle='--', alpha=0.7, label='theory')
ax.scatter(x, y2, marker='+', color='red', label='experiment')
plt.legend()
plt.show()
Exercise
The code on the right lists the top finalists from Eurovision 2022.
Plot a horizontal bar graph to show each country's points, using the drawing command
ax.barh(ev2022['country'], ev2022['points'])
There is a long list of plots we can draw.
The documentation
lists them all, but it's better to browse matplotlib galleries and copy the one you want.
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'points': [631, 466, 459, 438, 312, 268]}
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'points': [631, 466, 459, 438, 312, 268]}
fig,ax = plt.subplots()
ax.barh(ev2022['country'], ev2022['points'])
plt.show()
Subplots
A plot with multiple subplots is a great way to show how some complex output depends on
a parameter. Let's compare Eurovision jury versus audience scores.
To create a plot with multiple subplots,
fig,axes = plt.subplots(nrows,ncols)
The axes we get back will be either a single Axes object, or a one-dimensional array of Axes,
or a two-dimensional array of Axes, depending on nrows and ncols.
(The defaults are nrows=ncols=1, which is why plt.subplots()
gives us a simple plot with a single Axes.)
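To make the three cases concrete, here is a small sketch that inspects what plt.subplots hands back for a few grid shapes. The Agg backend line and the squeeze=False variant are additions for illustration, not part of the tutorial's own code; the multi-Axes results are NumPy arrays.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: draw off-screen so this runs without a display
import matplotlib.pyplot as plt

fig1, single = plt.subplots()        # nrows=ncols=1: a single Axes object
fig2, row = plt.subplots(1, 3)       # one row: 1-D array of three Axes
fig3, grid = plt.subplots(2, 3)      # 2-D array, indexed as grid[row, col]
fig4, always_2d = plt.subplots(1, 1, squeeze=False)  # force a 2-D array

print(isinstance(single, matplotlib.axes.Axes))  # True
print(row.shape)        # (3,)
print(grid.shape)       # (2, 3)
print(always_2d.shape)  # (1, 1)
```

Passing squeeze=False can be handy when nrows and ncols are variables, since the result can then always be indexed the same way.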
Modify the subplots call to give us two Axes side-by-side, using
fig,axes = plt.subplots(1,2).
Show jury points on axes[0] and audience points on axes[1].
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,ax = plt.subplots()
ax.barh(ev2022['country'], ev2022['jury'])
plt.show()
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2)
axes[0].barh(ev2022['country'], ev2022['jury'])
axes[1].barh(ev2022['country'], ev2022['audience'])
plt.show()
DRY
When we have several subplots, it's error-prone to have long blocks
of copy-pasted code. It's better to use a for loop instead,
so that we only need to write out our plotting commands once.
(This is good general-purpose programming advice. It's called DRY,
Don't Repeat Yourself.)
Modify this code to use a for loop, rather than repeating
the barh call. Hint: use zip.
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2)
axes[0].barh(ev2022['country'], ev2022['jury'])
axes[1].barh(ev2022['country'], ev2022['audience'])
plt.show()
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
plt.show()
Making it look good
Graphical legibility
The plots we've produced so far are pretty ugly, and an ugly
plot won't convey your message!
Code for plotting often starts out with a tiny core of the sort we've seen that does the actual plotting
(subplots, and
things drawn on the subplots) … and then it gets
swamped by a huge number of tweaks to make the plots legible,
i.e. easily understandable by your reader.
I recommend adding in these tweaks systematically, one by one, checking how it looks each time.
The following pages show my typical steps.
Labels
To label the x and y axes,
ax.set_xlabel(label)
ax.set_ylabel(label)
To label a subplot,
ax.set_title(label)
To label the entire plot,
fig.suptitle(label)
Label the subplots "Jury" and "Audience". Label the x-axis of each subplot "points".
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
plt.show()
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')
plt.show()
Axis limits
Matplotlib will usually pick axis limits automatically to fit everything we're drawing. If it
hasn't done a good job, we can set them manually:
ax.set_xlim([low,high])
ax.set_ylim([low,high])
To specify that subplots should use the same horizontal or vertical axis,
plt.subplots(…, sharex=bool, sharey=bool)
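As a rough illustration of how manual limits interact with shared axes, the sketch below uses made-up data. The Agg backend line is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: draw off-screen so this runs without a display
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, sharey=True)
axes[0].plot([0, 1, 2], [0, 10, 20])
axes[1].plot([0, 1, 2], [0, 50, 100])

axes[0].set_xlim([0, 1.5])  # x is not shared: only the left subplot changes
axes[1].set_ylim([0, 120])  # y is shared: this applies to both subplots

left_xlim = tuple(float(v) for v in axes[0].get_xlim())
same_ylim = tuple(axes[0].get_ylim()) == tuple(axes[1].get_ylim())
print(left_xlim)  # (0.0, 1.5)
print(same_ylim)  # True
```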
Update this figure so that the two subplots share the same y-axis.
What do you notice about the y-axis tick labels?
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')
plt.show()
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2, sharey=True)
# Since the y-axis is shared, no need to draw y-tick labels on both subplots
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')
plt.show()
Tick marks
Once we've got the axis limits we want, we may like to adjust tick positions and/or labels.
ax.set_xticks([x1,x2,…])
ax.set_xticklabels([lbl1,lbl2,…])
# similarly for set_yticks, set_yticklabels
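Here is a minimal sketch of these two calls on a throwaway plot; the data and label strings are invented for illustration, and the Agg backend line is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: draw off-screen so this runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2, 3], [0, 1, 4, 9])

ax.set_xticks([0, 1.5, 3])                      # where the ticks go
ax.set_xticklabels(['start', 'middle', 'end'])  # what each tick says

tick_positions = [float(t) for t in ax.get_xticks()]
tick_labels = [t.get_text() for t in ax.get_xticklabels()]
print(tick_positions)  # [0.0, 1.5, 3.0]
print(tick_labels)     # ['start', 'middle', 'end']
```

Setting the positions first and the labels second keeps the two lists the same length, which set_xticklabels relies on.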
Set ticks [0,50,…] for the first subplot, and [0,100,…] for the second.
Also, make the ticks more elegant, and draw gridlines, with these commands:
ax.tick_params(axis='x', labelrotation=-60)
ax.grid(axis='x', alpha=0.5)
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2, sharey=True)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')
plt.show()
import numpy as np
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2, sharey=True)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')
axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)
plt.show()
Size and layout
To set the overall size of the plot,
plt.subplots(…, figsize=(width,height))
The width and height are nominally measured in inches, but the actual output
depends on what dpi your computer thinks it's using.
Another useful command, to adjust the spacing of the subplots to accommodate all the tick labels, is
fig.set_tight_layout(True)
To control the ratio of subplot sizes, or the space between them,
pass in a gridspec_kw option to plt.subplots, for example
plt.subplots(…,
    gridspec_kw={'width_ratios': [0.3, 0.5]})
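The sketch below checks the effect of figsize and width_ratios by reading the sizes back from the figure. The ratio [1, 2] is an arbitrary choice for illustration, and the Agg backend line is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: draw off-screen so this runs without a display
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(6, 3),
                         gridspec_kw={'width_ratios': [1, 2]})

# Figure size in inches, as (width, height)
size = tuple(float(v) for v in fig.get_size_inches())
# Each Axes reports its position as a fraction of the figure
widths = [ax.get_position().width for ax in axes]
ratio = widths[1] / widths[0]

print(size)             # (6.0, 3.0)
print(round(ratio, 3))  # 2.0
```

Only the ratio between the numbers matters, so [0.3, 0.5] and [3, 5] should give the same layout.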
Set the overall figure size to be 6 inches wide and 3 inches tall, and fix the clipping
of the y-tickmarks.
import numpy as np
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2, sharey=True)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')
axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)
plt.show()
import numpy as np
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2, sharey=True, figsize=(6,3))
fig.set_tight_layout(True)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')
axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)
plt.show()
Colour theme
We can control the overall colour theme:
plt.style.use('dark_background')
plt.style.use('default')
We can also set the background colour of an individual Axes, or of the entire Figure:
ax.set_facecolor(col)
fig,axes = plt.subplots(…, facecolor=col)
There are several ways to specify colours in matplotlib. Examples:
(0.1, 0.2, 0.5), '#0f0f0f', '0.9', 'blue'.
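To see how these different spellings relate, the sketch below converts each one to an RGBA tuple with matplotlib.colors.to_rgba. Note that a hex string needs a leading '#'; is_color_like is used here to check that.

```python
import matplotlib.colors as mcolors

# RGB tuple, hex string, grey level as a string, and a named colour
specs = [(0.1, 0.2, 0.5), '#0f0f0f', '0.9', 'blue']
for spec in specs:
    print(spec, '->', mcolors.to_rgba(spec))

grey = mcolors.to_rgb('0.9')             # a numeric string means a grey level
hex_ok = mcolors.is_color_like('#0f0f0f')
hex_bare = mcolors.is_color_like('0f0f0f')
print(grey)      # (0.9, 0.9, 0.9)
print(hex_ok)    # True
print(hex_bare)  # False
```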
Set the background colour of the Axes to be '0.9', which means 90% grey.
import numpy as np
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2, sharey=True, figsize=(6,3))
fig.set_tight_layout(True)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')
axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)
plt.show()
import numpy as np
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2, sharey=True, figsize=(6,3))
fig.set_tight_layout(True)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')
axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)
    ax.set_facecolor('0.9')
plt.show()
Saving the plot
To save a plot, run this just before the final plt.show() command:
plt.savefig('filename')
It will format the output appropriately based on the filename's suffix (.svg, .png, .pdf, etc.).
I usually use some extra options, to trim off any extra space around the outside.
plt.savefig('filename',
    bbox_inches='tight',
    pad_inches=0)
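For anyone running locally, here is a minimal sketch that writes a PNG into a temporary directory and confirms the file exists. The filename example.png and the use of tempfile are choices made for this illustration, and the Agg backend line is an assumption so that it runs without a display.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: draw off-screen so this runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 2))
ax.plot([0, 1, 2], [0, 1, 4])

outdir = tempfile.mkdtemp()                # throwaway directory for the demo
path = os.path.join(outdir, 'example.png') # the .png suffix selects the format
plt.savefig(path, bbox_inches='tight', pad_inches=0)

saved = os.path.exists(path) and os.path.getsize(path) > 0
print(saved)  # True
```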
It won't do any good using savefig in this tutorial, because there's no way to access the
file it creates :-(
import numpy as np
import matplotlib.pyplot as plt
ev2022 = {
'country': ['ukraine','uk','spain','sweden','serbia','italy'],
'jury': [192, 283, 231, 258, 87, 158],
'audience': [439, 183, 228, 180, 225, 110]}
fig,axes = plt.subplots(1,2, sharey=True, figsize=(6,3))
fig.set_tight_layout(True)
for ax,k in zip(axes, ['jury','audience']):
    ax.barh(ev2022['country'], ev2022[k])
axes[0].set_title('Jury')
axes[1].set_title('Audience')
for ax in axes: ax.set_xlabel('points')
axes[0].set_xticks(np.arange(0,300,50))
axes[1].set_xticks(np.arange(0,600,100))
for ax in axes:
    ax.tick_params(axis='x', labelrotation=-60)
    ax.grid(axis='x', alpha=0.5)
    ax.set_facecolor('0.9')
plt.savefig('ev2022.pdf', bbox_inches='tight', pad_inches=0)
plt.show()
Summary
I like to build up plots step by step, adding pieces in order, and checking at each step what the plot looks like.
If I add everything in one go, chances are it won't work, and I won't know which bit went wrong.
Here are the steps we've covered:
- Prepare the data to plot
- Set up the Axes with plt.subplots(nrows,ncols)
- Use Axes methods to draw data inside each subplot: points, lines, bars, etc.
- Label each subplot, and label the x-axis and y-axis
- Set the x-axis and y-axis limits, and specify whether the axes should be shared between subplots
- Set tick marks
- Adjust the overall size and layout
- Tweak the colours
- Save the plot
After going through these steps, you may like to fine-tune the appearance — font sizes etc.
All the matplotlib commands have a huge number of options,
but it's hard to see any rhyme or reason behind the API,
so you'll need to make frequent use of search engines and ChatGPT,
and maybe if things get desperate even the
documentation.
Working with data
Loading from csv
We may have a csv file with data we want to plot.
The easiest way to import it is with the pandas package.
import pandas
dataframe = pandas.read_csv(url)
This returns a DataFrame object, which represents
tabular data (think Excel spreadsheet), and behaves like a dictionary of columns.
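To show the "dictionary of columns" behaviour without needing a network connection, this sketch reads a tiny made-up CSV from a string. The column names and values are invented for illustration; read_csv accepts a file-like object as well as a URL.

```python
import io
import pandas

# Made-up CSV text standing in for a real file or URL
csv_text = "x,y\n1,2.0\n2,4.1\n3,5.9\n"
df = pandas.read_csv(io.StringIO(csv_text))

column_names = list(df.columns)  # the "keys" of the dictionary
y_values = df['y'].tolist()      # look up a column by name
print(column_names)  # ['x', 'y']
print(y_values)      # [2.0, 4.1, 5.9]
```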
The code on the right fetches the dataset behind xkcd comic #2048.
Plot a scatter plot of xkcd['y'] against xkcd['x'].
import pandas
xkcd = pandas.read_csv('https://www.cl.cam.ac.uk/teaching/current/DataSci/data/xkcd.csv')
xkcd
import pandas
import matplotlib.pyplot as plt
xkcd = pandas.read_csv('https://www.cl.cam.ac.uk/teaching/current/DataSci/data/xkcd.csv')
fig,ax = plt.subplots()
ax.scatter(xkcd['x'], xkcd['y'])
plt.show()
Aggregating data
Pandas is good for processing data, e.g. tabulations.
Here's an example. We load a csv that has one row per stop-and-search incident in
Cambridgeshire in 2021. The groupby command on the right tabulates it by
the 'object_of_search' column.
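The tabulation can be tried on a toy table first. The rows below are invented, and this sketch uses .size() as a stand-in for the .apply(len) call in the tutorial's code; both count the rows in each group.

```python
import pandas

# Invented rows: one row per incident, as in the real dataset
toy = pandas.DataFrame({'object_of_search': [
    'drugs', 'drugs', 'weapons', 'drugs', 'stolen goods', 'weapons']})

# Count rows per category, then turn the result back into a plain table
counts = toy.groupby('object_of_search').size().reset_index(name='n')
print(counts)

as_dict = dict(zip(counts['object_of_search'], counts['n']))
```

The result has one row per distinct value, with the count in column 'n', which is the shape a bar chart needs.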
Pandas is a vast data-handling library, and it's a lifesaver for
anyone doing practical data science.
You can work through a pandas tutorial if you want to
learn more.
Several of the plots in the gallery
use pandas to pre-process data. The code on the right is just to show you what
sort of code to ignore, if all you want to learn about is matplotlib plotting code!
Draw a bar chart, with horizontal bars of length df['n'],
and y-labels df['object_of_search'].
import pandas
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)
df = stopsearch.groupby(['object_of_search']).apply(len).reset_index(name='n')
df
import pandas
import matplotlib.pyplot as plt
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)
df = stopsearch.groupby(['object_of_search']).apply(len).reset_index(name='n')
fig,ax = plt.subplots(figsize=(6,3))
ax.barh(df['object_of_search'], df['n'])
fig.tight_layout()
plt.show()
Plot gallery
Finding inspiration
The best way to learn matplotlib is to browse through plot galleries until you find something you like, and then copy it.
The following pages showcase some examples.
Another useful source: the official Matplotlib gallery.
Multiple lines, legends, datetime
If we draw multiple datasets and give them each a label,
ax.plot(…, label='label')
then it's easy to create a legend out of these labelled items: simply call
ax.legend()
I prefer to position my legend outside the Axes, using
ax.legend(loc='center left',
bbox_to_anchor=(1.05,0.5),
edgecolor='none')
fig.set_tight_layout(True)
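Put together on synthetic data, the legend placement might look like the sketch below. The sine and cosine curves are placeholders, fig.tight_layout() is used here as a one-off alternative to fig.set_tight_layout(True), and the Agg backend line is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: draw off-screen so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(x, np.sin(x), label='sin')
ax.plot(x, np.cos(x), label='cos')

# Anchor the legend's centre-left point just outside the right edge of the Axes
leg = ax.legend(loc='center left',
                bbox_to_anchor=(1.05, 0.5),
                edgecolor='none')
fig.tight_layout()  # make room so the legend is not clipped

legend_labels = [t.get_text() for t in leg.get_texts()]
print(legend_labels)  # ['sin', 'cos']
```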
Dates and times
Matplotlib understands datetimes. In the code on the right, the datetime column is a string
(that's what read_csv gives us), so we first convert it into a datetime object that pandas and matplotlib
can work with.
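A small offline sketch of that conversion, using three invented timestamps in place of the real csv. The column name when is hypothetical, and the Agg backend line is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: draw off-screen so this runs without a display
import matplotlib.pyplot as plt
import pandas

# Invented data: timestamps arrive as plain strings, as they would from read_csv
raw = pandas.DataFrame({
    'datetime': ['2021-03-01 09:30:00', '2021-03-02 14:00:00', '2021-03-04 23:15:00'],
    'n': [3, 1, 2]})

raw['when'] = pandas.to_datetime(raw['datetime'])  # strings -> datetimes
hours = raw['when'].dt.hour.tolist()               # the .dt accessor now works
print(hours)  # [9, 14, 23]

fig, ax = plt.subplots(figsize=(6, 2.5))
ax.plot(raw['when'], raw['n'])  # matplotlib spaces the points by real time
fig.autofmt_xdate()             # rotate the date tick labels
```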
import numpy as np
import pandas
import matplotlib.pyplot as plt
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)
stopsearch['date'] = pandas.to_datetime(stopsearch.datetime).dt.date
stopsearch['outcome'] = np.where(stopsearch.outcome=='A no further action disposal', 'nothing', 'find')
# Number of events per date, sorted by timestamp
# (if timestamps were unsorted, the line would wiggle backwards and forwards)
df = stopsearch.groupby(['date','outcome']).apply(len).unstack(fill_value=0).reset_index()
df = df.iloc[np.argsort(df.date)]
fig,ax = plt.subplots(figsize=(7,2.5))
ax.plot(df.date, df.find + df.nothing, label='stops', linewidth=1, color='0.5')
ax.plot(df.date, df.find, label='find', linewidth=3)
ax.set_title('Number of stop-and-search incidents per day')
ax.legend()
# Some magic to improve tick labels for an entire figure
fig.autofmt_xdate(bottom=0.2, rotation=-30, ha='left')
plt.show()
Histogram, annotations
Histogram
ax.hist(values, bins=int)
Line annotations
To draw a straight horizontal or vertical line or range,
ax.axhline(y)
ax.axhspan(ymin, ymax)
ax.axvline(x)
ax.axvspan(xmin, xmax)
For sloped lines, I think it's simplest to create a dummy dataframe with a row for the start and a row for the end,
and just use ax.plot.
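The sketch below draws one of each kind of annotation on an empty plot. For the sloped line it uses two plain lists rather than a dataframe, which amounts to the same thing for two points; the coordinates are arbitrary, and the Agg backend line is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: draw off-screen so this runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.axhline(0.5, color='0.5', linestyle='dashed')  # horizontal line at y=0.5
ax.axvspan(2, 3, alpha=0.2)                       # shaded band for 2 <= x <= 3

# Sloped line: one (x, y) pair for the start and one for the end
start, end = (0, 0), (4, 1)
ax.plot([start[0], end[0]], [start[1], end[1]], color='red')

n_lines = len(ax.lines)      # the axhline plus the sloped line
n_patches = len(ax.patches)  # the shaded span
print(n_lines, n_patches)    # 2 1
```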
Text annotation
ax.text(x, y, text)
This is different to the standard data-plotting commands, like ax.scatter,
in that x and y and text MUST NOT be vectors.
To control the alignment of the text with respect to the x and y coordinates, use e.g.
ha='left', va='top'.
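Since ax.text takes one position and one string at a time, labelling several points means looping. A minimal sketch with made-up points follows; the Agg backend line is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: draw off-screen so this runs without a display
import matplotlib.pyplot as plt

# Made-up points and names
xs = [1, 2, 3]
ys = [2, 4, 1]
names = ['a', 'b', 'c']

fig, ax = plt.subplots()
ax.scatter(xs, ys)  # scatter accepts whole vectors
for x, y, name in zip(xs, ys, names):
    # text needs scalars: one call per label
    ax.text(x, y, name, ha='left', va='bottom')

n_labels = len(ax.texts)
print(n_labels)  # 3
```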
import pandas
import matplotlib.pyplot as plt
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_202110w1.csv'
stopsearch = pandas.read_csv(url)
fig,ax = plt.subplots()
ax.hist(stopsearch.location_latitude, bins=45, alpha=.5, edgecolor='white')
ax.axvline(x=51.5285582, linestyle='dashed', color='black')
ax.text(51.5285582, 2500, 'London', ha='left', va='top')
ax.set_title('Stop-and-search incidents, Nov 2021')
plt.show()
Scatter plot, colour scale
Matplotlib has several built-in colour scales.
Choose a colour scale with either of
cmap = plt.get_cmap('scalename')
cmap = plt.get_cmap('scalename', n)
In the first case, cmap is a function from [0,1] to colours. In the second case,
it maps {0,1,…,n-1} to colours.
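As a quick check of the two forms, this sketch calls each kind of colour scale and looks at what comes back. The scale names 'viridis' and 'Set2' are just examples of built-in scales.

```python
import matplotlib.pyplot as plt

continuous = plt.get_cmap('viridis')  # call with a float in [0, 1]
discrete = plt.get_cmap('Set2', 4)    # call with an integer 0, 1, 2 or 3

low, high = continuous(0.0), continuous(1.0)  # opposite ends of the scale
first = discrete(0)                           # first of the four colours

print(len(low))    # 4  (red, green, blue, alpha)
print(discrete.N)  # 4
print(low != high) # True
```

Each call returns an RGBA tuple, which can be passed straight to arguments such as color= in ax.scatter.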
The code on the right plots stop-and-search incidents, colour-coded by police force.
It uses the discrete colour scale 'Set2' (which only has eight distinct colours,
so some forces share the same colour). The scatter plot shows us roughly an outline of England and Wales,
colour-coded by region.
We can also achieve the same outcome without using a ‘for’ loop, by
calling ax.scatter once and passing in the c argument,
ax.scatter(…, c=colourlist)
Click on “Show me” to see how it's done.
import numpy as np
import pandas
import matplotlib.pyplot as plt
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_202110w1.csv'
stopsearch = pandas.read_csv(url)
levels = stopsearch.force.unique()
cmap = plt.get_cmap('Set2', len(levels))
fig,ax = plt.subplots()
for i,lev in enumerate(levels):
    df = stopsearch.loc[stopsearch.force == lev]
    ax.scatter(df.location_longitude, df.location_latitude, s=3, alpha=.5, color=cmap(i))
# Set the aspect ratio, based on the UK’s average latitude
ax.set_aspect(1/np.cos(54/360*2*np.pi))
# Pick coordinates to show (I chose these after seeing the plot first)
ax.set_xlim([-5,2])
ax.set_ylim([50.2, 55.8])
# Get rid of the tick marks and the outer frame
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')
plt.show()
import numpy as np
import pandas
import matplotlib.pyplot as plt
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_202110w1.csv'
stopsearch = pandas.read_csv(url)
levels = stopsearch.force.unique()
cols = plt.get_cmap('Set2', len(levels))
forceI = stopsearch.force.replace({lev:i for i,lev in enumerate(levels)})
forceC = [cols(i) for i in forceI]
fig,ax = plt.subplots()
ax.scatter(stopsearch.location_longitude,
stopsearch.location_latitude,
s = 3,
alpha = .5,
c = forceC)
# Set the aspect ratio, based on the UK’s average latitude
ax.set_aspect(1/np.cos(54/360*2*np.pi))
# Pick coordinates to show (I chose these after seeing the plot first)
ax.set_xlim([-5,2])
ax.set_ylim([50.2, 55.8])
# Get rid of the tick marks and the outer frame
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')
plt.show()
Heatmap, colour bar
To plot a heatmap, in other words to plot an array of values using coloured rectangles for each cell,
ax.imshow(array,
    origin='lower',
    extent=(left,right,bottom,top),
    cmap='colscale', vmin=lo, vmax=hi)
This draws the array inside the rectangle specified by extent, with array[0,0]
in the bottom left. The values are coloured with the specified colour scale,
with the colour scale running from lo to hi.
To add a legend for the colour scale,
plt.colorbar(im, ax=ax)
where im is the return value from imshow, and where ax is the Axes or list of Axes
from which to steal space to squeeze in the legend.
(Pass ax by keyword: the second positional argument of plt.colorbar means something else.)
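Here is a minimal heatmap on a 2×3 array of made-up values, small enough to check by eye which cell ends up where. The Agg backend line is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: draw off-screen so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

data = np.arange(6).reshape(2, 3)  # 2 rows, 3 columns, values 0..5

fig, ax = plt.subplots()
im = ax.imshow(data,
               origin='lower',       # data[0, 0] goes in the bottom left
               extent=(0, 3, 0, 2),  # x runs 0..3, y runs 0..2
               cmap='Blues', vmin=0, vmax=5)
cbar = plt.colorbar(im, ax=ax)       # steals space from ax for the colour legend

colour_limits = tuple(float(v) for v in im.get_clim())
n_axes = len(fig.axes)  # the plot itself plus the colour bar
print(colour_limits)    # (0.0, 5.0)
print(n_axes)           # 2
```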
import numpy as np
import pandas
import matplotlib.pyplot as plt
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)
# Extract hour and day-of-week for each stop-and-search incident
stopsearch['datetime'] = pandas.to_datetime(stopsearch.datetime)
stopsearch['date'] = stopsearch.datetime.dt.date
stopsearch['day'] = stopsearch.datetime.dt.dayofweek # Mon=0, Sun=6
stopsearch['hour'] = stopsearch.datetime.dt.hour # 0 .. 23
# Count number of stops by day-of-week and hour-of-day
num_stops = stopsearch.groupby(['day','hour']).apply(len).reset_index(name='n')
# Normalize to get stops per day
# (since some days appear more than others, in this dataset)
num_days = stopsearch.groupby('day')['date'].apply(lambda x: (x.max()-x.min()).days/7 + 1).reset_index(name='d')
df = num_stops.merge(num_days, on='day')
df['n_per_day'] = df.n / df.d
df = df.set_index(['day','hour']).n_per_day.unstack(fill_value=0)
fig,ax = plt.subplots(figsize=(10,3))
columns,rows = df.columns, df.index
im = ax.imshow(df.values,
origin='lower',
extent=(0, 24, -0.5, 6.5),
cmap='Blues', vmin=0, vmax=1)
ax.set_yticks(np.arange(0,7))
ax.invert_yaxis()
ax.set_yticklabels(['Mon','Tue','Wed','Thu','Fri','Sat','Sun'])
ax.set_xticks([0,6,12,18])
ax.set_title('Number of stop-and-search incidents per hour')
plt.colorbar(im, ax=ax)
plt.show()
Next steps
Beyond matplotlib
Matplotlib is a powerful plotting library. There are many topics that this tutorial hasn't
touched on, such as interaction and animation.
But it has a creaking API …
Alternatives
You might get fed up, and switch to seaborn,
which is a matplotlib wrapper, or to plotly.
In my opinion these aren't complete enough to fully replace matplotlib,
and it's a needless pain to learn two packages. For serious plotting I switch to another language, R,
which has a far superior plotting library called ggplot2.
import numpy as np
import matplotlib.pyplot as plt
n = 5000
π = np.pi
k = np.random.choice(4, p=[.6,.3,.05,.05], size=n)
t = np.random.uniform(size=n)
x = np.column_stack([np.sin(2*π*t), 0.55*np.sin(2*π*(0.4*t+0.3)), -0.3*np.ones(n), 0.3*np.ones(n)])
y = np.column_stack([np.cos(2*π*t), 0.55*np.cos(2*π*(0.4*t+0.3)), 0.3*np.ones(n), 0.3*np.ones(n)])
xy = np.column_stack([x[np.arange(n), k], y[np.arange(n), k]])
xy = np.random.normal(loc=xy, scale=.08)
x,y = xy[:,0], xy[:,1]
fig,((ax_x,dummy),(ax_xy,ax_y)) = plt.subplots(2,2, figsize=(4,4), sharex='col', sharey='row',
gridspec_kw={'height_ratios':[1,2], 'width_ratios':[2,1]})
dummy.remove()
ax_xy.scatter(xy[:,0], xy[:,1], s=3, alpha=.1)
ax_x.hist(x, density=True, bins=60)
ax_y.hist(y, density=True, bins=60, orientation='horizontal')
fig.suptitle("Well done!", x=0.12, ha='left')
plt.show()