analysis.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import pandas as pd
  2. import matplotlib.pyplot as plt
  3. from pymongo import MongoClient
  4. import pickle
  5. import numpy as np
  6. import plotly.express as px
  7. from plotly.subplots import make_subplots
  8. import plotly.graph_objects as go
  9. from flask import Flask,request,jsonify
  10. from waitress import serve
  11. import time
  12. import random
  13. import argparse
  14. import logging
  15. import traceback
  16. import os
  17. app = Flask('analysis_report——service')
  18. def get_data_from_mongo(args):
  19. # 1.读数据
  20. mongodb_connection,mongodb_database,all_table,accuracy_table,model_table,model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['train_table'],args['accuracy_table'],args['model_table'],args['model_name']
  21. client = MongoClient(mongodb_connection)
  22. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  23. db = client[mongodb_database]
  24. # 将游标转换为列表,并创建 pandas DataFrame
  25. df_all = pd.DataFrame(db[all_table].find({}, {'_id': 0}))
  26. df_accuracy = pd.DataFrame(db[accuracy_table].find({}, {'_id': 0}))
  27. model_data = db[model_table].find_one({"model_name": model_name})
  28. if model_data is not None:
  29. model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
  30. # 反序列化模型
  31. model = pickle.loads(model_binary)
  32. client.close()
  33. return df_all,df_accuracy,model
  34. def draw_info(df_all,df_accuracy,model,features,args):
  35. #1.数据描述 数据描述:
  36. col_time = args['col_time']
  37. label = args['label']
  38. df_accuracy_beginTime = df_accuracy[col_time].min()
  39. df_accuracy_endTime = df_accuracy[col_time].max()
  40. df_train = df_all[df_all[col_time]<df_accuracy_beginTime][features+[col_time,label]]
  41. df_train_beginTime = df_train[col_time].min()
  42. df_train_endTime = df_train[col_time].max()
  43. text_content = f"训练数据时间范围:{df_train_beginTime} 至 {df_train_endTime},共{df_train.shape[0]}条记录,测试集数据时间范围:{df_accuracy_beginTime} 至 {df_accuracy_endTime}。<br>lightgbm模型参数:{model.params}"
  44. return text_content
  45. def draw_global_scatter(df,args):
  46. # --- 1. 实际功率和辐照度的散点图 ---
  47. col_x = args['scatter_col_x']
  48. col_y = args['label']
  49. scatter_fig = px.scatter(
  50. df,
  51. x=col_x,
  52. y=col_y,
  53. title=f"{col_x}和{col_y}的散点图",
  54. labels={"辐照度": "辐照度 (W/m²)", "实际功率": "实际功率 (kW)"}
  55. )
  56. return scatter_fig
  57. def draw_corr(df,features,args):
  58. # --- 2. 相关性热力图 ---
  59. # 计算相关性矩阵
  60. label = args['label']
  61. features_coor = features+[label]
  62. corr_matrix = df[features_coor].corr()
  63. # 使用 Plotly Express 绘制热力图
  64. heatmap_fig = px.imshow(corr_matrix,
  65. text_auto=True, # 显示数值
  66. color_continuous_scale='RdBu', # 配色方案
  67. title="Correlation Heatmap")
  68. heatmap_fig.update_coloraxes(showscale=False)
  69. return heatmap_fig
  70. def draw_feature_importance(model,features):
  71. # --- 3. 特征重要性排名 ---
  72. # 获取特征重要性
  73. importance = model.feature_importance() # 'split' 或 'gain',根据需求选择
  74. # 转换为 DataFrame 方便绘图
  75. feature_importance_df = pd.DataFrame({
  76. 'Feature': features,
  77. 'Importance': importance
  78. })
  79. feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)
  80. # 使用 Plotly Express 绘制条形图
  81. importance_fig = px.bar(feature_importance_df, x='Feature', y='Importance',
  82. title="特征重要性排名",
  83. labels={'Feature': '特征', 'Importance': '重要性'},
  84. color='Importance',
  85. color_continuous_scale='Viridis')
  86. # 更新每个 trace,确保没有图例
  87. importance_fig.update_layout(title="模型特征重要性排名",
  88. showlegend=False # 移除图例
  89. )
  90. importance_fig.update_coloraxes(showscale=False)
  91. return importance_fig
  92. def draw_data_info_table(content):
  93. # --- 4. 创建数据说明的表格 ---
  94. # 转换为表格格式:1行1列,且填充文字说明
  95. # 转换为表格格式
  96. # 创建一个空的图
  97. table_fig = go.Figure()
  98. # 第一部分: 显示文字说明
  99. table_fig.add_trace(go.Table(
  100. header=dict(
  101. values=["说明"], # 表格只有一列:说明
  102. fill_color="paleturquoise",
  103. align="center"
  104. ),
  105. cells=dict(
  106. values=[[content]] , # 第一行填入文本说明
  107. fill_color="lavender",
  108. align="center"
  109. )
  110. ))
  111. return table_fig
  112. def draw_accuracy_table(df,content):
  113. # --- 4. 每日的准确率表格 ---
  114. # 转换为表格格式
  115. table_fig = go.Figure(
  116. data=[
  117. go.Table(
  118. header=dict(
  119. values=list(df.columns),
  120. fill_color="paleturquoise",
  121. align="center"
  122. ),
  123. cells=dict(
  124. values=[df[col] for col in df.columns],
  125. fill_color="lavender",
  126. align="center"
  127. )
  128. )
  129. ]
  130. )
  131. table_fig.update_layout(title="准确率表", showlegend=False)
  132. return table_fig
  133. @app.route('/analysis_report', methods=['POST'])
  134. def analysis_report():
  135. start_time = time.time()
  136. result = {}
  137. success = 0
  138. path = ""
  139. print("Program starts execution!")
  140. try:
  141. args = request.values.to_dict()
  142. print('args',args)
  143. logger.info(args)
  144. #获取数据
  145. df_all, df_accuracy, model = get_data_from_mongo(args)
  146. features = model.feature_name()
  147. text_content = draw_info(df_all,df_accuracy,model,features,args)
  148. text_fig,scatter_fig,heatmap_fig,importance_fig,table_fig=draw_data_info_table(text_content),draw_global_scatter(df_all,args),draw_corr(df_all,features,args),draw_feature_importance(model,features),\
  149. draw_accuracy_table(df_accuracy,text_content)
  150. # --- 合并图表并保存到一个 HTML 文件 ---
  151. # 创建子图布局
  152. combined_fig = make_subplots(
  153. rows=5, cols=1,
  154. subplot_titles=["数据-模型概览","辐照度和实际功率的散点图", "相关性","特征重要性排名", "准确率表"],
  155. row_heights=[0.3, 0.6, 0.6, 0.6, 0.4],
  156. specs=[[{"type": "table"}], [{"type": "xy"}], [{"type": "heatmap"}], [{"type": "xy"}],[{"type": "table"}]] # 指定每个子图类型
  157. )
  158. # 添加文本信息到子图(第一行)
  159. # 添加文字说明
  160. for trace in text_fig.data:
  161. combined_fig.add_trace(trace, row=1, col=1)
  162. # 添加散点图
  163. for trace in scatter_fig.data:
  164. combined_fig.add_trace(trace, row=2, col=1)
  165. # 添加相关性热力图
  166. for trace in heatmap_fig.data:
  167. combined_fig.add_trace(trace, row=3, col=1)
  168. # 添加特征重要性排名图
  169. for trace in importance_fig.data:
  170. combined_fig.add_trace(trace, row=4, col=1)
  171. # 添加表格
  172. for trace in table_fig.data:
  173. combined_fig.add_trace(trace, row=5, col=1)
  174. # 更新布局
  175. combined_fig.update_layout(
  176. height=1500,
  177. title_text="分析结果汇总", # 添加换行符以适应文本内容
  178. title_x=0.5, # 中心对齐标题
  179. showlegend=False,
  180. )
  181. combined_fig.update_coloraxes(showscale=False)
  182. filename = f"{int(time.time() * 1000)}_{random.randint(1000, 9999)}.html"
  183. # 保存为 HTML
  184. directory = '/usr/share/nginx/html'
  185. if not os.path.exists(directory):
  186. os.makedirs(directory)
  187. file_path = os.path.join(directory, filename)
  188. # combined_fig.write_html(f"D://usr//{filename}")
  189. combined_fig.write_html(file_path)
  190. path = f"交互式 HTML 文件已生成!路径/data/html/{filename}"
  191. success = 1
  192. except Exception as e:
  193. my_exception = traceback.format_exc()
  194. my_exception.replace("\n","\t")
  195. result['msg'] = my_exception
  196. end_time = time.time()
  197. result['success'] = success
  198. result['args'] = args
  199. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  200. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  201. result['file_path'] = path
  202. print("Program execution ends!")
  203. return result
  204. if __name__=="__main__":
  205. print("Program starts execution!")
  206. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  207. logger = logging.getLogger("analysis_report log")
  208. from waitress import serve
  209. serve(app, host="0.0.0.0", port=10092)
  210. print("server start!")