model_training_ml.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import lightgbm as lgb
  2. import numpy as np
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.metrics import mean_squared_error, mean_absolute_error
  5. from flask import Flask, request
  6. import time
  7. import traceback
  8. import logging
  9. from common.database_dml import get_data_from_mongo, insert_pickle_model_into_mongo
  10. from common.processing_data_common import missing_features, str_to_list
  11. from sklearn.pipeline import Pipeline
  12. from sklearn.svm import SVR
  13. from sklearn.preprocessing import MinMaxScaler
  14. app = Flask('model_training_ml——service')
  15. """
  16. 基于model_training_lightgbm.py
  17. 机器学习通用训练方法,特点
  18. 1. 保存模型同时,保存模型特征
  19. 2. 支持模型训练样本权重(需要在预处理部分生成权重特征)
  20. 参数格式如下
  21. """
  22. def train_lgb(data_split, categorical_features, model_params, num_boost_round, sample_weight=None):
  23. X_train, X_test, y_train, y_test = data_split
  24. # 创建LightGBM数据集
  25. lgb_train = lgb.Dataset(X_train, y_train, categorical_feature=categorical_features, weight=sample_weight)
  26. lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
  27. # 设置参数
  28. params = {
  29. 'objective': 'regression',
  30. 'metric': 'rmse',
  31. 'boosting_type': 'gbdt',
  32. 'verbose': 1
  33. }
  34. print(type(model_params))
  35. params.update(model_params)
  36. # 训练模型
  37. print('Starting training...')
  38. gbm = lgb.train(params,
  39. lgb_train,
  40. num_boost_round=num_boost_round,
  41. valid_sets=[lgb_train, lgb_eval],
  42. )
  43. y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
  44. return gbm, y_pred
  45. def train_svr(data_split, model_params, sample_weight=None):
  46. X_train, X_test, y_train, y_test = data_split
  47. svr = Pipeline([('scaler', MinMaxScaler()),
  48. ('model', SVR(**model_params))])
  49. # 训练模型
  50. print('Starting training...')
  51. svr.fit(X_train, y_train, model__sample_weight=sample_weight)
  52. y_pred = svr.predict(X_test)
  53. return svr, y_pred
  54. def build_model(df, args):
  55. np.random.seed(42)
  56. # lightgbm预测下
  57. numerical_features, categorical_features, label, model_name, num_boost_round, model_params, col_time = str_to_list(
  58. args['numerical_features']), str_to_list(args['categorical_features']), args['label'], args['model_name'], int(
  59. args['num_boost_round']), eval(args['model_params']), args['col_time']
  60. # 样本权重
  61. sample_weight = None
  62. if 'sample_weight' in args.keys():
  63. sample_weight = args['sample_weight']
  64. features = numerical_features + categorical_features
  65. print("features:************", features)
  66. if 'is_limit' in df.columns:
  67. df = df[df['is_limit'] == False]
  68. # 清洗特征平均缺失率大于20%的天
  69. df = missing_features(df, features, col_time)
  70. df = df[~np.isnan(df[label])]
  71. # 拆分数据为训练集和测试集
  72. X_train, X_test, y_train, y_test = train_test_split(df[features], df[label], test_size=0.2, random_state=42,
  73. shuffle=False)
  74. model_type = args['model_type']
  75. # 区分常规机器学习模型和lgb,这里只实例化svr,后续可扩展
  76. if model_type == "lightgbm":
  77. model, y_pred = train_lgb([X_train, X_test, y_train, y_test], categorical_features, model_params,
  78. num_boost_round, sample_weight=sample_weight)
  79. elif model_type == "SVR":
  80. model, y_pred = train_svr([X_train, X_test, y_train, y_test], model_params, sample_weight=sample_weight)
  81. else:
  82. raise ValueError(f"Invalid model_type, must be one of [lightgbm, SVR]")
  83. # 评估
  84. mse = mean_squared_error(y_test, y_pred)
  85. rmse = np.sqrt(mse)
  86. mae = mean_absolute_error(y_test, y_pred)
  87. print(f'The test rmse is: {rmse},"The test mae is:"{mae}')
  88. return model, features
  89. @app.route('/model_training_ml', methods=['POST'])
  90. def model_training_ml():
  91. # 获取程序开始时间
  92. start_time = time.time()
  93. result = {}
  94. success = 0
  95. print("Program starts execution!")
  96. try:
  97. args = request.values.to_dict()
  98. print('args', args)
  99. logger.info(args)
  100. power_df = get_data_from_mongo(args)
  101. model, features = build_model(power_df, args)
  102. insert_pickle_model_into_mongo(model, args, features=features)
  103. success = 1
  104. except Exception as e:
  105. my_exception = traceback.format_exc()
  106. my_exception.replace("\n", "\t")
  107. result['msg'] = my_exception
  108. end_time = time.time()
  109. result['success'] = success
  110. result['args'] = args
  111. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  112. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  113. print("Program execution ends!")
  114. return result
  115. if __name__ == "__main__":
  116. print("Program starts execution!")
  117. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  118. logger = logging.getLogger("model_training_ml log")
  119. from waitress import serve
  120. serve(app, host="0.0.0.0", port=10125)
  121. print("server start!")