In [1]:
import numpy as np
import pandas
import matplotlib.pyplot as plt

Multiple lines, legends, datetime

If we draw multiple datasets and give them each a label,

ax.plot(..., label="mylabel")

then it's easy to create a legend out of these labelled items: simply call

ax.legend()

I prefer to position my legend outside the Axes, using

ax.legend(loc='center left',  bbox_to_anchor=(1.05,0.5), edgecolor='none')
fig.set_tight_layout(True)
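A minimal sketch of both steps, using made-up sine and cosine curves as the labelled datasets. With `loc='center left'` and `bbox_to_anchor=(1.05, 0.5)`, the legend's centre-left point is pinned just to the right of the Axes, because the anchor is given in Axes coordinates by default. This sketch calls `fig.tight_layout()` once, after the legend exists, as a one-off alternative to `fig.set_tight_layout(True)`.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up data: two labelled series on the same Axes
x = np.linspace(0, 2*np.pi, 50)

fig, ax = plt.subplots(figsize=(5, 2.5))
ax.plot(x, np.sin(x), label='sin')
ax.plot(x, np.cos(x), label='cos')

# Pin the legend's centre-left point just outside the right edge of the Axes
# (bbox_to_anchor is in Axes coordinates by default); drop the frame edge
leg = ax.legend(loc='center left', bbox_to_anchor=(1.05, 0.5), edgecolor='none')

# Shrink the Axes so the legend fits inside the figure
fig.tight_layout()
plt.show()
```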

Dates and times. Matplotlib understands datetimes. In this code, the datetime column is a string (that's what read_csv gives us), so we first convert it into datetime objects that pandas and matplotlib can work with.
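A minimal sketch of that conversion, with a few made-up timestamp strings standing in for the CSV column: `pandas.to_datetime` parses the strings, `.dt.date` keeps just the calendar date, and Matplotlib puts those dates on a date axis.

```python
import pandas
import matplotlib.pyplot as plt

# Made-up timestamps, stored as strings just as read_csv would give them
events = pandas.DataFrame({'datetime': ['2021-10-01T09:15:00', '2021-10-01T17:40:00',
                                        '2021-10-02T11:05:00', '2021-10-04T08:30:00']})

# Parse the strings, then keep only the calendar date
events['datetime'] = pandas.to_datetime(events.datetime)
events['date'] = events.datetime.dt.date

# Count events per date (groupby sorts the dates, so the line runs left to right)
per_day = events.groupby('date').size()

fig, ax = plt.subplots(figsize=(5, 2))
ax.plot(per_day.index, per_day.values, marker='o')
fig.autofmt_xdate()
plt.show()
```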

In [2]:
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)
stopsearch['date'] = pandas.to_datetime(stopsearch.datetime).dt.date
stopsearch['outcome'] = np.where(stopsearch.outcome=='A no further action disposal', 'nothing', 'find')

# Number of events per date, sorted by timestamp
# (if timestamps were unsorted, the line would wiggle backwards and forwards)
df = stopsearch.groupby(['date','outcome']).size().unstack(fill_value=0).reset_index()
df = df.sort_values('date')

fig,ax = plt.subplots(figsize=(7,2.5))
ax.plot(df.date, df.find + df.nothing, label='stops', linewidth=1, color='0.5')
ax.plot(df.date, df.find, label='find', linewidth=3)
ax.set_title('Number of stop-and-search incidents per day')
ax.legend()

# Rotate and align the date tick labels so they don't overlap (acts on the whole figure, not just ax)
fig.autofmt_xdate(bottom=0.2, rotation=-30, ha='left')

plt.show()

Histogram, line annotations, text annotation

Histogram.

ax.hist(VALUES, bins=INT)
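A minimal sketch with made-up normally distributed data. An integer `bins` asks for that many equal-width bins spanning the data. `ax.hist` also returns the bin counts, the bin edges and the drawn bars, which is handy for checking what was plotted.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up data: 1000 draws from a standard normal distribution
rng = np.random.default_rng(0)
values = rng.normal(size=1000)

fig, ax = plt.subplots()
# bins=20: twenty equal-width bins between min(values) and max(values)
counts, edges, patches = ax.hist(values, bins=20, edgecolor='white')
plt.show()
```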

Line annotations. To draw a straight horizontal or vertical line or range,

ax.axhline(y)
ax.axhspan(ymin, ymax)
ax.axvline(x)
ax.axvspan(xmin, xmax)
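A short sketch with made-up data. These lines and spans take their position in data coordinates along one axis. In the other direction they stretch across the whole Axes, so they stay full-width or full-height however the axis limits change.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up data
x = np.arange(10)
y = x**2

fig, ax = plt.subplots()
ax.plot(x, y)
h = ax.axhline(50, color='black', linestyle='dashed')   # horizontal line at y=50, full width
ax.axvspan(3, 5, color='orange', alpha=0.2)             # shaded band from x=3 to x=5, full height
plt.show()
```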

For sloped lines, I think it's simplest to create a dummy dataframe with a row for the start and a row for the end, and just use ax.plot.
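Here is the dummy-dataframe approach sketched with made-up points and an arbitrary reference line y = 2x + 1, both purely illustrative. Matplotlib 3.3 and later also offer `ax.axline(xy1, slope=...)`, which draws an infinite line through a point out to the edges of the Axes; it is shown as an alternative at the end.

```python
import numpy as np
import pandas
import matplotlib.pyplot as plt

# Made-up scatter data scattered around y = 2x + 1
rng = np.random.default_rng(1)
pts = pandas.DataFrame({'x': rng.uniform(0, 10, 30)})
pts['y'] = 2*pts.x + 1 + rng.normal(scale=2, size=30)

# Dummy dataframe: one row for the start of the line, one row for the end
line = pandas.DataFrame({'x': [0, 10]})
line['y'] = 2*line.x + 1

fig, ax = plt.subplots()
ax.scatter(pts.x, pts.y, s=10)
ax.plot(line.x, line.y, color='black')

# Alternative (Matplotlib >= 3.3): infinite line through (0, 1) with slope 2
ax.axline((0, 1), slope=2, color='red', linestyle='dotted')
plt.show()
```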

Text annotation.

ax.text(x, y, TEXT)

Unlike data-plotting commands such as ax.scatter, ax.text takes a single x, a single y and a single string TEXT; none of them may be vectors.

To control the alignment of the text with respect to the x and y coordinates, use e.g. ha='left', va='top'.
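Because `ax.text` places one string at a time, labelling several points takes a loop. A minimal sketch with made-up labels and coordinates:

```python
import pandas
import matplotlib.pyplot as plt

# Made-up labelled points
points = pandas.DataFrame({'label': ['A', 'B', 'C'],
                           'x': [1.0, 2.5, 4.0],
                           'y': [2.0, 1.0, 3.0]})

fig, ax = plt.subplots()
ax.scatter(points.x, points.y)
# One ax.text call per point: nudge each label right, left-align it there, centre it vertically
for row in points.itertuples():
    ax.text(row.x + 0.1, row.y, row.label, ha='left', va='center')
plt.show()
```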

In [3]:
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_202110w1.csv'
stopsearch = pandas.read_csv(url)

fig,ax = plt.subplots()
ax.hist(stopsearch.location_latitude, bins=45, alpha=.5, edgecolor='white')
ax.axvline(x=51.5285582, linestyle='dashed', color='black')
ax.text(51.5285582, 2500, 'London', ha='left', va='top')

ax.set_title('Stop-and-search incidents, Nov 2021')

plt.show()

Scatter plot, colour scale

Matplotlib has several built-in colour scales.

Choose a colour scale with either of

cmap = plt.get_cmap('SCALENAME')
cmap = plt.get_cmap('SCALENAME', n)

In the first case, cmap is a function from $[0,1]$ to colours. In the second case, it maps $\{0,1,\ldots,n-1\}$ to colours.
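A quick sketch of the two cases: calling a colour map returns an RGBA tuple. One gotcha worth knowing is that floats are read as positions in $[0,1]$ while integers are read as indices into the colour list, so `cmap(1.0)` and `cmap(1)` are generally different colours.

```python
import matplotlib.pyplot as plt

# Continuous scale: a function from [0,1] to RGBA colours
viridis = plt.get_cmap('viridis')
mid = viridis(0.5)            # colour halfway along the scale

# Discrete scale: 'Set2' resampled to 3 colours, indexed by 0, 1, 2
set2 = plt.get_cmap('Set2', 3)
colours = [set2(i) for i in range(3)]

# Floats and ints are treated differently
top = set2(1.0)     # float: the end of the scale, i.e. the same as index 2
second = set2(1)    # int: the second colour
```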

The code below plots stop-and-search incidents, colour-coded by police force. It uses the discrete colour scale 'Set2' (which only has eight distinct colours, so some forces share the same colour). The scatter plot shows us roughly an outline of England and Wales, colour-coded by region.

In [4]:
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_202110w1.csv'
stopsearch = pandas.read_csv(url)

levels = stopsearch.force.unique()
cmap = plt.get_cmap('Set2', len(levels))

fig,ax = plt.subplots()
for i,lev in enumerate(levels):
    df = stopsearch.loc[stopsearch.force == lev]
    ax.scatter(df.location_longitude, df.location_latitude, s=3, alpha=.5, color=cmap(i))

# Set the aspect ratio: at latitude 54° (roughly the UK's average), one degree of
# longitude is only cos(54°) times as long as one degree of latitude
ax.set_aspect(1/np.cos(np.deg2rad(54)))

# Pick coordinates to show (I chose these after seeing the plot first)
ax.set_xlim([-5,2])
ax.set_ylim([50.2, 55.8])

# Get rid of the tick marks and the outer frame
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')

plt.show()

We can also achieve the same outcome without using a ‘for’ loop, by calling ax.scatter once and passing in the c argument,

ax.scatter(…, c=COLOURLIST)
In [8]:
levels = stopsearch.force.unique()
cols = plt.get_cmap('Set2', len(levels))

# Map each force name to its position in levels, then each position to a colour
forceI = stopsearch.force.map({lev:i for i,lev in enumerate(levels)})
forceC = [cols(i) for i in forceI]

fig,ax = plt.subplots()
ax.scatter(stopsearch.location_longitude,
           stopsearch.location_latitude,
           s = 3,
           alpha = .5,
           c = forceC)

ax.set_aspect(1/np.cos(54/360*2*np.pi))
ax.set_xlim([-5,2])
ax.set_ylim([50.2, 55.8])
ax.set_xticks([])
ax.set_yticks([])
ax.axis('off')

plt.show()

Heatmap, colour bar

To plot a heatmap, in other words to plot an array of values using coloured rectangles for each cell,

ax.imshow(ARRAY, 
    origin='lower', 
    extent=(LEFT,RIGHT,BOTTOM,TOP),
    cmap='COLSCALE', vmin=LO, vmax=HI)

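This draws the array inside the rectangle specified by extent, with ARRAY[0,0] in the bottom left: rows of ARRAY run up the y-axis and columns run along the x-axis. The values are coloured with the specified colour scale, running from LO to HI; values outside that range get the colour at the nearer end.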
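A small sketch with a made-up 2×3 array to check the layout: with `origin='lower'` the first row is drawn at the bottom, and `extent` sets the data coordinates of the outer edges, so each cell here is a unit square.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up 2x3 array: 2 rows (y direction) and 3 columns (x direction)
arr = np.array([[0.0, 0.2, 0.4],
                [0.6, 0.8, 1.0]])

fig, ax = plt.subplots()
im = ax.imshow(arr,
               origin='lower',            # arr[0,0] at the bottom left
               extent=(0, 3, 0, 2),       # (LEFT, RIGHT, BOTTOM, TOP) in data coordinates
               cmap='Blues', vmin=0, vmax=1)
plt.show()
```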

To add a legend for the colour scale,

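plt.colorbar(im, ax=ax)

where im is the return value from imshow, and ax is the Axes or list of Axes from which to steal space to squeeze in the colour bar. Pass ax by keyword: the second positional parameter of plt.colorbar is cax, an Axes to draw the colour bar into.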
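One colour bar can serve several Axes. A sketch with two made-up heatmaps that share the same `vmin`/`vmax`, so that one scale is valid for both; passing the whole array of Axes as `ax` makes the colour bar take its space from both panels.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up data for two panels
rng = np.random.default_rng(2)
a = rng.uniform(size=(4, 6))
b = rng.uniform(size=(4, 6))

fig, axes = plt.subplots(1, 2, figsize=(8, 2.5))
# Same vmin/vmax in both panels, so a single colour bar is accurate for both
im_a = axes[0].imshow(a, origin='lower', cmap='Blues', vmin=0, vmax=1)
im_b = axes[1].imshow(b, origin='lower', cmap='Blues', vmin=0, vmax=1)

# Steal space from both panels for one shared colour bar
cbar = plt.colorbar(im_a, ax=axes)
plt.show()
```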

In [5]:
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)

# Extract hour and day-of-week for each stop-and-search incident
stopsearch['datetime'] = pandas.to_datetime(stopsearch.datetime)
stopsearch['date'] = stopsearch.datetime.dt.date
stopsearch['day'] = stopsearch.datetime.dt.dayofweek # Mon=0, Sun=6
stopsearch['hour'] = stopsearch.datetime.dt.hour # 0 .. 23

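# Count number of stops by day-of-week and hour-of-day
num_stops = stopsearch.groupby(['day','hour']).size().reset_index(name='n')
# Normalize to get the average number of stops per day, for each day-of-week
# (since some days of the week appear more often than others in this dataset):
# d = how many times that day-of-week occurs between the first and last dates seen
num_days = stopsearch.groupby('day')['date'].apply(lambda x: (x.max()-x.min()).days/7 + 1).reset_index(name='d')
df = num_stops.merge(num_days, on='day')
df['n_per_day'] = df.n / df.d
# Pivot to a 7x24 table (rows = day, columns = hour), filling any day/hour with no stops
df = df.set_index(['day','hour']).n_per_day.unstack(fill_value=0)
df = df.reindex(index=range(7), columns=range(24), fill_value=0)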

fig,ax = plt.subplots(figsize=(10,3))
im = ax.imshow(df.values, 
               origin='lower', 
               extent=(0, 24, -0.5, 6.5), 
               cmap='Blues', vmin=0, vmax=1)
ax.set_yticks(np.arange(0,7))
ax.invert_yaxis()
ax.set_yticklabels(['Mon','Tue','Wed','Thu','Fri','Sat','Sun'])
ax.set_xticks([0,6,12,18])

ax.set_title('Number of stop-and-search incidents per hour')

plt.colorbar(im, ax=ax)
plt.show()

Subplots not in a grid; subplot spacing

For an array of subplots, use plt.subplots(nrows,ncols). For ‘wrapped’ subplots, use

fig = plt.figure()
ax1 = fig.add_subplot(nrows, ncols, 1)
ax2 = fig.add_subplot(nrows, ncols, 2)
...

We still have to say what the total size of the grid will be (so Matplotlib knows how big to make each plot), but it only creates the subplots in that grid that we ask for.
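A sketch of the wrapped layout: five panels in a 2×3 grid, so the sixth slot simply stays empty. The panel contents are made-up curves.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 1, 20)

fig = plt.figure(figsize=(7, 4))
for k in range(5):
    # The grid size (2 rows, 3 columns) fixes each panel's size; k+1 is the
    # 1-based slot, filled left to right and then top to bottom
    ax = fig.add_subplot(2, 3, k + 1)
    ax.plot(x, x**(k + 1))
    ax.set_title(f'power {k + 1}')
plt.show()
```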

We can adjust subplot spacing with

fig.subplots_adjust(wspace=WSPACE, hspace=HSPACE)
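A sketch of the keyword form, with arbitrary values: `wspace` and `hspace` are the gaps between subplots, measured as fractions of the average Axes width and height respectively. The keywords matter, because the first two positional parameters of `subplots_adjust` are `left` and `bottom`.

```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2)
# Horizontal gap = 0.4 x average Axes width; vertical gap = 0.6 x average Axes height
fig.subplots_adjust(wspace=0.4, hspace=0.6)
plt.show()
```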

The plot below shows, for each day of the week, a histogram of the number of stops per day: how many stops there typically are on a Monday, on a Tuesday, and so on.

In [6]:
url = 'https://www.cl.cam.ac.uk/teaching/current/DataSci/data/stopsearch_cam2021.csv'
stopsearch = pandas.read_csv(url)

stopsearch['datetime'] = pandas.to_datetime(stopsearch.datetime)
stopsearch['weekday'] = stopsearch.datetime.dt.dayofweek # Mon=0, Sun=6
stopsearch['date'] = stopsearch.datetime.dt.date
# Number of stops on each date (dates with no stops at all don't appear, so zero counts are missing)
df = stopsearch.groupby(['date','weekday']).size().reset_index(name='n')

fig = plt.figure(figsize=(7,5))
fig.subplots_adjust(hspace=0.35)

weekday_names = ['Mon','Tue','Wed','Thu','Fri','Sat','Sun']
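ax0 = None
for i, weekday in enumerate(weekday_names):
    # Share x and y scales with the first panel, so hiding inner tick labels loses no information
    ax = fig.add_subplot(3, 3, i+1, sharex=ax0, sharey=ax0)
    if ax0 is None: ax0 = ax
    counts = df.loc[df.weekday==i, 'n']
    ax.hist(counts, bins=range(15), alpha=.3)
    ax.axvline(x=np.median(counts), linestyle='dotted', color='black')
    # Hide tick labels on panels that have another panel directly below / to the left
    if i < 4: ax.tick_params(labelbottom=False)
    if (i % 3) != 0: ax.tick_params(labelleft=False)
    ax.set_title(weekday, y=0.9, va='bottom')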

fig.suptitle('Distribution of number of stops per day')

plt.show()