model_training_ml.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import lightgbm as lgb
  2. import numpy as np
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.metrics import mean_squared_error, mean_absolute_error
  5. from flask import Flask, request, g
  6. import time
  7. import traceback
  8. import logging
  9. from common.database_dml import get_data_from_mongo, insert_pickle_model_into_mongo
  10. from common.processing_data_common import missing_features, str_to_list
  11. from sklearn.pipeline import Pipeline
  12. from sklearn.svm import SVR
  13. from sklearn.preprocessing import MinMaxScaler
  14. from data_processing.data_operation.weight import WEIGHT_REGISTER
  15. from io import StringIO
  16. from common.log_utils import init_request_logging, teardown_request_logging
  17. app = Flask('model_training_ml——service')
  18. # 请求前设置日志捕获
  19. @app.before_request
  20. def setup_logging():
  21. init_request_logging(logger)
  22. # 请求后清理日志处理器
  23. @app.after_request
  24. def teardown_logging(response):
  25. return teardown_request_logging(response, logger)
  26. def get_sample_weight(df, label, args):
  27. # 样本权重
  28. if 'sample_weight' in args.keys():
  29. if args['sample_weight'] in WEIGHT_REGISTER.keys():
  30. sample_weight = WEIGHT_REGISTER[args['sample_weight']](df[label].values.reshape(-1), **args)
  31. logger.info(f"use predefined weights {args['sample_weight']}")
  32. elif args['sample_weight'] in df.columns.tolist():
  33. sample_weight = df[args['sample_weight']].values.reshape(-1)
  34. logger.info(f'use dataframe col {args["sample_weight"]}')
  35. else:
  36. sample_weight = None
  37. logger.info('sample_weight is neither in the predefined weights nor a column of the DataFrame, not applicable')
  38. else:
  39. sample_weight = None
  40. logger.info('no sample_weight')
  41. return sample_weight
  42. def train_lgb(data_split, categorical_features, model_params, num_boost_round=100, sample_weight=None):
  43. X_train, X_test, y_train, y_test = data_split
  44. # 创建LightGBM数据集
  45. lgb_train = lgb.Dataset(X_train, y_train, categorical_feature=categorical_features, weight=sample_weight)
  46. lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
  47. # 设置参数
  48. params = {
  49. 'objective': 'regression',
  50. 'metric': 'rmse',
  51. 'boosting_type': 'gbdt',
  52. 'verbose': 1
  53. }
  54. params.update(model_params)
  55. # 训练模型
  56. print('Starting training...')
  57. gbm = lgb.train(params,
  58. lgb_train,
  59. num_boost_round=num_boost_round,
  60. valid_sets=[lgb_train, lgb_eval],
  61. )
  62. y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
  63. return gbm, y_pred
  64. def train_svr(data_split, model_params, sample_weight=None):
  65. X_train, X_test, y_train, y_test = data_split
  66. svr = Pipeline([('scaler', MinMaxScaler()),
  67. ('model', SVR(**model_params))])
  68. # 训练模型
  69. print('Starting training...')
  70. svr.fit(X_train, y_train, model__sample_weight=sample_weight)
  71. y_pred = svr.predict(X_test)
  72. return svr, y_pred
  73. def build_model(df, args):
  74. np.random.seed(42)
  75. # 参数
  76. numerical_features, categorical_features, label, model_name, model_params, col_time = str_to_list(
  77. args['numerical_features']), str_to_list(args['categorical_features']), args['label'], args['model_name'], eval(
  78. args['model_params']), args['col_time']
  79. features = numerical_features + categorical_features
  80. print("features:************", features)
  81. if 'is_limit' in df.columns:
  82. df = df[df['is_limit'] == False]
  83. # 清洗特征平均缺失率大于20%的天
  84. df = missing_features(df, features, col_time)
  85. df = df[~np.isnan(df[label])]
  86. # 拆分数据为训练集和测试集
  87. df_train, df_test = train_test_split(df, test_size=0.2, random_state=42,
  88. shuffle=False)
  89. X_train, y_train = df_train[features].values, df_train[label].values
  90. X_test, y_test = df_test[features].values, df_test[label].values
  91. # 获取样本权重
  92. sample_weight = get_sample_weight(df_train, label=label, args=args)
  93. model_type = args['model_type']
  94. # 区分常规机器学习模型和lgb,这里只实例化svr,后续可扩展
  95. if model_type == "lightgbm":
  96. logger.info("lightgbm training")
  97. num_boost_round = int(args['num_boost_round'])
  98. model, y_pred = train_lgb([X_train, X_test, y_train, y_test], categorical_features, model_params,
  99. num_boost_round, sample_weight=sample_weight)
  100. elif model_type == "svr":
  101. logger.info("svr training")
  102. model, y_pred = train_svr([X_train, X_test, y_train, y_test], model_params, sample_weight=sample_weight)
  103. else:
  104. raise ValueError(f"Invalid model_type, must be one of [lightgbm, svr]")
  105. # 评估
  106. mse = mean_squared_error(y_test, y_pred)
  107. rmse = np.sqrt(mse)
  108. mae = mean_absolute_error(y_test, y_pred)
  109. logger.info(f'The test rmse is: {round(rmse, 2)},"The test mae is:"{round(mae, 2)}')
  110. return model, features
  111. @app.route('/model_training_ml', methods=['POST'])
  112. def model_training_ml():
  113. # 获取程序开始时间
  114. start_time = time.time()
  115. result = {}
  116. success = 0
  117. print("Program starts execution!")
  118. try:
  119. args = request.values.to_dict()
  120. logger.info(args)
  121. power_df = get_data_from_mongo(args)
  122. model, features = build_model(power_df, args)
  123. insert_pickle_model_into_mongo(model, args, features=features)
  124. success = 1
  125. except Exception as e:
  126. my_exception = traceback.format_exc()
  127. logger.error(my_exception)
  128. end_time = time.time()
  129. result['success'] = success
  130. result['args'] = args
  131. result['log'] = g.log_stream.getvalue().splitlines()
  132. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  133. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  134. print("Program execution ends!")
  135. return result
  136. if __name__ == "__main__":
  137. print("Program starts execution!")
  138. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  139. logger = logging.getLogger("model_training_ml log")
  140. from waitress import serve
  141. serve(app, host="0.0.0.0", port=10128)
  142. print("server start!")