Week 1 - Tools/Plotting dataset¶
About the dataset¶
Global dataset of historical soybean yields, 1981–2016.
LLM help¶
Prompts used with ChatGPT 5.1:
- "I have a database split up into different NC4 files. I want to visualize this dataset using Python. How do I do it?"
- "How to add for the plot "ds_all["Yield"].sel(year=2000).plot()" a world map underneath to visualize in which countries the Yield is."
- "Rather than selecting a specific year, I want the plot to be interactive and be able above the graph to type the year I want to show"
- "And for this graph "ts = ds_all["Yield"].sel(lat=45, lon=10, method="nearest") ts.plot()" I want to be able to choose the median per country rather than a specific longitude and latitude and therefore I want to be interactive and choose the country I want to select"
pip -q install xarray netCDF4 cartopy
Note: you may need to restart the kernel to use updated packages.
Dataset reading¶
import os #os.path to manipulate paths
import xarray as xr #for multidimensional labeled arrays and datasets, using dimensions, coordinates, and attributes on top of NumPy-like arrays
from pathlib import Path
directory = Path("/home/jovyan/work/jeogeorge/datasets/soybean_yield_1981-2016")

# Gather the per-year NetCDF files ("yield_*.nc4") as strings, ordered by name.
files = [
    str(entry)
    for entry in directory.iterdir()
    if entry.name.startswith("yield_") and entry.name.endswith(".nc4")
]
files.sort()

# The year is the second underscore-separated token of the file name,
# cut at its first dot.
years = [
    int(os.path.basename(name).split("_")[1].split(".")[0])
    for name in files
]

# Open each file and give it a length-1 "year" dimension so that the
# yearly grids can be stacked.
datasets = [
    xr.open_dataset(name).expand_dims(year=[year])
    for name, year in zip(files, years)
]

# Stack along "year", then rename the data variable and record its unit.
ds_all = xr.concat(datasets, dim="year").rename({"var": "Yield"})
ds_all["Yield"].attrs["units"] = r"t.ha$^{-1}$"
print(ds_all)
<xarray.Dataset> Size: 37MB
Dimensions: (year: 36, lat: 360, lon: 720)
Coordinates:
* year (year) int64 288B 1981 1982 1983 1984 1985 ... 2013 2014 2015 2016
* lat (lat) float64 3kB -89.75 -89.25 -88.75 -88.25 ... 88.75 89.25 89.75
* lon (lon) float64 6kB 0.25 0.75 1.25 1.75 ... 358.2 358.8 359.2 359.8
Data variables:
Yield (year, lat, lon) float32 37MB nan nan nan nan ... nan nan nan nan
Plotting a year of the dataset¶
# Quick-look map: select the year-2000 slice of Yield and draw it with xarray's default plot.
ds_all["Yield"].sel(year=2000).plot()
<matplotlib.collections.QuadMesh at 0xe297a33dbcb0>
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
# Yield grid for the year 2000
yield_2000 = ds_all["Yield"].sel(year=2000)

# PlateCarree (plain lat/lon) is used both as the map projection and as the
# coordinate system of the data.
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(projection=ccrs.PlateCarree())

colorbar_options = {"label": "Yield (t.ha$^{-1}$)"}
yield_2000.plot(
    ax=ax,
    transform=ccrs.PlateCarree(),
    cbar_kwargs=colorbar_options,
)

# Basemap: coastlines over the whole globe
ax.coastlines()
ax.set_global()
ax.set_title("Soybean Yield in 2000 (t.ha$^{-1}$)")

plt.show()
Integrate slider into plot to select year¶
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
years_avail = ds_all.year.values


def plot_year(year):
    """Draw the global soybean-yield map of one year with coastlines and borders."""
    yearly_yield = ds_all["Yield"].sel(year=year)

    fig = plt.figure()
    ax = fig.add_subplot(projection=ccrs.PlateCarree())
    yearly_yield.plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        add_colorbar=True,
        cbar_kwargs={"label": r"Yield (t ha$^{-1}$)"},
    )

    # Basemap features
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax.set_global()

    ax.set_title(f"Soybean Yield in {year}")
    plt.show()


# Slider running from the first to the last available year, starting at the first.
first_year = int(years_avail.min())
last_year = int(years_avail.max())
interact(
    plot_year,
    year=IntSlider(
        min=first_year,
        max=last_year,
        step=1,
        value=first_year,
        description="Year",
    ),
)
interactive(children=(IntSlider(value=1981, description='Year', max=2016, min=1981), Output()), _dom_classes=(…
<function __main__.plot_year(year)>
ChatGPT 5.1 prompt: "Using matplotlib animation, I want to modify this script to create a GIF file showing all the years from 1981 to 2016"
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import cartopy.crs as ccrs
# Frames of the animation. range() excludes its stop value, so the stop must
# be 2017 for the last dataset year (2016) to be shown.
years = range(1981, 2017)

# Use one colour scale for every frame. The colourbar is drawn only once
# (from the first frame), so letting each frame autoscale would make the
# colourbar wrong for all later frames.
yield_all_years = ds_all["Yield"]
yield_min = float(yield_all_years.min())
yield_max = float(yield_all_years.max())

fig = plt.figure(figsize=(10, 5))
ax = plt.axes(projection=ccrs.PlateCarree())

# First frame; its mesh is what the colourbar is attached to.
im = yield_all_years.sel(year=years[0]).plot(
    ax=ax,
    transform=ccrs.PlateCarree(),
    cmap="viridis",
    vmin=yield_min,
    vmax=yield_max,
    add_colorbar=False,
)
cbar = plt.colorbar(im, ax=ax, orientation="vertical", pad=0.02)
cbar.set_label("Yield (t.ha$^{-1}$)")
ax.coastlines()
ax.set_global()


def update(year):
    """Redraw the map for one animation frame.

    Parameters
    ----------
    year : int
        Year to display; must be a label of the dataset's "year" coordinate.

    Returns
    -------
    tuple
        One-element tuple holding the axes that was redrawn.
    """
    # Clearing the axes also removes the coastlines and title, so they are
    # re-added below on every frame.
    ax.clear()
    ds_all["Yield"].sel(year=year).plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap="viridis",
        vmin=yield_min,
        vmax=yield_max,
        add_colorbar=False,
    )
    ax.coastlines()
    ax.set_global()
    # Literal braces must be doubled inside an f-string; a single "{-1}" is
    # evaluated as the expression -1 and breaks the superscript.
    ax.set_title(f"Soybean Yield in {year} (t.ha$^{{-1}}$)")
    return ax,


ani = animation.FuncAnimation(
    fig,
    update,
    frames=years,
    interval=500,  # milliseconds between frames
)

# NOTE: the animation is only displayed here, nothing is written to disk.
# Exporting a GIF requires an explicit ani.save(...) call with a GIF-capable writer.
plt.show()
plt.close()
Plot yield data for a specific location¶
# Yield values of the grid cell nearest to lat=45, lon=10
ts = ds_all["Yield"].sel(lat=45, lon=10, method="nearest")
# Draw the selected values as individual points
ts.plot.scatter()
<matplotlib.collections.PathCollection at 0xe297a3425010>
import cartopy.io.shapereader as shpreader
import numpy as np
from shapely.geometry import Point
import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt
# 1. Country boundaries from Natural Earth
# Country (admin level 0) polygons at 110m resolution
shp_path = shpreader.natural_earth(
    category="cultural",
    name="admin_0_countries",
    resolution="110m",
)

# Read all shapefile records into a list
reader = shpreader.Reader(shp_path)
records = list(reader.records())

# Country names in alphabetical order
countries = [rec.attributes["NAME"] for rec in records]
countries.sort()

# 2. Latitude–longitude grid of the dataset
lons = ds_all["lon"].values
lats = ds_all["lat"].values

# 2-D coordinate arrays (rows follow lat, columns follow lon) used for
# point-in-polygon testing
lon2d, lat2d = np.meshgrid(lons, lats)
# 3. Build a mask of grid cells inside a given country
def country_mask(name):
    """
    Build a boolean grid marking the cells whose centre lies inside a country.

    Parameters
    ----------
    name : str
        Country name as stored in the "NAME" attribute of the shapefile records.

    Returns
    -------
    numpy.ndarray
        Boolean array with the same shape as ``lon2d``; True where the grid
        point is contained in one of the country's polygons.

    Raises
    ------
    ValueError
        If no record has the given name.
    """
    # Find the matching country record
    rec = next((r for r in records if r.attributes["NAME"] == name), None)
    if rec is None:
        raise ValueError(f"Unknown country name: {name!r}")
    geom = rec.geometry

    # Multi-part countries are tested one polygon at a time
    polys = list(geom.geoms) if geom.geom_type == "MultiPolygon" else [geom]

    # The dataset grid counts longitude from 0 to 360, whereas the Natural
    # Earth polygons use -180..180. Without wrapping, no grid point east of
    # 180 can ever fall inside a western-hemisphere country.
    lon_wrapped = np.where(lon2d > 180, lon2d - 360, lon2d)

    mask = np.zeros(lon2d.shape, dtype=bool)
    for poly in polys:
        # Only points inside the polygon's bounding box can be inside the
        # polygon, so the costly contains() test is limited to those points.
        min_lon, min_lat, max_lon, max_lat = poly.bounds
        in_box = (
            (lon_wrapped >= min_lon)
            & (lon_wrapped <= max_lon)
            & (lat2d >= min_lat)
            & (lat2d <= max_lat)
        )
        for i, j in zip(*np.nonzero(in_box)):
            if not mask[i, j]:
                mask[i, j] = poly.contains(Point(lon_wrapped[i, j], lat2d[i, j]))
    return mask
# 4. Compute a country-level time series by median aggregation
def country_timeseries(name):
    """
    Median yield of a country, aggregated over its grid cells.

    Cells outside the country's mask are blanked out first; the median is
    then taken across the "lat" and "lon" dimensions, leaving the remaining
    dimension(s) of the Yield variable.
    """
    inside_country = country_mask(name)
    masked_yield = ds_all["Yield"].where(inside_country)
    return masked_yield.median(dim=("lat", "lon"))
# 5. Interactive widget to plot individual data points
@interact(
    country=widgets.Dropdown(
        options=countries,
        value="China",
        description="Country",
    )
)
def plot_country(country):
    """Scatter-plot the median yield series of the country chosen in the dropdown."""
    median_series = country_timeseries(country)
    median_series.plot.scatter()  # individual points rather than a connected line

    plt.title(f"Median Soybean yield in {country}")
    plt.ylabel(r"Yield (t.ha$^{-1}$)")
    plt.show()
interactive(children=(Dropdown(description='Country', index=31, options=('Afghanistan', 'Albania', 'Algeria', …
ChatGPT 5.1 prompt: "Rather than the median, show all the data points of a country using boxplots"
# 4. Collect all yield values per year for a country
def country_yearly_values(name):
    """
    Collect the yield values inside a country, grouped by year.

    Returns a tuple ``(yearly_values, years)``: a list holding one flat array
    of non-NaN Yield values per year, and the array of year labels in the
    same order.
    """
    masked = ds_all["Yield"].where(country_mask(name))
    year_labels = masked["year"].values

    # Flatten each yearly grid, then drop the NaN entries
    flat_per_year = (masked.sel(year=label).values.flatten() for label in year_labels)
    yearly_values = [vals[~np.isnan(vals)] for vals in flat_per_year]

    return yearly_values, year_labels
# 5. Interactive boxplot
@interact(
    country=widgets.Dropdown(
        options=countries,
        value="China",
        description="Country",
    )
)
def plot_country(country):
    """Draw one box per year showing the yield values inside the selected country."""
    values, years = country_yearly_values(country)

    plt.boxplot(values, showfliers=True, medianprops={"color": "black"})

    # Label box positions 1..N with the corresponding years
    tick_positions = np.arange(1, len(years) + 1)
    plt.xticks(ticks=tick_positions, labels=years, rotation=90)

    plt.xlabel("Year")
    plt.ylabel(r"Yield (t.ha$^{-1}$)")
    plt.title(f"Distribution of Soybean Yield in {country}")
    plt.tight_layout()
    plt.show()
interactive(children=(Dropdown(description='Country', index=31, options=('Afghanistan', 'Albania', 'Algeria', …
Using regionmask¶
pip install -q regionmask
Note: you may need to restart the kernel to use updated packages.
import matplotlib.pyplot as plt
import regionmask
import ipywidgets as widgets
from ipywidgets import interact
# 1. Use regionmask's built-in Natural Earth (v5.0.0) country regions, 110m set.
# NOTE: this rebinds `countries`, which earlier in the notebook held a sorted list of names.
countries = regionmask.defined_regions.natural_earth_v5_0_0.countries_110
# 2. Create a mask on the xarray grid (dims: lat, lon).
# The result is compared against region numbers below to pick out one country.
mask = countries.mask(ds_all) # ds_all must have lon/lat coords
# 3. Helper: compute time series for one country
def get_country_data(country_name):
    """Return the Yield data restricted to the grid cells of the named country."""
    region_number = countries.map_keys(country_name)  # name -> region number
    in_country = mask == region_number  # True where the cell belongs to that region
    return ds_all["Yield"].where(in_country)
# 4. Interactive dropdown
# Country names for the dropdown, in alphabetical order
country_list = sorted(countries.names)


@interact(
    country=widgets.Dropdown(
        options=country_list,
        value="China",  # must be one of the entries of country_list
        description="Country",
    )
)
def plot_country_boxplot(country):
    """Box-plot the yield values of the selected country, one box per year."""
    country_yield = get_country_data(country)

    # Long-format table with the coordinates as columns, without NaN rows
    table = country_yield.to_dataframe(name="Yield")
    table = table.reset_index()
    table = table.dropna()

    table.boxplot(
        column="Yield",
        by="year",
        showfliers=True,  # keep outliers visible
        grid=False,
    )

    plt.suptitle("")  # blank out the figure-level title
    plt.title(f"Soybean yield distribution in {country}")
    plt.xlabel("Year")
    plt.ylabel(r"Yield (t.ha$^{-1}$)")
    plt.xticks(rotation=45)
    plt.show()
interactive(children=(Dropdown(description='Country', index=31, options=('Afghanistan', 'Albania', 'Algeria', …
Adding Country variable to original dataset¶
ChatGPT5.1 prompt "Based on the previous script, I want to edit the dataset by adding a new variable "Country" for each specific yield data based on longitude and latitude, and based on this dataset, <xarray.Dataset> Size: 37MB Dimensions: (year: 36, lat: 360, lon: 720) Coordinates: * year (year) int64 288B 1981 1982 1983 1984 1985 ... 2013 2014 2015 2016 * lat (lat) float64 3kB -89.75 -89.25 -88.75 -88.25 ... 88.75 89.25 89.75 * lon (lon) float64 6kB 0.25 0.75 1.25 1.75 ... 358.2 358.8 359.2 359.8 Data variables: Yield (year, lat, lon) float32 37MB nan nan nan nan ... nan nan nan nan ""
import matplotlib.pyplot as plt
import regionmask
import numpy as np
import xarray as xr

countries = regionmask.defined_regions.natural_earth_v5_0_0.countries_110
mask = countries.mask(ds_all)  # same lon/lat grid as ds_all

# Region number of every grid cell (NaN where the cell is in no region)
region_idx = mask  # DataArray

# Start with "Ocean" everywhere, then overwrite the cells of each region with
# that region's name. Names are matched by region number instead of by list
# position, so the labels stay correct even if the region numbers are not
# exactly 0..N-1. NaN cells never compare equal and therefore keep "Ocean".
country_arr = np.full(region_idx.shape, "Ocean", dtype=object)
region_numbers = region_idx.values
for number, region_name in zip(countries.numbers, countries.names):
    country_arr[region_numbers == number] = region_name

# Wrap the labels in a DataArray with the same coords/dims as the mask
Country = xr.DataArray(
    country_arr,
    coords=region_idx.coords,
    dims=region_idx.dims,
    name="Country",
)

# Attach the labels to the dataset
ds_all["Country"] = Country
print(ds_all)
<xarray.Dataset> Size: 39MB
Dimensions: (year: 36, lat: 360, lon: 720)
Coordinates:
* year (year) int64 288B 1981 1982 1983 1984 1985 ... 2013 2014 2015 2016
* lat (lat) float64 3kB -89.75 -89.25 -88.75 -88.25 ... 88.75 89.25 89.75
* lon (lon) float64 6kB 0.25 0.75 1.25 1.75 ... 358.2 358.8 359.2 359.8
Data variables:
Yield (year, lat, lon) float32 37MB nan nan nan nan ... nan nan nan nan
Country (lat, lon) object 2MB 'Antarctica' 'Antarctica' ... 'Ocean' 'Ocean'
# Flatten the dataset into a table with the coordinates, Yield and Country as columns
df = ds_all[["Yield", "Country"]].to_dataframe().reset_index()
print(df)

# Keep only the rows that have a yield value and a country label other than "Ocean"
has_yield = df["Yield"].notna()
in_country = df["Country"] != "Ocean"
df = df[has_yield & in_country]
print(df)

# Save in Parquet format
df.to_parquet("/home/jovyan/work/jeogeorge/datasets/soybean_yield_country.parquet")
year lat lon Yield Country
0 1981 -89.75 0.25 NaN Antarctica
1 1981 -89.75 0.75 NaN Antarctica
2 1981 -89.75 1.25 NaN Antarctica
3 1981 -89.75 1.75 NaN Antarctica
4 1981 -89.75 2.25 NaN Antarctica
... ... ... ... ... ...
9331195 2016 89.75 357.75 NaN Ocean
9331196 2016 89.75 358.25 NaN Ocean
9331197 2016 89.75 358.75 NaN Ocean
9331198 2016 89.75 359.25 NaN Ocean
9331199 2016 89.75 359.75 NaN Ocean
[9331200 rows x 5 columns]
year lat lon Yield Country
102140 1981 -19.25 310.25 1.765330 Brazil
102853 1981 -18.75 306.75 1.765330 Brazil
131604 1981 1.25 282.25 1.520654 Colombia
131605 1981 1.25 282.75 1.250026 Colombia
131611 1981 1.25 285.75 1.898537 Colombia
... ... ... ... ... ...
9278164 2016 53.25 122.25 2.729482 China
9278165 2016 53.25 122.75 2.966493 China
9278166 2016 53.25 123.25 2.846184 China
9278167 2016 53.25 123.75 1.603321 China
9278168 2016 53.25 124.25 1.497798 China
[212570 rows x 5 columns]
Example of data distribution for a specific country¶
import matplotlib.pyplot as plt
import numpy as np
country_name = "United States of America"

# Rows of the selected country that carry a yield value
is_selected = df["Country"] == country_name
sel = df[is_selected & df["Yield"].notna()]

# One array of yield values per year, in ascending year order
years = sorted(sel["year"].unique())
groups = sel.groupby("year")
data_by_year = [year_rows["Yield"].values for _, year_rows in groups]

# Boxes are placed at 0..N-1 and labelled with their years
box_positions = np.arange(len(years))
plt.boxplot(data_by_year, positions=box_positions, showfliers=True)
plt.xticks(box_positions, years, rotation=90)

plt.xlabel("Year")
plt.ylabel(r"Yield (t ha$^{-1}$)")
plt.title(f"Distribution of soybean yield in {country_name} by year")
plt.tight_layout()
plt.show()