# model_keras_fenqu.py — two-branch (分区) LSTM model: build, train, predict (Keras / TF1)
  1. # -*- coding: UTF-8 -*-
  2. from keras.layers import Input, Dense, LSTM, concatenate, Conv1D, Conv2D, MaxPooling1D, Reshape, Flatten, Lambda
  3. from keras.models import Model, load_model
  4. from keras.callbacks import ModelCheckpoint, EarlyStopping
  5. from keras import optimizers
  6. from keras.callbacks import TensorBoard
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. from keras.callbacks import TensorBoard, EarlyStopping
  10. def get_keras_model(opt):
  11. lstm_input = Input(shape=(opt.Model['time_step'], opt.input_size_lstm))
  12. lstm = lstm_input
  13. for i in range(opt.Model['lstm_layers']):
  14. rs = True
  15. if i == opt.Model['lstm_layers']-1:
  16. rs = False
  17. lstm = LSTM(units=opt.Model['hidden_size'], dropout=opt.Model['dropout_rate'], return_sequences=rs)(lstm)
  18. output = Dense(16, name='dense_1')(lstm)
  19. # output = Flatten(data_format='channels_last')(output)
  20. lstm1 = lstm_input
  21. for i in range(opt.Model['lstm_layers']):
  22. rs = True
  23. if i == opt.Model['lstm_layers']-1:
  24. rs = False
  25. lstm1 = LSTM(units=opt.Model['hidden_size'], dropout=opt.Model['dropout_rate'], return_sequences=rs)(lstm1)
  26. output1 = Dense(16, name='dense_2')(lstm1)
  27. # output1 = Flatten(data_format='channels_last')(output1)
  28. outputs = Lambda(sum)([output, output1])
  29. # outputs = Dense(16, name='dense_3')(outputs)
  30. model = Model(lstm_input, [output, output1])
  31. # model = Model(lstm_input, outputs)
  32. # model.compile(loss={'dense_1': 'mse', 'dense_2': 'mse', 'dense_3': 'mse'},
  33. # loss_weights={'dense_1': 500, 'dense_2': 500, 'dense_3': 0.04},
  34. # metrics={'dense_1': ['accuracy', 'mse'], 'dense_2': ['accuracy', 'mse'], 'dense_3': ['accuracy', 'mse']},
  35. # optimizer='adam') # metrics=["mae"]
  36. model.compile(loss={'dense_1': 'mse', 'dense_2': 'mse'},
  37. metrics={'dense_1': ['accuracy', 'mse'], 'dense_2': ['accuracy', 'mse'],},
  38. optimizer='adam') # metrics=["mae"]
  39. return model
  40. def train_init(use_cuda=False):
  41. import tensorflow as tf
  42. from keras.backend.tensorflow_backend import set_session
  43. if use_cuda:
  44. # gpu init
  45. sess_config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True)
  46. sess_config.gpu_options.per_process_gpu_memory_fraction = 0.7 # 最多使用70%GPU内存
  47. sess_config.gpu_options.allow_growth=True # 初始化时不全部占满GPU显存, 按需分配
  48. sess = tf.Session(config=sess_config)
  49. set_session(sess)
  50. else:
  51. session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
  52. tf.set_random_seed(1234)
  53. sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
  54. set_session(sess)
  55. def train(opt, train_and_valid_data):
  56. train_init(opt.use_cuda)
  57. train_X, train_Y, valid_X, valid_Y = train_and_valid_data
  58. print("----------", np.array(train_X[0]).shape)
  59. print("++++++++++", np.array(train_X[1]).shape)
  60. model = get_keras_model(opt)
  61. model.summary()
  62. weight_lstm_1, bias_lstm_1 = model.get_layer('dense_1').get_weights()
  63. print("weight_lstm_1 = ", weight_lstm_1)
  64. print("bias_lstm_1 = ", bias_lstm_1)
  65. if opt.add_train:
  66. model.load_weights(opt.model_save_path + 'model_kerass.h5')
  67. check_point = ModelCheckpoint(filepath=opt.model_save_path + opt.save_name, monitor='val_loss',
  68. save_best_only=True, mode='auto')
  69. early_stop = EarlyStopping(monitor='val_loss', patience=opt.Model['patience'], mode='auto')
  70. history = model.fit(train_X, train_Y, batch_size=opt.Model['batch_size'], epochs=opt.Model['epoch'], verbose=2,
  71. validation_data=(valid_X, valid_Y), callbacks=[check_point, early_stop])
  72. loss = history.history['loss']
  73. epochs = range(1, len(loss) + 1)
  74. plt.title('Loss')
  75. # plt.plot(epochs, acc, 'red', label='Training acc')
  76. plt.plot(epochs, loss, 'blue', label='Validation loss')
  77. plt.legend()
  78. # plt.show()
  79. def predict(config, test_X):
  80. model = get_keras_model(config)
  81. model.load_weights(config.model_save_path + 'model_' + config.save_frame + '.h5')
  82. result = model.predict(test_X, batch_size=1)
  83. # result = result.reshape((-1, config.output_size))
  84. return result