LSTM.py

import torch
from torch import nn


class TimeSeriesTransformer(nn.Module):
    def __init__(self, input_size_3=18, num_layers=1):
        super(TimeSeriesTransformer, self).__init__()
        self.hidden_size3 = 128
        self.num_layers = num_layers
        # Bidirectional GRU over the input sequence
        self.GRU3 = nn.GRU(input_size_3, self.hidden_size3, num_layers,
                           batch_first=True, bidirectional=True)
        # Scores each time step of the bidirectional GRU output (2 * hidden_size3 features)
        self.attention_C = nn.Linear(2 * self.hidden_size3, 1)
        self.fc1 = nn.Linear(2 * self.hidden_size3, 128)
        self.bn1 = nn.BatchNorm1d(128)   # defined but not applied in forward
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)    # defined but not applied in forward
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 1)

    def attention_layer(self, lstm_out, attention_w):
        # lstm_out: (batch_size, seq_len, 2 * hidden_size3)
        attention_scores = attention_w(lstm_out)                     # (batch_size, seq_len, 1)
        attention_weights = torch.softmax(attention_scores, dim=1)   # softmax over the time dimension
        context_vector = torch.sum(attention_weights * lstm_out, dim=1)  # (batch_size, 2 * hidden_size3)
        return context_vector

    # def forward(self, inputs_1, inputs_2, inputs_3):
    def forward(self, inputs_3):
        # inputs_1.shape = (batch_size, seq_len, input_size_1)
        # inputs_2.shape = (batch_size, seq_len, input_size_2)
        # inputs_3.shape = (batch_size, seq_len, input_size_3)
        # Initialize the hidden state (2 * num_layers directions for the bidirectional GRU)
        h2 = torch.zeros(self.num_layers * 2, inputs_3.size(0), self.hidden_size3).to(inputs_3.device)
        output_3, _ = self.GRU3(inputs_3, h2)
        context_C = self.attention_layer(output_3, self.attention_C)
        # cat of a single context vector; leftover from the multi-input version commented out above
        h = torch.cat([context_C], dim=1)
        h = self.fc1(h)
        h = self.relu1(h)
        h = self.fc2(h)
        h = self.relu2(h)
        output = self.fc3(h)
        return output
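
A minimal usage sketch (not part of LSTM.py): the batch size of 32 and sequence length of 30 below are arbitrary assumptions; only the feature dimension of 18 comes from the default input_size_3 above.

import torch

from LSTM import TimeSeriesTransformer  # assumes the class above is saved as LSTM.py

model = TimeSeriesTransformer(input_size_3=18, num_layers=1)
model.eval()

# dummy batch: (batch_size, seq_len, input_size_3)
x = torch.randn(32, 30, 18)

with torch.no_grad():
    y = model(x)

print(y.shape)  # torch.Size([32, 1]) -- one scalar prediction per sequence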