David 2 周之前
父节点
当前提交
cc6d80421a
共有 2 个文件被更改,包括 2 次插入2 次删除
  1. 1 1
      models_processing/model_tf/tf_lstm3_pre.py
  2. 1 1
      models_processing/model_tf/tf_lstm3_train.py

+ 1 - 1
models_processing/model_tf/tf_lstm3_pre.py

@@ -17,7 +17,7 @@ from copy import deepcopy
 model_lock = Lock()
 from itertools import chain
 from common.logs import Log
-from tf_lstm import TSHandler
+from models_processing.model_tf.tf_bilstm import TSHandler
 logger = Log('tf_ts3').logger
 np.random.seed(42)  # NumPy随机种子
 # tf.set_random_seed(42)  # TensorFlow随机种子

+ 1 - 1
models_processing/model_tf/tf_lstm3_train.py

@@ -12,7 +12,7 @@ import logging, argparse
 from data_processing.data_operation.data_handler import DataHandler
 import time, yaml, threading
 from copy import deepcopy
-from models_processing.model_tf.tf_lstm import TSHandler
+from models_processing.model_tf.tf_bilstm import TSHandler
 from common.database_dml_koi import *
 from common.logs import Log
 logger = Log('tf_ts3').logger