model_keras.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # -*- coding: UTF-8 -*-
  2. from keras.layers import Input, Dense, LSTM
  3. from keras.models import Model
  4. from keras.callbacks import ModelCheckpoint, EarlyStopping
  5. def get_keras_model(opt):
  6. input1 = Input(shape=(opt.Model['time_step'], opt.input_size))
  7. lstm = input1
  8. for i in range(opt.Model['lstm_layers']):
  9. lstm = LSTM(units=opt.Model['hidden_size'],dropout=opt.Model['dropout_rate'],return_sequences=True)(lstm)
  10. output = Dense(opt.output_size)(lstm)
  11. model = Model(input1, output)
  12. model.compile(loss='mse', optimizer='adam') # metrics=["mae"]
  13. return model
  14. def gpu_train_init():
  15. import tensorflow as tf
  16. from keras.backend.tensorflow_backend import set_session
  17. sess_config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True)
  18. sess_config.gpu_options.per_process_gpu_memory_fraction = 0.7 # 最多使用70%GPU内存
  19. sess_config.gpu_options.allow_growth=True # 初始化时不全部占满GPU显存, 按需分配
  20. sess = tf.Session(config=sess_config)
  21. set_session(sess)
  22. def train(opt, train_and_valid_data):
  23. if opt.use_cuda: gpu_train_init()
  24. train_X, train_Y, valid_X, valid_Y = train_and_valid_data
  25. model = get_keras_model(opt)
  26. model.summary()
  27. if opt.add_train:
  28. model.load_weights(opt.model_save_path + opt.model_name)
  29. check_point = ModelCheckpoint(filepath=opt.model_save_path + opt.model_name, monitor='val_loss',
  30. save_best_only=True, mode='auto')
  31. early_stop = EarlyStopping(monitor='val_loss', patience=opt.Model['patience'], mode='auto')
  32. model.fit(train_X, train_Y, batch_size=opt.Model['batch_size'], epochs=opt.Model['epoch'], verbose=2,
  33. validation_data=(valid_X, valid_Y), callbacks=[check_point, early_stop])
  34. def predict(config, test_X):
  35. model = get_keras_model(config)
  36. model.load_weights(config.model_save_path + config.model_name)
  37. result = model.predict(test_X, batch_size=1)
  38. result = result.reshape((-1, config.output_size))
  39. return result