model_training_ml.py 5.6 KB

import lightgbm as lgb
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from flask import Flask, request
import time
import traceback
import logging
from common.database_dml import get_data_from_mongo, insert_pickle_model_into_mongo
from common.processing_data_common import missing_features, str_to_list
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR
from sklearn.preprocessing import MinMaxScaler
from data_processing.data_operation.weight import WEIGHT_REGISTER

app = Flask('model_training_ml_service')
def get_sample_weight(df, label, args):
    # Sample weights: look up a predefined weighting scheme or use a DataFrame column
    sample_weight = None
    if 'sample_weight' in args.keys():
        if args['sample_weight'] in WEIGHT_REGISTER.keys():
            sample_weight = WEIGHT_REGISTER[args['sample_weight']](df[label].values.reshape(-1), **args)
        elif args['sample_weight'] in df.columns.tolist():
            sample_weight = df[args['sample_weight']].values.reshape(-1)
        else:
            print('sample_weight is neither in the predefined weights nor a column of the DataFrame, not applicable')
    return sample_weight
def train_lgb(data_split, categorical_features, model_params, num_boost_round=100, sample_weight=None):
    X_train, X_test, y_train, y_test = data_split
    # Create the LightGBM datasets
    lgb_train = lgb.Dataset(X_train, y_train, categorical_feature=categorical_features, weight=sample_weight)
    lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
    # Default parameters, overridden by the request's model_params
    params = {
        'objective': 'regression',
        'metric': 'rmse',
        'boosting_type': 'gbdt',
        'verbose': 1
    }
    params.update(model_params)
    # Train the model
    print('Starting training...')
    gbm = lgb.train(params,
                    lgb_train,
                    num_boost_round=num_boost_round,
                    valid_sets=[lgb_train, lgb_eval])
    y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
    return gbm, y_pred
def train_svr(data_split, model_params, sample_weight=None):
    X_train, X_test, y_train, y_test = data_split
    svr = Pipeline([('scaler', MinMaxScaler()),
                    ('model', SVR(**model_params))])
    # Train the model
    print('Starting training...')
    svr.fit(X_train, y_train, model__sample_weight=sample_weight)
    y_pred = svr.predict(X_test)
    return svr, y_pred
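# Illustrative only: 'model_params' arrives as a string in the request and is eval'd in
# build_model below, so it is expected to be a Python dict literal. The values shown here
# are assumptions rather than project defaults; the keys are standard LightGBM / sklearn SVR
# parameter names.
#   model_type == "lightgbm":  model_params = "{'num_leaves': 31, 'learning_rate': 0.05, 'min_data_in_leaf': 20}"
#   model_type == "svr":       model_params = "{'kernel': 'rbf', 'C': 1.0, 'epsilon': 0.1}"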
def build_model(df, args):
    np.random.seed(42)
    # Parse parameters from the request arguments
    numerical_features, categorical_features, label, model_name, model_params, col_time = str_to_list(
        args['numerical_features']), str_to_list(args['categorical_features']), args['label'], args['model_name'], eval(
        args['model_params']), args['col_time']
    features = numerical_features + categorical_features
    print("features:************", features)
    if 'is_limit' in df.columns:
        df = df[df['is_limit'] == False]
    # Drop days whose average feature missing rate exceeds 20%
    df = missing_features(df, features, col_time)
    df = df[~np.isnan(df[label])]
    # Split the data into training and test sets (time-ordered, no shuffling)
    df_train, df_test = train_test_split(df, test_size=0.2, random_state=42,
                                         shuffle=False)
    X_train, y_train = df_train[features].values, df_train[label].values
    X_test, y_test = df_test[features].values, df_test[label].values
    # Resolve sample weights
    sample_weight = get_sample_weight(df_train, label=label, args=args)
    model_type = args['model_type']
    # Dispatch between LightGBM and conventional ML models; only SVR is wired up here, extensible later
    if model_type == "lightgbm":
        num_boost_round = int(args['num_boost_round'])
        model, y_pred = train_lgb([X_train, X_test, y_train, y_test], categorical_features, model_params,
                                  num_boost_round, sample_weight=sample_weight)
    elif model_type == "svr":
        model, y_pred = train_svr([X_train, X_test, y_train, y_test], model_params, sample_weight=sample_weight)
    else:
        raise ValueError(f"Invalid model_type '{model_type}', must be one of [lightgbm, svr]")
    # Evaluation
    mse = mean_squared_error(y_test, y_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_test, y_pred)
    print(f'The test rmse is: {rmse}, the test mae is: {mae}')
    return model, features
@app.route('/model_training_ml', methods=['POST'])
def model_training_ml():
    # Record program start time
    start_time = time.time()
    result = {}
    success = 0
    print("Program starts execution!")
    try:
        args = request.values.to_dict()
        print('args', args)
        logger.info(args)
        power_df = get_data_from_mongo(args)
        model, features = build_model(power_df, args)
        insert_pickle_model_into_mongo(model, args, features=features)
        success = 1
    except Exception:
        my_exception = traceback.format_exc()
        my_exception = my_exception.replace("\n", "\t")
        result['msg'] = my_exception
    end_time = time.time()
    result['success'] = success
    result['args'] = args
    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
    print("Program execution ends!")
    return result
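# A rough sketch of how a client might call this endpoint, assuming the service is running
# on the port configured below. Only the keys read in build_model are taken from this file
# (numerical_features, categorical_features, label, model_name, model_params, col_time,
# model_type, num_boost_round); the example values, the list format expected by str_to_list,
# and any extra keys required by get_data_from_mongo are assumptions.
#
#   import requests
#   resp = requests.post('http://localhost:10125/model_training_ml', data={
#       'numerical_features': '[wind_speed,temperature]',
#       'categorical_features': '[]',
#       'label': 'power',
#       'model_name': 'demo_model',
#       'model_params': "{'num_leaves': 31, 'learning_rate': 0.05}",
#       'col_time': 'date_time',
#       'model_type': 'lightgbm',
#       'num_boost_round': '100',
#   })
#   print(resp.json())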
if __name__ == "__main__":
    print("Program starts execution!")
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger = logging.getLogger("model_training_ml log")
    from waitress import serve
    serve(app, host="0.0.0.0", port=10125)
    print("server start!")