model_training_lightgbm.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import lightgbm as lgb
  2. import numpy as np
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.metrics import mean_squared_error,mean_absolute_error
  5. from flask import Flask,request
  6. import time
  7. import traceback
  8. import logging
  9. from common.database_dml import get_data_from_mongo,insert_pickle_model_into_mongo
  10. from common.processing_data_common import missing_features,str_to_list
  11. app = Flask('model_training_lightgbm——service')
  12. def build_model(df,args):
  13. np.random.seed(42)
  14. #lightgbm预测下
  15. numerical_features,categorical_features,label,model_name,num_boost_round,model_params,col_time = str_to_list(args['numerical_features']),str_to_list(args['categorical_features']),args['label'],args['model_name'],int(args['num_boost_round']),eval(args['model_params']),args['col_time']
  16. features = numerical_features+categorical_features
  17. print("features:************",features)
  18. if 'is_limit' in df.columns:
  19. df = df[df['is_limit']==False]
  20. # 清洗特征平均缺失率大于20%的天
  21. df = missing_features(df, features, col_time)
  22. # 拆分数据为训练集和测试集
  23. X_train, X_test, y_train, y_test = train_test_split(df[features], df[label], test_size=0.2, random_state=42)
  24. # 创建LightGBM数据集
  25. lgb_train = lgb.Dataset(X_train, y_train,categorical_feature=categorical_features)
  26. lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
  27. # 设置参数
  28. params = {
  29. 'objective': 'regression',
  30. 'metric': 'rmse',
  31. 'boosting_type': 'gbdt',
  32. 'verbose':1
  33. }
  34. params.update(model_params)
  35. # 训练模型
  36. print('Starting training...')
  37. gbm = lgb.train(params,
  38. lgb_train,
  39. num_boost_round=num_boost_round,
  40. valid_sets=[lgb_train, lgb_eval],
  41. )
  42. y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
  43. # 评估
  44. mse = mean_squared_error(y_test, y_pred)
  45. rmse = np.sqrt(mse)
  46. mae = mean_absolute_error(y_test, y_pred)
  47. print(f'The test rmse is: {rmse},"The test mae is:"{mae}')
  48. return gbm
  49. @app.route('/model_training_lightgbm', methods=['POST'])
  50. def model_training_lightgbm():
  51. # 获取程序开始时间
  52. start_time = time.time()
  53. result = {}
  54. success = 0
  55. print("Program starts execution!")
  56. try:
  57. args = request.values.to_dict()
  58. print('args',args)
  59. logger.info(args)
  60. power_df = get_data_from_mongo(args)
  61. model = build_model(power_df,args)
  62. insert_pickle_model_into_mongo(model,args)
  63. success = 1
  64. except Exception as e:
  65. my_exception = traceback.format_exc()
  66. my_exception.replace("\n","\t")
  67. result['msg'] = my_exception
  68. end_time = time.time()
  69. result['success'] = success
  70. result['args'] = args
  71. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  72. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  73. print("Program execution ends!")
  74. return result
  75. if __name__=="__main__":
  76. print("Program starts execution!")
  77. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  78. logger = logging.getLogger("model_training_lightgbm log")
  79. from waitress import serve
  80. serve(app, host="0.0.0.0", port=10089)
  81. print("server start!")