# database_dml.py
  1. import pymongo
  2. from pymongo import MongoClient, UpdateOne, DESCENDING, ASCENDING
  3. from pymongo.errors import PyMongoError
  4. import pandas as pd
  5. from sqlalchemy import create_engine
  6. import pickle
  7. from io import BytesIO
  8. import joblib
  9. import h5py
  10. import tensorflow as tf
  11. import os
  12. import tempfile
  13. def get_data_from_mongo(args):
  14. mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
  15. mongodb_database = args['mongodb_database']
  16. mongodb_read_table = args['mongodb_read_table']
  17. query_dict = {}
  18. if 'timeBegin' in args.keys():
  19. timeBegin = args['timeBegin']
  20. query_dict.update({"$gte": timeBegin})
  21. if 'timeEnd' in args.keys():
  22. timeEnd = args['timeEnd']
  23. query_dict.update({"$lte": timeEnd})
  24. client = MongoClient(mongodb_connection)
  25. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  26. db = client[mongodb_database]
  27. collection = db[mongodb_read_table] # 集合名称
  28. if len(query_dict) != 0:
  29. query = {"dateTime": query_dict}
  30. cursor = collection.find(query)
  31. else:
  32. cursor = collection.find()
  33. data = list(cursor)
  34. df = pd.DataFrame(data)
  35. # 4. 删除 _id 字段(可选)
  36. if '_id' in df.columns:
  37. df = df.drop(columns=['_id'])
  38. client.close()
  39. return df
  40. def get_df_list_from_mongo(args):
  41. mongodb_connection,mongodb_database,mongodb_read_table = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_read_table'].split(',')
  42. df_list = []
  43. client = MongoClient(mongodb_connection)
  44. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  45. db = client[mongodb_database]
  46. for table in mongodb_read_table:
  47. collection = db[table] # 集合名称
  48. data_from_db = collection.find() # 这会返回一个游标(cursor)
  49. # 将游标转换为列表,并创建 pandas DataFrame
  50. df = pd.DataFrame(list(data_from_db))
  51. if '_id' in df.columns:
  52. df = df.drop(columns=['_id'])
  53. df_list.append(df)
  54. client.close()
  55. return df_list
  56. def insert_data_into_mongo(res_df, args):
  57. """
  58. 插入数据到 MongoDB 集合中,可以选择覆盖、追加或按指定的 key 进行更新插入。
  59. 参数:
  60. - res_df: 要插入的 DataFrame 数据
  61. - args: 包含 MongoDB 数据库和集合名称的字典
  62. - overwrite: 布尔值,True 表示覆盖,False 表示追加
  63. - update_keys: 列表,指定用于匹配的 key 列,如果存在则更新,否则插入 'col1','col2'
  64. """
  65. mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
  66. mongodb_database = args['mongodb_database']
  67. mongodb_write_table = args['mongodb_write_table']
  68. overwrite = 1
  69. update_keys = None
  70. if 'overwrite' in args.keys():
  71. overwrite = int(args['overwrite'])
  72. if 'update_keys' in args.keys():
  73. update_keys = args['update_keys'].split(',')
  74. client = MongoClient(mongodb_connection)
  75. db = client[mongodb_database]
  76. collection = db[mongodb_write_table]
  77. # 覆盖模式:删除现有集合
  78. if overwrite:
  79. if mongodb_write_table in db.list_collection_names():
  80. collection.drop()
  81. print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
  82. # 将 DataFrame 转为字典格式
  83. data_dict = res_df.to_dict("records") # 每一行作为一个字典
  84. # 如果没有数据,直接返回
  85. if not data_dict:
  86. print("No data to insert.")
  87. return
  88. # 如果指定了 update_keys,则执行 upsert(更新或插入)
  89. if update_keys and not overwrite:
  90. operations = []
  91. for record in data_dict:
  92. # 构建查询条件,用于匹配要更新的文档
  93. query = {key: record[key] for key in update_keys}
  94. operations.append(UpdateOne(query, {'$set': record}, upsert=True))
  95. # 批量执行更新/插入操作
  96. if operations:
  97. result = collection.bulk_write(operations)
  98. client.close()
  99. print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
  100. else:
  101. # 追加模式:直接插入新数据
  102. collection.insert_many(data_dict)
  103. client.close()
  104. print("Data inserted successfully!")
  105. def get_data_fromMysql(params):
  106. mysql_conn = params['mysql_conn']
  107. query_sql = params['query_sql']
  108. #数据库读取实测气象
  109. engine = create_engine(f"mysql+pymysql://{mysql_conn}")
  110. # 定义SQL查询
  111. with engine.connect() as conn:
  112. df = pd.read_sql_query(query_sql, conn)
  113. return df
  114. def insert_pickle_model_into_mongo(model, args, features=None):
  115. mongodb_connection, mongodb_database, mongodb_write_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
  116. args['mongodb_database'], args['mongodb_write_table'], args['model_name']
  117. client = MongoClient(mongodb_connection)
  118. db = client[mongodb_database]
  119. # 序列化模型
  120. model_bytes = pickle.dumps(model)
  121. model_data = {
  122. 'model_name': model_name,
  123. 'model': model_bytes, # 将模型字节流存入数据库
  124. }
  125. # 保存模型特征
  126. if features is not None:
  127. model_data['features'] = features
  128. print('Training completed!')
  129. if mongodb_write_table in db.list_collection_names():
  130. db[mongodb_write_table].drop()
  131. print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
  132. collection = db[mongodb_write_table] # 集合名称
  133. collection.insert_one(model_data)
  134. client.close()
  135. print("model inserted successfully!")
  136. def insert_h5_model_into_mongo(model,feature_scaler_bytes,target_scaler_bytes ,args):
  137. mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
  138. args['mongodb_database'],args['scaler_table'],args['model_table'],args['model_name'])
  139. client = MongoClient(mongodb_connection)
  140. db = client[mongodb_database]
  141. if scaler_table in db.list_collection_names():
  142. db[scaler_table].drop()
  143. print(f"Collection '{scaler_table} already exist, deleted successfully!")
  144. collection = db[scaler_table] # 集合名称
  145. # Save the scalers in MongoDB as binary data
  146. collection.insert_one({
  147. "feature_scaler": feature_scaler_bytes.read(),
  148. "target_scaler": target_scaler_bytes.read()
  149. })
  150. print("scaler_model inserted successfully!")
  151. if model_table in db.list_collection_names():
  152. db[model_table].drop()
  153. print(f"Collection '{model_table} already exist, deleted successfully!")
  154. model_table = db[model_table]
  155. fd, temp_path = None, None
  156. client = None
  157. try:
  158. # ------------------------- 临时文件处理 -------------------------
  159. fd, temp_path = tempfile.mkstemp(suffix='.keras')
  160. os.close(fd) # 立即释放文件锁
  161. # ------------------------- 模型保存 -------------------------
  162. try:
  163. model.save(temp_path) # 不指定save_format,默认使用keras新格式
  164. except Exception as e:
  165. raise RuntimeError(f"模型保存失败: {str(e)}") from e
  166. # ------------------------- 数据插入 -------------------------
  167. with open(temp_path, 'rb') as f:
  168. result = model_table.insert_one({
  169. "model_name": args['model_name'],
  170. "model_data": f.read(),
  171. })
  172. print(f"✅ 模型 {args['model_name']} 保存成功 | 文档ID: {result.inserted_id}")
  173. return str(result.inserted_id)
  174. except Exception as e:
  175. # ------------------------- 异常分类处理 -------------------------
  176. error_type = "数据库操作" if isinstance(e, (pymongo.errors.PyMongoError, RuntimeError)) else "系统错误"
  177. print(f"❌ {error_type} - 详细错误: {str(e)}")
  178. raise # 根据业务需求决定是否重新抛出
  179. finally:
  180. # ------------------------- 资源清理 -------------------------
  181. if client:
  182. client.close()
  183. if temp_path and os.path.exists(temp_path):
  184. try:
  185. os.remove(temp_path)
  186. except PermissionError:
  187. print(f"⚠️ 临时文件清理失败: {temp_path}")
  188. # def insert_trained_model_into_mongo(model, args):
  189. # mongodb_connection,mongodb_database,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
  190. # args['mongodb_database'],args['model_table'],args['model_name'])
  191. #
  192. # gen_time, params_json, descr = args['gen_time'], args['params'], args['descr']
  193. # client = MongoClient(mongodb_connection)
  194. # db = client[mongodb_database]
  195. # if model_table in db.list_collection_names():
  196. # db[model_table].drop()
  197. # print(f"Collection '{model_table} already exist, deleted successfully!")
  198. # model_table = db[model_table]
  199. #
  200. # # 创建 BytesIO 缓冲区
  201. # model_buffer = BytesIO()
  202. # # 将模型保存为 HDF5 格式到内存 (BytesIO)
  203. # model.save(model_buffer, save_format='h5')
  204. # # 将指针移到缓冲区的起始位置
  205. # model_buffer.seek(0)
  206. # # 获取模型的二进制数据
  207. # model_data = model_buffer.read()
  208. # # 将模型保存到 MongoDB
  209. # model_table.insert_one({
  210. # "model_name": model_name,
  211. # "model_data": model_data,
  212. # "gen_time": gen_time,
  213. # "params": params_json,
  214. # "descr": descr
  215. # })
  216. # print("模型成功保存到 MongoDB!")
  217. def insert_scaler_model_into_mongo(feature_scaler_bytes, scaled_target_bytes, args):
  218. mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
  219. args['mongodb_database'],args['scaler_table'],args['model_table'],args['model_name'])
  220. client = MongoClient(mongodb_connection)
  221. db = client[mongodb_database]
  222. if scaler_table in db.list_collection_names():
  223. db[scaler_table].drop()
  224. print(f"Collection '{scaler_table} already exist, deleted successfully!")
  225. collection = db[scaler_table] # 集合名称
  226. # Save the scalers in MongoDB as binary data
  227. collection.insert_one({
  228. "feature_scaler": feature_scaler_bytes.read(),
  229. "target_scaler": scaled_target_bytes.read()
  230. })
  231. client.close()
  232. print("scaler_model inserted successfully!")
  233. def get_h5_model_from_mongo(args, custom=None):
  234. mongodb_connection,mongodb_database,model_table,model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['model_table'],args['model_name']
  235. client = MongoClient(mongodb_connection)
  236. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  237. db = client[mongodb_database]
  238. collection = db[model_table] # 集合名称
  239. # 查询 MongoDB 获取模型数据
  240. model_doc = collection.find_one({"model_name": model_name})
  241. if model_doc:
  242. model_data = model_doc['model_data'] # 获取模型的二进制数据
  243. # 创建临时文件(自动删除)
  244. with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
  245. tmp_file.write(model_data)
  246. tmp_file_path = tmp_file.name # 记录文件路径
  247. # 从临时文件加载模型
  248. model = tf.keras.models.load_model(
  249. tmp_file_path,
  250. custom_objects=custom
  251. )
  252. print(f"{args['model_name']} 模型成功从 MongoDB 加载!")
  253. client.close()
  254. # 确保删除临时文件
  255. if tmp_file_path and os.path.exists(tmp_file_path):
  256. try:
  257. os.remove(tmp_file_path)
  258. print(f"🧹 已清理临时文件: {tmp_file_path}")
  259. except Exception as cleanup_err:
  260. print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
  261. return model
  262. else:
  263. print(f"未找到model_name为 {model_name} 的模型。")
  264. client.close()
  265. return None
  266. def get_scaler_model_from_mongo(args, only_feature_scaler=False):
  267. mongodb_connection, mongodb_database, scaler_table, = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", args['mongodb_database'], args['scaler_table'])
  268. client = MongoClient(mongodb_connection)
  269. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  270. db = client[mongodb_database]
  271. collection = db[scaler_table] # 集合名称
  272. # Retrieve the scalers from MongoDB
  273. scaler_doc = collection.find_one()
  274. # Deserialize the scalers
  275. feature_scaler_bytes = BytesIO(scaler_doc["feature_scaler"])
  276. feature_scaler = joblib.load(feature_scaler_bytes)
  277. if only_feature_scaler:
  278. return feature_scaler
  279. target_scaler_bytes = BytesIO(scaler_doc["target_scaler"])
  280. target_scaler = joblib.load(target_scaler_bytes)
  281. client.close()
  282. return feature_scaler,target_scaler