# model_prediction_lightgbm.py (3.8 KB)

import argparse
import logging
import os
import pickle
import time
import traceback

import lightgbm as lgb
import numpy as np
import pandas as pd
from flask import Flask, request, jsonify
from pymongo import MongoClient
from waitress import serve
  12. app = Flask('model_prediction_lightgbm——service')
  13. def get_data_from_mongo(args):
  14. mongodb_connection,mongodb_database,mongodb_read_table,timeBegin,timeEnd = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_read_table'],args['timeBegin'],args['timeEnd']
  15. client = MongoClient(mongodb_connection)
  16. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  17. db = client[mongodb_database]
  18. collection = db[mongodb_read_table] # 集合名称
  19. query = {"dateTime": {"$gte": timeBegin, "$lte": timeEnd}}
  20. cursor = collection.find(query)
  21. data = list(cursor)
  22. df = pd.DataFrame(data)
  23. # 4. 删除 _id 字段(可选)
  24. if '_id' in df.columns:
  25. df = df.drop(columns=['_id'])
  26. client.close()
  27. return df
  28. def insert_data_into_mongo(res_df,args):
  29. mongodb_connection,mongodb_database,mongodb_write_table = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_write_table']
  30. client = MongoClient(mongodb_connection)
  31. db = client[mongodb_database]
  32. if mongodb_write_table in db.list_collection_names():
  33. db[mongodb_write_table].drop()
  34. print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
  35. collection = db[mongodb_write_table] # 集合名称
  36. # 将 DataFrame 转为字典格式
  37. data_dict = res_df.to_dict("records") # 每一行作为一个字典
  38. # 插入到 MongoDB
  39. collection.insert_many(data_dict)
  40. print("data inserted successfully!")
  41. def model_prediction(df,args):
  42. mongodb_connection,mongodb_database,mongodb_model_table,model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_model_table'],args['model_name']
  43. client = MongoClient(mongodb_connection)
  44. db = client[mongodb_database]
  45. collection = db[mongodb_model_table]
  46. model_data = collection.find_one({"model_name": model_name})
  47. if model_data is not None:
  48. model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
  49. # 反序列化模型
  50. model = pickle.loads(model_binary)
  51. df['predict'] = model.predict(df[model.feature_name()])
  52. print("model predict result successfully!")
  53. return df
  54. @app.route('/model_prediction_lightgbm', methods=['POST'])
  55. def model_prediction_lightgbm():
  56. # 获取程序开始时间
  57. start_time = time.time()
  58. result = {}
  59. success = 0
  60. print("Program starts execution!")
  61. try:
  62. args = request.values.to_dict()
  63. print('args',args)
  64. logger.info(args)
  65. power_df = get_data_from_mongo(args)
  66. model = model_prediction(power_df,args)
  67. insert_data_into_mongo(model,args)
  68. success = 1
  69. except Exception as e:
  70. my_exception = traceback.format_exc()
  71. my_exception.replace("\n","\t")
  72. result['msg'] = my_exception
  73. end_time = time.time()
  74. result['success'] = success
  75. result['args'] = args
  76. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  77. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  78. print("Program execution ends!")
  79. return result
  80. if __name__=="__main__":
  81. print("Program starts execution!")
  82. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  83. logger = logging.getLogger("model_prediction_lightgbm log")
  84. from waitress import serve
  85. serve(app, host="0.0.0.0", port=10090)
  86. print("server start!")