# database_dml.py — MongoDB / MySQL read-write helpers for DataFrames and serialized models
  1. from pymongo import MongoClient, UpdateOne
  2. import pandas as pd
  3. from sqlalchemy import create_engine
  4. import pickle
  5. from io import BytesIO
  6. import joblib
  7. import h5py
  8. import tensorflow as tf
  9. def get_data_from_mongo(args):
  10. mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
  11. mongodb_database = args['mongodb_database']
  12. mongodb_read_table = args['mongodb_read_table']
  13. query_dict = {}
  14. if 'timeBegin' in args.keys():
  15. timeBegin = args['timeBegin']
  16. query_dict.update({"$gte": timeBegin})
  17. if 'timeEnd' in args.keys():
  18. timeEnd = args['timeEnd']
  19. query_dict.update({"$lte": timeEnd})
  20. client = MongoClient(mongodb_connection)
  21. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  22. db = client[mongodb_database]
  23. collection = db[mongodb_read_table] # 集合名称
  24. if len(query_dict) != 0:
  25. query = {"dateTime": query_dict}
  26. cursor = collection.find(query)
  27. else:
  28. cursor = collection.find()
  29. data = list(cursor)
  30. df = pd.DataFrame(data)
  31. # 4. 删除 _id 字段(可选)
  32. if '_id' in df.columns:
  33. df = df.drop(columns=['_id'])
  34. client.close()
  35. return df
  36. def get_df_list_from_mongo(args):
  37. mongodb_connection,mongodb_database,mongodb_read_table = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['mongodb_read_table'].split(',')
  38. df_list = []
  39. client = MongoClient(mongodb_connection)
  40. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  41. db = client[mongodb_database]
  42. for table in mongodb_read_table:
  43. collection = db[table] # 集合名称
  44. data_from_db = collection.find() # 这会返回一个游标(cursor)
  45. # 将游标转换为列表,并创建 pandas DataFrame
  46. df = pd.DataFrame(list(data_from_db))
  47. if '_id' in df.columns:
  48. df = df.drop(columns=['_id'])
  49. df_list.append(df)
  50. client.close()
  51. return df_list
  52. def insert_data_into_mongo(res_df, args):
  53. """
  54. 插入数据到 MongoDB 集合中,可以选择覆盖、追加或按指定的 key 进行更新插入。
  55. 参数:
  56. - res_df: 要插入的 DataFrame 数据
  57. - args: 包含 MongoDB 数据库和集合名称的字典
  58. - overwrite: 布尔值,True 表示覆盖,False 表示追加
  59. - update_keys: 列表,指定用于匹配的 key 列,如果存在则更新,否则插入 'col1','col2'
  60. """
  61. mongodb_connection = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/"
  62. mongodb_database = args['mongodb_database']
  63. mongodb_write_table = args['mongodb_write_table']
  64. overwrite = 1
  65. update_keys = None
  66. if 'overwrite' in args.keys():
  67. overwrite = int(args['overwrite'])
  68. if 'update_keys' in args.keys():
  69. update_keys = args['update_keys'].split(',')
  70. client = MongoClient(mongodb_connection)
  71. db = client[mongodb_database]
  72. collection = db[mongodb_write_table]
  73. # 覆盖模式:删除现有集合
  74. if overwrite:
  75. if mongodb_write_table in db.list_collection_names():
  76. collection.drop()
  77. print(f"Collection '{mongodb_write_table}' already exists, deleted successfully!")
  78. # 将 DataFrame 转为字典格式
  79. data_dict = res_df.to_dict("records") # 每一行作为一个字典
  80. # 如果没有数据,直接返回
  81. if not data_dict:
  82. print("No data to insert.")
  83. return
  84. # 如果指定了 update_keys,则执行 upsert(更新或插入)
  85. if update_keys and not overwrite:
  86. operations = []
  87. for record in data_dict:
  88. # 构建查询条件,用于匹配要更新的文档
  89. query = {key: record[key] for key in update_keys}
  90. operations.append(UpdateOne(query, {'$set': record}, upsert=True))
  91. # 批量执行更新/插入操作
  92. if operations:
  93. result = collection.bulk_write(operations)
  94. print(f"Matched: {result.matched_count}, Upserts: {result.upserted_count}")
  95. else:
  96. # 追加模式:直接插入新数据
  97. collection.insert_many(data_dict)
  98. print("Data inserted successfully!")
  99. def get_data_fromMysql(params):
  100. mysql_conn = params['mysql_conn']
  101. query_sql = params['query_sql']
  102. #数据库读取实测气象
  103. engine = create_engine(f"mysql+pymysql://{mysql_conn}")
  104. # 定义SQL查询
  105. env_df = pd.read_sql_query(query_sql, engine)
  106. return env_df
  107. def insert_pickle_model_into_mongo(model, args):
  108. mongodb_connection, mongodb_database, mongodb_write_table, model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", \
  109. args['mongodb_database'], args['mongodb_write_table'], args['model_name']
  110. client = MongoClient(mongodb_connection)
  111. db = client[mongodb_database]
  112. # 序列化模型
  113. model_bytes = pickle.dumps(model)
  114. model_data = {
  115. 'model_name': model_name,
  116. 'model': model_bytes, # 将模型字节流存入数据库
  117. }
  118. print('Training completed!')
  119. if mongodb_write_table in db.list_collection_names():
  120. db[mongodb_write_table].drop()
  121. print(f"Collection '{mongodb_write_table} already exist, deleted successfully!")
  122. collection = db[mongodb_write_table] # 集合名称
  123. collection.insert_one(model_data)
  124. print("model inserted successfully!")
  125. def insert_h5_model_into_mongo(model,feature_scaler_bytes,target_scaler_bytes ,args):
  126. mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
  127. args['mongodb_database'],args['scaler_table'],args['model_table'],args['model_name'])
  128. client = MongoClient(mongodb_connection)
  129. db = client[mongodb_database]
  130. if scaler_table in db.list_collection_names():
  131. db[scaler_table].drop()
  132. print(f"Collection '{scaler_table} already exist, deleted successfully!")
  133. collection = db[scaler_table] # 集合名称
  134. # Save the scalers in MongoDB as binary data
  135. collection.insert_one({
  136. "feature_scaler": feature_scaler_bytes.read(),
  137. "target_scaler": target_scaler_bytes.read()
  138. })
  139. print("scaler_model inserted successfully!")
  140. if model_table in db.list_collection_names():
  141. db[model_table].drop()
  142. print(f"Collection '{model_table} already exist, deleted successfully!")
  143. model_table = db[model_table]
  144. # 创建 BytesIO 缓冲区
  145. model_buffer = BytesIO()
  146. # 将模型保存为 HDF5 格式到内存 (BytesIO)
  147. model.save(model_buffer, save_format='h5')
  148. # 将指针移到缓冲区的起始位置
  149. model_buffer.seek(0)
  150. # 获取模型的二进制数据
  151. model_data = model_buffer.read()
  152. # 将模型保存到 MongoDB
  153. model_table.insert_one({
  154. "model_name": model_name,
  155. "model_data": model_data
  156. })
  157. print("模型成功保存到 MongoDB!")
  158. def insert_trained_model_into_mongo(model, args):
  159. mongodb_connection,mongodb_database,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
  160. args['mongodb_database'],args['model_table'],args['model_name'])
  161. gen_time, params_json, descr = args['gen_time'], args['params'], args['descr']
  162. client = MongoClient(mongodb_connection)
  163. db = client[mongodb_database]
  164. if model_table in db.list_collection_names():
  165. db[model_table].drop()
  166. print(f"Collection '{model_table} already exist, deleted successfully!")
  167. model_table = db[model_table]
  168. # 创建 BytesIO 缓冲区
  169. model_buffer = BytesIO()
  170. # 将模型保存为 HDF5 格式到内存 (BytesIO)
  171. model.save(model_buffer, save_format='h5')
  172. # 将指针移到缓冲区的起始位置
  173. model_buffer.seek(0)
  174. # 获取模型的二进制数据
  175. model_data = model_buffer.read()
  176. # 将模型保存到 MongoDB
  177. model_table.insert_one({
  178. "model_name": model_name,
  179. "model_data": model_data,
  180. "gen_time": gen_time,
  181. "params": params_json,
  182. "descr": descr
  183. })
  184. print("模型成功保存到 MongoDB!")
  185. def insert_scaler_model_into_mongo(feature_scaler_bytes, scaled_target_bytes, args):
  186. mongodb_connection,mongodb_database,scaler_table,model_table,model_name = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",
  187. args['mongodb_database'],args['scaler_table'],args['model_table'],args['model_name'])
  188. client = MongoClient(mongodb_connection)
  189. db = client[mongodb_database]
  190. if scaler_table in db.list_collection_names():
  191. db[scaler_table].drop()
  192. print(f"Collection '{scaler_table} already exist, deleted successfully!")
  193. collection = db[scaler_table] # 集合名称
  194. # Save the scalers in MongoDB as binary data
  195. collection.insert_one({
  196. "feature_scaler": feature_scaler_bytes.read(),
  197. "target_scaler": scaled_target_bytes.read()
  198. })
  199. print("scaler_model inserted successfully!")
  200. def get_h5_model_from_mongo(args, custom=None):
  201. mongodb_connection,mongodb_database,model_table,model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['model_table'],args['model_name']
  202. client = MongoClient(mongodb_connection)
  203. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  204. db = client[mongodb_database]
  205. collection = db[model_table] # 集合名称
  206. # 查询 MongoDB 获取模型数据
  207. model_doc = collection.find_one({"model_name": model_name})
  208. if model_doc:
  209. model_data = model_doc['model_data'] # 获取模型的二进制数据
  210. # 将二进制数据加载到 BytesIO 缓冲区
  211. model_buffer = BytesIO(model_data)
  212. # 从缓冲区加载模型
  213. # 使用 h5py 和 BytesIO 从内存中加载模型
  214. with h5py.File(model_buffer, 'r') as f:
  215. model = tf.keras.models.load_model(f, custom_objects=custom)
  216. print(f"{model_name}模型成功从 MongoDB 加载!")
  217. client.close()
  218. return model
  219. else:
  220. print(f"未找到model_name为 {model_name} 的模型。")
  221. client.close()
  222. return None
  223. def get_scaler_model_from_mongo(args, only_feature_scaler=False):
  224. mongodb_connection, mongodb_database, scaler_table, = ("mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/", args['mongodb_database'], args['scaler_table'])
  225. client = MongoClient(mongodb_connection)
  226. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  227. db = client[mongodb_database]
  228. collection = db[scaler_table] # 集合名称
  229. # Retrieve the scalers from MongoDB
  230. scaler_doc = collection.find_one()
  231. # Deserialize the scalers
  232. feature_scaler_bytes = BytesIO(scaler_doc["feature_scaler"])
  233. feature_scaler = joblib.load(feature_scaler_bytes)
  234. if only_feature_scaler:
  235. return feature_scaler
  236. target_scaler_bytes = BytesIO(scaler_doc["target_scaler"])
  237. target_scaler = joblib.load(target_scaler_bytes)
  238. return feature_scaler,target_scaler