test.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. #!/usr/bin/env python
  2. # -*- coding:utf-8 -*-
  3. # @FileName :test.py
  4. # @Time :2025/3/13 14:19
  5. # @Author :David
  6. # @Company: shenyang JY
  7. from enum import Enum
  8. import paramiko
  9. from datetime import datetime, timedelta
  10. from typing import Optional
  11. import os
  12. import zipfile
  13. import shutil
  14. import tempfile
  15. # 配置信息
  16. SFTP_HOST = '192.168.1.33'
  17. SFTP_PORT = 2022
  18. SFTP_USERNAME = 'liudawei'
  19. SFTP_PASSWORD = 'liudawei@123'
  20. # 在原配置部分添加以下配置
  21. DEST_SFTP_HOST = 'dest_sftp.example.com'
  22. DEST_SFTP_PORT = 22
  23. DEST_SFTP_USERNAME = 'dest_username'
  24. DEST_SFTP_PASSWORD = 'dest_password'
  25. DEFAULT_TARGET_DIR = 'cdq' # 默认上传目录
  26. # 更新后的三级映射
  27. MAPPINGS = {
  28. 'koi': {('Zone', '1.0'): {'J00645'}},
  29. 'lucky': {}, 'seer': {('lgb', '1.0'): {'J00578'}}
  30. }
  31. def get_next_target_time(current_time=None):
  32. """获取下一个目标时刻"""
  33. if current_time is None:
  34. current_time = datetime.now()
  35. target_hours = [0, 6, 12, 18]
  36. current_hour = current_time.hour
  37. for hour in sorted(target_hours):
  38. if current_hour < hour:
  39. return current_time.replace(hour=hour, minute=0, second=0, microsecond=0)
  40. # 如果当前时间超过所有目标小时,使用次日0点
  41. return (current_time + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
  42. def download_files_via_sftp(mappings, datetime_str, local_temp_dir, model_type):
  43. """
  44. 封装SFTP连接和文件下载的完整流程
  45. :param mappings: 文件映射配置
  46. :param datetime_str: 日期时间字符串,用于文件名
  47. :param local_temp_dir: 本地临时目录路径
  48. """
  49. transport = None
  50. sftp = None
  51. try:
  52. # 创建SSH传输通道
  53. transport = paramiko.Transport((SFTP_HOST, SFTP_PORT))
  54. transport.connect(username=SFTP_USERNAME, password=SFTP_PASSWORD)
  55. # 创建SFTP客户端
  56. sftp = paramiko.SFTPClient.from_transport(transport)
  57. # 执行文件下载
  58. for engineer in mappings:
  59. datetime_str = datetime_str if engineer == 'koi' else 2025012000
  60. remote_base = f"/{engineer}/"
  61. try:
  62. sftp.chdir(remote_base)
  63. except FileNotFoundError:
  64. print(f"工程师目录不存在: {remote_base}")
  65. continue
  66. for model_version in mappings[engineer]:
  67. target_file = f"jy_{engineer}.{'.'.join(model_version)}_{datetime_str}.zip"
  68. remote_path = os.path.join(remote_base, target_file).replace("\\", "/")
  69. local_path = os.path.join(local_temp_dir, target_file).replace("\\", "/")
  70. try:
  71. sftp.get(remote_path, local_path)
  72. print(f"下载成功: {remote_path} -> {local_path}")
  73. except Exception as e:
  74. print(f"文件下载失败 {remote_path}: {str(e)}")
  75. except paramiko.AuthenticationException:
  76. print("认证失败,请检查用户名和密码")
  77. except paramiko.SSHException as e:
  78. print(f"SSH连接异常: {str(e)}")
  79. except Exception as e:
  80. print(f"未知错误: {str(e)}")
  81. finally:
  82. # 遍历到最后一个中短期,确保连接关闭
  83. if model_type == 'zcq':
  84. if sftp:
  85. sftp.close()
  86. if transport and transport.is_active():
  87. transport.close()
  88. def upload_to_sftp(local_path: str, target_dir: str) -> bool:
  89. """上传文件到SFTP服务器
  90. Args:
  91. local_path: 本地文件路径
  92. target_dir: 远程目标目录
  93. Returns:
  94. 上传是否成功 (True/False)
  95. """
  96. transport: Optional[paramiko.Transport] = None
  97. sftp: Optional[paramiko.SFTPClient] = None
  98. try:
  99. # 验证本地文件存在
  100. if not os.path.isfile(local_path):
  101. raise FileNotFoundError(f"本地文件不存在: {local_path}")
  102. # 创建SFTP连接
  103. transport = paramiko.Transport((DEST_SFTP_HOST, DEST_SFTP_PORT))
  104. transport.connect(username=DEST_SFTP_USERNAME, password=DEST_SFTP_PASSWORD)
  105. sftp = paramiko.SFTPClient.from_transport(transport)
  106. # 执行上传
  107. remote_filename = os.path.basename(local_path)
  108. remote_path = f"{target_dir}/{remote_filename}"
  109. sftp.put(local_path, remote_path)
  110. print(f"[SUCCESS] 上传完成: {remote_path}")
  111. return True
  112. except Exception as e:
  113. print(f"[ERROR] 上传失败: {str(e)}")
  114. return False
  115. finally:
  116. # 确保资源释放
  117. if sftp:
  118. sftp.close()
  119. if transport and transport.is_active():
  120. transport.close()
  121. def process_zips(mappings, local_temp_dir, datetime_str, final_collect_dir):
  122. """处理所有下载的ZIP文件并收集场站目录"""
  123. for engineer in mappings:
  124. datetime_str = datetime_str if engineer == 'koi' else 2025012000
  125. for model_version in mappings[engineer]:
  126. target_file = f"jy_{engineer}.{'.'.join(model_version)}_{datetime_str}_dq.zip"
  127. zip_path = os.path.join(local_temp_dir, target_file).replace("\\", "/")
  128. station_codes = mappings[engineer][model_version]
  129. if not os.path.exists(zip_path):
  130. continue
  131. # 创建临时解压目录
  132. with tempfile.TemporaryDirectory() as temp_extract:
  133. # 解压ZIP文件
  134. try:
  135. with zipfile.ZipFile(zip_path, 'r') as zf:
  136. zf.extractall(temp_extract)
  137. except zipfile.BadZipFile:
  138. print(f"无效的ZIP文件: {zip_path}")
  139. continue
  140. # 收集场站目录
  141. for root, dirs, files in os.walk(temp_extract):
  142. for dir_name in dirs:
  143. if dir_name in station_codes:
  144. src = os.path.join(root, dir_name)
  145. dest = os.path.join(final_collect_dir, dir_name)
  146. if not os.path.exists(dest):
  147. shutil.copytree(src, dest)
  148. print(f"已收集场站: {dir_name}")
  149. def create_final_zip(final_collect_dir: str, datetime_str: str, model_type: str) -> str:
  150. """创建ZIP压缩包并返回完整路径
  151. Args:
  152. final_collect_dir: 需要打包的源目录
  153. datetime_str: 时间戳字符串
  154. model_type: 模型类型标识
  155. Returns:
  156. 生成的ZIP文件完整路径
  157. """
  158. # 确保缓存目录存在
  159. os.makedirs('../cache/ftp', exist_ok=True)
  160. # 构造标准化文件名
  161. zip_filename = f"jy_algo_{datetime_str}_{model_type}.zip"
  162. output_path = os.path.join('../cache/ftp', zip_filename)
  163. try:
  164. with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zf:
  165. for root, _, files in os.walk(final_collect_dir):
  166. for file in files:
  167. file_path = os.path.join(root, file)
  168. arcname = os.path.relpath(file_path, final_collect_dir)
  169. zf.write(file_path, arcname)
  170. print(f"[SUCCESS] ZIP创建成功: {output_path}")
  171. return output_path
  172. except Exception as e:
  173. print(f"[ERROR] 创建ZIP失败: {str(e)}")
  174. raise
  175. def clean_up_file(file_path: str) -> None:
  176. """安全删除本地文件"""
  177. try:
  178. if os.path.exists(file_path):
  179. os.remove(file_path)
  180. print(f"[CLEANUP] 已删除本地文件: {file_path}")
  181. except Exception as e:
  182. print(f"[WARNING] 文件删除失败: {str(e)}")
  183. def prod_data_handler(mappings):
  184. # 创建临时工作目录
  185. for model_type in ['cdq', 'dq', 'zcq']:
  186. with tempfile.TemporaryDirectory() as local_temp_dir:
  187. final_collect_dir = os.path.join(local_temp_dir, 'collected_stations')
  188. os.makedirs(final_collect_dir, exist_ok=True)
  189. # 计算目标时间
  190. target_time = get_next_target_time()
  191. datetime_str = target_time.strftime("%Y%m%d%H")
  192. print(f"目标时间: {datetime_str}")
  193. datetime_str = '2025012412'
  194. # 下载文件
  195. download_files_via_sftp(mappings, datetime_str, local_temp_dir, model_type)
  196. # 处理下载的文件
  197. process_zips(mappings, local_temp_dir, datetime_str, final_collect_dir)
  198. # 创建最终ZIP
  199. zip_path = create_final_zip(final_collect_dir, datetime_str, model_type)
  200. # 上传打包ZIP文件
  201. if upload_to_sftp(zip_path, f"/{model_type}"):
  202. # 步骤3: 上传成功后清理
  203. clean_up_file(zip_path)
  204. else:
  205. print("[WARNING] 上传未成功,保留本地文件")
  206. if __name__ == "__main__":
  207. prod_data_handler()