# Transformer_base.py

import math

import torch
from torch import nn

from utils.Arg import Arg

arg = Arg()


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (Vaswani et al., 2017)."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the (max_len, d_model) table of sine/cosine encodings.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch dimension
        # of inputs shaped (seq_len, batch, d_model).
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model)
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
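
# Usage sketch (illustrative, not part of the original file): since `pe` is kept as
# (max_len, 1, d_model), the module expects inputs of shape (seq_len, batch, d_model)
# and broadcasts the encoding across the batch, e.g.:
#
#     enc = PositionalEncoding(d_model=16)
#     x = torch.zeros(50, 8, 16)          # (seq_len, batch, d_model)
#     assert enc(x).shape == (50, 8, 16)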
# Earlier variant kept for reference: the same model with an LSTM in front of the Transformer.
# class TimeSeriesTransformer(nn.Module):
#     def __init__(self, input_dim, output_dim, d_model, nhead, num_layers, dropout=0.1):
#         super().__init__()
#         self.input_dim = input_dim
#         self.output_dim = output_dim
#         self.pos_enc = PositionalEncoding(d_model, dropout)
#         self.lstm = nn.LSTM(input_dim, d_model, batch_first=True)  # Add an LSTM layer
#         self.transformer = nn.Transformer(d_model, nhead, num_layers, dropout=dropout)
#         self.act = nn.GELU()  # Try GELU as the activation function
#         self.hidden_dim = 128  # Hidden layer dimension
#         self.linear = nn.Linear(d_model, self.hidden_dim)
#         self.output_proj = nn.Linear(self.hidden_dim, output_dim)
#         self.layer_norm = nn.LayerNorm(d_model)
#
#     def forward(self, src, tgt):
#         src = self.pos_enc(src)  # Apply positional encoding before src goes through the LSTM and Transformer
#         src, _ = self.lstm(src)
#         tgt = self.pos_enc(tgt)  # Apply positional encoding before tgt goes through the LSTM and Transformer
#         tgt, _ = self.lstm(tgt)
#
#         output = self.transformer(src, tgt)
#         output = self.layer_norm(output)
#         output = self.act(self.linear(output))
#         output = self.output_proj(output)
#         output = output.squeeze(0)
#         return output
class TimeSeriesTransformer(nn.Module):
    def __init__(self, input_dim, output_dim, d_model, nhead, num_layers, dropout=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.pos_enc = PositionalEncoding(d_model, dropout)
        # Note: num_layers fills nn.Transformer's third positional argument
        # (num_encoder_layers); the decoder keeps its default number of layers.
        self.transformer = nn.Transformer(d_model, nhead, num_layers, dropout=dropout)
        self.act = nn.GELU()  # Try GELU as the activation function
        self.hidden_dim = 64  # Reduce the hidden layer dimension
        self.linear = nn.Linear(d_model, self.hidden_dim)
        self.output_proj = nn.Linear(self.hidden_dim, output_dim)
        self.layer_norm = nn.LayerNorm(d_model)
        self.out_act = nn.Tanh()  # output activation is Tanh (was misleadingly named `sigmoid`)

    def forward(self, src, tgt):
        # src/tgt: (seq_len, batch, d_model); features must already have size d_model,
        # since this variant has no input projection.
        src = self.pos_enc(src)  # Apply positional encoding before src enters the Transformer
        tgt = self.pos_enc(tgt)  # Apply positional encoding before tgt enters the Transformer
        output = self.transformer(src, tgt)
        output = self.layer_norm(output)
        output = self.act(self.linear(output))
        output = self.output_proj(output)
        output = self.out_act(output.squeeze(0))
        return output
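

# Minimal smoke test (a sketch; the hyperparameters and tensor sizes below are illustrative,
# not taken from the repo's configuration).
if __name__ == "__main__":
    model = TimeSeriesTransformer(input_dim=64, output_dim=1, d_model=64, nhead=4, num_layers=2)
    # nn.Transformer defaults to batch_first=False, so src/tgt are (seq_len, batch, d_model).
    src = torch.randn(30, 16, 64)
    tgt = torch.randn(10, 16, 64)
    out = model(src, tgt)
    print(out.shape)  # torch.Size([10, 16, 1]); squeeze(0) is a no-op when tgt_len > 1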