# training.py
  1. import random
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torch.utils.data import Dataset, DataLoader
  6. import numpy as np
  7. import pandas as pd
  8. from dataset.TimeDataset import TimeSeriesDataset
  9. #from model.Transformer_base import TimeSeriesTransformer
  10. from model.LSTM import TimeSeriesTransformer
  11. from tqdm import tqdm
  12. from utils.Arg import Arg
  13. from utils import ModeTest
  14. import matplotlib.pyplot as plt
  15. import training_model
# Central argument/config object; all hyperparameters below are read from it.
arg = Arg()
# Hyperparameters (defaults live in utils.Arg — see that class for values)
input_dim = arg.input_dim            # number of input features per time step
output_dim = arg.output_dim          # number of predicted features per time step
input_seq_length = arg.input_seq_length    # encoder/input window length
output_seq_length = arg.output_seq_length  # prediction horizon length
d_model = arg.d_model                # model hidden size (transformer naming kept for the LSTM drop-in)
nhead = arg.nhead                    # attention heads — unused by the LSTM model, kept for compatibility
num_layers = arg.num_layers          # stacked layers in the model
dropout = arg.dropout                # dropout probability
batch_size = arg.batch_size          # mini-batch size for the DataLoader
epochs = arg.epochs                  # number of training epochs
# Prefer GPU when available; note the model itself must also be moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  29. def setup_seed(seed):
  30. torch.manual_seed( seed )
  31. torch.cuda.manual_seed_all( seed )
  32. np.random.seed( seed )
  33. random.seed( seed )
  34. def plot_fig(time,out,tag):
  35. plt.plot(time,out)
  36. plt.plot(time,tag)
  37. plt.show()
  38. def train(model_use = False):
  39. #model = TimeSeriesTransformer(input_dim, output_dim, d_model, nhead, num_layers, dropout).to(device)
  40. model = TimeSeriesTransformer()
  41. #model = torch.compile(model, mode="reduce-overhead")
  42. if model_use:
  43. print("载入历史训练的模型")
  44. model.load_state_dict(torch.load('./save/best_loss_short_encoder.pt'))
  45. optimizer = optim.Adam(model.parameters(), lr=1e-3)
  46. criterion = nn.MSELoss()#nn.L1Loss()#nn.MSELoss()
  47. best_loss = float('inf')
  48. model.train()
  49. for epoch in range(epochs):
  50. epoch_loss = 0.0
  51. data_len = 0
  52. for i in tqdm(range(6), desc='Training progress:'):
  53. if i == 4 :
  54. continue
  55. file_inputs_2 = './data/Dataset_training/NWP/NWP_{}.csv'.format(i)
  56. file_inputs_3 = './data/Dataset_training/power/power_{}.csv'.format(i)
  57. dataset = TimeSeriesDataset(file_inputs_3, file_inputs_2)
  58. dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
  59. # 训练模型
  60. data_loss = 0.0
  61. for batch_idx, (input_seq, output_seq) in enumerate(dataloader):
  62. input_seq, output_seq = input_seq.to(device), output_seq.to(device)
  63. # 前向传播
  64. # input_seq = input_seq.permute(1, 0, 2)
  65. # tgt = input_seq[-1:]
  66. #predictions = model(input_seq,tgt)
  67. predictions = model(input_seq)
  68. # 计算损失
  69. loss = criterion(predictions, output_seq)
  70. # 反向传播
  71. optimizer.zero_grad()
  72. loss.backward()
  73. optimizer.step()
  74. data_loss += loss.item()
  75. data_len += len(dataloader)
  76. epoch_loss += data_loss
  77. #print(f"Datasate is {i} ,Loss is {data_loss/data_len}")
  78. print(f"Epoch {epoch+1}, Loss: {epoch_loss / data_len}")
  79. # 保存 模型
  80. if epoch_loss < best_loss:
  81. best_loss = epoch_loss
  82. print("Best loss model is saved")
  83. torch.save(model.state_dict(), './save/best_loss_short_encoder.pt')
if __name__ == '__main__':
    # Fixed seed so evaluation runs are reproducible.
    setup_seed(50)
    model_use = True
    model = TimeSeriesTransformer()
    #model = torch.compile(model, mode="reduce-overhead")
    if model_use:
        # Load previously trained weights ("载入历史训练的模型" = loading a
        # historically trained model). NOTE(review): this path differs from the
        # checkpoint train() writes ('./save/best_loss_short_encoder.pt') —
        # confirm which checkpoint is intended.
        print("载入历史训练的模型")
        model.load_state_dict(torch.load('save/lstm_base.pt'))
    # The commented-out calls below are prior experiments (retraining on
    # specific year/month slices); kept as a record of what was tried.
    #training_model.base_train(model)
    # re_train_for_data(model, 2022,11)
    # re_train_for_data(model, 2022,9)
    #train(model_use = True)
    #training_model.re_train_for_data(model, 2022, 5)
    #training_model.re_train_for_data(model, 2022, 10)
    # for i in [5]:
    # #training_model.re_train_for_data(model, 2022,11)
    # training_model.re_train_for_data(model,2023,4)
    # ModeTest.test_model(2023,i,"lstm_base_pro.pt") #69
    # break
    # Evaluate the saved model on month 4 of 2022 only (range(4, 5) is a
    # single iteration; the loop form makes multi-month sweeps easy to restore).
    for i in range(4, 5):
        #training_model.re_train_for_turbine_sum_power(model)
        #training_model.re_train_for_data(model, 2023,1)
        #training_model.re_train_for_data(model, 2022,1)
        #training_model.re_train_for_data(model, 2023, 2)
        ModeTest.test_model(2022, i, "lstm_base_pro.pt")