model_training_lightgbm.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import lightgbm as lgb
  2. import argparse
  3. import pandas as pd
  4. import numpy as np
  5. from pymongo import MongoClient
  6. import pickle
  7. from sklearn.model_selection import train_test_split
  8. from flask import Flask,request,jsonify
  9. from waitress import serve
  10. import time
  11. import logging
  12. import traceback
  13. app = Flask('model_training_lightgbm——service')
  14. def get_data_from_mongo(args):
  15. mongodb_connection,mongodb_database,mongodb_read_table = args['mongodb_connection'],args['mongodb_database'],args['mongodb_read_table']
  16. client = MongoClient(mongodb_connection)
  17. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  18. db = client[mongodb_database]
  19. collection = db[mongodb_read_table] # 集合名称
  20. data_from_db = collection.find() # 这会返回一个游标(cursor)
  21. # 将游标转换为列表,并创建 pandas DataFrame
  22. df = pd.DataFrame(list(data_from_db))
  23. client.close()
  24. return df
  25. def insert_model_into_mongo(model_data,args):
  26. mongodb_connection,mongodb_database,mongodb_write_table = args['mongodb_connection'],args['mongodb_database'],args['mongodb_write_table']
  27. client = MongoClient(mongodb_connection)
  28. db = client[mongodb_database]
  29. if mongodb_write_table in db.list_collection_names():
  30. db[mongodb_write_table].drop()
  31. print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
  32. collection = db[mongodb_write_table] # 集合名称
  33. collection.insert_one(model_data)
  34. print("model inserted successfully!")
  35. def build_model(df,args):
  36. np.random.seed(42)
  37. #lightgbm预测下
  38. numerical_features,categorical_features,label,model_name,learning_rate,num_leaves,min_data_in_leaf = str_to_list(args['numerical_features']),str_to_list(args['categorical_features']),args['label'],args['model_name'],args['learning_rate'],args['num_leaves'],args['min_data_in_leaf']
  39. features = numerical_features+categorical_features
  40. print("features:************",features)
  41. # 拆分数据为训练集和测试集
  42. X_train, X_test, y_train, y_test = train_test_split(df[features], df[label], test_size=0.2, random_state=42)
  43. # 创建LightGBM数据集
  44. lgb_train = lgb.Dataset(X_train, y_train,categorical_feature=categorical_features)
  45. lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
  46. # 设置参数
  47. params = {
  48. 'objective': 'regression',
  49. 'metric': 'rmse',
  50. 'boosting_type': 'gbdt',
  51. 'num_leaves': num_leaves,
  52. 'learning_rate': learning_rate,
  53. 'min_data_in_leaf': min_data_in_leaf, # 叶子节点最小数据量
  54. }
  55. # 训练模型
  56. print('Starting training...')
  57. gbm = lgb.train(params,
  58. lgb_train,
  59. num_boost_round=500,
  60. valid_sets=[lgb_train, lgb_eval],
  61. )
  62. # 序列化模型
  63. model_bytes = pickle.dumps(gbm)
  64. model_data = {
  65. 'model_name': model_name,
  66. 'model': model_bytes, #将模型字节流存入数据库
  67. }
  68. print('Training completed!')
  69. return model_data
  70. def str_to_list(arg):
  71. if arg == '':
  72. return []
  73. else:
  74. return arg.split(',')
  75. @app.route('/model_training_lightgbm', methods=['POST'])
  76. def model_training_lightgbm():
  77. # 获取程序开始时间
  78. start_time = time.time()
  79. result = {}
  80. success = 0
  81. print("Program starts execution!")
  82. try:
  83. args = request.values.to_dict()
  84. print('args',args)
  85. logger.info(args)
  86. power_df = get_data_from_mongo(args)
  87. model = build_model(power_df,args)
  88. insert_model_into_mongo(model,args)
  89. success = 1
  90. except Exception as e:
  91. my_exception = traceback.format_exc()
  92. my_exception.replace("\n","\t")
  93. result['msg'] = my_exception
  94. end_time = time.time()
  95. result['success'] = success
  96. result['args'] = args
  97. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  98. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  99. print("Program execution ends!")
  100. return result
  101. if __name__=="__main__":
  102. print("Program starts execution!")
  103. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  104. logger = logging.getLogger("model_training_lightgbm log")
  105. from waitress import serve
  106. serve(app, host="0.0.0.0", port=10090)
  107. print("server start!")