analysis.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. # -*- coding: utf-8 -*-
  2. import pandas as pd
  3. import matplotlib.pyplot as plt
  4. from pymongo import MongoClient
  5. import pickle
  6. import numpy as np
  7. import plotly.express as px
  8. from plotly.subplots import make_subplots
  9. import plotly.graph_objects as go
  10. from flask import Flask,request,jsonify
  11. from waitress import serve
  12. import time
  13. import random
  14. import argparse
  15. import logging
  16. import traceback
  17. import os
  18. import lightgbm as lgb
  19. app = Flask('analysis_report——service')
  20. def get_data_from_mongo(args):
  21. # 1.读数据
  22. mongodb_connection,mongodb_database,all_table,accuracy_table,model_table,model_name = "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/",args['mongodb_database'],args['train_table'],args['accuracy_table'],args['model_table'],args['model_name']
  23. client = MongoClient(mongodb_connection)
  24. # 选择数据库(如果数据库不存在,MongoDB 会自动创建)
  25. db = client[mongodb_database]
  26. # 将游标转换为列表,并创建 pandas DataFrame
  27. df_all = pd.DataFrame(db[all_table].find({}, {'_id': 0}))
  28. df_accuracy = pd.DataFrame(db[accuracy_table].find({}, {'_id': 0}))
  29. model_data = db[model_table].find_one({"model_name": model_name})
  30. if model_data is not None:
  31. model_binary = model_data['model'] # 确保这个字段是存储模型的二进制数据
  32. # 反序列化模型
  33. model = pickle.loads(model_binary)
  34. client.close()
  35. return df_all,df_accuracy,model
  36. def draw_info(df_all,df_accuracy,model,features,args):
  37. #1.数据描述 数据描述:
  38. col_time = args['col_time']
  39. label = args['label']
  40. df_accuracy_beginTime = df_accuracy[col_time].min()
  41. df_accuracy_endTime = df_accuracy[col_time].max()
  42. df_train = df_all[df_all[col_time]<df_accuracy_beginTime][features+[col_time,label]]
  43. df_train_beginTime = df_train[col_time].min()
  44. df_train_endTime = df_train[col_time].max()
  45. text_content = f"训练数据时间范围:{df_train_beginTime} 至 {df_train_endTime},共{df_train.shape[0]}条记录,测试集数据时间范围:{df_accuracy_beginTime} 至 {df_accuracy_endTime}。<br>lightgbm模型参数:{model.params}"
  46. return text_content
  47. def draw_global_scatter(df,args):
  48. # --- 1. 实际功率和辐照度的散点图 ---
  49. col_x = args['scatter_col_x']
  50. col_y = args['label']
  51. scatter_fig = px.scatter(
  52. df,
  53. x=col_x,
  54. y=col_y,
  55. title=f"{col_x}和{col_y}的散点图",
  56. labels={"辐照度": "辐照度 (W/m²)", "实际功率": "实际功率 (kW)"}
  57. )
  58. return scatter_fig
  59. def draw_corr(df,features,args):
  60. # --- 2. 相关性热力图 ---
  61. # 计算相关性矩阵
  62. label = args['label']
  63. features_coor = features+[label]
  64. corr_matrix = df[features_coor].corr()
  65. # 使用 Plotly Express 绘制热力图
  66. heatmap_fig = px.imshow(corr_matrix,
  67. text_auto=True, # 显示数值
  68. color_continuous_scale='RdBu', # 配色方案
  69. title="Correlation Heatmap")
  70. heatmap_fig.update_coloraxes(showscale=False)
  71. return heatmap_fig
  72. def draw_feature_importance(model,features):
  73. # --- 3. 特征重要性排名 ---
  74. # 获取特征重要性
  75. importance = model.feature_importance() # 'split' 或 'gain',根据需求选择
  76. # 转换为 DataFrame 方便绘图
  77. feature_importance_df = pd.DataFrame({
  78. 'Feature': features,
  79. 'Importance': importance
  80. })
  81. feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)
  82. # 使用 Plotly Express 绘制条形图
  83. importance_fig = px.bar(feature_importance_df, x='Feature', y='Importance',
  84. title="特征重要性排名",
  85. labels={'Feature': '特征', 'Importance': '重要性'},
  86. color='Importance',
  87. color_continuous_scale='Viridis')
  88. # 更新每个 trace,确保没有图例
  89. importance_fig.update_layout(title="模型特征重要性排名",
  90. showlegend=False # 移除图例
  91. )
  92. importance_fig.update_coloraxes(showscale=False)
  93. return importance_fig
  94. def draw_data_info_table(content):
  95. # --- 4. 创建数据说明的表格 ---
  96. # 转换为表格格式:1行1列,且填充文字说明
  97. # 转换为表格格式
  98. # 创建一个空的图
  99. table_fig = go.Figure()
  100. # 第一部分: 显示文字说明
  101. table_fig.add_trace(go.Table(
  102. header=dict(
  103. values=["说明"], # 表格只有一列:说明
  104. fill_color="paleturquoise",
  105. align="center"
  106. ),
  107. cells=dict(
  108. values=[[content]] , # 第一行填入文本说明
  109. fill_color="lavender",
  110. align="center"
  111. )
  112. ))
  113. return table_fig
  114. def draw_accuracy_table(df,content):
  115. # --- 4. 每日的准确率表格 ---
  116. # 转换为表格格式
  117. table_fig = go.Figure(
  118. data=[
  119. go.Table(
  120. header=dict(
  121. values=list(df.columns),
  122. fill_color="paleturquoise",
  123. align="center"
  124. ),
  125. cells=dict(
  126. values=[df[col] for col in df.columns],
  127. fill_color="lavender",
  128. align="center"
  129. )
  130. )
  131. ]
  132. )
  133. table_fig.update_layout(title="准确率表", showlegend=False)
  134. return table_fig
  135. @app.route('/analysis_report', methods=['POST'])
  136. def analysis_report():
  137. start_time = time.time()
  138. result = {}
  139. success = 0
  140. path = ""
  141. print("Program starts execution!")
  142. try:
  143. args = request.values.to_dict()
  144. print('args',args)
  145. logger.info(args)
  146. #获取数据
  147. df_all, df_accuracy, model = get_data_from_mongo(args)
  148. features = model.feature_name()
  149. text_content = draw_info(df_all,df_accuracy,model,features,args)
  150. text_fig,scatter_fig,heatmap_fig,importance_fig,table_fig=draw_data_info_table(text_content),draw_global_scatter(df_all,args),draw_corr(df_all,features,args),draw_feature_importance(model,features),\
  151. draw_accuracy_table(df_accuracy,text_content)
  152. # --- 合并图表并保存到一个 HTML 文件 ---
  153. # 创建子图布局
  154. combined_fig = make_subplots(
  155. rows=5, cols=1,
  156. subplot_titles=["数据-模型概览","辐照度和实际功率的散点图", "相关性","特征重要性排名", "准确率表"],
  157. row_heights=[0.3, 0.6, 0.6, 0.6, 0.4],
  158. specs=[[{"type": "table"}], [{"type": "xy"}], [{"type": "heatmap"}], [{"type": "xy"}],[{"type": "table"}]] # 指定每个子图类型
  159. )
  160. # 添加文本信息到子图(第一行)
  161. # 添加文字说明
  162. for trace in text_fig.data:
  163. combined_fig.add_trace(trace, row=1, col=1)
  164. # 添加散点图
  165. for trace in scatter_fig.data:
  166. combined_fig.add_trace(trace, row=2, col=1)
  167. # 添加相关性热力图
  168. for trace in heatmap_fig.data:
  169. combined_fig.add_trace(trace, row=3, col=1)
  170. # 添加特征重要性排名图
  171. for trace in importance_fig.data:
  172. combined_fig.add_trace(trace, row=4, col=1)
  173. # 添加表格
  174. for trace in table_fig.data:
  175. combined_fig.add_trace(trace, row=5, col=1)
  176. # 更新布局
  177. combined_fig.update_layout(
  178. height=1500,
  179. title_text="分析结果汇总", # 添加换行符以适应文本内容
  180. title_x=0.5, # 中心对齐标题
  181. showlegend=False,
  182. )
  183. combined_fig.update_coloraxes(showscale=False)
  184. filename = f"{int(time.time() * 1000)}_{random.randint(1000, 9999)}.html"
  185. # 保存为 HTML
  186. directory = '/usr/share/nginx/html'
  187. if not os.path.exists(directory):
  188. os.makedirs(directory)
  189. file_path = os.path.join(directory, filename)
  190. # combined_fig.write_html(f"D://usr//{filename}")
  191. combined_fig.write_html(file_path)
  192. path = f"http://ds2:10093/data/html/{filename}"
  193. success = 1
  194. except Exception as e:
  195. my_exception = traceback.format_exc()
  196. my_exception.replace("\n","\t")
  197. result['msg'] = my_exception
  198. end_time = time.time()
  199. result['success'] = success
  200. result['args'] = args
  201. result['start_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
  202. result['end_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
  203. result['file_path'] = path
  204. print("Program execution ends!")
  205. return result
  206. if __name__=="__main__":
  207. print("Program starts execution!")
  208. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  209. logger = logging.getLogger("analysis_report log")
  210. from waitress import serve
  211. serve(app, host="0.0.0.0", port=10092)
  212. print("server start!")