-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathduet_model.py
More file actions
67 lines (56 loc) · 2.59 KB
/
duet_model.py
File metadata and controls
67 lines (56 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from ts_benchmark.baselines.duet.layers.linear_extractor_cluster import Linear_extractor_cluster
import torch.nn as nn
from einops import rearrange
from ts_benchmark.baselines.duet.utils.masked_attention import Mahalanobis_mask, Encoder, EncoderLayer, FullAttention, AttentionLayer
import torch
class DUETModel(nn.Module):
    """DUET forecasting model.

    Pipeline: a linear-extractor cluster produces per-channel temporal
    features (with an importance/regularization term), then — when the
    series is multivariate — a transformer attends across channels under a
    Mahalanobis-distance-based attention mask, and a linear head projects
    each channel's feature onto the prediction horizon.

    NOTE(review): `config` is a project-defined object; the semantics of
    its attributes (seq_len, enc_in, d_model, ...) are assumed from usage
    here — confirm against the project's config schema.
    """

    def __init__(self, config):
        super().__init__()
        # Temporal feature extractor; also returns an importance loss term.
        self.cluster = Linear_extractor_cluster(config)
        self.CI = config.CI          # channel-independent processing flag
        self.n_vars = config.enc_in  # number of input variables/channels
        # Builds a cross-channel attention mask from the raw input series.
        self.mask_generator = Mahalanobis_mask(config.seq_len)

        def build_encoder_layer():
            # One self-attention layer over the channel dimension.
            inner_attention = FullAttention(
                True,
                config.factor,
                attention_dropout=config.dropout,
                output_attention=config.output_attention,
            )
            return EncoderLayer(
                AttentionLayer(inner_attention, config.d_model, config.n_heads),
                config.d_model,
                config.d_ff,
                dropout=config.dropout,
                activation=config.activation,
            )

        # Transformer applied across channels (not across time).
        self.Channel_transformer = Encoder(
            [build_encoder_layer() for _ in range(config.e_layers)],
            norm_layer=torch.nn.LayerNorm(config.d_model),
        )
        # Projects each channel's d_model feature to the forecast horizon.
        self.linear_head = nn.Sequential(
            nn.Linear(config.d_model, config.pred_len),
            nn.Dropout(config.fc_dropout),
        )

    def forward(self, input):
        # input: [batch_size, seq_len, n_vars]
        if self.CI:
            # Fold channels into the batch axis so every channel is
            # processed independently by the shared extractor.
            folded = rearrange(input, 'b l n -> (b n) l 1')
            extracted, L_importance = self.cluster(folded)
            temporal_feature = rearrange(extracted, '(b n) l 1 -> b l n', b=input.shape[0])
        else:
            temporal_feature, L_importance = self.cluster(input)

        # B x d_model x n_vars -> B x n_vars x d_model
        temporal_feature = rearrange(temporal_feature, 'b d n -> b n d')

        if self.n_vars > 1:
            # Multivariate: mix information across channels under the
            # learned Mahalanobis attention mask.
            channels_first = rearrange(input, 'b l n -> b n l')
            channel_mask = self.mask_generator(channels_first)
            feature, _attention = self.Channel_transformer(
                x=temporal_feature, attn_mask=channel_mask
            )
        else:
            # Univariate: nothing to mix across channels.
            feature = temporal_feature

        output = self.linear_head(feature)
        output = rearrange(output, 'b n d -> b d n')
        # Presumably reverses the RevIN-style instance normalization applied
        # inside the cluster ("denorm" mode) — project-defined; verify.
        output = self.cluster.revin(output, "denorm")
        return output, L_importance