Pārlūkot izejas kodu

新增训练与预测任务的 编辑接口

刘桐 2 nedēļas atpakaļ
vecāks
revīzija
9c939b5a42

+ 152 - 0
xvji-admin/src/main/java/com/xvji/web/controller/PredictTaskController.java

@@ -259,6 +259,158 @@ public class PredictTaskController {
         }
     }
 
+    /**
+     * 编辑预测任务
+     * @param taskInfo
+     * @return
+     */
+    @Transactional
+    @PostMapping("/updateTask")
+    public AjaxResult updateTask(@RequestBody Map<String , Object> taskInfo){
+        try {
+            PredictTask task = new PredictTask();
+            if (taskInfo == null){
+                return AjaxResult.error("预测任务不存在 , 无法编辑");
+            }
+            // id校验
+            Object pTaskIdObj = taskInfo.get("pTaskId");
+            if (pTaskIdObj == null){
+                return AjaxResult.error("任务id 不可为空");
+            }
+            Long pTaskId = null;
+            try {
+                pTaskId = Long.parseLong(pTaskIdObj.toString());
+            }catch (Exception e){
+                return AjaxResult.error("传入的任务id不正确");
+            }
+
+            //状态校验
+            Integer pTaskStatus = null;
+            try {
+                Object pTaskStatusObj = taskInfo.get("pTaskStatus");
+                if (pTaskStatusObj != null){
+                    pTaskStatus = Integer.parseInt(pTaskStatusObj.toString());
+                    if (pTaskStatus != 0 && pTaskStatus != 1){
+                        return AjaxResult.error("任务状态必须是0或1,请检查输入");
+                    }
+                }
+            }catch (Exception e){
+                return AjaxResult.error("任务状态格式不正确");
+            }
+
+            task.setPTaskId(pTaskId);
+            task.setPTaskName((String) taskInfo.get("pTaskName"));
+            task.setPCronExpression((String) taskInfo.get("pCronExpression"));
+            task.setPQuartzTask((String) taskInfo.get("pQuartzTask"));
+            task.setPRunInfo((String) taskInfo.get("pRunInfo"));
+            task.setPTaskStatus(pTaskStatus);
+
+            boolean updateTask = predictTaskService.updateById(task);
+            if (!updateTask){
+                return AjaxResult.error("编辑任务失败");
+            }
+            QueryWrapper<Component> componentQueryWrapper = new QueryWrapper<>();
+            componentQueryWrapper
+                    .eq("TASK_ID", pTaskId)
+                    .eq("TASK_TYPE", 1);//预测任务类型为1
+            componentService.remove(componentQueryWrapper);
+
+            List<Long> componentIds = new ArrayList<>();
+
+            // 数据获取组件
+            Map<String, Object> dataAcquisition = (Map<String, Object>) taskInfo.get("dataAcquisition");
+            if (dataAcquisition != null) {
+                Boolean isEnable = parseEnableValue(dataAcquisition.get("enable"), "数据获取");
+                String type = (String) dataAcquisition.get("name");
+                Component component = createComponent(pTaskId, type, dataAcquisition, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException("数据获取组件保存失败");
+                }
+            }
+
+            // 数据处理组件
+            Map<String, Object> dataCleaning = (Map<String, Object>) taskInfo.get("dataCleaning");
+            if (dataCleaning != null) {
+                Boolean isEnable = parseEnableValue(dataCleaning.get("enable"), "数据处理");
+                String type = (String) dataCleaning.get("name");
+                Component component = createComponent(pTaskId, type, dataCleaning, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException("数据处理组件保存失败");
+                }
+            }
+
+            // 限电清洗组件
+            Map<String, Object> powerRationing = (Map<String, Object>) taskInfo.get("powerRationing");
+            if (powerRationing != null) {
+                Boolean isEnable = parseEnableValue(powerRationing.get("enable"), "限电清洗");
+                String type = (String) powerRationing.get("name");
+                Component component = createComponent(pTaskId, type, powerRationing, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException("限电清洗组件保存失败");
+                }
+            }
+
+            // 模型组件(预测任务不需要额外添加modelTest组件)
+            Map<String, Object> model = (Map<String, Object>) taskInfo.get("model");
+            if (model != null) {
+                Boolean isEnable = parseEnableValue(model.get("enable"), "模型");
+                String type = (String) model.get("name");
+                Component component = createComponent(pTaskId, type, model, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException(type + "组件保存失败");
+                }
+            }
+
+            // 后处理组件
+            Map<String, Object> postProcessing = (Map<String, Object>) taskInfo.get("processing");
+            if (postProcessing != null) {
+                Boolean isEnable = parseEnableValue(postProcessing.get("enable"), "后处理");
+                String type = (String) postProcessing.get("name");
+                Component component = createComponent(pTaskId, type, postProcessing, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException("后处理组件保存失败");
+                }
+            }
+
+            // 分析报告组件
+            Map<String, Object> analysisReport = (Map<String, Object>) taskInfo.get("report");
+            if (analysisReport != null) {
+                Boolean isEnable = parseEnableValue(analysisReport.get("enable"), "分析报告");
+                String type = (String) analysisReport.get("name");
+                Component component = createComponent(pTaskId, type, analysisReport, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException("分析报告组件保存失败");
+                }
+            }
+
+            // 更新任务表中的组件 ID 关联(逗号分隔)
+            if (!componentIds.isEmpty()) {
+                String componentIdsStr = componentIds.stream()
+                        .map(String::valueOf)
+                        .collect(java.util.stream.Collectors.joining(","));
+                task.setPComponentIds(componentIdsStr);
+                predictTaskService.updateById(task);
+            }
+
+            return AjaxResult.success("编辑预测任务成功");
+
+        }catch (Exception e){
+            return AjaxResult.error("编辑预测任务失败" + e.getMessage());
+        }
+    }
+
 
     ////////////////////////////////////////// 辅助方法 //////////////////////////////////////////////////////////////////////
 

+ 171 - 4
xvji-admin/src/main/java/com/xvji/web/controller/TrainTaskController.java

@@ -59,14 +59,14 @@ public class TrainTaskController {
             task.setTRunInfo((String) taskInfo.get("tRunInfo"));
             task.setTAnalysisReport((String) taskInfo.get("tAnalysisReport"));
 
-            Object taskStatusObj = taskInfo.get("ttaskStatus");
+            Object taskStatusObj = taskInfo.get("tTaskStatus");
             if (taskStatusObj != null){
                 try {
-                    Integer ttaskStatus = Integer.parseInt(taskStatusObj.toString());
-                    if (ttaskStatus != 0 && ttaskStatus != 1) {
+                    Integer tTaskStatus = Integer.parseInt(taskStatusObj.toString());
+                    if (tTaskStatus != 0 && tTaskStatus != 1) {
                         return AjaxResult.error("任务状态必须是0或1,请检查输入");
                     }
-                    task.setTTaskStatus(ttaskStatus);
+                    task.setTTaskStatus(tTaskStatus);
                 }catch (Exception e){
                     return AjaxResult.error("传入的任务状态格式不正确");
                 }
@@ -279,6 +279,173 @@ public class TrainTaskController {
     }
 
 
+    /**
+     * 编辑训练任务
+     * @param taskInfo
+     * @return
+     */
+    @Transactional
+    @PostMapping("/updateTask")
+    public AjaxResult updateTask(@RequestBody Map<String , Object> taskInfo){
+
+        try {
+            TrainTask task = new TrainTask();
+            if (taskInfo == null){
+                return AjaxResult.error("训练任务不存在 , 无法编辑");
+            }
+            Object tTaskIdObj = taskInfo.get("tTaskId");
+            if (tTaskIdObj == null){
+                return AjaxResult.error("训练任务id不能为空");
+            }
+            Long tTaskId = null;
+            try {
+                tTaskId = Long.parseLong(tTaskIdObj.toString());
+            } catch (NumberFormatException e) {
+                return AjaxResult.error("传入的任务ID格式不正确,请传数字类型");
+            }
+
+            Object taskStatusObj = taskInfo.get("tTaskStatus");
+            Integer tTaskStatus = null;
+            if (taskStatusObj != null){
+                try {
+                    tTaskStatus = Integer.parseInt(taskStatusObj.toString());
+                    if (tTaskStatus != 0 && tTaskStatus != 1) {
+                        return AjaxResult.error("任务状态必须是0或1,请检查输入");
+                    }
+                }catch (Exception e){
+                    return AjaxResult.error("传入的任务状态格式不正确");
+                }
+            }
+
+            task.setTTaskName((String) taskInfo.get("tTaskName"));
+            task.setTCronExpression((String) taskInfo.get("tCronExpression"));
+            task.setTQuartzTask((String) taskInfo.get("tQuartzTask"));
+            task.setTRunInfo((String) taskInfo.get("tRunInfo"));
+            task.setTAnalysisReport((String) taskInfo.get("tAnalysisReport"));
+            task.setTTaskStatus(tTaskStatus);
+            task.setTTaskId(tTaskId);
+            boolean updateTask = trainTaskService.updateById(task);
+            if (!updateTask){
+                return AjaxResult.error("更新训练任务失败");
+            }
+
+            // 先删除原关联组件
+            QueryWrapper<Component> componentWrapper = new QueryWrapper<>();
+            componentWrapper.eq("TASK_ID", tTaskId)
+                    .eq("TASK_TYPE", 0);
+            componentService.remove(componentWrapper);
+
+            List<Long> componentIds = new ArrayList<>();
+
+            // 1. 数据获取组件
+            Map<String, Object> dataAcquisition = (Map<String, Object>) taskInfo.get("dataAcquisition");
+            if (dataAcquisition != null) {
+                Boolean isEnable = parseEnableValue(dataAcquisition.get("enable"), "数据获取");
+                String type = (String) dataAcquisition.get("name");
+                Component component = createComponent(tTaskId, type, dataAcquisition, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException("数据获取组件保存失败");
+                }
+            }
+
+            // 2. 数据处理组件
+            Map<String, Object> dataCleaning = (Map<String, Object>) taskInfo.get("dataCleaning");
+            if (dataCleaning != null) {
+                Boolean isEnable = parseEnableValue(dataCleaning.get("enable"), "数据处理");
+                String type = (String) dataCleaning.get("name");
+                Component component = createComponent(tTaskId, type, dataCleaning, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException("数据处理组件保存失败");
+                }
+            }
+
+            // 3. 限电清洗组件
+            Map<String, Object> powerRationing = (Map<String, Object>) taskInfo.get("powerRationing");
+            if (powerRationing != null) {
+                Boolean isEnable = parseEnableValue(powerRationing.get("enable"), "限电清洗");
+                String type = (String) powerRationing.get("name");
+                Component component = createComponent(tTaskId, type, powerRationing, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException("限电清洗组件保存失败");
+                }
+            }
+
+            // 4. 模型组件 + 关联模型测试组件
+            Map<String, Object> model = (Map<String, Object>) taskInfo.get("model");
+            if (model != null) {
+                Boolean modelIsEnable = parseEnableValue(model.get("enable"), "模型");
+                String modelType = (String) model.get("name");
+                Component modelComponent = createComponent(tTaskId, modelType, model, modelIsEnable);
+                if (componentService.save(modelComponent)) {
+                    componentIds.add(modelComponent.getComponentId());
+
+                    // 训练类模型关联测试组件
+                    if ("LSTM-训练".equals(modelType) || "机器学习模型-训练".equals(modelType)) {
+                        Map<String, Object> modelTest = (Map<String, Object>) taskInfo.get("modelTest");
+                        if (modelTest != null) {
+                            Boolean testIsEnable = parseEnableValue(modelTest.get("enable"), "模型测试");
+                            String testType = (String) modelTest.get("name");
+                            Component testComponent = createComponent(tTaskId, testType, modelTest, testIsEnable);
+                            if (componentService.save(testComponent)) {
+                                componentIds.add(testComponent.getComponentId());
+                            } else {
+                                throw new RuntimeException(testType + "组件保存失败");
+                            }
+                        }
+                    }
+                } else {
+                    throw new RuntimeException(modelType + "组件保存失败");
+                }
+            }
+
+            // 5. 后处理组件
+            Map<String, Object> postProcessing = (Map<String, Object>) taskInfo.get("processing");
+            if (postProcessing != null) {
+                Boolean isEnable = parseEnableValue(postProcessing.get("enable"), "后处理");
+                String type = (String) postProcessing.get("name");
+                Component component = createComponent(tTaskId, type, postProcessing, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException("后处理组件保存失败");
+                }
+            }
+
+            // 6. 分析报告组件
+            Map<String, Object> analysisReport = (Map<String, Object>) taskInfo.get("report");
+            if (analysisReport != null) {
+                Boolean isEnable = parseEnableValue(analysisReport.get("enable"), "分析报告");
+                String type = (String) analysisReport.get("name");
+                Component component = createComponent(tTaskId, type, analysisReport, isEnable);
+                if (componentService.save(component)) {
+                    componentIds.add(component.getComponentId());
+                } else {
+                    throw new RuntimeException("分析报告组件保存失败");
+                }
+            }
+
+            // 更新任务关联组件ID
+            if (!componentIds.isEmpty()) {
+                String componentIdsStr = componentIds.stream()
+                        .map(String::valueOf)
+                        .collect(java.util.stream.Collectors.joining(","));
+                task.setTComponentIds(componentIdsStr);
+                trainTaskService.updateById(task);
+            }
+
+
+            return AjaxResult.success("更新训练任务成功");
+        }catch (Exception e){
+            e.printStackTrace();
+            return AjaxResult.error("更新训练任务失败" + e.getMessage());
+        }
+    }
 
     ///////////////////////////////////////////////////////////////////////////
     /**