model_prediction_lightgbm.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import lightgbm as lgb
  2. import argparse
  3. import pandas as pd
  4. import numpy as np
  5. from pymongo import MongoClient
  6. import pickle
  7. from flask import Flask,request,jsonify
  8. from waitress import serve
  9. import time
  10. import logging
  11. import traceback
  12. app = Flask('model_prediction_lightgbm——service')
  13. def get_data_from_mongo(args):
  14. mongodb_connection,mongodb_database,mongodb_read_table = args['mongodb_connection'],args['mongodb_database'],args['mongodb_read_table']
  15. client = MongoClient(mongodb_connection)
  16. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  17. db = client[mongodb_database]
  18. collection = db[mongodb_read_table] # 集合名称
  19. data_from_db = collection.find() # 这会返回一个游标(cursor)
  20. # 将游标转换为列表,并创建 pandas DataFrame
  21. df = pd.DataFrame(list(data_from_db))
  22. client.close()
  23. return df
  24. def insert_data_into_mongo(res_df,args):
  25. mongodb_connection,mongodb_database,mongodb_write_table = args['mongodb_connection'],args['mongodb_database'],args['mongodb_write_table']
  26. client = MongoClient(mongodb_connection)
  27. db = client[mongodb_database]
  28. if mongodb_write_table in db.list_collection_names():
  29. db[mongodb_write_table].drop()
  30. print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
  31. collection = db[mongodb_write_table] # 集合名称
  32. # 将 DataFrame 转为字典格式
  33. data_dict = res_df.to_dict("records") # 每一行作为一个字典
  34. # 插入到 MongoDB
  35. collection.insert_many(data_dict)
  36. print("data inserted successfully!")
  37. def model_prediction(df,args):
  38. mongodb_connection,mongodb_database,mongodb_model_table,model_name = args['mongodb_connection'],args['mongodb_database'],args['mongodb_model_table'],args['model_name']
  39. client = MongoClient(mongodb_connection)
  40. db = client[mongodb_database]
  41. collection = db[mongodb_model_table]
  42. model_data = collection.find_one({"model_name": model_name})
  43. if model_data is not None:
  44. model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
  45. # 反序列化模型
  46. model = pickle.loads(model_binary)
  47. df['predict'] = model.predict(df[model.feature_name()])
  48. print("model predict result successfully!")
  49. return df
  50. @app.route('/model_prediction_lightgbm', methods=['POST'])
  51. def model_prediction_lightgbm():
  52. # 获取程序开始时间
  53. start_time = time.time()
  54. result = {}
  55. success = 0
  56. print("Program starts execution!")
  57. try:
  58. args = request.values.to_dict()
  59. print('args',args)
  60. logger.info(args)
  61. power_df = get_data_from_mongo(args)
  62. model = model_prediction(power_df,args)
  63. insert_data_into_mongo(model,args)
  64. success = 1
  65. except Exception as e:
  66. my_exception = traceback.format_exc()
  67. my_exception.replace("\n","\t")
  68. result['msg'] = my_exception
  69. end_time = time.time()
  70. result['success'] = success
  71. result['args'] = args
  72. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  73. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  74. print("Program execution ends!")
  75. return result
  76. if __name__=="__main__":
  77. print("Program starts execution!")
  78. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  79. logger = logging.getLogger("model_prediction_lightgbm log")
  80. from waitress import serve
  81. serve(app, host="0.0.0.0", port=10089)
  82. print("server start!")