import torch
from torch import nn


class TimeSeriesTransformer(nn.Module):
    def __init__(self, input_size_3=18, num_layers=1):
        super(TimeSeriesTransformer, self).__init__()
        self.hidden_size3 = 128
        self.num_layers = num_layers
        # Bidirectional GRU over the third input stream
        self.GRU3 = nn.GRU(input_size_3, self.hidden_size3, num_layers,
                           batch_first=True, bidirectional=True)
        # Attention scoring layer: maps each GRU time step to a scalar score
        self.attention_C = nn.Linear(2 * self.hidden_size3, 1)
        # Fully connected regression head
        # (bn1 and bn2 are defined here but not applied in forward())
        self.fc1 = nn.Linear(2 * self.hidden_size3, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 1)

    def attention_layer(self, lstm_out, attention_w):
        # lstm_out is the GRU output: (batch_size, seq_len, 2 * hidden_size)
        attention_scores = attention_w(lstm_out)                # (batch, seq_len, 1)
        attention_weights = torch.softmax(attention_scores, dim=1)
        # Weighted sum over the time dimension -> (batch, 2 * hidden_size)
        context_vector = torch.sum(attention_weights * lstm_out, dim=1)
        return context_vector

    # def forward(self, inputs_1, inputs_2, inputs_3):
    def forward(self, inputs_3):
        # inputs_1.shape = (batch_size, seq_len, input_size_1)
        # inputs_2.shape = (batch_size, seq_len, input_size_2)
        # inputs_3.shape = (batch_size, seq_len, input_size_3)
        # Initialize the hidden state (a GRU has no cell state);
        # num_layers * 2 because the GRU is bidirectional
        h2 = torch.zeros(self.num_layers * 2, inputs_3.size(0),
                         self.hidden_size3).to(inputs_3.device)
        output_3, _ = self.GRU3(inputs_3, h2)
        context_C = self.attention_layer(output_3, self.attention_C)
        # Leftover from the multi-input variant: concatenating a single
        # context vector is a no-op, so h == context_C
        h = torch.cat([context_C], dim=1)
        h = self.fc1(h)
        h = self.relu1(h)
        h = self.fc2(h)
        h = self.relu2(h)
        output = self.fc3(h)
        return output
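A minimal smoke test of the module as defined above. The batch size and sequence length here are arbitrary placeholders, not values from the original post; only the feature width of 18 is fixed by the default input_size_3=18.

model = TimeSeriesTransformer(input_size_3=18, num_layers=1)
model.eval()  # no dropout, and the BatchNorm layers are unused, but eval() is still good practice

# Dummy batch: (batch_size=32, seq_len=30, input_size_3=18)
dummy_inputs_3 = torch.randn(32, 30, 18)
with torch.no_grad():
    preds = model(dummy_inputs_3)
print(preds.shape)  # torch.Size([32, 1]) -- one regression output per sample

Because the attention layer collapses the time dimension into a single context vector, the model accepts any sequence length at inference time; only the last feature dimension must match input_size_3.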