Tutorial for the Mixture of Experts (MoE) Forecasting Model

This notebook provides a minimal example of how to use the MoE forecasting model.

MoE runs in two settings:

1. Using external expert models
2. Using free parameters (no external experts)

Example code is provided for both cases below.

[1]:
# workaround to enable info-level logging in a Jupyter notebook; the invalid
# 'WORKAROUND' value below is set on purpose, and the error it triggers is expected
%config Application.log_level='WORKAROUND'
%config Application.log_level='INFO'
import logging
logging.getLogger().setLevel(logging.INFO)
ERROR:root:The 'log_level' trait of an IPKernelApp instance expected any of [0, 10, 20, 30, 40, 50, 'DEBUG', 'INFO', 'WARN', 'ERROR', 'CRITICAL'], not the str 'WORKAROUND'.

Load dataset

Note: change the data directory below if needed; an example of pointing the loader at a custom directory follows the next cell.

[2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from ts_datasets.forecast import M4

time_series, metadata = M4()[0]
print(time_series.shape)
INFO:numexpr.utils:Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
100%|██████████| 414/414 [00:00<00:00, 610.42it/s]
(748, 1)
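
If the M4 files are stored outside the default location, you can point the loader at them directly. The `subset` and `rootdir` keyword arguments below are assumptions about the `ts_datasets.forecast.M4` loader and may need adjusting for your version:

# hypothetical: load M4 from an explicit data directory
dataset = M4(subset="Hourly", rootdir="path/to/data/M4")
time_series, metadata = dataset[0]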

Now, we'll split the data into train & test splits and visualize the 0th dimension of the data.

[3]:
from merlion.utils import TimeSeries

train_data = TimeSeries.from_pd(time_series[metadata["trainval"]])
test_data  = TimeSeries.from_pd(time_series[~metadata["trainval"]])


print('train timeseries shape: ', train_data.to_pd().values.shape)
print('test timeseries shape: ', test_data.to_pd().values.shape)

column_names = list(train_data.to_pd().columns)
idx=0
tr = train_data.to_pd()[column_names[idx]]
plt.plot(tr)
plt.show()

te = test_data.to_pd()[column_names[idx]]
plt.plot(te)
plt.show()
train timeseries shape:  (700, 1)
test timeseries shape:  (48, 1)
[Plot: the training split (dimension 0)]
[Plot: the test split (dimension 0)]

Create an MoE model composed of external expert models and train it

Specify hyper-parameters

[4]:
# save directory for ensemble state. Replace it with your own choice.
save_dir = 'models/moe'

###
nfree_experts=0 # <- no free parameters provided
lookback_len=20
max_forecast_steps=3
target_seq_index = 0
use_gpu=False
###


## Pytorch network hyper-params. These are also the hyper-params used when moe_model=None is passed to MoE_ForecasterEnsemble.
hidden_dim=256
dim_head = 2
mlp_dim=256
dim_dropout=0. # if data is multi-dimensional, this can be set to a non-zero value to allow the model to handle missing dimensions at test time
time_step_dropout=0
## Pytorch network hyper-params
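
To make the roles of `lookback_len` and `max_forecast_steps` concrete, here is a minimal NumPy sketch (an illustration only, not the Merlion API) of the windowing idea: each training sample pairs `lookback_len` past values with the next `max_forecast_steps` targets.

import numpy as np

# toy univariate series; each sample pairs a lookback window with its forecast targets
series = np.arange(30, dtype=float)
n = len(series) - lookback_len - max_forecast_steps + 1
X = np.stack([series[i:i + lookback_len] for i in range(n)])                 # inputs
Y = np.stack([series[i + lookback_len:i + lookback_len + max_forecast_steps]
              for i in range(n)])                                            # targets
print(X.shape, Y.shape)  # (8, 20) (8, 3) with the hyper-parameters above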

Create the expert models and the MoE ensembler, and train them

[5]:
from merlion.models.factory import ModelFactory
from merlion.models.ensemble.MoE_forecast import MoE_ForecasterEnsemble, MoE_ForecasterEnsembleConfig, TransformerModel
from merlion.models.ensemble.base import EnsembleTrainConfig
from merlion.transform.base import Identity
from merlion.transform.resample import TemporalResample


## Define configs for all the experts as well as the MoE ensembler
config_sarima = {
    "order": [15, 1, 5],
    "seasonal_order": [2, 0, 1, 24],
    "max_forecast_steps": max_forecast_steps,
    "target_seq_index": target_seq_index,
    "transform": Identity()
}

config_arima = {
        "order": [15, 1, 5],
        "max_forecast_steps": max_forecast_steps,
        "target_seq_index": target_seq_index,
        "transform": Identity()
    }

config_vector_ar = {"max_forecast_steps": max_forecast_steps, "target_seq_index": target_seq_index, "maxlags": 14}

config_ensemble = MoE_ForecasterEnsembleConfig(
    batch_size=64, lr=0.0001, nfree_experts=nfree_experts, epoch_max=100,
    lookback_len=lookback_len, max_forecast_steps=max_forecast_steps,
    target_seq_index=target_seq_index, use_gpu=use_gpu,
    transform=TemporalResample())

train_config_ensemble = EnsembleTrainConfig(valid_frac=0.5)

# Define expert models
model = ModelFactory.create("Sarima", **config_sarima)
model2 = ModelFactory.create("Arima", **config_arima)
model3 = ModelFactory.create("VectorAR", **config_vector_ar)

models = [model, model2, model3]
nexperts = len(models)

'''
Instantiate the deep network for MoE. It can also be passed as None, in which case the default
Pytorch network specified in the MoE_ForecasterEnsemble class is used. The network below is, in
fact, the default network used by MoE_ForecasterEnsemble.
'''
moe_model = TransformerModel(
    input_dim=train_data.dim, lookback_len=lookback_len, nexperts=nexperts,
    output_dim=max_forecast_steps, nfree_experts=nfree_experts,
    hid_dim=hidden_dim, dim_head=dim_head, mlp_dim=mlp_dim,
    pool='cls', dim_dropout=dim_dropout, time_step_dropout=time_step_dropout)
# moe_model = None # use me if you want to see the default model in use

# create MoE forecaster model
ensemble = MoE_ForecasterEnsemble(config=config_ensemble, models=models, moe_model=moe_model)

# train & save MoE
loss_list = ensemble.train(train_data=train_data, train_config=train_config_ensemble)
ensemble.save(save_dir)
INFO:merlion.models.ensemble.MoE_forecast:Training model 1/3...
INFO:merlion.models.ensemble.MoE_forecast:Training model 2/3...
INFO:merlion.models.ensemble.MoE_forecast:Training model 3/3...
INFO:merlion.models.ensemble.MoE_forecast:Extracting and storing expert predictions
  0%|          | 0/6 [00:00<?, ?it/s]INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
 17%|█▋        | 1/6 [00:06<00:33,  6.78s/it]INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
 33%|███▎      | 2/6 [00:10<00:19,  4.91s/it]INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
 50%|█████     | 3/6 [00:14<00:13,  4.52s/it]INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
 67%|██████▋   | 4/6 [00:17<00:08,  4.11s/it]INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
 83%|████████▎ | 5/6 [00:22<00:04,  4.34s/it]INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
100%|██████████| 6/6 [00:23<00:00,  3.91s/it]
Epoch 1 Loss: 1.304425: 100%|██████████| 6/6 [00:03<00:00,  1.92it/s]
Epoch 2 Loss: 1.326118: 100%|██████████| 6/6 [00:03<00:00,  1.85it/s]
Epoch 3 Loss: 1.285732: 100%|██████████| 6/6 [00:03<00:00,  1.61it/s]
Epoch 4 Loss: 1.239944: 100%|██████████| 6/6 [00:03<00:00,  1.86it/s]
Epoch 5 Loss: 1.217539: 100%|██████████| 6/6 [00:03<00:00,  1.92it/s]
Epoch 6 Loss: 1.194959: 100%|██████████| 6/6 [00:04<00:00,  1.43it/s]
Epoch 7 Loss: 1.182055: 100%|██████████| 6/6 [00:03<00:00,  1.86it/s]
Epoch 8 Loss: 1.175249: 100%|██████████| 6/6 [00:02<00:00,  2.15it/s]
Epoch 9 Loss: 1.173702: 100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
Epoch 10 Loss: 1.186111: 100%|██████████| 6/6 [00:02<00:00,  2.36it/s]
Epoch 11 Loss: 1.184515: 100%|██████████| 6/6 [00:02<00:00,  2.08it/s]
Epoch 12 Loss: 1.169571: 100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
Epoch 13 Loss: 1.174548: 100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
Epoch 14 Loss: 1.180143: 100%|██████████| 6/6 [00:02<00:00,  2.30it/s]
Epoch 15 Loss: 1.186909: 100%|██████████| 6/6 [00:03<00:00,  1.95it/s]
Epoch 16 Loss: 1.175635: 100%|██████████| 6/6 [00:02<00:00,  2.36it/s]
Epoch 17 Loss: 1.189920: 100%|██████████| 6/6 [00:02<00:00,  2.23it/s]
Epoch 18 Loss: 1.182831: 100%|██████████| 6/6 [00:03<00:00,  1.88it/s]
Epoch 19 Loss: 1.176439: 100%|██████████| 6/6 [00:03<00:00,  1.90it/s]
Epoch 20 Loss: 1.174525: 100%|██████████| 6/6 [00:02<00:00,  2.06it/s]
Epoch 21 Loss: 1.172877: 100%|██████████| 6/6 [00:02<00:00,  2.07it/s]
Epoch 22 Loss: 1.182378: 100%|██████████| 6/6 [00:03<00:00,  1.80it/s]
Epoch 23 Loss: 1.165431: 100%|██████████| 6/6 [00:02<00:00,  2.23it/s]
Epoch 24 Loss: 1.178699: 100%|██████████| 6/6 [00:02<00:00,  2.48it/s]
Epoch 25 Loss: 1.179477: 100%|██████████| 6/6 [00:02<00:00,  2.32it/s]
Epoch 26 Loss: 1.174176: 100%|██████████| 6/6 [00:02<00:00,  2.10it/s]
Epoch 27 Loss: 1.187021: 100%|██████████| 6/6 [00:02<00:00,  2.32it/s]
Epoch 28 Loss: 1.175006: 100%|██████████| 6/6 [00:02<00:00,  2.47it/s]
Epoch 29 Loss: 1.179644: 100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
Epoch 30 Loss: 1.172667: 100%|██████████| 6/6 [00:02<00:00,  2.01it/s]
Epoch 31 Loss: 1.166672: 100%|██████████| 6/6 [00:02<00:00,  2.45it/s]
Epoch 32 Loss: 1.180196: 100%|██████████| 6/6 [00:02<00:00,  2.21it/s]
Epoch 33 Loss: 1.180659: 100%|██████████| 6/6 [00:02<00:00,  2.03it/s]
Epoch 34 Loss: 1.172211: 100%|██████████| 6/6 [00:02<00:00,  2.04it/s]
Epoch 35 Loss: 1.179359: 100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
Epoch 36 Loss: 1.171128: 100%|██████████| 6/6 [00:03<00:00,  1.97it/s]
Epoch 37 Loss: 1.175621: 100%|██████████| 6/6 [00:03<00:00,  1.65it/s]
Epoch 38 Loss: 1.178169: 100%|██████████| 6/6 [00:02<00:00,  2.02it/s]
Epoch 39 Loss: 1.171359: 100%|██████████| 6/6 [00:03<00:00,  1.94it/s]
Epoch 40 Loss: 1.178329: 100%|██████████| 6/6 [00:02<00:00,  2.09it/s]
Epoch 41 Loss: 1.183362: 100%|██████████| 6/6 [00:03<00:00,  1.99it/s]
Epoch 42 Loss: 1.192560: 100%|██████████| 6/6 [00:02<00:00,  2.57it/s]
Epoch 43 Loss: 1.180345: 100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
Epoch 44 Loss: 1.173638: 100%|██████████| 6/6 [00:02<00:00,  2.13it/s]
Epoch 45 Loss: 1.152926: 100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
Epoch 46 Loss: 1.173875: 100%|██████████| 6/6 [00:02<00:00,  2.49it/s]
Epoch 47 Loss: 1.164550: 100%|██████████| 6/6 [00:02<00:00,  2.30it/s]
Epoch 48 Loss: 1.159929: 100%|██████████| 6/6 [00:03<00:00,  1.90it/s]
Epoch 49 Loss: 1.153424: 100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
Epoch 50 Loss: 1.154480: 100%|██████████| 6/6 [00:02<00:00,  2.25it/s]
Epoch 51 Loss: 1.140436: 100%|██████████| 6/6 [00:02<00:00,  2.04it/s]
Epoch 52 Loss: 1.146496: 100%|██████████| 6/6 [00:03<00:00,  1.67it/s]
Epoch 53 Loss: 1.103287: 100%|██████████| 6/6 [00:03<00:00,  1.95it/s]
Epoch 54 Loss: 1.126842: 100%|██████████| 6/6 [00:03<00:00,  1.91it/s]
Epoch 55 Loss: 1.090270: 100%|██████████| 6/6 [00:03<00:00,  1.89it/s]
Epoch 56 Loss: 1.083246: 100%|██████████| 6/6 [00:02<00:00,  2.12it/s]
Epoch 57 Loss: 1.103495: 100%|██████████| 6/6 [00:02<00:00,  2.47it/s]
Epoch 58 Loss: 1.081064: 100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
Epoch 59 Loss: 1.031706: 100%|██████████| 6/6 [00:02<00:00,  2.14it/s]
Epoch 60 Loss: 1.028626: 100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
Epoch 61 Loss: 1.070291: 100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
Epoch 62 Loss: 1.070380: 100%|██████████| 6/6 [00:02<00:00,  2.25it/s]
Epoch 63 Loss: 1.057747: 100%|██████████| 6/6 [00:03<00:00,  1.96it/s]
Epoch 64 Loss: 1.022432: 100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
Epoch 65 Loss: 1.092833: 100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
Epoch 66 Loss: 1.046667: 100%|██████████| 6/6 [00:02<00:00,  2.01it/s]
Epoch 67 Loss: 1.024409: 100%|██████████| 6/6 [00:03<00:00,  1.95it/s]
Epoch 68 Loss: 1.006745: 100%|██████████| 6/6 [00:02<00:00,  2.12it/s]
Epoch 69 Loss: 1.026753: 100%|██████████| 6/6 [00:03<00:00,  1.99it/s]
Epoch 70 Loss: 1.035914: 100%|██████████| 6/6 [00:03<00:00,  1.78it/s]
Epoch 71 Loss: 1.011550: 100%|██████████| 6/6 [00:03<00:00,  1.76it/s]
Epoch 72 Loss: 1.025698: 100%|██████████| 6/6 [00:02<00:00,  2.32it/s]
Epoch 73 Loss: 1.021778: 100%|██████████| 6/6 [00:02<00:00,  2.44it/s]
Epoch 74 Loss: 0.999083: 100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
Epoch 75 Loss: 1.003671: 100%|██████████| 6/6 [00:02<00:00,  2.02it/s]
Epoch 76 Loss: 1.007998: 100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
Epoch 77 Loss: 0.992967: 100%|██████████| 6/6 [00:02<00:00,  2.45it/s]
Epoch 78 Loss: 1.039574: 100%|██████████| 6/6 [00:02<00:00,  2.33it/s]
Epoch 79 Loss: 0.979768: 100%|██████████| 6/6 [00:03<00:00,  1.85it/s]
Epoch 80 Loss: 0.980897: 100%|██████████| 6/6 [00:02<00:00,  2.22it/s]
Epoch 81 Loss: 1.021148: 100%|██████████| 6/6 [00:02<00:00,  2.36it/s]
Epoch 82 Loss: 0.964973: 100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
Epoch 83 Loss: 0.970942: 100%|██████████| 6/6 [00:03<00:00,  1.83it/s]
Epoch 84 Loss: 0.996262: 100%|██████████| 6/6 [00:03<00:00,  1.74it/s]
Epoch 85 Loss: 0.961435: 100%|██████████| 6/6 [00:02<00:00,  2.04it/s]
Epoch 86 Loss: 0.975354: 100%|██████████| 6/6 [00:02<00:00,  2.07it/s]
Epoch 87 Loss: 0.980454: 100%|██████████| 6/6 [00:02<00:00,  2.06it/s]
Epoch 88 Loss: 0.957371: 100%|██████████| 6/6 [00:02<00:00,  2.04it/s]
Epoch 89 Loss: 0.940112: 100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
Epoch 90 Loss: 0.939432: 100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
Epoch 91 Loss: 0.952867: 100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
Epoch 92 Loss: 0.973232: 100%|██████████| 6/6 [00:02<00:00,  2.08it/s]
Epoch 93 Loss: 0.988913: 100%|██████████| 6/6 [00:02<00:00,  2.53it/s]
Epoch 94 Loss: 0.954114: 100%|██████████| 6/6 [00:02<00:00,  2.44it/s]
Epoch 95 Loss: 0.906431: 100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
Epoch 96 Loss: 0.909996: 100%|██████████| 6/6 [00:03<00:00,  1.94it/s]
Epoch 97 Loss: 0.945528: 100%|██████████| 6/6 [00:02<00:00,  2.18it/s]
Epoch 98 Loss: 0.922009: 100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
Epoch 99 Loss: 0.920070: 100%|██████████| 6/6 [00:02<00:00,  2.21it/s]
Epoch 100 Loss: 0.912859: 100%|██████████| 6/6 [00:03<00:00,  1.72it/s]

Load the saved ensemble model

[6]:
ensemble_loaded = MoE_ForecasterEnsemble.load(save_dir)
WARNING:merlion.models.ensemble.base:When initializing an ensemble, you must either provide the dict `model_configs` (mapping each model's name to its config) when creating the `DetectorEnsembleConfig`, or provide a list of `models` to the constructor of `EnsembleBase`. Received both. Overriding `model_configs` with the configs belonging to `models`.

Forecast using the loaded model

[7]:
lookback_len=20
forecast_len=3

sample_length = lookback_len + forecast_len

target_seq_index=0
start_idx = 0
end_idx = sample_length

timestamps = test_data.univariates[test_data.names[0]].time_stamps[start_idx: end_idx]
data = test_data.to_pd().values
data = data[start_idx: end_idx]

timestamps = timestamps[lookback_len:]
x = data[:lookback_len]
x_ts = test_data[start_idx: start_idx + lookback_len]
y = data[lookback_len:, target_seq_index]

print('True output:\n')
print(y)

# perform single forecast
print('Performing single forecast:\n')
forecast, se = ensemble_loaded.forecast(time_stamps=timestamps,
                 time_series_prev=x_ts, expert_idx=None, mode='max', use_gpu=False)

print('Forecast\n', forecast)
print('Standard Error\n', se)

# perform batch forecast (for simplicity, just feeding a list of single sample)
print('\n\nPerforming batch forecast (notice the output is a list):\n')
forecast, se = ensemble_loaded.batch_forecast(time_stamps_list=[timestamps],
                 time_series_prev_list = [x_ts], expert_idx=None, mode='max', use_gpu=False)

print('Forecasts\n', forecast)
print('Standard Errors\n', se)
True output:

[803. 769. 751.]
Performing single forecast:

Forecast
                              H1
1677-10-28 00:00:00  810.339722
1677-10-28 01:00:00  796.296448
1677-10-28 02:00:00  760.550842
Standard Error
                          H1_err
1677-10-28 00:00:00   42.765224
1677-10-28 01:00:00  198.829971
1677-10-28 02:00:00  350.278381


Performing batch forecast (notice the output is a list):

Forecasts
 [                           H1_0
1677-10-28 00:00:00  810.339722
1677-10-28 01:00:00  796.296448
1677-10-28 02:00:00  760.550842]
Standard Errors
 [                       H1_err_0
1677-10-28 00:00:00   42.765224
1677-10-28 01:00:00  198.829971
1677-10-28 02:00:00  350.278381]

Retrieve forecasts of individual experts along with their confidence from the loaded model

[8]:
# forecast 3 steps at the beginning of test_data, using a lookback window of length 20
lookback_len=20
forecast_len=3

sample_length = lookback_len + forecast_len

target_seq_index=0
start_idx = 0
end_idx = sample_length

timestamps = test_data.univariates[test_data.names[0]].time_stamps[start_idx: end_idx]
data = test_data.to_pd().values
data = data[start_idx: end_idx]

timestamps = timestamps[lookback_len:]
x = data[:lookback_len] # shape (20,1)
x_ts = test_data[start_idx: start_idx + lookback_len]
y = data[lookback_len:, target_seq_index]

print('True output:\n')
print(y)

# perform single forecast
print('Getting individual expert forecast and standard deviation for single data (notice the array shape):\n')
forecast, std = ensemble_loaded._forecast(time_stamps=timestamps,
                 time_series_prev=x_ts, expert_idx=None, use_gpu=False)

print(f'Forecast (shape: {forecast.shape})\n', forecast)
print(f'Standard deviation (shape: {std.shape})\n', std)

# perform batch forecast (for simplicity, just feeding a list of single sample)
print('\n\nGetting individual expert forecast and standard deviation for a batch of data (notice the array shape):\n')
forecast, std = ensemble_loaded._batch_forecast(time_stamps_list=[timestamps],
                 time_series_prev_array=np.expand_dims(x, axis=0), # shape (1,20,1)
                 time_series_prev_list = [x_ts], expert_idx=None, use_gpu=False)

print(f'Forecast (shape: {forecast.shape})\n', forecast)
print(f'Standard deviation (shape: {std.shape})\n', std)
True output:

[803. 769. 751.]
Getting individual expert forecast and standard deviation for single data (notice the array shape):

Forecast (shape: (3, 3))
 [[828.93043244 796.29646378 760.5508621 ]
 [812.10700752 760.22152836 709.16665671]
 [810.33973458 757.56872591 700.9503937 ]]
Standard deviation (shape: (3, 3))
 [[0.361902   0.47081903 0.7306497 ]
 [0.26711744 0.3186435  0.14048512]
 [0.3709806  0.21053748 0.12886517]]


Getting individual expert forecast and standard deviation for a batch of data (notice the array shape):

Forecast (shape: (1, 3, 3))
 [[[828.93043244 796.29646378 760.5508621 ]
  [812.10700752 760.22152836 709.16665671]
  [810.33973458 757.56872591 700.9503937 ]]]
Standard deviation (shape: (1, 3, 3))
 [[[0.361902   0.47081903 0.7306497 ]
  [0.26711744 0.3186435  0.14048512]
  [0.3709806  0.21053748 0.12886517]]]
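
Given per-expert forecasts and confidence scores like the ones above, the two combination strategies (`mode='max'` and `mode='mean'`) can be sketched in NumPy. This illustrates the idea rather than reproducing the exact code inside `MoE_ForecasterEnsemble`; the values are taken from the cell above (note that the per-expert "standard deviation" scores sum to 1 across experts at each step, i.e., they behave like gating weights):

import numpy as np

# rows: experts; columns: forecast steps (values copied from the output above)
expert_forecasts = np.array([[828.93, 796.30, 760.55],
                             [812.11, 760.22, 709.17],
                             [810.34, 757.57, 700.95]])
confidences = np.array([[0.362, 0.471, 0.731],
                        [0.267, 0.319, 0.140],
                        [0.371, 0.211, 0.129]])

# 'max': at each step, take the forecast of the most confident expert
max_forecast = expert_forecasts[confidences.argmax(axis=0), np.arange(3)]

# 'mean': confidence-weighted average of the expert forecasts at each step
weights = confidences / confidences.sum(axis=0, keepdims=True)
mean_forecast = (weights * expert_forecasts).sum(axis=0)

print(max_forecast)  # [810.34 796.30 760.55] -- matches the mode='max' forecast earlier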

Evaluate MoE

The code below shows how to evaluate MoE.

[9]:


expert_idx=None
# if expert_idx is None, MoE uses all the provided experts and combines them with the 'mode' strategy below
# if it is an int (e.g., 0), MoE only uses the external expert at that index of `models` to make forecasts
mode='max' # either 'mean' or 'max'. 'max' picks the expert with the highest confidence; 'mean' computes the confidence-weighted average.
use_gpu=False # set to True if a GPU is available, for faster execution
use_batch_forecast=True # keep True for higher speed

y_pred_list, std_list, y_list, sMAPE_conf, sMAPE_not_conf, recall, overall_sMAPE = \
    ensemble_loaded.evaluate(test_data, mode=mode, expert_idx=expert_idx,
                             use_gpu=use_gpu, use_batch_forecast=use_batch_forecast, confidence_thres=100)

out_idx=0 # plot this idx of all the steps forecasted by MoE
print(y_pred_list.shape)
plt.plot(y_pred_list[:100, out_idx], '--', color='k', label='prediction', linewidth=1) # plotting the first 100 points for clarity
plt.plot(y_list[:100, out_idx], color='b', label='data', linewidth=1)
# plt.fill_between(range(y_pred_list[:100, out_idx].shape[0]), y_pred_list[:100, out_idx]-std_list[:100, out_idx],\
#                  y_pred_list[:100, out_idx]+std_list[:100, out_idx]) # standard deviation error band
plt.legend()
plt.show()

print(f'sMAPE on confident samples: {sMAPE_conf:.2f}')
print(f'sMAPE on not confident samples: {sMAPE_not_conf:.2f}')
print(f'Percentage of samples on which MoE was confident: {recall:.2f}%  (use a different confidence_thres to change this)')
print(f'sMAPE on all samples: {overall_sMAPE:.2f}')

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

sMAPE_conf: 2.184 sMAPE_not_conf: 0.000 recall: 100.000% | Plain sMAPE 2.184:   0%|          | 0/1 [00:01<?, ?it/s]
sMAPE_conf: 2.184 sMAPE_not_conf: 0.000 recall: 100.000% | Plain sMAPE 2.184: 100%|██████████| 1/1 [00:01<00:00,  1.71s/it]
torch.Size([25, 3])
[Plot: MoE predictions (dashed) vs. data on the first 100 test points]
sMAPE on confident samples: 2.18
sMAPE on not confident samples: 0.00
Percentage of samples on which MoE was confident: 100.00%  (use a different confidence_thres to change this)
sMAPE on all samples: 2.18
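
For reference, here is a minimal sketch of the sMAPE metric reported above, assuming the common definition with a factor of 200 (Merlion's internal implementation may differ in details such as zero handling):

import numpy as np

def smape(y_true, y_pred):
    """Symmetric MAPE in percent: 200/n * sum(|y - yhat| / (|y| + |yhat|))."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return 200.0 * np.mean(np.abs(y_true - y_pred) / (np.abs(y_true) + np.abs(y_pred)))

print(smape([803., 769., 751.], [810.34, 796.30, 760.55]))  # ~1.89 on the single forecast above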

Create an MoE model containing free parameters (no external experts) and train it

Specify hyper-parameters

[10]:
# save directory for ensemble state. Replace it with your own choice.
save_dir = 'models/moe2'

###
nfree_experts= 3000 # <- number of free experts
lookback_len=20
max_forecast_steps=3
target_seq_index = 0
use_gpu=False
###


## Pytorch network hyper-params. These are also the hyper-params used when moe_model=None is passed to MoE_ForecasterEnsemble.
hidden_dim=256
dim_head = 2
mlp_dim=256
dim_dropout=0. # if data is multi-dimensional, this can be set to a non-zero value to allow the model to handle missing dimensions at test time
time_step_dropout=0
## Pytorch network hyper-params

Create the MoE ensembler and train it

[11]:
config_ensemble = MoE_ForecasterEnsembleConfig(
    batch_size=64, lr=0.0001, nfree_experts=nfree_experts, epoch_max=300,
    lookback_len=lookback_len, max_forecast_steps=max_forecast_steps,
    target_seq_index=target_seq_index, use_gpu=use_gpu,
    transform=TemporalResample())

train_config_ensemble = EnsembleTrainConfig(valid_frac=0.5)

# Define expert models
models = [] # <- no external experts provided
nexperts = len(models)

# instantiate deep network for MoE
moe_model = TransformerModel(
    input_dim=len(train_data.names), lookback_len=lookback_len, nexperts=nexperts,
    output_dim=max_forecast_steps, nfree_experts=nfree_experts,
    hid_dim=hidden_dim, dim_head=dim_head, mlp_dim=mlp_dim,
    pool='cls', dim_dropout=dim_dropout, time_step_dropout=time_step_dropout)
moe_model = None # override with None to use the default network (same architecture as above)

# create MoE forecaster model
ensemble = MoE_ForecasterEnsemble(config=config_ensemble, models=models, moe_model=moe_model)

# train MoE
loss_list = ensemble.train(train_data=train_data, train_config=train_config_ensemble)

ensemble.save(save_dir)
Epoch 1 Loss: 8.253956: 100%|██████████| 11/11 [00:03<00:00,  3.00it/s]
Epoch 2 Loss: 8.215341: 100%|██████████| 11/11 [00:03<00:00,  3.09it/s]
Epoch 3 Loss: 8.140997: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
Epoch 4 Loss: 8.030143: 100%|██████████| 11/11 [00:02<00:00,  3.70it/s]
Epoch 5 Loss: 7.880345: 100%|██████████| 11/11 [00:03<00:00,  3.66it/s]
Epoch 6 Loss: 7.679686: 100%|██████████| 11/11 [00:03<00:00,  3.23it/s]
Epoch 7 Loss: 7.431627: 100%|██████████| 11/11 [00:03<00:00,  3.65it/s]
Epoch 8 Loss: 7.121155: 100%|██████████| 11/11 [00:03<00:00,  3.66it/s]
Epoch 9 Loss: 6.788598: 100%|██████████| 11/11 [00:03<00:00,  3.32it/s]
Epoch 10 Loss: 6.469567: 100%|██████████| 11/11 [00:03<00:00,  2.98it/s]
Epoch 11 Loss: 6.230416: 100%|██████████| 11/11 [00:03<00:00,  3.48it/s]
Epoch 12 Loss: 6.048095: 100%|██████████| 11/11 [00:03<00:00,  3.60it/s]
Epoch 13 Loss: 5.915649: 100%|██████████| 11/11 [00:03<00:00,  2.79it/s]
Epoch 14 Loss: 5.811184: 100%|██████████| 11/11 [00:04<00:00,  2.69it/s]
Epoch 15 Loss: 5.744193: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
Epoch 16 Loss: 5.686575: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
Epoch 17 Loss: 5.657169: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
Epoch 18 Loss: 5.635586: 100%|██████████| 11/11 [00:03<00:00,  3.36it/s]
Epoch 19 Loss: 5.627381: 100%|██████████| 11/11 [00:03<00:00,  3.62it/s]
Epoch 20 Loss: 5.592227: 100%|██████████| 11/11 [00:03<00:00,  3.46it/s]
Epoch 21 Loss: 5.565175: 100%|██████████| 11/11 [00:03<00:00,  3.18it/s]
Epoch 22 Loss: 5.561776: 100%|██████████| 11/11 [00:03<00:00,  3.66it/s]
Epoch 23 Loss: 5.541681: 100%|██████████| 11/11 [00:03<00:00,  3.61it/s]
Epoch 24 Loss: 5.545103: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
Epoch 25 Loss: 5.522188: 100%|██████████| 11/11 [00:03<00:00,  3.21it/s]
Epoch 26 Loss: 5.508581: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s]
Epoch 27 Loss: 5.489652: 100%|██████████| 11/11 [00:03<00:00,  3.46it/s]
Epoch 28 Loss: 5.476248: 100%|██████████| 11/11 [00:04<00:00,  2.51it/s]
Epoch 29 Loss: 5.466399: 100%|██████████| 11/11 [00:03<00:00,  3.08it/s]
Epoch 30 Loss: 5.475472: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
Epoch 31 Loss: 5.461387: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
Epoch 32 Loss: 5.433480: 100%|██████████| 11/11 [00:03<00:00,  3.43it/s]
Epoch 33 Loss: 5.417164: 100%|██████████| 11/11 [00:03<00:00,  3.66it/s]
Epoch 34 Loss: 5.391122: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
Epoch 35 Loss: 5.355259: 100%|██████████| 11/11 [00:03<00:00,  3.24it/s]
Epoch 36 Loss: 5.324664: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
Epoch 37 Loss: 5.289907: 100%|██████████| 11/11 [00:03<00:00,  3.53it/s]
Epoch 38 Loss: 5.252770: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s]
Epoch 39 Loss: 5.224870: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s]
Epoch 40 Loss: 5.189947: 100%|██████████| 11/11 [00:03<00:00,  3.58it/s]
Epoch 41 Loss: 5.141784: 100%|██████████| 11/11 [00:03<00:00,  3.42it/s]
Epoch 42 Loss: 5.089564: 100%|██████████| 11/11 [00:03<00:00,  2.83it/s]
Epoch 43 Loss: 5.053859: 100%|██████████| 11/11 [00:03<00:00,  2.92it/s]
Epoch 44 Loss: 5.017187: 100%|██████████| 11/11 [00:03<00:00,  3.16it/s]
Epoch 45 Loss: 4.980298: 100%|██████████| 11/11 [00:03<00:00,  2.88it/s]
Epoch 46 Loss: 4.946596: 100%|██████████| 11/11 [00:03<00:00,  2.99it/s]
Epoch 47 Loss: 4.913864: 100%|██████████| 11/11 [00:03<00:00,  3.65it/s]
Epoch 48 Loss: 4.886967: 100%|██████████| 11/11 [00:03<00:00,  3.64it/s]
Epoch 49 Loss: 4.860845: 100%|██████████| 11/11 [00:03<00:00,  3.38it/s]
Epoch 50 Loss: 4.831104: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
Epoch 51 Loss: 4.816362: 100%|██████████| 11/11 [00:03<00:00,  3.58it/s]
Epoch 52 Loss: 4.779174: 100%|██████████| 11/11 [00:03<00:00,  3.39it/s]
Epoch 53 Loss: 4.736116: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
Epoch 54 Loss: 4.717974: 100%|██████████| 11/11 [00:03<00:00,  3.40it/s]
Epoch 55 Loss: 4.690528: 100%|██████████| 11/11 [00:03<00:00,  3.47it/s]
Epoch 56 Loss: 4.657373: 100%|██████████| 11/11 [00:04<00:00,  2.68it/s]
Epoch 57 Loss: 4.630430: 100%|██████████| 11/11 [00:04<00:00,  2.64it/s]
Epoch 58 Loss: 4.626685: 100%|██████████| 11/11 [00:03<00:00,  3.33it/s]
Epoch 59 Loss: 4.576031: 100%|██████████| 11/11 [00:03<00:00,  3.12it/s]
Epoch 60 Loss: 4.536126: 100%|██████████| 11/11 [00:03<00:00,  3.28it/s]
Epoch 61 Loss: 4.522658: 100%|██████████| 11/11 [00:03<00:00,  3.41it/s]
Epoch 62 Loss: 4.482119: 100%|██████████| 11/11 [00:02<00:00,  3.70it/s]
Epoch 63 Loss: 4.440992: 100%|██████████| 11/11 [00:03<00:00,  3.66it/s]
Epoch 64 Loss: 4.420945: 100%|██████████| 11/11 [00:03<00:00,  3.26it/s]
Epoch 65 Loss: 4.380464: 100%|██████████| 11/11 [00:03<00:00,  3.59it/s]
Epoch 66 Loss: 4.354569: 100%|██████████| 11/11 [00:03<00:00,  3.41it/s]
Epoch 67 Loss: 4.330108: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s]
Epoch 68 Loss: 4.302790: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
Epoch 69 Loss: 4.276388: 100%|██████████| 11/11 [00:03<00:00,  3.63it/s]
Epoch 70 Loss: 4.252587: 100%|██████████| 11/11 [00:03<00:00,  3.08it/s]
Epoch 71 Loss: 4.220399: 100%|██████████| 11/11 [00:03<00:00,  2.80it/s]
Epoch 72 Loss: 4.197742: 100%|██████████| 11/11 [00:03<00:00,  2.84it/s]
Epoch 73 Loss: 4.184703: 100%|██████████| 11/11 [00:03<00:00,  3.26it/s]
Epoch 74 Loss: 4.177047: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
Epoch 75 Loss: 4.132162: 100%|██████████| 11/11 [00:03<00:00,  3.12it/s]
Epoch 76 Loss: 4.098515: 100%|██████████| 11/11 [00:03<00:00,  3.65it/s]
Epoch 77 Loss: 4.085912: 100%|██████████| 11/11 [00:03<00:00,  3.63it/s]
Epoch 78 Loss: 4.072778: 100%|██████████| 11/11 [00:03<00:00,  3.48it/s]
Epoch 79 Loss: 4.026296: 100%|██████████| 11/11 [00:03<00:00,  3.13it/s]
Epoch 80 Loss: 4.000065: 100%|██████████| 11/11 [00:03<00:00,  3.65it/s]
Epoch 81 Loss: 3.971918: 100%|██████████| 11/11 [00:03<00:00,  3.48it/s]
Epoch 82 Loss: 3.954153: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
Epoch 83 Loss: 3.924814: 100%|██████████| 11/11 [00:03<00:00,  3.21it/s]
Epoch 84 Loss: 3.895262: 100%|██████████| 11/11 [00:03<00:00,  3.16it/s]
Epoch 85 Loss: 3.874157: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s]
Epoch 86 Loss: 3.862910: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
Epoch 87 Loss: 3.845997: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
Epoch 88 Loss: 3.811604: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
Epoch 89 Loss: 3.794181: 100%|██████████| 11/11 [00:03<00:00,  3.51it/s]
Epoch 90 Loss: 3.760030: 100%|██████████| 11/11 [00:03<00:00,  3.47it/s]
Epoch 91 Loss: 3.738671: 100%|██████████| 11/11 [00:03<00:00,  3.12it/s]
Epoch 92 Loss: 3.720268: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
Epoch 93 Loss: 3.713584: 100%|██████████| 11/11 [00:02<00:00,  3.70it/s]
Epoch 94 Loss: 3.671633: 100%|██████████| 11/11 [00:03<00:00,  3.62it/s]
Epoch 95 Loss: 3.651652: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
Epoch 96 Loss: 3.639432: 100%|██████████| 11/11 [00:03<00:00,  3.07it/s]
Epoch 97 Loss: 3.607364: 100%|██████████| 11/11 [00:03<00:00,  3.26it/s]
Epoch 98 Loss: 3.583107: 100%|██████████| 11/11 [00:03<00:00,  3.33it/s]
Epoch 99 Loss: 3.562718: 100%|██████████| 11/11 [00:03<00:00,  3.03it/s]
Epoch 100 Loss: 3.548420: 100%|██████████| 11/11 [00:03<00:00,  2.93it/s]
Epoch 101 Loss: 3.531171: 100%|██████████| 11/11 [00:04<00:00,  2.66it/s]
Epoch 102 Loss: 3.509029: 100%|██████████| 11/11 [00:03<00:00,  3.47it/s]
Epoch 103 Loss: 3.485391: 100%|██████████| 11/11 [00:03<00:00,  3.54it/s]
Epoch 104 Loss: 3.450888: 100%|██████████| 11/11 [00:03<00:00,  3.43it/s]
Epoch 105 Loss: 3.424491: 100%|██████████| 11/11 [00:03<00:00,  3.10it/s]
Epoch 106 Loss: 3.403693: 100%|██████████| 11/11 [00:02<00:00,  3.74it/s]
Epoch 107 Loss: 3.390665: 100%|██████████| 11/11 [00:03<00:00,  3.58it/s]
Epoch 108 Loss: 3.359253: 100%|██████████| 11/11 [00:03<00:00,  3.43it/s]
Epoch 109 Loss: 3.345198: 100%|██████████| 11/11 [00:03<00:00,  2.97it/s]
Epoch 110 Loss: 3.349600: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s]
Epoch 111 Loss: 3.329763: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
Epoch 112 Loss: 3.303129: 100%|██████████| 11/11 [00:03<00:00,  3.32it/s]
Epoch 113 Loss: 3.266704: 100%|██████████| 11/11 [00:03<00:00,  2.82it/s]
Epoch 114 Loss: 3.240781: 100%|██████████| 11/11 [00:03<00:00,  2.94it/s]
Epoch 115 Loss: 3.215163: 100%|██████████| 11/11 [00:03<00:00,  3.25it/s]
Epoch 116 Loss: 3.198615: 100%|██████████| 11/11 [00:03<00:00,  2.92it/s]
Epoch 117 Loss: 3.169775: 100%|██████████| 11/11 [00:03<00:00,  3.29it/s]
Epoch 118 Loss: 3.152784: 100%|██████████| 11/11 [00:03<00:00,  3.47it/s]
Epoch 119 Loss: 3.128009: 100%|██████████| 11/11 [00:03<00:00,  3.60it/s]
Epoch 120 Loss: 3.106502: 100%|██████████| 11/11 [00:03<00:00,  3.67it/s]
Epoch 121 Loss: 3.093420: 100%|██████████| 11/11 [00:03<00:00,  3.51it/s]
Epoch 122 Loss: 3.078810: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
Epoch 123 Loss: 3.050844: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
Epoch 124 Loss: 3.052847: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s]
Epoch 125 Loss: 3.025279: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
Epoch 126 Loss: 3.009476: 100%|██████████| 11/11 [00:03<00:00,  3.38it/s]
Epoch 127 Loss: 2.985609: 100%|██████████| 11/11 [00:04<00:00,  2.57it/s]
Epoch 128 Loss: 2.966175: 100%|██████████| 11/11 [00:03<00:00,  2.95it/s]
Epoch 129 Loss: 2.948951: 100%|██████████| 11/11 [00:03<00:00,  3.23it/s]
Epoch 130 Loss: 2.939453: 100%|██████████| 11/11 [00:03<00:00,  3.16it/s]
Epoch 131 Loss: 2.932430: 100%|██████████| 11/11 [00:03<00:00,  3.57it/s]
Epoch 132 Loss: 2.893917: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
Epoch 133 Loss: 2.855756: 100%|██████████| 11/11 [00:03<00:00,  3.59it/s]
Epoch 134 Loss: 2.839368: 100%|██████████| 11/11 [00:02<00:00,  3.70it/s]
Epoch 135 Loss: 2.817158: 100%|██████████| 11/11 [00:03<00:00,  3.30it/s]
Epoch 136 Loss: 2.791486: 100%|██████████| 11/11 [00:03<00:00,  3.09it/s]
Epoch 137 Loss: 2.780610: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
Epoch 138 Loss: 2.761558: 100%|██████████| 11/11 [00:03<00:00,  3.42it/s]
Epoch 139 Loss: 2.750291: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
Epoch 140 Loss: 2.730247: 100%|██████████| 11/11 [00:03<00:00,  3.28it/s]
Epoch 141 Loss: 2.698439: 100%|██████████| 11/11 [00:04<00:00,  2.67it/s]
Epoch 142 Loss: 2.685364: 100%|██████████| 11/11 [00:03<00:00,  3.15it/s]
Epoch 143 Loss: 2.662411: 100%|██████████| 11/11 [00:03<00:00,  2.98it/s]
Epoch 144 Loss: 2.658126: 100%|██████████| 11/11 [00:03<00:00,  2.92it/s]
Epoch 145 Loss: 2.632239: 100%|██████████| 11/11 [00:03<00:00,  3.14it/s]
Epoch 146 Loss: 2.610601: 100%|██████████| 11/11 [00:03<00:00,  3.34it/s]
Epoch 147 Loss: 2.583837: 100%|██████████| 11/11 [00:03<00:00,  3.65it/s]
Epoch 148 Loss: 2.556195: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
Epoch 149 Loss: 2.547484: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
Epoch 150 Loss: 2.525524: 100%|██████████| 11/11 [00:02<00:00,  3.67it/s]
Epoch 151 Loss: 2.514556: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
Epoch 152 Loss: 2.501829: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
Epoch 153 Loss: 2.478203: 100%|██████████| 11/11 [00:03<00:00,  3.21it/s]
Epoch 154 Loss: 2.457710: 100%|██████████| 11/11 [00:03<00:00,  3.10it/s]
Epoch 155 Loss: 2.436375: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s]
Epoch 156 Loss: 2.412483: 100%|██████████| 11/11 [00:03<00:00,  2.77it/s]
Epoch 157 Loss: 2.394380: 100%|██████████| 11/11 [00:03<00:00,  3.07it/s]
Epoch 158 Loss: 2.380202: 100%|██████████| 11/11 [00:03<00:00,  3.58it/s]
Epoch 159 Loss: 2.355152: 100%|██████████| 11/11 [00:03<00:00,  3.50it/s]
Epoch 160 Loss: 2.347697: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
Epoch 161 Loss: 2.345931: 100%|██████████| 11/11 [00:03<00:00,  3.60it/s]
Epoch 162 Loss: 2.296478: 100%|██████████| 11/11 [00:03<00:00,  3.64it/s]
Epoch 163 Loss: 2.282725: 100%|██████████| 11/11 [00:03<00:00,  3.58it/s]
Epoch 164 Loss: 2.249418: 100%|██████████| 11/11 [00:03<00:00,  3.25it/s]
Epoch 165 Loss: 2.241911: 100%|██████████| 11/11 [00:03<00:00,  3.60it/s]
Epoch 166 Loss: 2.215718: 100%|██████████| 11/11 [00:03<00:00,  3.45it/s]
Epoch 167 Loss: 2.194424: 100%|██████████| 11/11 [00:03<00:00,  3.08it/s]
Epoch 168 Loss: 2.184618: 100%|██████████| 11/11 [00:03<00:00,  3.13it/s]
Epoch 169 Loss: 2.156022: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s]
Epoch 170 Loss: 2.152164: 100%|██████████| 11/11 [00:03<00:00,  3.05it/s]
Epoch 171 Loss: 2.128624: 100%|██████████| 11/11 [00:03<00:00,  2.82it/s]
Epoch 172 Loss: 2.109892: 100%|██████████| 11/11 [00:03<00:00,  2.82it/s]
Epoch 173 Loss: 2.137191: 100%|██████████| 11/11 [00:02<00:00,  3.74it/s]
Epoch 174 Loss: 2.118212: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
Epoch 175 Loss: 2.092953: 100%|██████████| 11/11 [00:03<00:00,  3.12it/s]
Epoch 176 Loss: 2.065081: 100%|██████████| 11/11 [00:03<00:00,  3.61it/s]
Epoch 177 Loss: 2.064701: 100%|██████████| 11/11 [00:02<00:00,  3.69it/s]
Epoch 178 Loss: 2.061531: 100%|██████████| 11/11 [00:03<00:00,  3.42it/s]
Epoch 179 Loss: 2.041025: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
Epoch 180 Loss: 1.999497: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
Epoch 181 Loss: 1.953928: 100%|██████████| 11/11 [00:03<00:00,  3.42it/s]
Epoch 182 Loss: 1.929946: 100%|██████████| 11/11 [00:03<00:00,  2.86it/s]
Epoch 183 Loss: 1.923069: 100%|██████████| 11/11 [00:03<00:00,  2.93it/s]
Epoch 184 Loss: 1.905399: 100%|██████████| 11/11 [00:03<00:00,  3.23it/s]
Epoch 185 Loss: 1.892242: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s]
Epoch 186 Loss: 1.869357: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
Epoch 187 Loss: 1.840604: 100%|██████████| 11/11 [00:03<00:00,  3.59it/s]
Epoch 188 Loss: 1.837614: 100%|██████████| 11/11 [00:02<00:00,  3.68it/s]
Epoch 189 Loss: 1.816457: 100%|██████████| 11/11 [00:03<00:00,  3.57it/s]
Epoch 190 Loss: 1.786459: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
Epoch 191 Loss: 1.776885: 100%|██████████| 11/11 [00:02<00:00,  3.79it/s]
Epoch 192 Loss: 1.766457: 100%|██████████| 11/11 [00:03<00:00,  3.61it/s]
Epoch 193 Loss: 1.742621: 100%|██████████| 11/11 [00:03<00:00,  3.39it/s]
Epoch 194 Loss: 1.723335: 100%|██████████| 11/11 [00:03<00:00,  2.79it/s]
Epoch 195 Loss: 1.708375: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
Epoch 196 Loss: 1.692905: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s]
Epoch 197 Loss: 1.694218: 100%|██████████| 11/11 [00:03<00:00,  2.89it/s]
Epoch 198 Loss: 1.677686: 100%|██████████| 11/11 [00:03<00:00,  2.96it/s]
Epoch 199 Loss: 1.644678: 100%|██████████| 11/11 [00:03<00:00,  3.09it/s]
Epoch 200 Loss: 1.632020: 100%|██████████| 11/11 [00:03<00:00,  3.16it/s]
Epoch 201 Loss: 1.604089: 100%|██████████| 11/11 [00:03<00:00,  3.53it/s]
Epoch 202 Loss: 1.597124: 100%|██████████| 11/11 [00:03<00:00,  3.08it/s]
Epoch 203 Loss: 1.580226: 100%|██████████| 11/11 [00:02<00:00,  3.74it/s]
Epoch 204 Loss: 1.577800: 100%|██████████| 11/11 [00:03<00:00,  3.62it/s]
Epoch 205 Loss: 1.556550: 100%|██████████| 11/11 [00:03<00:00,  3.33it/s]
Epoch 206 Loss: 1.531670: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
Epoch 207 Loss: 1.524184: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
Epoch 208 Loss: 1.504326: 100%|██████████| 11/11 [00:03<00:00,  3.43it/s]
Epoch 209 Loss: 1.495385: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s]
Epoch 210 Loss: 1.477028: 100%|██████████| 11/11 [00:03<00:00,  2.92it/s]
Epoch 211 Loss: 1.466456: 100%|██████████| 11/11 [00:03<00:00,  3.16it/s]
Epoch 212 Loss: 1.437730: 100%|██████████| 11/11 [00:03<00:00,  2.97it/s]
Epoch 213 Loss: 1.439750: 100%|██████████| 11/11 [00:04<00:00,  2.70it/s]
Epoch 214 Loss: 1.427772: 100%|██████████| 11/11 [00:03<00:00,  3.19it/s]
Epoch 215 Loss: 1.417080: 100%|██████████| 11/11 [00:03<00:00,  3.54it/s]
Epoch 216 Loss: 1.403871: 100%|██████████| 11/11 [00:03<00:00,  3.40it/s]
Epoch 217 Loss: 1.374399: 100%|██████████| 11/11 [00:03<00:00,  3.29it/s]
Epoch 218 Loss: 1.362701: 100%|██████████| 11/11 [00:02<00:00,  3.83it/s]
Epoch 219 Loss: 1.348216: 100%|██████████| 11/11 [00:02<00:00,  3.76it/s]
Epoch 220 Loss: 1.337092: 100%|██████████| 11/11 [00:03<00:00,  3.51it/s]
Epoch 221 Loss: 1.331288: 100%|██████████| 11/11 [00:03<00:00,  3.04it/s]
Epoch 222 Loss: 1.322314: 100%|██████████| 11/11 [00:03<00:00,  3.54it/s]
Epoch 223 Loss: 1.303005: 100%|██████████| 11/11 [00:03<00:00,  3.46it/s]
Epoch 224 Loss: 1.308330: 100%|██████████| 11/11 [00:03<00:00,  2.86it/s]
Epoch 225 Loss: 1.307320: 100%|██████████| 11/11 [00:05<00:00,  2.17it/s]
Epoch 226 Loss: 1.259610: 100%|██████████| 11/11 [00:04<00:00,  2.59it/s]
Epoch 227 Loss: 1.247879: 100%|██████████| 11/11 [00:04<00:00,  2.44it/s]
Epoch 228 Loss: 1.232500: 100%|██████████| 11/11 [00:03<00:00,  2.84it/s]
Epoch 229 Loss: 1.223805: 100%|██████████| 11/11 [00:04<00:00,  2.71it/s]
Epoch 230 Loss: 1.210760: 100%|██████████| 11/11 [00:03<00:00,  3.23it/s]
Epoch 231 Loss: 1.193618: 100%|██████████| 11/11 [00:03<00:00,  3.09it/s]
Epoch 232 Loss: 1.186537: 100%|██████████| 11/11 [00:03<00:00,  3.09it/s]
Epoch 233 Loss: 1.159970: 100%|██████████| 11/11 [00:03<00:00,  2.82it/s]
Epoch 234 Loss: 1.156743: 100%|██████████| 11/11 [00:03<00:00,  3.41it/s]
Epoch 235 Loss: 1.146074: 100%|██████████| 11/11 [00:03<00:00,  2.96it/s]
Epoch 236 Loss: 1.151413: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
Epoch 237 Loss: 1.131100: 100%|██████████| 11/11 [00:03<00:00,  3.30it/s]
Epoch 238 Loss: 1.122501: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
Epoch 239 Loss: 1.099850: 100%|██████████| 11/11 [00:03<00:00,  2.81it/s]
Epoch 240 Loss: 1.086024: 100%|██████████| 11/11 [00:04<00:00,  2.57it/s]
Epoch 241 Loss: 1.072572: 100%|██████████| 11/11 [00:03<00:00,  2.93it/s]
Epoch 242 Loss: 1.068453: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
Epoch 243 Loss: 1.044812: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s]
Epoch 244 Loss: 1.030824: 100%|██████████| 11/11 [00:03<00:00,  3.23it/s]
Epoch 245 Loss: 1.019708: 100%|██████████| 11/11 [00:03<00:00,  3.34it/s]
Epoch 246 Loss: 1.011575: 100%|██████████| 11/11 [00:03<00:00,  3.59it/s]
Epoch 247 Loss: 0.990726: 100%|██████████| 11/11 [00:03<00:00,  3.46it/s]
Epoch 248 Loss: 0.986041: 100%|██████████| 11/11 [00:03<00:00,  2.91it/s]
Epoch 249 Loss: 0.981658: 100%|██████████| 11/11 [00:03<00:00,  3.21it/s]
Epoch 250 Loss: 0.980595: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
Epoch 251 Loss: 0.968010: 100%|██████████| 11/11 [00:03<00:00,  3.13it/s]
Epoch 252 Loss: 0.962707: 100%|██████████| 11/11 [00:03<00:00,  2.80it/s]
Epoch 253 Loss: 0.942022: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
Epoch 254 Loss: 0.944430: 100%|██████████| 11/11 [00:03<00:00,  2.89it/s]
Epoch 255 Loss: 0.929161: 100%|██████████| 11/11 [00:03<00:00,  2.88it/s]
Epoch 256 Loss: 0.937370: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
Epoch 257 Loss: 0.903115: 100%|██████████| 11/11 [00:03<00:00,  3.54it/s]
Epoch 258 Loss: 0.883225: 100%|██████████| 11/11 [00:03<00:00,  3.25it/s]
Epoch 259 Loss: 0.879743: 100%|██████████| 11/11 [00:03<00:00,  2.87it/s]
Epoch 260 Loss: 0.868025: 100%|██████████| 11/11 [00:03<00:00,  3.35it/s]
Epoch 261 Loss: 0.871274: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s]
Epoch 262 Loss: 0.845327: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
Epoch 263 Loss: 0.830346: 100%|██████████| 11/11 [00:03<00:00,  2.75it/s]
Epoch 264 Loss: 0.824749: 100%|██████████| 11/11 [00:03<00:00,  3.36it/s]
Epoch 265 Loss: 0.825509: 100%|██████████| 11/11 [00:03<00:00,  3.28it/s]
Epoch 266 Loss: 0.826433: 100%|██████████| 11/11 [00:04<00:00,  2.64it/s]
Epoch 267 Loss: 0.803352: 100%|██████████| 11/11 [00:04<00:00,  2.58it/s]
Epoch 268 Loss: 0.803057: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
Epoch 269 Loss: 0.779101: 100%|██████████| 11/11 [00:03<00:00,  3.00it/s]
Epoch 270 Loss: 0.776825: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
Epoch 271 Loss: 0.781099: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
Epoch 272 Loss: 0.785840: 100%|██████████| 11/11 [00:03<00:00,  3.28it/s]
Epoch 273 Loss: 0.802700: 100%|██████████| 11/11 [00:03<00:00,  3.29it/s]
Epoch 274 Loss: 0.772463: 100%|██████████| 11/11 [00:03<00:00,  3.00it/s]
Epoch 275 Loss: 0.764557: 100%|██████████| 11/11 [00:03<00:00,  3.41it/s]
Epoch 276 Loss: 0.756142: 100%|██████████| 11/11 [00:03<00:00,  3.40it/s]
Epoch 277 Loss: 0.749000: 100%|██████████| 11/11 [00:03<00:00,  3.18it/s]
Epoch 278 Loss: 0.727899: 100%|██████████| 11/11 [00:03<00:00,  2.91it/s]
Epoch 279 Loss: 0.736639: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
Epoch 280 Loss: 0.732949: 100%|██████████| 11/11 [00:04<00:00,  2.73it/s]
Epoch 281 Loss: 0.714376: 100%|██████████| 11/11 [00:03<00:00,  2.79it/s]
Epoch 282 Loss: 0.706372: 100%|██████████| 11/11 [00:04<00:00,  2.62it/s]
Epoch 283 Loss: 0.718250: 100%|██████████| 11/11 [00:03<00:00,  3.10it/s]
Epoch 284 Loss: 0.702688: 100%|██████████| 11/11 [00:03<00:00,  3.32it/s]
Epoch 285 Loss: 0.683941: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
Epoch 286 Loss: 0.669917: 100%|██████████| 11/11 [00:03<00:00,  3.04it/s]
Epoch 287 Loss: 0.657017: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
Epoch 288 Loss: 0.645382: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s]
Epoch 289 Loss: 0.646296: 100%|██████████| 11/11 [00:03<00:00,  3.01it/s]
Epoch 290 Loss: 0.640179: 100%|██████████| 11/11 [00:03<00:00,  2.87it/s]
Epoch 291 Loss: 0.632041: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
Epoch 292 Loss: 0.623746: 100%|██████████| 11/11 [00:03<00:00,  3.01it/s]
Epoch 293 Loss: 0.606923: 100%|██████████| 11/11 [00:04<00:00,  2.67it/s]
Epoch 294 Loss: 0.608943: 100%|██████████| 11/11 [00:04<00:00,  2.50it/s]
Epoch 295 Loss: 0.607394: 100%|██████████| 11/11 [00:03<00:00,  2.89it/s]
Epoch 296 Loss: 0.594561: 100%|██████████| 11/11 [00:04<00:00,  2.60it/s]
Epoch 297 Loss: 0.581359: 100%|██████████| 11/11 [00:04<00:00,  2.61it/s]
Epoch 298 Loss: 0.581903: 100%|██████████| 11/11 [00:03<00:00,  3.34it/s]
Epoch 299 Loss: 0.572175: 100%|██████████| 11/11 [00:03<00:00,  3.24it/s]
Epoch 300 Loss: 0.570626: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]

Load the saved ensemble model

[12]:
ensemble_loaded = MoE_ForecasterEnsemble.load(save_dir)
WARNING:merlion.models.ensemble.base:When initializing an ensemble, you must either provide the dict `model_configs` (mapping each model's name to its config) when creating the `DetectorEnsembleConfig`, or provide a list of `models` to the constructor of `EnsembleBase`. Received both. Overriding `model_configs` with the configs belonging to `models`.

Evaluate MoE

The code below shows how to evaluate MoE.

[13]:


expert_idx=None
# when no external experts are used, the value of expert_idx is ignored by the
# forecast/batch_forecast/evaluate functions
mode='max' # either 'mean' or 'max'. 'max' picks the expert with the highest confidence; 'mean' computes the confidence-weighted average.
use_gpu=False # set to True if a GPU is available, for faster execution
use_batch_forecast=True # keep True for higher speed

y_pred_list, std_list, y_list, sMAPE_conf, sMAPE_not_conf, recall, overall_sMAPE = \
    ensemble_loaded.evaluate(test_data, mode=mode, expert_idx=expert_idx,
                             use_gpu=use_gpu, use_batch_forecast=use_batch_forecast, confidence_thres=1.)

out_idx=0 # plot this idx of all the steps forecasted by MoE
print(y_pred_list.shape)
plt.plot(y_pred_list[:100, out_idx], '--', color='k', label='prediction', linewidth=1)
plt.plot(y_list[:100, out_idx], color='b', label='data', linewidth=1)
# plt.fill_between(range(y_pred_list[:100, out_idx].shape[0]), y_pred_list[:100, out_idx]-std_list[:100, out_idx],\
#                  y_pred_list[:100, out_idx]+std_list[:100, out_idx]) # standard deviation error band
plt.legend()
plt.show()

print(f'sMAPE on confident samples: {sMAPE_conf:.2f}')
print(f'sMAPE on not confident samples: {sMAPE_not_conf:.2f}')
print(f'Percentage of samples on which MoE was confident: {recall:.2f}% (use a different confidence_thres to change this)')
print(f'sMAPE on all samples: {overall_sMAPE:.2f}')

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

sMAPE_conf: 1.872 sMAPE_not_conf: 2.411 recall: 21.333% | Plain sMAPE 2.296: 100%|██████████| 1/1 [00:00<00:00, 37.30it/s]
torch.Size([25, 3])
[Plot: MoE predictions (dashed) vs. data on the first 100 test points]
sMAPE on confident samples: 1.87
sMAPE on not confident samples: 2.41
Percentage of samples on which MoE was confident: 21.33% (use a different confidence_thres to change this)
sMAPE on all samples: 2.30
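
Finally, the effect of `confidence_thres` can be sketched as follows: `evaluate` scores confident and non-confident samples separately, and `recall` is the percentage of samples deemed confident. The boolean mask below is a stand-in for the criterion MoE derives internally from `confidence_thres`; the exact test lives in `MoE_ForecasterEnsemble.evaluate`.

import numpy as np

def split_smape(y_true, y_pred, confident):
    """Score confident vs. non-confident samples separately, as `evaluate` reports."""
    def smape(a, b):
        return 200.0 * np.mean(np.abs(a - b) / (np.abs(a) + np.abs(b)))
    conf = smape(y_true[confident], y_pred[confident]) if confident.any() else 0.0
    not_conf = smape(y_true[~confident], y_pred[~confident]) if (~confident).any() else 0.0
    recall = 100.0 * confident.mean()  # percentage of samples deemed confident
    return conf, not_conf, recall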