Linear Regression of multivariate data

In this example, we demonstrate how to use sklearn_xarray classes to solve a simple linear regression problem on a synthetic dataset.

This example demonstrates the use of Stacker and Select.

import numpy as np
import xarray as xr
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline, make_union

from sklearn_xarray import Stacker, Select

# Make synthetic data
lat, lon = np.ogrid[-45:45:50j, 0:360:100j]
noise = np.random.randn(lat.shape[0], lon.shape[1])

data_vars = {
    'a': (['lat', 'lon'], np.sin(lat/90 + lon/100)),
    'b': (['lat', 'lon'], np.cos(lat/90 + lon/100)),
    'noise': (['lat', 'lon'], noise)
}
coords = {'lat': lat.ravel(), 'lon': lon.ravel()}
dataset = xr.Dataset(data_vars, coords)
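The grid construction relies on NumPy's open-mesh indexing. `np.ogrid` with the complex step `50j` returns 50 evenly spaced points, both endpoints included, as a column of shape (50, 1), while `100j` returns a row of shape (1, 100); arithmetic between the two broadcasts to the full 50 × 100 field. A minimal NumPy-only check of those shapes (the names `field`, `lat_coord` and `lon_coord` are illustrative and not part of the original example):

```python
import numpy as np

# Open (non-broadcast) grids: lat is a column, lon is a row.
lat, lon = np.ogrid[-45:45:50j, 0:360:100j]
print(lat.shape, lon.shape)  # (50, 1) (1, 100)

# Combining them broadcasts to the full (lat, lon) grid used by the data variables.
field = np.sin(lat / 90 + lon / 100)
print(field.shape)  # (50, 100)

# ravel() turns each open grid into the 1-D coordinate vector passed to xr.Dataset.
lat_coord, lon_coord = lat.ravel(), lon.ravel()
print(lat_coord[0], lat_coord[-1])  # both endpoints are included
print(lon_coord[1])                 # longitude spacing of 360 / 99 degrees
print(field.size)                   # total number of grid points
```

The longitude spacing of 360 / 99 ≈ 3.636 degrees and the 50 × 100 = 5000 grid points are what later show up as the `lon` values and the `samples` dimension of the stacked output.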

Make a simple linear model for the output, with standard normal noise ε:

\[y = a + 0.5\,b + 1 + 0.3\,\varepsilon\]

x = dataset[['a', 'b']]
y = dataset.a + dataset.b * .5 + .3 * dataset.noise + 1


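Now we want to fit a linear regression model using these data. Each variable is picked out with Select and turned into a feature column with Stacker; make_union concatenates the two columns before they are passed to LinearRegression.

mod = make_pipeline(
    make_union(
        make_pipeline(Select('a'), Stacker()),
        make_pipeline(Select('b'), Stacker())),
    LinearRegression())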
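Conceptually, each branch of the union turns one variable into a single feature column, and the union places the columns side by side. The sketch below imitates that with plain NumPy. The helper `select_and_stack` is hypothetical: it stands in for `make_pipeline(Select(...), Stacker())` by flattening a 2-D (lat, lon) field into an (n_samples, 1) column, whereas the real sklearn_xarray classes return labelled DataArray objects.

```python
import numpy as np

# Same synthetic fields as in the example, as plain NumPy arrays.
lat, lon = np.ogrid[-45:45:50j, 0:360:100j]
fields = {
    'a': np.sin(lat / 90 + lon / 100),
    'b': np.cos(lat / 90 + lon / 100),
}

def select_and_stack(data, name):
    """Hypothetical stand-in for make_pipeline(Select(name), Stacker()):
    pick one variable and flatten its (lat, lon) grid into one feature column."""
    return data[name].reshape(-1, 1)

# The union step: concatenate the per-variable columns horizontally,
# giving the (n_samples, n_features) matrix that LinearRegression receives.
X = np.hstack([select_and_stack(fields, 'a'), select_and_stack(fields, 'b')])
print(X.shape)  # (5000, 2)
```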

For now, we have to use Stacker manually to transform the output data into a 2D array of shape (samples, features):

y_np = Stacker().fit_transform(y)


<xarray.DataArray (samples: 5000, features: 1)>
array([[ 1.138895],
       [ 0.799281],
       [ 0.790091],
       [ 0.388912],
       ...])
Coordinates:
  * samples   (samples) MultiIndex
  - lat       (samples) float64 -45.0 -45.0 -45.0 -45.0 -45.0 -45.0 -45.0 ...
  - lon       (samples) float64 0.0 3.636 7.273 10.91 14.55 18.18 21.82 ...
  * features  (features) int64 1
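The `samples` MultiIndex above lists `lat` first and `lon` second, with `lon` varying fastest. A NumPy sketch of the same flattening, assuming row-major (C-order) stacking, reproduces that ordering: reshape the (lat, lon) grid to `(-1, 1)` and build the matching coordinate pairs with `np.meshgrid(..., indexing='ij')`. The names `y_grid`, `y_col`, `sample_lat` and `sample_lon` are illustrative only, and the seeded generator is an assumption (the example itself is unseeded).

```python
import numpy as np

lat, lon = np.ogrid[-45:45:50j, 0:360:100j]
rng = np.random.default_rng(0)  # seeded noise, unlike the original example
noise = rng.standard_normal((lat.shape[0], lon.shape[1]))
a = np.sin(lat / 90 + lon / 100)
b = np.cos(lat / 90 + lon / 100)
y_grid = a + 0.5 * b + 0.3 * noise + 1

# Flatten the (lat, lon) grid into a single-feature 2-D array (samples, features).
y_col = y_grid.reshape(-1, 1)
print(y_col.shape)  # (5000, 1)

# Coordinate pair for every sample, in the same row-major order as y_col.
lat_2d, lon_2d = np.meshgrid(lat.ravel(), lon.ravel(), indexing='ij')
sample_lat, sample_lon = lat_2d.ravel(), lon_2d.ravel()
print(sample_lat[:3])  # latitude stays fixed for the first row of the grid
print(sample_lon[:3])  # longitude advances in steps of 360 / 99
```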

Fit the model:

mod.fit(x, y_np)

# print the coefficients
lm = mod.named_steps['linearregression']
coefs = tuple(lm.coef_.flat)
print("The exact regression model is y = 1 + a + .5 b + noise")
print("The estimated coefficients are a: {}, b: {}".format(*coefs))
print("The estimated intercept is {}".format(lm.intercept_[0]))
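As a cross-check on the estimated coefficients, the same regression can be solved directly with NumPy least squares on the flattened fields, with a column of ones appended for the intercept. This sketch swaps in `np.linalg.lstsq` for `LinearRegression` and uses a seeded generator (an assumption, since the example is unseeded), so its estimates will not match the printed values exactly; both should land close to the true values 1, 0.5 and 1.

```python
import numpy as np

lat, lon = np.ogrid[-45:45:50j, 0:360:100j]
rng = np.random.default_rng(0)  # seeded for reproducibility; the example itself is unseeded
noise = rng.standard_normal((lat.shape[0], lon.shape[1]))
a = np.sin(lat / 90 + lon / 100)
b = np.cos(lat / 90 + lon / 100)
y = a + 0.5 * b + 0.3 * noise + 1

# Design matrix: one column per feature plus a column of ones for the intercept.
X = np.column_stack([a.ravel(), b.ravel(), np.ones(a.size)])
coef, *_ = np.linalg.lstsq(X, y.ravel(), rcond=None)
coef_a, coef_b, intercept = coef
print("a: {:.3f}, b: {:.3f}, intercept: {:.3f}".format(coef_a, coef_b, intercept))
```

With 5000 samples and noise of standard deviation 0.3, the estimates typically deviate from the true values by only about a hundredth, which is the same order as the deviations in the output below.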


The exact regression model is y = 1 + a + .5 b + noise
The estimated coefficients are a: 0.9826705586550489, b: 0.5070234156860342
The estimated intercept is 1.0154227436758414

Total running time of the script: ( 0 minutes 0.584 seconds)
