model_prediction_lstm.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from flask import Flask,request
  2. import time
  3. import logging
  4. import traceback
  5. import numpy as np
  6. from itertools import chain
  7. from common.database_dml import get_data_from_mongo,insert_data_into_mongo,get_h5_model_from_mongo,get_scaler_model_from_mongo
  8. app = Flask('model_prediction_lstm——service')
  9. # 创建时间序列数据
  10. def create_sequences(data_features,data_target,time_steps):
  11. X, y = [], []
  12. if len(data_features)<time_steps:
  13. print("数据长度不能比时间步长小!")
  14. return np.array(X), np.array(y)
  15. else:
  16. for i in range(len(data_features) - time_steps+1):
  17. X.append(data_features[i:(i + time_steps)])
  18. if len(data_target)>0:
  19. y.append(data_target[i + time_steps -1])
  20. return np.array(X), np.array(y)
  21. def model_prediction(df,args):
  22. features, time_steps, col_time, model_name,col_reserve = str_to_list(args['features']), int(args['time_steps']),args['col_time'],args['model_name'],str_to_list(args['col_reserve'])
  23. feature_scaler,target_scaler = get_scaler_model_from_mongo(args)
  24. df = df.fillna(method='ffill').fillna(method='bfill').sort_values(by=col_time)
  25. scaled_features = feature_scaler.transform(df[features])
  26. X_predict, _ = create_sequences(scaled_features, [], time_steps)
  27. # 加载模型时传入自定义损失函数
  28. # model = load_model(f'{farmId}_model.h5', custom_objects={'rmse': rmse})
  29. model = get_h5_model_from_mongo(args)
  30. y_predict = list(chain.from_iterable(target_scaler.inverse_transform([model.predict(X_predict).flatten()])))
  31. result = df[-len(y_predict):]
  32. result['predict'] = y_predict
  33. result.loc[result['predict'] < 0, 'predict'] = 0
  34. result['model'] = model_name
  35. features_reserve = col_reserve + ['model', 'predict']
  36. return result[set(features_reserve)]
  37. def str_to_list(arg):
  38. if arg == '':
  39. return []
  40. else:
  41. return arg.split(',')
  42. @app.route('/model_prediction_lstm', methods=['POST'])
  43. def model_prediction_lstm():
  44. # 获取程序开始时间
  45. start_time = time.time()
  46. result = {}
  47. success = 0
  48. print("Program starts execution!")
  49. try:
  50. args = request.values.to_dict()
  51. print('args',args)
  52. logger.info(args)
  53. power_df = get_data_from_mongo(args)
  54. model = model_prediction(power_df,args)
  55. insert_data_into_mongo(model,args)
  56. success = 1
  57. except Exception as e:
  58. my_exception = traceback.format_exc()
  59. my_exception.replace("\n","\t")
  60. result['msg'] = my_exception
  61. end_time = time.time()
  62. result['success'] = success
  63. result['args'] = args
  64. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  65. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  66. print("Program execution ends!")
  67. return result
  68. if __name__=="__main__":
  69. print("Program starts execution!")
  70. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  71. logger = logging.getLogger("model_prediction_lstm log")
  72. from waitress import serve
  73. serve(app, host="0.0.0.0", port=10097)
  74. print("server start!")