@@ -87,8 +87,9 @@ def create_sequences(data_features,data_target,time_steps):
     return np.array(X), np.array(y)
 
 def model_prediction(df,args):
-    mongodb_connection, mongodb_database, scaler_table, features, time_steps = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
-                                                args['mongodb_database'], args['scaler_table'],args['features'],args['time_steps'])
+
+    mongodb_connection, mongodb_database, scaler_table, features, time_steps, col_time = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
+                                                args['mongodb_database'], args['scaler_table'], str_to_list(args['features']),int(args['time_steps']),args['col_time'])
     client = MongoClient(mongodb_connection)
     # Select the database (MongoDB creates it automatically if it does not exist)
     db = client[mongodb_database]
@@ -96,10 +97,12 @@ def model_prediction(df,args):
     # Retrieve the scalers from MongoDB
     scaler_doc = collection.find_one()
     # Deserialize the scalers
+
     feature_scaler_bytes = BytesIO(scaler_doc["feature_scaler"])
     feature_scaler = joblib.load(feature_scaler_bytes)
     target_scaler_bytes = BytesIO(scaler_doc["target_scaler"])
     target_scaler = joblib.load(target_scaler_bytes)
+    df = df.fillna(method='ffill').fillna(method='bfill').sort_values(by=col_time)
     scaled_features = feature_scaler.transform(df[features])
     X_predict, _ = create_sequences(scaled_features, [], time_steps)
     # Pass the custom loss function when loading the model
@@ -110,6 +113,11 @@ def model_prediction(df,args):
     result['predict'] = y_predict
     return result
 
+def str_to_list(arg):
+    if arg == '':
+        return []
+    else:
+        return arg.split(',')
 
 @app.route('/model_prediction_lstm', methods=['POST'])
 def model_prediction_lstm():
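
For context on why the new str_to_list helper and the int() cast were added: the request arguments evidently arrive as plain strings, so 'features' has to be split into a list and 'time_steps' converted to an integer before they reach create_sequences. A minimal sketch of the intended behavior follows; the sample argument values are hypothetical, only the keys and the helper come from the diff.

# Sketch only: sample values are made up, keys mirror the ones read in model_prediction.
args = {
    'mongodb_database': 'some_db',
    'scaler_table': 'some_scaler_table',
    'features': 'radiation,temperature,humidity',
    'time_steps': '48',
    'col_time': 'dateTime',
}

features = str_to_list(args['features'])   # ['radiation', 'temperature', 'humidity']
time_steps = int(args['time_steps'])       # 48; create_sequences needs an int, not the string '48'
assert str_to_list('') == []               # the empty-string guard avoids [''] from ''.split(',')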