"""Flask service that trains a LightGBM regression model from MongoDB data.

``POST /model_training_lightgbm`` reads the documents whose ``dateTime`` lies in a
requested range, trains a LightGBM regressor on them, and stores the pickled
booster in a (re-created) MongoDB collection.
"""
import argparse
import ast
import logging
import os
import pickle
import time
import traceback

import lightgbm as lgb
import numpy as np
import pandas as pd
from flask import Flask, request, jsonify
from pymongo import MongoClient
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
from waitress import serve

app = Flask('model_training_lightgbm——service')

# Module-level so request handlers can log no matter how the app is served
# (previously it only existed when the file was run as __main__).
logger = logging.getLogger("model_training_lightgbm log")

# NOTE(review): credentials should not live in source control. Set MONGODB_URI
# in the deployment environment and remove this fallback once that is done.
MONGODB_URI = os.environ.get(
    "MONGODB_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
)


def get_data_from_mongo(args):
    """Load the documents whose ``dateTime`` lies in ``[timeBegin, timeEnd]``.

    Args:
        args: Request parameters; reads ``mongodb_database``,
            ``mongodb_read_table``, ``timeBegin`` and ``timeEnd``.

    Returns:
        A DataFrame of the matching documents with the ``_id`` column removed.
        It is empty when no document matches.
    """
    mongodb_database = args['mongodb_database']
    mongodb_read_table = args['mongodb_read_table']
    time_begin = args['timeBegin']
    time_end = args['timeEnd']

    client = MongoClient(MONGODB_URI)
    try:
        # Selecting a database/collection that does not exist does not fail;
        # MongoDB creates it lazily.
        collection = client[mongodb_database][mongodb_read_table]
        query = {"dateTime": {"$gte": time_begin, "$lte": time_end}}
        df = pd.DataFrame(list(collection.find(query)))
    finally:
        # Close even when the query raises so connections are not leaked.
        client.close()

    # Mongo's ObjectId is not a training feature.
    if '_id' in df.columns:
        df = df.drop(columns=['_id'])
    return df


def insert_model_into_mongo(model_data, args):
    """Replace the target collection's contents with a single model document.

    Drops ``args['mongodb_write_table']`` in ``args['mongodb_database']`` if it
    exists, then inserts ``model_data`` into it.
    """
    mongodb_database = args['mongodb_database']
    mongodb_write_table = args['mongodb_write_table']

    client = MongoClient(MONGODB_URI)
    try:
        db = client[mongodb_database]
        # NOTE(review): drop-then-insert is not atomic; if insert_one fails the
        # previously stored model is already gone.
        if mongodb_write_table in db.list_collection_names():
            db[mongodb_write_table].drop()
            print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
        db[mongodb_write_table].insert_one(model_data)
        print("model inserted successfully!")
    finally:
        client.close()


def _parse_model_params(raw):
    """Parse the user-supplied LightGBM parameter overrides without executing code.

    ``raw`` must be a Python dict literal such as ``"{'learning_rate': 0.05}"``;
    an empty/blank string means "no overrides".

    Raises:
        ValueError: if ``raw`` is not a valid literal or not a dict.
        SyntaxError: if ``raw`` is not parseable as a Python literal.
    """
    if raw is None or not raw.strip():
        return {}
    # literal_eval instead of eval: this string comes straight from the HTTP
    # request, and eval would let any caller run arbitrary code on the server.
    parsed = ast.literal_eval(raw)
    if not isinstance(parsed, dict):
        raise ValueError(
            f"model_params must be a dict literal, got {type(parsed).__name__}"
        )
    return parsed


def build_model(df, args):
    """Train a LightGBM regressor on ``df`` and package it for storage.

    Args:
        df: Training data containing every feature column and the label column.
        args: Request parameters; reads ``numerical_features`` and
            ``categorical_features`` (comma-separated column names), ``label``,
            ``model_name``, ``num_boost_round`` and ``model_params`` (a Python
            dict literal merged over the default parameters below).

    Returns:
        ``{'model_name': <name>, 'model': <pickled lgb.Booster bytes>}``.

    Raises:
        ValueError: if ``df`` is empty or ``model_params`` is not a dict literal.
    """
    np.random.seed(42)
    numerical_features = str_to_list(args['numerical_features'])
    categorical_features = str_to_list(args['categorical_features'])
    label = args['label']
    model_name = args['model_name']
    num_boost_round = int(args['num_boost_round'])
    model_params = _parse_model_params(args['model_params'])
    features = numerical_features + categorical_features
    print("features:************", features)

    if df.empty:
        raise ValueError("No training data found for the requested time range.")

    # Hold out 20% of the rows for evaluation.
    X_train, X_test, y_train, y_test = train_test_split(
        df[features], df[label], test_size=0.2, random_state=42
    )

    lgb_train = lgb.Dataset(X_train, y_train, categorical_feature=categorical_features)
    lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

    # Defaults; any key supplied in model_params overrides these.
    params = {
        'objective': 'regression',
        'metric': 'rmse',
        'boosting_type': 'gbdt',
        'verbose': 1,
    }
    merged_param = params | model_params

    print('Starting training...')
    gbm = lgb.train(
        merged_param,
        lgb_train,
        num_boost_round=num_boost_round,
        valid_sets=[lgb_train, lgb_eval],
    )

    y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)

    # Evaluate on the held-out split.
    mse = mean_squared_error(y_test, y_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_test, y_pred)
    print(f"The test rmse is: {rmse}, The test mae is: {mae}")

    # Store the booster as a pickled byte string in the Mongo document.
    model_data = {
        'model_name': model_name,
        'model': pickle.dumps(gbm),
    }
    print('Training completed!')
    return model_data


def str_to_list(arg):
    """Split a comma-separated string into stripped, non-empty items.

    ``''`` (or only separators/whitespace) yields ``[]``; ``"a, b"`` yields
    ``['a', 'b']``.
    """
    return [item.strip() for item in arg.split(',') if item.strip()]


@app.route('/model_training_lightgbm', methods=['POST'])
def model_training_lightgbm():
    """Run the fetch -> train -> store pipeline for one request.

    Returns a dict (serialized to JSON by Flask) with ``success`` (1 or 0), the
    echoed request ``args``, ``start_time``/``end_time`` as local-time strings,
    and, on failure, ``msg`` holding the traceback with newlines replaced by tabs.
    """
    start_time = time.time()
    result = {}
    success = 0
    # Pre-bind so the summary below still works if reading the form raises.
    args = {}
    print("Program starts execution!")
    try:
        args = request.values.to_dict()
        print('args', args)
        logger.info(args)
        power_df = get_data_from_mongo(args)
        model = build_model(power_df, args)
        insert_model_into_mongo(model, args)
        success = 1
    except Exception:
        # Request boundary: report any failure in the response body instead of
        # letting Flask turn it into a bare 500.
        logger.exception("model training failed")
        # str.replace returns a new string, so the result must be assigned.
        result['msg'] = traceback.format_exc().replace("\n", "\t")

    end_time = time.time()
    result['success'] = success
    result['args'] = args
    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
    print("Program execution ends!")
    return result


if __name__ == "__main__":
    print("Program starts execution!")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )
    # serve() blocks until shutdown, so announce startup before calling it.
    print("server start!")
    serve(app, host="0.0.0.0", port=10089)