Jeoffrey George - Fab Futures - Data Science


Week 1 - Tools/Plotting dataset

About the dataset

The global dataset of historical yields for soybean 1981–2016

LLM help

Prompts used with ChatGPT 5.1:

  • "I have a database split up into different NC4 files. I want to visualize this dataset using Python. How do I do it?"
  • "How to add for the plot "ds_all["Yield"].sel(year=2000).plot()" a world map underneath to visualize in which countries the Yield is."
  • "Rather than selecting a specific year, I want the plot to be interactive and be able above the graph to type the year I want to show"
  • "And for this graph "ts = ds_all["Yield"].sel(lat=45, lon=10, method="nearest") ts.plot()" I want to be able to choose the median per country rather than a specific longitude and latitude and therefore I want to be interactive and choose the country I want to select"

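Requirements

  • xarray
  • netCDF4
  • cartopy
  • matplotlib
  • ipywidgets (interactive slider and dropdowns)
  • shapely (point-in-polygon tests)
  • regionmask (installed further below)
  • pandas and pyarrow or fastparquet (DataFrame conversion and Parquet export)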
In [1]:
%pip install -q xarray netCDF4 cartopy
Note: you may need to restart the kernel to use updated packages.

Dataset reading

In [2]:
import os #os.path to manipulate paths
import xarray as xr #for multidimensional labeled arrays and datasets, using dimensions, coordinates, and attributes on top of NumPy-like arrays

from pathlib import Path
directory = Path("/home/jovyan/work/jeogeorge/datasets/soybean_yield_1981-2016")
files = sorted(                         # Sort the resulting list alphabetically
    str(p)                              # Convert each Path object to a string
    for p in directory.iterdir()        # Iterate over all items in the directory
    if p.name.startswith("yield_")      # Keep only files whose names start with 'yield_'
    and p.name.endswith(".nc4")         # Keep only files whose names end with '.nc4'
)

# Extract years from filenames
years = [
    int(os.path.basename(f).split("_")[1].split(".")[0])
    for f in files
]
datasets = []
for f, y in zip(files, years):
    ds = xr.open_dataset(f)
    # add a year dimension
    ds = ds.expand_dims(year=[y])
    datasets.append(ds)

ds_all = xr.concat(datasets, dim="year")

# Rename variable and unit
ds_all = ds_all.rename({"var": "Yield"})
ds_all["Yield"].attrs["units"] = r"t.ha$^{-1}$"


print(ds_all)
<xarray.Dataset> Size: 37MB
Dimensions:  (year: 36, lat: 360, lon: 720)
Coordinates:
  * year     (year) int64 288B 1981 1982 1983 1984 1985 ... 2013 2014 2015 2016
  * lat      (lat) float64 3kB -89.75 -89.25 -88.75 -88.25 ... 88.75 89.25 89.75
  * lon      (lon) float64 6kB 0.25 0.75 1.25 1.75 ... 358.2 358.8 359.2 359.8
Data variables:
    Yield    (year, lat, lon) float32 37MB nan nan nan nan ... nan nan nan nan
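The year of each file is recovered by splitting its name on `_` and `.`. A slightly more defensive sketch, assuming every file follows the `yield_<YYYY>.nc4` pattern implied by the filter above, matches the name with a regular expression and fails loudly otherwise. It also checks that the years form an unbroken sequence, which the year slider and the animation rely on when stepping one year at a time. `parse_year` and `is_contiguous` are hypothetical helper names.

```python
import re

# Assumed naming pattern: "yield_<four-digit year>.nc4"
YEAR_PATTERN = re.compile(r"^yield_(\d{4})\.nc4$")


def parse_year(filename: str) -> int:
    """Return the year encoded in a filename such as 'yield_1981.nc4'."""
    match = YEAR_PATTERN.match(filename)
    if match is None:
        raise ValueError(f"Unexpected filename: {filename!r}")
    return int(match.group(1))


def is_contiguous(years: list[int]) -> bool:
    """True if the sorted years increase by exactly one at each step."""
    ordered = sorted(years)
    return all(b - a == 1 for a, b in zip(ordered, ordered[1:]))


names = ["yield_1981.nc4", "yield_1982.nc4", "yield_1983.nc4"]
parsed = [parse_year(n) for n in names]
print(parsed, is_contiguous(parsed))  # [1981, 1982, 1983] True
```

In the notebook it would be applied to `os.path.basename(f)` for each path in `files`.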

Plotting a year of the dataset

In [3]:
ds_all["Yield"].sel(year=2000).plot()
Out[3]:
<matplotlib.collections.QuadMesh at 0xe297a33dbcb0>
[Figure: default xarray plot of soybean yield in 2000]
In [4]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

# Data for one year
yield_2000 = ds_all["Yield"].sel(year=2000)

plt.figure(figsize=(10, 5))

# Plot data; PlateCarree - lat/lon projection of the data
ax = plt.axes(projection=ccrs.PlateCarree())
yield_2000.plot(
    ax=ax,
    transform=ccrs.PlateCarree(),
    cbar_kwargs={"label": "Yield (t.ha$^{-1}$)"}
)

# Add basemap features
ax.coastlines()
ax.set_global()
ax.set_title("Soybean Yield in 2000 (t.ha$^{-1}$)")

plt.show()
[Figure: soybean yield in 2000 on a PlateCarree world map with coastlines]

Integrating a slider into the plot to select the year

In [5]:
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

years_avail = ds_all.year.values

def plot_year(year):
    data = ds_all["Yield"].sel(year=year)

    plt.figure()
    ax = plt.axes(projection=ccrs.PlateCarree())
    data.plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        add_colorbar=True,
        cbar_kwargs={"label": r"Yield (t ha$^{-1}$)"}
    )
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax.set_global()
    ax.set_title(f"Soybean Yield in {year}")
    plt.show()

interact(
    plot_year,
    year=IntSlider(
        min=int(years_avail.min()),
        max=int(years_avail.max()),
        step=1,
        value=int(years_avail.min()),
        description="Year",
    )
)
[Interactive widget: "Year" slider from 1981 to 2016 above the yield map]
Out[5]:
<function __main__.plot_year(year)>

ChatGPT 5.1 prompt: "Using matplolib animation, I want to modify this script to create a GIF file showing all the years from 1981 to 2016"

In [15]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import cartopy.crs as ccrs

# range() excludes its stop value, so 2017 is needed to include 2016
anim_years = range(1981, 2017)

# Fix the colour scale for all frames so the single colorbar stays valid
vmin = float(ds_all["Yield"].min())
vmax = float(ds_all["Yield"].max())

fig = plt.figure(figsize=(10, 5))
ax = plt.axes(projection=ccrs.PlateCarree())
im = ds_all["Yield"].sel(year=anim_years[0]).plot(
    ax=ax,
    transform=ccrs.PlateCarree(),
    cmap="viridis",
    vmin=vmin,
    vmax=vmax,
    add_colorbar=False
)
cbar = plt.colorbar(im, ax=ax, orientation="vertical", pad=0.02)
cbar.set_label("Yield (t.ha$^{-1}$)")
ax.coastlines()
ax.set_global()

def update(year):
    """Redraw the map for one frame (one year)."""
    ax.clear()
    ds_all["Yield"].sel(year=year).plot(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap="viridis",
        vmin=vmin,
        vmax=vmax,
        add_colorbar=False
    )
    ax.coastlines()
    ax.set_global()
    ax.set_title(f"Soybean Yield in {year} (t.ha$^{-1}$)")
    return ax,

ani = animation.FuncAnimation(
    fig,
    update,
    frames=anim_years,
    interval=500  # milliseconds between frames
)

# Write the GIF with Matplotlib's Pillow writer (2 fps = 500 ms per frame)
ani.save("soybean_yield_1981_2016.gif", writer="pillow", fps=2)
plt.show()
[Figure: static preview of the animation figure]

[Animation: soybean_yield_1981_2016.gif]
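In the animation, every frame would otherwise get its own automatic colour range, so a colorbar built from the first frame does not match later frames unless `vmin`/`vmax` are fixed. A minimal sketch, assuming NumPy only and a tiny stand-in array in place of `ds_all["Yield"].values`, computes shared limits once over all years while ignoring the NaN (ocean / no soybean) cells. Rather than the absolute extremes, percentiles can be used so that a few extreme cells do not wash out the map; the 2nd and 98th percentiles used as defaults here are an arbitrary choice, and `shared_color_limits` is a hypothetical name.

```python
import numpy as np


def shared_color_limits(stack, low_pct=2.0, high_pct=98.0):
    """Return (vmin, vmax) shared by all frames of a (year, lat, lon) stack.

    NaN cells are ignored; the percentiles trim extreme values.
    """
    vmin, vmax = np.nanpercentile(stack, [low_pct, high_pct])
    return float(vmin), float(vmax)


# Stand-in for ds_all["Yield"].values: 3 "years" on a 2 x 2 grid
demo = np.array([
    [[1.0, np.nan], [2.0, 3.0]],
    [[1.5, np.nan], [2.5, 3.5]],
    [[0.5, np.nan], [2.0, 4.0]],
])

vmin, vmax = shared_color_limits(demo, 0, 100)
print(vmin, vmax)  # 0.5 4.0
```

The two numbers would then be passed as `vmin=vmin, vmax=vmax` to every `.plot()` call in the animation. With 36 frames (1981 to 2016) at `interval=500`, one loop of the GIF lasts 36 × 0.5 s = 18 s.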

Plot yield data for a specific location

In [7]:
ts = ds_all["Yield"].sel(lat=45, lon=10, method="nearest")
ts.plot.scatter()
Out[7]:
<matplotlib.collections.PathCollection at 0xe297a3425010>
[Figure: yield time series at the grid cell nearest to 45° N, 10° E]
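The dataset's longitudes run from 0.25 to 359.75 (see the `print(ds_all)` output), whereas many sources, including the Natural Earth country outlines used below, use -180 to 180. A western-hemisphere site therefore has to be converted before calling `.sel(..., method="nearest")`: 93° W, for example, corresponds to 267 in the 0 to 360 convention. A small sketch of both conversions with NumPy (`to_0_360` and `to_pm180` are hypothetical names):

```python
import numpy as np


def to_0_360(lon):
    """Map longitudes from [-180, 180) to [0, 360)."""
    return np.mod(lon, 360.0)


def to_pm180(lon):
    """Map longitudes from [0, 360) to [-180, 180)."""
    return (np.asarray(lon, dtype=float) + 180.0) % 360.0 - 180.0


print(to_0_360(-93.0))                                 # 267.0
print(to_pm180(np.array([0.25, 180.25, 359.75])).tolist())  # [0.25, -179.75, -0.25]
```

For a hypothetical point at 42° N, 93° W this would read `ds_all["Yield"].sel(lat=42, lon=to_0_360(-93.0), method="nearest")`.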

Plot yield data for a chosen country through an interactive plot

Using cartopy and shapely libraries

In [8]:
import cartopy.io.shapereader as shpreader
import numpy as np
import shapely  # shapely >= 2.0 provides the vectorised shapely.contains_xy
import xarray as xr
import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt

# 1. Load country boundaries from Natural Earth

# Load shapefile for country polygons (admin level 0)
shp_path = shpreader.natural_earth(
    resolution='110m',
    category='cultural',
    name='admin_0_countries'
)

# Read shapefile records
reader = shpreader.Reader(shp_path)
records = list(reader.records())

# Extract and sort country names
countries = sorted(rec.attributes["NAME"] for rec in records)

# 2. Prepare latitude–longitude grid from dataset

lats = ds_all["lat"].values
lons = ds_all["lon"].values

# The dataset stores longitudes as 0-360, Natural Earth as -180-180:
# convert, otherwise cells west of Greenwich (lon > 180) never fall inside any outline
lons_180 = np.where(lons > 180, lons - 360, lons)

# Create 2D lat/lon grid for point-in-polygon testing
lon2d, lat2d = np.meshgrid(lons_180, lats)

# 3. Build a mask of grid cells inside a given country
def country_mask(name):
    # find the matching country record
    rec = next(r for r in records if r.attributes["NAME"] == name)

    # Test all grid points at once; works for Polygon and MultiPolygon alike
    mask = shapely.contains_xy(rec.geometry, lon2d, lat2d)

    # Wrap as a DataArray so xarray aligns it on (lat, lon)
    return xr.DataArray(mask, coords={"lat": lats, "lon": lons}, dims=("lat", "lon"))

# 4. Compute a country-level time series by median aggregation

def country_timeseries(name):
    """
    Returns the median yield over all grid cells inside
    the selected country for each time step.
    """
    mask = country_mask(name)
    data = ds_all["Yield"].where(mask)
    return data.median(dim=("lat", "lon"))

# 5. Interactive widget to plot individual data points

@interact(
    country=widgets.Dropdown(
        options=countries,
        value="China",
        description="Country"
    )
)
def plot_country(country):
    """
    Plots the country's median yield time series as individual points.
    """

    ts = country_timeseries(country)
    ts.plot.scatter()  # scatter instead of line

    plt.ylabel(r"Yield (t.ha$^{-1}$)")
    plt.title(f"Median Soybean yield in {country}")
    plt.show()
[Interactive widget: "Country" dropdown (default: China) above the median-yield scatter plot]
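The country mask is built by testing every grid-cell centre against the country outline. To show what such a point-in-polygon test does, here is a NumPy-only sketch of the even-odd (ray-casting) rule that classifies a whole meshgrid at once for one simple ring. It is an illustration, not shapely's implementation: holes, multipolygons and points lying exactly on an edge are not handled, and the square outline is made up. As with the real shapefiles, the ring and the grid must use the same longitude convention. `points_in_ring` is a hypothetical name.

```python
import numpy as np


def points_in_ring(x, y, ring_x, ring_y):
    """Even-odd ray casting: True where point (x, y) lies inside the ring.

    x, y: arrays of matching shape (e.g. from np.meshgrid).
    ring_x, ring_y: vertex coordinates of one closed ring, without holes.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    inside = np.zeros(x.shape, dtype=bool)
    for i in range(len(ring_x)):
        j = i - 1  # previous vertex (wraps to the last one for i = 0)
        xi, yi, xj, yj = ring_x[i], ring_y[i], ring_x[j], ring_y[j]
        # Does the edge straddle the horizontal line through the point?
        straddles = (yi > y) != (yj > y)
        # x-coordinate where the edge crosses that line (horizontal edges
        # give inf/nan here, but they never straddle, so they are ignored)
        with np.errstate(divide="ignore", invalid="ignore"):
            x_cross = xj + (y - yj) * (xi - xj) / (yi - yj)
        # Each crossing to the right of the point toggles inside/outside
        inside ^= straddles & (x < x_cross)
    return inside


lon2d, lat2d = np.meshgrid([95.0, 105.0, 115.0], [25.0, 35.0])
square_x = [100.0, 110.0, 110.0, 100.0]
square_y = [20.0, 20.0, 30.0, 30.0]
mask = points_in_ring(lon2d, lat2d, square_x, square_y)
print(mask.tolist())  # [[False, True, False], [False, False, False]]
```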

ChatGPT 5.1 prompt: "Rather than the median show all the data point of a country using boxplots"

In [9]:
# 4. Collect all yield values per year for a country
def country_yearly_values(name):
    """
    Returns (yearly_values, years): a list with one array of all
    non-NaN Yield values per year inside the country, and the years.
    """
    mask = country_mask(name)
    data = ds_all["Yield"].where(mask)

    yearly_values = []
    for y in data["year"].values:
        vals = data.sel(year=y).values.flatten()
        vals = vals[~np.isnan(vals)]  # remove NaNs
        yearly_values.append(vals)

    return yearly_values, data["year"].values

# 5. Interactive boxplot

@interact(
    country=widgets.Dropdown(
        options=countries,
        value="China",
        description="Country",
    )
)
def plot_country(country):
    values, years = country_yearly_values(country)

    plt.boxplot(
        values,
        showfliers=True,
        medianprops=dict(color="black"),
    )

    plt.xticks(
        ticks=np.arange(1, len(years) + 1),
        labels=years,
        rotation=90,
    )

    plt.ylabel(r"Yield (t.ha$^{-1}$)")
    plt.xlabel("Year")
    plt.title(f"Distribution of Soybean Yield in {country}")
    plt.tight_layout()
    plt.show()
[Interactive widget: "Country" dropdown (default: China) above the yearly yield boxplots]
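Each box summarises one year's grid-cell yields. A minimal sketch of the statistics behind a single box, assuming Matplotlib's defaults (quartiles from `np.percentile` with linear interpolation, whiskers at `whis=1.5`): the box spans Q1 to Q3, each whisker reaches the most extreme value still within 1.5 × IQR of the box, and anything beyond is drawn as a flier (outlier). The sample values are made up and `box_stats` is a hypothetical name.

```python
import numpy as np


def box_stats(values, whis=1.5):
    """Quartiles, Tukey whiskers and fliers for one box (NaNs dropped)."""
    x = np.asarray(values, dtype=float)
    x = x[~np.isnan(x)]
    q1, med, q3 = np.percentile(x, [25, 50, 75])
    iqr = q3 - q1
    # Whiskers: most extreme data points still inside the whis * IQR fences
    upper = x[x <= q3 + whis * iqr]
    lower = x[x >= q1 - whis * iqr]
    whishi = upper.max() if upper.size and upper.max() >= q3 else q3
    whislo = lower.min() if lower.size and lower.min() <= q1 else q1
    # Fliers: everything outside the whiskers
    fliers = x[(x < whislo) | (x > whishi)]
    return {"q1": q1, "med": med, "q3": q3,
            "whislo": whislo, "whishi": whishi, "fliers": fliers}


stats = box_stats([1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 9.0, np.nan])
print(stats["q1"], stats["med"], stats["q3"])  # 1.75 2.5 3.25
```

Under the default settings, these numbers should correspond to what `plt.boxplot(..., showfliers=True)` draws for the same values, with 9.0 shown as the single outlier.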

Using regionmask

In [10]:
%pip install -q regionmask
Note: you may need to restart the kernel to use updated packages.
In [11]:
import matplotlib.pyplot as plt
import regionmask
import ipywidgets as widgets
from ipywidgets import interact

# 1. Use regionmask's built-in Natural Earth country outlines (1:110m)
countries = regionmask.defined_regions.natural_earth_v5_0_0.countries_110

# 2. Create a mask on the xarray grid (dims: lat, lon): each cell holds the
#    region number of its country, or NaN outside every country.
#    For new regionmask versions: .mask(obj) or .mask(lon, lat);
#    wrapping of the dataset's 0-360 longitudes is detected automatically.
mask = countries.mask(ds_all)   # ds_all must have lon/lat coords

# 3. Helper: compute time series for one country
def get_country_data(country_name):
    # Convert name → region index
    idx = countries.map_keys(country_name)      # region number
    country_mask = mask == idx                  # True where grid in that country
    data = ds_all["Yield"].where(country_mask)
    return data

# 4. Interactive dropdown
country_list = list(countries.names)            # list of country names
# (optional, to sort alphabetically)
country_list = sorted(country_list)

@interact(
    country=widgets.Dropdown(
        options=country_list,
        value="China",          # any default that exists in country_list
        description="Country",
    )
)
def plot_country_boxplot(country):
    data = get_country_data(country)

    # Convert to tidy dataframe
    df = (
        data
        .to_dataframe(name="Yield")
        .reset_index()
        .dropna()
    )
    df.boxplot(
        column="Yield",
        by="year",
        showfliers=True,   # show outliers 
        grid=False
    )
    plt.suptitle("")
    plt.title(f"Soybean yield distribution in {country}")
    plt.xlabel("Year")
    plt.ylabel(r"Yield (t.ha$^{-1}$)")
    plt.xticks(rotation=45)
    plt.show()
[Interactive widget: "Country" dropdown (default: China) above the pandas boxplot of yearly yields]

Adding a Country variable to the original dataset

ChatGPT 5.1 prompt: "Based on the previous script, I want to edit the dataset by adding a new variable "Country" for each specific yield data based on longitude and latitude, and based on this dataset," followed by the pasted `print(ds_all)` output from the Dataset reading section (dimensions year: 36, lat: 360, lon: 720; data variable Yield).

In [12]:
import numpy as np
import xarray as xr
import regionmask

countries = regionmask.defined_regions.natural_earth_v5_0_0.countries_110
mask = countries.mask(ds_all)   # same lon/lat grid as ds_all

# mask: region numbers on the (lat, lon) grid, NaN outside every country
region_idx = mask  # DataArray

# Lookup array: position = region number, value = country name.
# Indexing by countries.numbers avoids assuming the numbers are exactly 0..N-1.
lookup = np.full(max(countries.numbers) + 1, "Ocean", dtype=object)
lookup[countries.numbers] = countries.names

# Array of labels with the same shape as mask;
# cells outside every country (mostly ocean) are labelled "Ocean"
country_arr = np.full(region_idx.shape, "Ocean", dtype=object)

# regionmask uses NaN for ocean / outside regions
valid = ~np.isnan(region_idx.values)
country_arr[valid] = lookup[region_idx.values[valid].astype(int)]

# Turn this into a DataArray with same coords/dims as mask
Country = xr.DataArray(
    country_arr,
    coords=region_idx.coords,
    dims=region_idx.dims,
    name="Country",
)

# Attach to the dataset (it broadcasts over year when converted to a table)
ds_all["Country"] = Country

print(ds_all)
<xarray.Dataset> Size: 39MB
Dimensions:  (year: 36, lat: 360, lon: 720)
Coordinates:
  * year     (year) int64 288B 1981 1982 1983 1984 1985 ... 2013 2014 2015 2016
  * lat      (lat) float64 3kB -89.75 -89.25 -88.75 -88.25 ... 88.75 89.25 89.75
  * lon      (lon) float64 6kB 0.25 0.75 1.25 1.75 ... 358.2 358.8 359.2 359.8
Data variables:
    Yield    (year, lat, lon) float32 37MB nan nan nan nan ... nan nan nan nan
    Country  (lat, lon) object 2MB 'Antarctica' 'Antarctica' ... 'Ocean' 'Ocean'
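The `Country` variable is built by turning regionmask's numeric region ids (NaN outside every country) into names. The core of that lookup can be shown with plain NumPy; the ids and names below are made up, and indexing through a lookup array keyed by region number, rather than assuming the numbers are exactly 0 to N-1, is a defensive choice of this sketch. `ids_to_names` is a hypothetical name.

```python
import numpy as np


def ids_to_names(region_ids, numbers, names, fill="Ocean"):
    """Map float region ids (NaN = no region) to an object array of names."""
    ids = np.asarray(region_ids, dtype=float)
    # lookup[k] holds the name of region number k (fill for unused numbers)
    lookup = np.full(max(numbers) + 1, fill, dtype=object)
    lookup[list(numbers)] = list(names)
    out = np.full(ids.shape, fill, dtype=object)
    valid = ~np.isnan(ids)
    out[valid] = lookup[ids[valid].astype(int)]
    return out


grid = np.array([[0.0, np.nan], [2.0, 0.0]])
labels = ids_to_names(grid, numbers=[0, 1, 2], names=["Brazil", "China", "India"])
print(labels.tolist())  # [['Brazil', 'Ocean'], ['India', 'Brazil']]
```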
In [13]:
# Turn full dataset into a DataFrame with country labels
df = ds_all[["Yield", "Country"]].to_dataframe().reset_index()
print(df)
# Keep only cells that have a yield value and lie inside a country
df = df.dropna(subset=["Yield"])
df = df[df["Country"] != "Ocean"]
print(df)
# Save in Parquet format (requires pyarrow or fastparquet)
df.to_parquet("/home/jovyan/work/jeogeorge/datasets/soybean_yield_country.parquet")
         year    lat     lon  Yield     Country
0        1981 -89.75    0.25    NaN  Antarctica
1        1981 -89.75    0.75    NaN  Antarctica
2        1981 -89.75    1.25    NaN  Antarctica
3        1981 -89.75    1.75    NaN  Antarctica
4        1981 -89.75    2.25    NaN  Antarctica
...       ...    ...     ...    ...         ...
9331195  2016  89.75  357.75    NaN       Ocean
9331196  2016  89.75  358.25    NaN       Ocean
9331197  2016  89.75  358.75    NaN       Ocean
9331198  2016  89.75  359.25    NaN       Ocean
9331199  2016  89.75  359.75    NaN       Ocean

[9331200 rows x 5 columns]
         year    lat     lon     Yield   Country
102140   1981 -19.25  310.25  1.765330    Brazil
102853   1981 -18.75  306.75  1.765330    Brazil
131604   1981   1.25  282.25  1.520654  Colombia
131605   1981   1.25  282.75  1.250026  Colombia
131611   1981   1.25  285.75  1.898537  Colombia
...       ...    ...     ...       ...       ...
9278164  2016  53.25  122.25  2.729482     China
9278165  2016  53.25  122.75  2.966493     China
9278166  2016  53.25  123.25  2.846184     China
9278167  2016  53.25  123.75  1.603321     China
9278168  2016  53.25  124.25  1.497798     China

[212570 rows x 5 columns]
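Once every cell carries a `Country` label, the long table makes country-year summaries a plain group-by. With pandas this would be `df.groupby(["Country", "year"])["Yield"].median()`; the same idea in standard-library Python, using a few made-up rows shaped like the table above, is sketched below (`median_by_country_year` is a hypothetical name).

```python
from collections import defaultdict
from statistics import median


def median_by_country_year(rows):
    """rows: iterable of (year, lat, lon, yield, country) tuples."""
    groups = defaultdict(list)
    for year, _lat, _lon, value, country in rows:
        groups[(country, year)].append(value)
    # One median per (country, year) key, in sorted key order
    return {key: median(vals) for key, vals in sorted(groups.items())}


# Made-up yields (t/ha) laid out like the DataFrame columns
rows = [
    (1981, -19.25, 310.25, 1.5, "Brazil"),
    (1981, -18.75, 306.75, 2.0, "Brazil"),
    (1981, -18.25, 306.75, 2.5, "Brazil"),
    (2016, 53.25, 122.25, 3.0, "China"),
    (2016, 53.25, 122.75, 2.0, "China"),
]
result = median_by_country_year(rows)
print(result)  # {('Brazil', 1981): 2.0, ('China', 2016): 2.5}
```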

Example of data distribution for a specific country

In [14]:
import matplotlib.pyplot as plt
import numpy as np
country_name = "United States of America"
sel = df[(df["Country"] == country_name) & df["Yield"].notna()]

groups = sel.groupby("year")  
years = sorted(sel["year"].unique())
data_by_year = [groups.get_group(y)["Yield"].values for y in years]

plt.boxplot(
    data_by_year,
    positions=np.arange(len(years)),
    showfliers=True,  
)

plt.xticks(np.arange(len(years)), years, rotation=90)
plt.xlabel("Year")
plt.ylabel(r"Yield (t ha$^{-1}$)")
plt.title(f"Distribution of soybean yield in {country_name} by year")
plt.tight_layout()
plt.show()
[Figure: yearly boxplots of soybean yield in the United States of America]