
Merge branch 'dev_david' of anweiguo/algorithm_platform into dev_awg

liudawei 3 months ago
parent
commit
838bc40d03

+ 130 - 20
common/database_dml_koi.py

@@ -174,12 +174,12 @@ def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any])
 
     try:
         # ------------------------- Temporary file handling -------------------------
-        fd, temp_path = tempfile.mkstemp(suffix='.h5')
+        fd, temp_path = tempfile.mkstemp(suffix='.keras')
         os.close(fd)  # release the file lock immediately
 
         # ------------------------- Model saving -------------------------
         try:
-            model.save(temp_path, save_format='h5')
+            model.save(temp_path) # save_format not specified; defaults to the new Keras format
         except Exception as e:
             raise RuntimeError(f"Model save failed: {str(e)}") from e
 
@@ -189,9 +189,17 @@ def insert_trained_model_into_mongo(model: tf.keras.Model, args: Dict[str, Any])
         collection = db[args['model_table']]
 
         # ------------------------- Index check -------------------------
-        if "gen_time_1" not in collection.index_information():
-            collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
-            print("Created time index")
+        # index_info = collection.index_information()
+        # if "gen_time_1" not in index_info:
+        #     print("Creating index...")
+        #     collection.create_index(
+        #         [("gen_time", ASCENDING)],
+        #         name="gen_time_1",
+        #         background=True
+        #     )
+        #     print("Index created successfully")
+        # else:
+        #     print("Index already exists, skipping creation")
 
         # ------------------------- Capacity control -------------------------
         # Use a more efficient counting method
@@ -271,9 +279,9 @@ def insert_scaler_model_into_mongo(feature_scaler_bytes: BytesIO, target_scaler_
         collection = db[args['scaler_table']]
 
         # ------------------------- Index maintenance -------------------------
-        if "gen_time_1" not in collection.index_information():
-            collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
-            print("⏱️ Created time-sort index")
+        # if "gen_time_1" not in collection.index_information():
+        #     collection.create_index([("gen_time", ASCENDING)], name="gen_time_1")
+        #     print("⏱️ Created time-sort index")
 
         # ------------------------- Capacity control -------------------------
         # Use approximate counts for better performance (an error within a few dozen documents is acceptable)
@@ -337,6 +345,7 @@ def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict
     # ------------------------- Environment configuration -------------------------
     mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
     client = None
+    tmp_file_path = None  # track the temporary file path
     try:
         # ------------------------- Database connection -------------------------
         client = MongoClient(
@@ -379,7 +388,7 @@ def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict
             # with h5py.File(model_buffer, 'r', driver='fileobj') as f:
             #     model = tf.keras.models.load_model(f, custom_objects=custom_objects)
             # Create a temporary file
-            with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp_file:
+            with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
                 tmp_file.write(model_data)
                 tmp_file_path = tmp_file.name  # capture the temporary file path
 
@@ -400,7 +409,109 @@ def get_h5_model_from_mongo( args: Dict[str, Any], custom_objects: Optional[Dict
         # ------------------------- Resource cleanup -------------------------
         if client:
             client.close()
+        # Make sure the temporary file is removed
+        if tmp_file_path and os.path.exists(tmp_file_path):
+            try:
+                os.remove(tmp_file_path)
+                print(f"🧹 Removed temporary file: {tmp_file_path}")
+            except Exception as cleanup_err:
+                print(f"⚠️ Failed to remove temporary file: {str(cleanup_err)}")
+
+
+def get_keras_model_from_mongo(
+        args: Dict[str, Any],
+        custom_objects: Optional[Dict[str, Any]] = None
+) -> Optional[tf.keras.Model]:
+    """
+    Fetch the latest version of the specified model from MongoDB (supports the Keras format)
+
+    Parameters:
+    args : dict - dictionary with the following keys:
+        mongodb_database: database name
+        model_table: collection name
+        model_name: name of the model to fetch
+    custom_objects: dict - dictionary of custom Keras objects
+
+    Returns:
+    tf.keras.Model - the loaded Keras model
+    """
+    # ------------------------- Parameter validation -------------------------
+    required_keys = {'mongodb_database', 'model_table', 'model_name'}
+    if missing := required_keys - args.keys():
+        raise ValueError(f"❌ Missing required parameters: {missing}")
+
+    # ------------------------- Environment configuration -------------------------
+    mongo_uri = os.getenv("MONGO_URI", "mongodb://root:sdhjfREWFWEF23e@192.168.1.43:30000/")
+    client = None
+    tmp_file_path = None  # track the temporary file path
+
+    try:
+        # ------------------------- Database connection -------------------------
+        client = MongoClient(
+            mongo_uri,
+            maxPoolSize=10,
+            socketTimeoutMS=5000
+        )
+        db = client[args['mongodb_database']]
+        collection = db[args['model_table']]
 
+        # ------------------------- Index maintenance -------------------------
+        # index_name = "model_gen_time_idx"
+        # if index_name not in collection.index_information():
+        #     collection.create_index(
+        #         [("model_name", 1), ("gen_time", DESCENDING)],
+        #         name=index_name
+        #     )
+        #     print("⏱️ Created compound index")
+
+        # ------------------------- Efficient query -------------------------
+        model_doc = collection.find_one(
+            {"model_name": args['model_name']},
+            sort=[('gen_time', DESCENDING)],
+            projection={"model_data": 1, "gen_time": 1}
+        )
+
+        if not model_doc:
+            print(f"⚠️ 未找到模型 '{args['model_name']}' 的有效记录")
+            return None
+
+        # ------------------------- Memory-efficient loading -------------------------
+        model_data = model_doc['model_data']
+
+        # Create a temporary file (cleaned up in the finally block)
+        with tempfile.NamedTemporaryFile(suffix=".keras", delete=False) as tmp_file:
+            tmp_file.write(model_data)
+            tmp_file_path = tmp_file.name  # record the file path
+
+        # Load the model from the temporary file
+        model = tf.keras.models.load_model(
+            tmp_file_path,
+            custom_objects=custom_objects
+        )
+
+        print(f"Model {args['model_name']} loaded from MongoDB successfully!")
+        return model
+
+    except tf.errors.NotFoundError as e:
+        print(f"❌ 模型结构缺失关键组件: {str(e)}")
+        raise RuntimeError("模型架构不完整") from e
+
+    except Exception as e:
+        print(f"❌ 系统异常: {str(e)}")
+        raise
+
+    finally:
+        # ------------------------- 资源清理 -------------------------
+        if client:
+            client.close()
+
+        # 确保删除临时文件
+        if tmp_file_path and os.path.exists(tmp_file_path):
+            try:
+                os.remove(tmp_file_path)
+                print(f"🧹 已清理临时文件: {tmp_file_path}")
+            except Exception as cleanup_err:
+                print(f"⚠️ 临时文件清理失败: {str(cleanup_err)}")
 
 def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool = False) -> Union[object, Tuple[object, object]]:
     """
@@ -441,14 +552,14 @@ def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool
         collection = db[args['scaler_table']]
 
         # ------------------------- Index maintenance -------------------------
-        index_name = "model_gen_time_idx"
-        if index_name not in collection.index_information():
-            collection.create_index(
-                [("model_name", 1), ("gen_time", DESCENDING)],
-                name=index_name,
-                background=True  # build in the background to avoid blocking
-            )
-            print("⏱️ Created compound index for the feature scaler")
+        # index_name = "model_gen_time_idx"
+        # if index_name not in collection.index_information():
+        #     collection.create_index(
+        #         [("model_name", 1), ("gen_time", DESCENDING)],
+        #         name=index_name,
+        #         background=True  # build in the background to avoid blocking
+        #     )
+        #     print("⏱️ Created compound index for the feature scaler")
 
         # ------------------------- Efficient query -------------------------
         scaler_doc = collection.find_one(
@@ -496,7 +607,6 @@ def get_scaler_model_from_mongo(args: Dict[str, Any], only_feature_scaler: bool
         raise RuntimeError(f"❌ 未知系统异常: {str(e)}") from e
     finally:
         # ------------------------- Resource cleanup -------------------------
-        # if client:
-        #     client.close()
-        pass
+        if client:
+            client.close()
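
Note: a minimal usage sketch of the new get_keras_model_from_mongo loader added above (the database, collection, and model names below are illustrative placeholders, not values from this commit):

    # Load the latest Keras-format model previously stored by insert_trained_model_into_mongo.
    from common.database_dml_koi import get_keras_model_from_mongo

    args = {
        'mongodb_database': 'example_db',   # illustrative database name
        'model_table': 'example_models',    # illustrative collection name
        'model_name': 'example_model',      # illustrative model name
    }
    # Pass custom_objects when the model was saved with a custom loss,
    # mirroring the handler changes in models_processing/model_tf below.
    model = get_keras_model_from_mongo(args, custom_objects=None)
    if model is not None:
        model.summary()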
 

+ 62 - 3
data_processing/data_operation/pre_prod_ftp.py

@@ -9,8 +9,7 @@
 Functionality to implement:
 1. Call the station-configuration interface and, based on:
 type (ultra-short-term, short-term, mid-term)
-    algorithm engineer / model version
-        station code
+    algorithm engineer / model version / station code
 fetch all currently configured station models
 2. From the station model configuration and the time slot, build:
 type (ultra-short-term, short-term, mid-term)
@@ -21,6 +20,66 @@
 (3) merge the station-code sets of all algorithm engineers to form the per-type set of station codes
 3. Compress into three zip files by type (ultra-short-term, short-term, mid-term) and upload them to the production FTP
 """
+from collections import defaultdict
+import requests
+import json
+import os
+import paramiko
+import zipfile
+from io import BytesIO
+from datetime import datetime
+api_url = 'http://itil.jiayuepowertech.com:9958/itil/api/stationModelConfig'
+nick_name = {
+    '0': 'seer',
+    '1': 'koi',
+    '2': 'lucky'
+}
+def fetch_station_records(model_type, is_current=1):
+    """
+    Call the interface to fetch station records
+    :param model_type: model type — 0 ultra-short-term, 1 short-term, 2 mid-term
+    :param is_current: whether the model is currently active (e.g. 1 or 0)
+    :return: list of station records, or an error message
+    """
+    params = {
+        "paramModelType": str(model_type),
+        "paramIsCurrent": str(is_current)  # 适配接口参数格式
+    }
+
+    try:
+        response = requests.get(api_url, params=params, timeout=10)
+        response.raise_for_status()  # raise on HTTP errors
+        return response.json()  # assumes the interface returns JSON
+    except requests.exceptions.RequestException as e:
+        return {"error": f"请求失败: {str(e)}"}
+    except json.JSONDecodeError:
+        return {"error": "接口返回非JSON数据"}
+
+
+def process_station_data(api_data):
+    """
+    Process the interface data and build the three-level mapping
+    :param api_data: raw data returned by the interface (assumed to be a list of dicts)
+    :return: structured data that can be used to build the output table
+    """
+    # Create the mapping dict
+    mapping = {"lucky":{}, "seer":{}, "koi":{}}
+    # Iterate over each station record
+    for record in api_data:
+        # Extract the key fields (adjust to the actual interface field names)
+        engineer = nick_name.get(record.get("engineerName"), "unknown")
+        model_name = record.get("modelName")
+        model_version = record.get("modelVersion")
+        station_code = record.get("stationCode")
+        assert engineer in mapping
+        if all([engineer, model_name, model_version, station_code]):
+            mapping[engineer].setdefault((model_name, model_version), set()).add(station_code)
+    return mapping
+
+
+
 
 if __name__ == "__main__":
-    run_code = 0
+    models = fetch_station_records(1)
+    mapping = process_station_data(models['data'])
+    print(mapping)
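
For orientation, the three-level mapping returned by process_station_data has the shape sketched below; the station codes are the ones used in the MAPPINGS example of the test script that follows, and the comment shows how that script derives the per-model zip name from one entry:

    # engineer nickname -> (model_name, model_version) -> set of station codes
    mapping = {
        'koi':   {('Zone', '1.0'): {'J00645'}},
        'seer':  {('lgb', '1.0'): {'J00001'}},
        'lucky': {},
    }
    # test.py turns each key into a remote file name, e.g. for 'koi' with
    # ('Zone', '1.0') and datetime_str '2025012118':
    #   jy_koi.Zone.1.0_2025012118_dq.zip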

+ 239 - 0
data_processing/data_operation/test.py

@@ -0,0 +1,239 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# @FileName  :test.py
+# @Time      :2025/3/13 14:19
+# @Author    :David
+# @Company: shenyang JY
+
+import paramiko
+from datetime import datetime, timedelta
+import os
+import zipfile
+import shutil
+import tempfile
+
+# Configuration
+SFTP_HOST = '192.168.1.33'
+SFTP_PORT = 2022
+SFTP_USERNAME = 'liudawei'
+SFTP_PASSWORD = 'liudawei@123'
+# Additional settings appended to the original configuration section
+DEST_SFTP_HOST = 'dest_sftp.example.com'
+DEST_SFTP_PORT = 22
+DEST_SFTP_USERNAME = 'dest_username'
+DEST_SFTP_PASSWORD = 'dest_password'
+DEFAULT_TARGET_DIR = 'cdq'  # default upload directory
+
+# Updated three-level mapping
+MAPPINGS = {
+    'koi': {('Zone', '1.0'): {'J00645'}},
+    'lucky': {}, 'seer': {('lgb', '1.0'): {'J00001'}}
+}
+
+
+def get_next_target_time(current_time=None):
+    """获取下一个目标时刻"""
+    if current_time is None:
+        current_time = datetime.now()
+
+    target_hours = [0, 6, 12, 18]
+    current_hour = current_time.hour
+
+    for hour in sorted(target_hours):
+        if current_hour < hour:
+            return current_time.replace(hour=hour, minute=0, second=0, microsecond=0)
+
+    # If the current time is past all target hours, roll over to 00:00 of the next day
+    return (current_time + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
+
+
+def download_sftp_files(sftp, mappings, datetime_str, local_temp_dir):
+    """下载所有需要的SFTP文件"""
+    for engineer in mappings:
+        remote_base = f"/{engineer}/"  # SFTP根目录下的工程师目录
+        try:
+            sftp.chdir(remote_base)
+        except FileNotFoundError:
+            print(f"工程师目录不存在: {remote_base}")
+            continue
+
+        for model_version in mappings[engineer]:
+            # Build the target file name (model name and version are already merged)
+
+            target_file = f"jy_{engineer}.{'.'.join(model_version)}_{datetime_str}_dq.zip"
+            remote_path = os.path.join(remote_base, target_file).replace("\\", "/")
+            local_path = os.path.join(local_temp_dir, target_file).replace("\\", "/")
+
+            try:
+                sftp.get(remote_path, local_path)
+                print(f"下载成功: {remote_path}")
+            except Exception as e:
+                print(f"文件下载失败 {remote_path}: {str(e)}")
+
+
+def download_files_via_sftp(mappings, datetime_str, local_temp_dir):
+    """
+    Wrap the full SFTP connect-and-download flow
+    :param mappings: file mapping configuration
+    :param datetime_str: datetime string used in the file names
+    :param local_temp_dir: path of the local temporary directory
+    """
+    transport = None
+    sftp = None
+    try:
+        # Open the SSH transport channel
+        transport = paramiko.Transport((SFTP_HOST, SFTP_PORT))
+        transport.connect(username=SFTP_USERNAME, password=SFTP_PASSWORD)
+
+        # Create the SFTP client
+        sftp = paramiko.SFTPClient.from_transport(transport)
+
+        # Download the files
+        for engineer in mappings:
+            remote_base = f"/{engineer}/"
+            try:
+                sftp.chdir(remote_base)
+            except FileNotFoundError:
+                print(f"工程师目录不存在: {remote_base}")
+                continue
+
+            for model_version in mappings[engineer]:
+                target_file = f"jy_{engineer}.{'.'.join(model_version)}_{datetime_str}_dq.zip"
+                remote_path = os.path.join(remote_base, target_file).replace("\\", "/")
+                local_path = os.path.join(local_temp_dir, target_file).replace("\\", "/")
+
+                try:
+                    sftp.get(remote_path, local_path)
+                    print(f"下载成功: {remote_path} -> {local_path}")
+                except Exception as e:
+                    print(f"文件下载失败 {remote_path}: {str(e)}")
+
+    except paramiko.AuthenticationException:
+        print("Authentication failed; check the username and password")
+    except paramiko.SSHException as e:
+        print(f"SSH connection error: {str(e)}")
+    except Exception as e:
+        print(f"Unknown error: {str(e)}")
+    finally:
+        # Make sure the connections are closed
+        if sftp:
+            sftp.close()
+        if transport and transport.is_active():
+            transport.close()
+
+def upload_to_sftp(local_path, target_dir):
+    """上传文件到目标SFTP服务器"""
+    transport = None
+    sftp = None
+    try:
+        # Open a new transport connection
+        transport = paramiko.Transport((DEST_SFTP_HOST, DEST_SFTP_PORT))
+        transport.connect(username=DEST_SFTP_USERNAME, password=DEST_SFTP_PASSWORD)
+        sftp = paramiko.SFTPClient.from_transport(transport)
+
+        # Make sure the target directory exists
+        try:
+            sftp.chdir(target_dir)
+        except FileNotFoundError:
+            sftp.mkdir(target_dir)
+            print(f"已创建远程目录: {target_dir}")
+
+        # Build the remote path
+        filename = os.path.basename(local_path)
+        remote_path = f"{target_dir}/{filename}"
+
+        # Upload the file
+        sftp.put(local_path, remote_path)
+        print(f"Uploaded to: {remote_path}")
+
+    except Exception as e:
+        print(f"上传失败: {str(e)}")
+        raise
+    finally:
+        # Make sure the connections are closed
+        if sftp:
+            sftp.close()
+        if transport and transport.is_active():
+            transport.close()
+
+def process_zips(mappings, local_temp_dir, datetime_str, final_collect_dir):
+    """处理所有下载的ZIP文件并收集场站目录"""
+    for engineer in mappings:
+        for model_version in mappings[engineer]:
+            target_file = f"jy_{engineer}.{'.'.join(model_version)}_{datetime_str}_dq.zip"
+            zip_path = os.path.join(local_temp_dir, target_file).replace("\\", "/")
+            station_codes = mappings[engineer][model_version]
+
+            if not os.path.exists(zip_path):
+                continue
+
+            # Create a temporary extraction directory
+            with tempfile.TemporaryDirectory() as temp_extract:
+                # Extract the ZIP file
+                try:
+                    with zipfile.ZipFile(zip_path, 'r') as zf:
+                        zf.extractall(temp_extract)
+                except zipfile.BadZipFile:
+                    print(f"无效的ZIP文件: {zip_path}")
+                    continue
+
+                # Collect the station directories
+                for root, dirs, files in os.walk(temp_extract):
+                    for dir_name in dirs:
+                        if dir_name in station_codes:
+                            src = os.path.join(root, dir_name)
+                            dest = os.path.join(final_collect_dir, dir_name)
+
+                            if not os.path.exists(dest):
+                                shutil.copytree(src, dest)
+                                print(f"已收集场站: {dir_name}")
+
+
+def create_final_zip(final_collect_dir, datetime_str, output_path="result.zip"):
+    """创建最终打包的ZIP文件"""
+    with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zf:
+        for root, dirs, files in os.walk(final_collect_dir):
+            for file in files:
+                file_path = os.path.join(root, file)
+                arcname = os.path.relpath(file_path, final_collect_dir)
+                zf.write(file_path, arcname)
+    print(f"最终打包完成: {output_path}")
+
+
+
+
+def main():
+    # Create a temporary working directory
+    with tempfile.TemporaryDirectory() as local_temp_dir:
+        final_collect_dir = os.path.join(local_temp_dir, 'collected_stations')
+        os.makedirs(final_collect_dir, exist_ok=True)
+
+        # Compute the target time
+        target_time = get_next_target_time()
+        datetime_str = target_time.strftime("%Y%m%d%H")
+        datetime_str = '2025012118'  # hard-coded override for testing
+        print(f"Target time: {datetime_str}")
+
+        # Connect to SFTP
+        # transport = paramiko.Transport((SFTP_HOST, SFTP_PORT))
+        # transport.connect(username=SFTP_USERNAME, password=SFTP_PASSWORD)
+        # sftp = paramiko.SFTPClient.from_transport(transport)
+
+        # Download the files
+        download_files_via_sftp(MAPPINGS, datetime_str, local_temp_dir)
+
+        # Close the SFTP connection
+        # sftp.close()
+        # transport.close()
+
+        # Process the downloaded files
+        process_zips(MAPPINGS, local_temp_dir, datetime_str, final_collect_dir)
+
+        # Create the final ZIP
+        create_final_zip(final_collect_dir, datetime_str)
+
+        # Upload the packaged ZIP file (result.zip is the default output of create_final_zip)
+        upload_to_sftp("result.zip", DEFAULT_TARGET_DIR)
+
+if __name__ == "__main__":
+    main()
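
A quick sanity check of the scheduling helper above (a small sketch; assumes get_next_target_time is in scope, e.g. run inside this module):

    from datetime import datetime

    # At 14:19 the next of the 00/06/12/18 slots is 18:00 the same day...
    t = get_next_target_time(datetime(2025, 3, 13, 14, 19))
    assert t.strftime("%Y%m%d%H") == "2025031318"

    # ...and after 18:00 it rolls over to 00:00 of the next day.
    t = get_next_target_time(datetime(2025, 3, 13, 19, 0))
    assert t.strftime("%Y%m%d%H") == "2025031400"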

+ 2 - 2
models_processing/model_tf/tf_bp.py

@@ -32,7 +32,7 @@ class BPHandler(object):
         try:
             with model_lock:
                 # loss = region_loss(self.opt)
-                self.model = get_h5_model_from_mongo(args)
+                self.model = get_keras_model_from_mongo(args)
         except Exception as e:
             self.logger.info("加载模型权重失败:{}".format(e.args))
 
@@ -89,7 +89,7 @@ class BPHandler(object):
             if self.opt.Model['add_train']:
                 # Perform additional training to support model repair
                 loss = region_loss(self.opt)
-                base_train_model = get_h5_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
+                base_train_model = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
                 base_train_model.summary()
                 self.logger.info("已加载加强训练基础模型")
             else:

+ 2 - 2
models_processing/model_tf/tf_cnn.py

@@ -32,7 +32,7 @@ class CNNHandler(object):
         try:
             with model_lock:
                 loss = region_loss(self.opt)
-                self.model = get_h5_model_from_mongo(args, {type(loss).__name__: loss})
+                self.model = get_keras_model_from_mongo(args, {type(loss).__name__: loss})
         except Exception as e:
             self.logger.info("加载模型权重失败:{}".format(e.args))
 
@@ -61,7 +61,7 @@ class CNNHandler(object):
             if self.opt.Model['add_train']:
                 # Perform additional training to support model repair
                 loss = region_loss(self.opt)
-                base_train_model = get_h5_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
+                base_train_model = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
                 base_train_model.summary()
                 self.logger.info("已加载加强训练基础模型")
             else:

+ 2 - 2
models_processing/model_tf/tf_lstm.py

@@ -31,7 +31,7 @@ class TSHandler(object):
         try:
             with model_lock:
                 loss = region_loss(self.opt)
-                self.model = get_h5_model_from_mongo(args, {type(loss).__name__: loss})
+                self.model = get_keras_model_from_mongo(args, {type(loss).__name__: loss})
         except Exception as e:
             self.logger.info("加载模型权重失败:{}".format(e.args))
 
@@ -58,7 +58,7 @@ class TSHandler(object):
             if self.opt.Model['add_train']:
                 # Perform additional training to support model repair
                 loss = region_loss(self.opt)
-                base_train_model = get_h5_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
+                base_train_model = get_keras_model_from_mongo(vars(self.opt), {type(loss).__name__: loss})
                 base_train_model.summary()
                 self.logger.info("已加载加强训练基础模型")
             else: