Issue to discuss the M layer design. For context, see design document here: sktime/enhancement-proposals#39
My current proposed design is based on the v1 and v2 metadata layer designs. Long-term state:
- starts off as the current "metadata" class, but is named like the model, e.g., `TFT`. The current lightning network is renamed `TFT_NN`
- has tags and `get_test_params` etc., similar to the current "metadata" class
- `__init__` has all args of the two objects: the loader (D2, e.g., `DecoderEncoderModule`) and the network (e.g., `TFT_NN`), minus data
- method `get_loader_class` gets the loader class (e.g., the class `DecoderEncoderModule`); `get_loader(data: TimeSeries)` produces a loader object, an instance of the `get_loader_class` return
- method `get_nn_class` returns the nn class (e.g., `TFT_NN`); `get_nn(loader)` gets an instance of the nn class
- finally, there is a method `init(data)`, which calls the above in sequence and produces a pair of nn and loader, as if the two `get` methods were called in sequence
- `__call__` dispatches to `init`
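A minimal sketch of what the base class and a concrete package could look like; `_BasePkg`, `DummyPkg`, `DummyLoader`, `DummyNN`, and the `_loader_kwargs`/`_nn_kwargs` attributes are all hypothetical placeholder names for illustration, not existing pytorch-forecasting classes:

```python
class _BasePkg:
    """Sketch of the proposed model package base class (hypothetical name)."""

    def get_loader_class(self):
        raise NotImplementedError

    def get_nn_class(self):
        raise NotImplementedError

    def get_loader(self, data):
        # instantiate the loader class with data plus loader-related init args
        return self.get_loader_class()(data, **self._loader_kwargs)

    def get_nn(self, loader):
        # instantiate the nn class from the loader plus nn-related init args
        return self.get_nn_class()(loader, **self._nn_kwargs)

    def init(self, data):
        # call the two get methods in sequence, return the (nn, loader) pair
        loader = self.get_loader(data)
        nn = self.get_nn(loader)
        return nn, loader

    __call__ = init


# minimal stand-ins to make the dispatch concrete without torch installed
class DummyLoader:
    def __init__(self, data, batch_size=32):
        self.data, self.batch_size = data, batch_size


class DummyNN:
    def __init__(self, loader, hidden_size=64):
        self.loader, self.hidden_size = loader, hidden_size


class DummyPkg(_BasePkg):
    def __init__(self, batch_size=32, hidden_size=64):
        # args of both objects, minus data
        self._loader_kwargs = {"batch_size": batch_size}
        self._nn_kwargs = {"hidden_size": hidden_size}

    def get_loader_class(self):
        return DummyLoader

    def get_nn_class(self):
        return DummyNN


net, loader = DummyPkg(batch_size=16)([1, 2, 3])
```

The point is that a concrete package only supplies the two `get_*_class` methods and its `__init__` signature; the construction sequence lives once in the base class.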
So, a usage vignette could look like:
```python
from lightning.pytorch import Trainer
from torch import nn

from pytorch_forecasting import TimeSeries
from pytorch_forecasting.metrics import MAE, SMAPE
from pytorch_forecasting.models import TFT

dataset = TimeSeries(...)

model_cfg = TFT(
    max_encoder_length=30,
    max_prediction_length=1,
    batch_size=32,
    loss=nn.MSELoss(),
    logging_metrics=[MAE(), SMAPE()],
    optimizer="adam",
    hidden_size=64,
    num_layers=2,
    attention_head_size=4,
)

net, loader = model_cfg(dataset)

trainer = Trainer(
    max_epochs=5,
    accelerator="auto",
    devices=1,
    enable_progress_bar=True,
    log_every_n_steps=10,
)
trainer.fit(net, loader)
```
etc
The only things that change for other models are the model class and its args/values, i.e., what goes into `model_cfg`.
In sktime, we would add the trainer as an arg to `__init__`, and sktime `fit(data)` does `self.trainer(*self.model_cfg(data))` (with some potential conversion for `data` - or we could allow `TimeSeries` as an `mtype`).