import torch
from torch import nn


class TimeSeriesTransformer(nn.Module):
    """Bidirectional GRU encoder with attention pooling and an MLP regression head.

    Note: despite the class name, this model uses a GRU rather than a
    Transformer encoder.
    """

    def __init__(self, input_size_3=18, num_layers=1):
        super(TimeSeriesTransformer, self).__init__()
        self.hidden_size3 = 128
        self.num_layers = num_layers
        # Bidirectional GRU over the third input branch.
        self.GRU3 = nn.GRU(input_size_3, self.hidden_size3, num_layers,
                           batch_first=True, bidirectional=True)
        # Linear layer that scores each time step for attention pooling.
        self.attention_C = nn.Linear(2 * self.hidden_size3, 1)
        # MLP head. bn1 and bn2 are defined here but not applied in forward() below.
        self.fc1 = nn.Linear(2 * self.hidden_size3, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 1)

    def attention_layer(self, gru_out, attention_w):
        # gru_out: (batch_size, seq_len, 2 * hidden_size)
        attention_scores = attention_w(gru_out)                      # (batch, seq_len, 1)
        attention_weights = torch.softmax(attention_scores, dim=1)   # softmax over the time dimension
        context_vector = torch.sum(attention_weights * gru_out, dim=1)  # (batch, 2 * hidden_size)
        return context_vector

    # def forward(self, inputs_1, inputs_2, inputs_3):
    def forward(self, inputs_3):
        # inputs_1.shape = (batch_size, seq_len, input_size_1)
        # inputs_2.shape = (batch_size, seq_len, input_size_2)
        # inputs_3.shape = (batch_size, seq_len, input_size_3)
        # Initialize the hidden state (a GRU has no cell state).
        h2 = torch.zeros(self.num_layers * 2, inputs_3.size(0),
                         self.hidden_size3).to(inputs_3.device)
        output_3, _ = self.GRU3(inputs_3, h2)
        context_C = self.attention_layer(output_3, self.attention_C)
        # torch.cat with a single tensor is a no-op; it is kept as the point where
        # the context vectors of the other input branches would be concatenated
        # in the multi-input variant of forward() commented out above.
        h = torch.cat([context_C], dim=1)
        h = self.fc1(h)
        h = self.relu1(h)
        h = self.fc2(h)
        h = self.relu2(h)
        output = self.fc3(h)
        return output
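

# A minimal smoke test for the module above, sketched as an assumption rather
# than taken from the original project: the batch size (4) and sequence length
# (30) below are illustrative; only input_size_3=18 matches the constructor
# default. It checks that a forward pass produces one scalar per sample.
if __name__ == "__main__":
    model = TimeSeriesTransformer(input_size_3=18, num_layers=1)
    dummy = torch.randn(4, 30, 18)   # (batch_size, seq_len, input_size_3)
    with torch.no_grad():
        prediction = model(dummy)
    print(prediction.shape)          # expected: torch.Size([4, 1])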