model_training_lightgbm.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import lightgbm as lgb
  2. import argparse
  3. import pandas as pd
  4. import numpy as np
  5. from pymongo import MongoClient
  6. import pickle
  7. from sklearn.model_selection import train_test_split
  8. from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score
  9. from flask import Flask,request,jsonify
  10. from waitress import serve
  11. import time
  12. import logging
  13. import traceback
  # Flask WSGI application exposing the training endpoint; it is served by
  # waitress when this file is run as a script.
  app = Flask(import_name='model_training_lightgbm——service')
  def get_data_from_mongo(args):
      """Load the records of a time window from a MongoDB collection.

      Args:
          args: Request parameters. Reads the keys 'mongodb_database',
              'mongodb_read_table', 'timeBegin' and 'timeEnd'.

      Returns:
          A pandas DataFrame with one row per document whose 'dateTime' lies
          in [timeBegin, timeEnd] (both bounds inclusive), without Mongo's
          '_id' column. The frame is empty when no document matches.

      Raises:
          KeyError: if one of the required keys is missing from ``args``.
      """
      import os

      # NOTE(review): credentials are committed in source. MONGODB_URI lets a
      # deployment supply its own; the literal is only kept as a fallback so
      # existing deployments keep working -- move it to a secret store and rotate.
      mongodb_connection = os.environ.get(
          "MONGODB_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
      )
      mongodb_database = args['mongodb_database']
      mongodb_read_table = args['mongodb_read_table']
      time_begin = args['timeBegin']
      time_end = args['timeEnd']

      client = MongoClient(mongodb_connection)
      try:
          # Select the database and collection (MongoDB creates a missing
          # database lazily).
          collection = client[mongodb_database][mongodb_read_table]
          # The bounds are used exactly as received, without parsing; presumably
          # 'dateTime' is stored in a format that compares correctly with them
          # -- TODO confirm against the stored documents.
          query = {"dateTime": {"$gte": time_begin, "$lte": time_end}}
          df = pd.DataFrame(list(collection.find(query)))
      finally:
          # Release the client even when the query or DataFrame build fails.
          client.close()

      # Drop Mongo's internal document id (optional column, not a feature).
      if '_id' in df.columns:
          df = df.drop(columns=['_id'])
      return df
  def insert_model_into_mongo(model_data, args):
      """Write a trained-model document into MongoDB, replacing the collection.

      If the target collection already exists it is dropped first, so any
      documents previously stored in it are discarded before ``model_data``
      is inserted.

      Args:
          model_data: Document to insert, e.g. the dict returned by
              ``build_model`` ('model_name' plus the pickled 'model' bytes).
          args: Request parameters. Reads 'mongodb_database' and
              'mongodb_write_table'.

      Raises:
          KeyError: if one of the required keys is missing from ``args``.
      """
      import os

      # NOTE(review): credentials are committed in source. MONGODB_URI lets a
      # deployment supply its own; the literal is only kept as a fallback so
      # existing deployments keep working -- move it to a secret store and rotate.
      mongodb_connection = os.environ.get(
          "MONGODB_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
      )
      mongodb_database = args['mongodb_database']
      mongodb_write_table = args['mongodb_write_table']

      client = MongoClient(mongodb_connection)
      try:
          db = client[mongodb_database]
          # NOTE(review): drop-then-insert is not atomic; if insert_one fails,
          # the previously stored model is already gone.
          if mongodb_write_table in db.list_collection_names():
              db[mongodb_write_table].drop()
              print(f"Collection '{mongodb_write_table}' already exist, deleted successfully!")
          db[mongodb_write_table].insert_one(model_data)
          print("model inserted successfully!")
      finally:
          # A new client is created on every call; close it so the long-running
          # service does not accumulate open connections.
          client.close()
  def build_model(df, args):
      """Train a LightGBM regression model on ``df`` and return it serialized.

      The rows are split 80/20 (``random_state=42``) into train and test sets.
      The model is trained with both sets as validation sets, and the test
      RMSE and MAE are printed.

      Args:
          df: Training data; must contain every feature column and the label.
          args: Request parameters. Reads:
              'numerical_features' / 'categorical_features' -- comma-separated
                  column names (either may be empty);
              'label' -- target column name;
              'model_name' -- name stored next to the model;
              'num_boost_round' -- number of boosting iterations (int string);
              'model_params' -- Python dict *literal* of LightGBM parameters
                  that override the defaults below ('' means no overrides).

      Returns:
          dict with 'model_name' and 'model' (the pickled trained booster).

      Raises:
          KeyError: a required arg or DataFrame column is missing.
          ValueError / SyntaxError: 'num_boost_round' is not an integer,
              'model_params' is not a valid dict literal, no feature columns
              were given, or ``df`` is empty.
      """
      import ast

      np.random.seed(42)

      numerical_features = str_to_list(args['numerical_features'])
      categorical_features = str_to_list(args['categorical_features'])
      label = args['label']
      model_name = args['model_name']
      num_boost_round = int(args['num_boost_round'])
      # SECURITY: 'model_params' arrives straight from the HTTP request, so it
      # must never reach eval() (that would allow arbitrary code execution).
      # ast.literal_eval only accepts Python literals; expressions such as
      # dict(a=1) or 1/10 are rejected on purpose.
      raw_params = args['model_params'].strip()
      model_params = ast.literal_eval(raw_params) if raw_params else {}
      if not isinstance(model_params, dict):
          raise ValueError(
              f"model_params must be a dict literal, got {type(model_params).__name__}"
          )

      features = numerical_features + categorical_features
      print("features:************", features)
      if not features:
          raise ValueError("At least one numerical or categorical feature is required.")
      if df.empty:
          raise ValueError("No training data: the input DataFrame is empty.")

      # Hold out 20% of the rows as the test set.
      X_train, X_test, y_train, y_test = train_test_split(
          df[features], df[label], test_size=0.2, random_state=42
      )
      lgb_train = lgb.Dataset(X_train, y_train, categorical_feature=categorical_features)
      lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

      # Defaults; any key present in model_params overrides them (the right
      # operand of | wins).
      params = {
          'objective': 'regression',
          'metric': 'rmse',
          'boosting_type': 'gbdt',
          'verbose': 1,
      }
      merged_param = params | model_params

      print('Starting training...')
      gbm = lgb.train(
          merged_param,
          lgb_train,
          num_boost_round=num_boost_round,
          valid_sets=[lgb_train, lgb_eval],
      )

      # Evaluate on the held-out rows.
      y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
      rmse = np.sqrt(mean_squared_error(y_test, y_pred))
      mae = mean_absolute_error(y_test, y_pred)
      print(f'The test rmse is: {rmse}, The test mae is: {mae}')

      # NOTE(review): the booster is stored as a pickle; whoever unpickles it
      # later must trust the database contents.
      model_data = {
          'model_name': model_name,
          'model': pickle.dumps(gbm),  # model byte stream stored in the database
      }
      print('Training completed!')
      return model_data
  def str_to_list(arg):
      """Split a comma-separated request parameter into a list of names.

      Whitespace around each item is stripped and empty items are dropped, so
      'a, b' and 'a,b,' both give ['a', 'b'] instead of producing column names
      such as ' b' or '' that do not exist. '' (or None) gives [].
      """
      if not arg:
          return []
      return [name for item in arg.split(',') if (name := item.strip())]
  @app.route('/model_training_lightgbm', methods=['POST'])
  def model_training_lightgbm():
      """POST endpoint: load data from MongoDB, train a model, store it back.

      The request values (``request.values``) are passed unchanged to
      ``get_data_from_mongo``, ``build_model`` and ``insert_model_into_mongo``.

      Returns:
          dict used as the response body, with:
          'success' -- 1 if every step finished, otherwise 0;
          'msg' -- only on failure: the traceback on one line (newlines
              replaced by tabs);
          'args' -- the received parameters ({} if they could not be read);
          'start_time' / 'end_time' -- local time, '%Y-%m-%d %H:%M:%S'.
      """
      # Look the logger up here instead of relying on the global that is only
      # created under `if __name__ == "__main__"`; when the module is imported
      # by another server that global does not exist. Same name -> same logger.
      logger = logging.getLogger("model_training_lightgbm log")

      start_time = time.time()  # program start time
      result = {}
      success = 0
      args = {}  # defined up front so the result can be built even if parsing fails
      print("Program starts execution!")
      try:
          args = request.values.to_dict()
          print('args', args)
          logger.info(args)
          power_df = get_data_from_mongo(args)
          model = build_model(power_df, args)
          insert_model_into_mongo(model, args)
          success = 1
      except Exception:
          # Request boundary: report the failure to the caller instead of
          # propagating it, but keep the full traceback in the server log.
          logger.exception("model training failed")
          result['msg'] = traceback.format_exc().replace("\n", "\t")
      end_time = time.time()
      result['success'] = success
      result['args'] = args
      result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
      result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
      print("Program execution ends!")
      return result
  if __name__ == "__main__":
      print("Program starts execution!")
      logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
      # Named logger shared with the request handler.
      logger = logging.getLogger("model_training_lightgbm log")
      # waitress.serve() runs the server in the foreground and does not return
      # while it is running, so announce startup before calling it.
      # (`serve` is already imported at the top of the file.)
      print("server start!")
      # NOTE(review): binds on all interfaces and the endpoint has no
      # authentication -- make sure the port is not publicly reachable.
      serve(app, host="0.0.0.0", port=10089)