123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- import lightgbm as lgb
- import pandas as pd
- import numpy as np
- from pymongo import MongoClient
- import pickle
- from sklearn.model_selection import train_test_split
- from sklearn.metrics import mean_squared_error,mean_absolute_error
- from flask import Flask,request
- import time
- import traceback
- import logging
# Flask application object; the route handlers in this file register on it.
app = Flask(import_name='model_training_lightgbm——service')
def get_data_from_mongo(args):
    """Load the documents of a time range from MongoDB into a DataFrame.

    Args:
        args: Mapping that provides ``mongodb_database``,
            ``mongodb_read_table``, ``timeBegin`` and ``timeEnd``. The two
            time values are used as inclusive bounds on the ``dateTime``
            field of the stored documents.

    Returns:
        A ``pandas.DataFrame`` built from the matching documents, with the
        ``_id`` column removed when it is present.

    Raises:
        KeyError: If one of the required keys is missing from ``args``.
    """
    # NOTE(review): the connection string (including credentials) is
    # hard-coded in source; it should come from configuration or a secret
    # store instead.
    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
    mongodb_database = args['mongodb_database']
    mongodb_read_table = args['mongodb_read_table']
    timeBegin = args['timeBegin']
    timeEnd = args['timeEnd']

    query = {"dateTime": {"$gte": timeBegin, "$lte": timeEnd}}

    # The context manager closes the client even when the query raises, so a
    # failed request no longer leaks the connection.
    with MongoClient(mongodb_connection) as client:
        # Indexing a client/database by name returns a handle; MongoDB
        # creates the database lazily if it does not exist yet.
        collection = client[mongodb_database][mongodb_read_table]
        # Materialize the cursor while the client is still open.
        data = list(collection.find(query))

    df = pd.DataFrame(data)
    # Drop MongoDB's ``_id`` field when present (an empty result has no columns).
    if '_id' in df.columns:
        df = df.drop(columns=['_id'])
    return df
-
def insert_model_into_mongo(model_data, args):
    """Store ``model_data`` as the only document of the target collection.

    If the target collection already exists it is dropped first, so after a
    successful call the collection contains exactly the inserted document.

    Args:
        model_data: Document (dict) to insert.
        args: Mapping that provides ``mongodb_database`` and
            ``mongodb_write_table``.

    Raises:
        KeyError: If one of the required keys is missing from ``args``.
    """
    # NOTE(review): the connection string (including credentials) is
    # hard-coded in source; it should come from configuration or a secret
    # store instead.
    mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
    mongodb_database = args['mongodb_database']
    mongodb_write_table = args['mongodb_write_table']

    # The context manager guarantees the client is closed on success and on
    # error; the original never closed it.
    with MongoClient(mongodb_connection) as client:
        db = client[mongodb_database]
        if mongodb_write_table in db.list_collection_names():
            db[mongodb_write_table].drop()
            print(f"Collection '{mongodb_write_table}' already exist, deleted successfully!")
        collection = db[mongodb_write_table]
        collection.insert_one(model_data)
    print("model inserted successfully!")
def build_model(df, args):
    """Train a LightGBM regression model on ``df`` and return it pickled.

    The data is split 80/20 into train and test sets (``random_state=42``),
    a booster is trained, and RMSE / MAE on the test set are printed.

    Args:
        df: DataFrame containing the feature columns and the label column.
        args: Mapping of string settings:
            ``numerical_features`` / ``categorical_features`` are
            comma-separated column names, ``label`` is the target column,
            ``model_name`` is stored alongside the model,
            ``num_boost_round`` is converted with ``int``, and
            ``model_params`` is the text of a Python dict literal whose
            entries override the default training parameters (an empty or
            blank string means "no overrides").

    Returns:
        A dict with ``model_name`` and ``model`` (the pickled booster bytes).

    Raises:
        KeyError: If a required key is missing from ``args``.
        ValueError / SyntaxError: If ``model_params`` is not a valid literal.
        TypeError: If ``model_params`` does not evaluate to a dict.
    """
    import ast

    np.random.seed(42)

    numerical_features = str_to_list(args['numerical_features'])
    categorical_features = str_to_list(args['categorical_features'])
    label = args['label']
    model_name = args['model_name']
    num_boost_round = int(args['num_boost_round'])

    # SECURITY: ``model_params`` arrives in the HTTP request. The original
    # code passed it to eval(), which lets any caller execute arbitrary code.
    # ast.literal_eval accepts only literals (dicts, lists, numbers, strings,
    # booleans, None), which covers LightGBM parameter dicts.
    raw_model_params = args['model_params']
    model_params = ast.literal_eval(raw_model_params) if raw_model_params.strip() else {}
    if not isinstance(model_params, dict):
        raise TypeError(
            f"model_params must be a dict literal, got {type(model_params).__name__}"
        )

    features = numerical_features + categorical_features
    print("features:************",features)

    # Split the data into training and test sets.
    X_train, X_test, y_train, y_test = train_test_split(
        df[features], df[label], test_size=0.2, random_state=42
    )

    # Build the LightGBM datasets; the eval set reuses the training set's binning.
    lgb_train = lgb.Dataset(X_train, y_train, categorical_feature=categorical_features)
    lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

    # Default parameters; caller-supplied entries take precedence.
    params = {
        'objective': 'regression',
        'metric': 'rmse',
        'boosting_type': 'gbdt',
        'verbose': 1,
    }
    merged_param = params | model_params

    # Train the model.
    print('Starting training...')
    gbm = lgb.train(
        merged_param,
        lgb_train,
        num_boost_round=num_boost_round,
        valid_sets=[lgb_train, lgb_eval],
    )
    y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)

    # Evaluate on the held-out test set.
    mse = mean_squared_error(y_test, y_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_test, y_pred)
    print(f'The test rmse is: {rmse},"The test mae is:"{mae}')

    # Serialize the model so the byte stream can be stored in the database.
    model_bytes = pickle.dumps(gbm)
    model_data = {
        'model_name': model_name,
        'model': model_bytes,
    }
    print('Training completed!')
    return model_data
def str_to_list(arg):
    """Split a comma-separated string into a list of its items.

    An empty string yields an empty list. Items are returned exactly as they
    appear between the commas (no whitespace stripping).
    """
    if arg != '':
        return arg.split(',')
    return []
@app.route('/model_training_lightgbm', methods=['POST'])
def model_training_lightgbm():
    """Handle a training request: load data, train a model, store it.

    Reads the request values as a dict of settings, fetches the training
    data from MongoDB, trains the LightGBM model and writes the pickled
    model back to MongoDB.

    Returns:
        A dict with ``success`` (1 on success, 0 on failure), ``args`` (the
        request values), ``start_time`` and ``end_time`` (local time,
        ``%Y-%m-%d %H:%M:%S``) and, on failure, ``msg`` holding the traceback
        with newlines replaced by tabs.
    """
    # Look the logger up here: the module-level ``logger`` is only created
    # when the file runs as a script, so relying on it raises NameError when
    # the app is imported by another server.
    log = logging.getLogger("model_training_lightgbm log")

    # Record the start time of the request.
    start_time = time.time()
    result = {}
    success = 0
    # Pre-bind so the response can be built even if reading the request fails.
    args = {}
    print("Program starts execution!")
    try:
        args = request.values.to_dict()
        print('args',args)
        log.info(args)
        power_df = get_data_from_mongo(args)
        model = build_model(power_df, args)
        insert_model_into_mongo(model, args)
        success = 1
    except Exception:
        # Top-level boundary: report the failure to the caller instead of
        # letting the request crash.
        log.exception("model training failed")
        # str.replace returns a new string; the original discarded it, so the
        # newlines were never actually replaced.
        result['msg'] = traceback.format_exc().replace("\n", "\t")
    end_time = time.time()

    result['success'] = success
    result['args'] = args
    result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
    result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
    print("Program execution ends!")
    return result
# Script entry point: configure logging and serve the Flask app with waitress.
if __name__ == "__main__":
    print("Program starts execution!")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )
    # Module-level logger used by the request handler.
    logger = logging.getLogger("model_training_lightgbm log")
    from waitress import serve

    # serve() blocks until the server shuts down, so the start-up message has
    # to be printed before the call; after it, the message would only appear
    # at shutdown.
    print("server start!")
    serve(app, host="0.0.0.0", port=10089)
|