|
@@ -59,14 +59,14 @@ public class TrainTaskController {
|
|
|
task.setTRunInfo((String) taskInfo.get("tRunInfo"));
|
|
|
task.setTAnalysisReport((String) taskInfo.get("tAnalysisReport"));
|
|
|
|
|
|
- Object taskStatusObj = taskInfo.get("ttaskStatus");
|
|
|
+ Object taskStatusObj = taskInfo.get("tTaskStatus");
|
|
|
if (taskStatusObj != null){
|
|
|
try {
|
|
|
- Integer ttaskStatus = Integer.parseInt(taskStatusObj.toString());
|
|
|
- if (ttaskStatus != 0 && ttaskStatus != 1) {
|
|
|
+ Integer tTaskStatus = Integer.parseInt(taskStatusObj.toString());
|
|
|
+ if (tTaskStatus != 0 && tTaskStatus != 1) {
|
|
|
return AjaxResult.error("任务状态必须是0或1,请检查输入");
|
|
|
}
|
|
|
- task.setTTaskStatus(ttaskStatus);
|
|
|
+ task.setTTaskStatus(tTaskStatus);
|
|
|
}catch (Exception e){
|
|
|
return AjaxResult.error("传入的任务状态格式不正确");
|
|
|
}
|
|
@@ -279,6 +279,173 @@ public class TrainTaskController {
|
|
|
}
|
|
|
|
|
|
|
|
|
+ /**
|
|
|
+ * 编辑训练任务
|
|
|
+ * @param taskInfo
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ @Transactional
|
|
|
+ @PostMapping("/updateTask")
|
|
|
+ public AjaxResult updateTask(@RequestBody Map<String , Object> taskInfo){
|
|
|
+
|
|
|
+ try {
|
|
|
+ TrainTask task = new TrainTask();
|
|
|
+ if (taskInfo == null){
|
|
|
+ return AjaxResult.error("训练任务不存在 , 无法编辑");
|
|
|
+ }
|
|
|
+ Object tTaskIdObj = taskInfo.get("tTaskId");
|
|
|
+ if (tTaskIdObj == null){
|
|
|
+ return AjaxResult.error("训练任务id不能为空");
|
|
|
+ }
|
|
|
+ Long tTaskId = null;
|
|
|
+ try {
|
|
|
+ tTaskId = Long.parseLong(tTaskIdObj.toString());
|
|
|
+ } catch (NumberFormatException e) {
|
|
|
+ return AjaxResult.error("传入的任务ID格式不正确,请传数字类型");
|
|
|
+ }
|
|
|
+
|
|
|
+ Object taskStatusObj = taskInfo.get("tTaskStatus");
|
|
|
+ Integer tTaskStatus = null;
|
|
|
+ if (taskStatusObj != null){
|
|
|
+ try {
|
|
|
+ tTaskStatus = Integer.parseInt(taskStatusObj.toString());
|
|
|
+ if (tTaskStatus != 0 && tTaskStatus != 1) {
|
|
|
+ return AjaxResult.error("任务状态必须是0或1,请检查输入");
|
|
|
+ }
|
|
|
+ }catch (Exception e){
|
|
|
+ return AjaxResult.error("传入的任务状态格式不正确");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ task.setTTaskName((String) taskInfo.get("tTaskName"));
|
|
|
+ task.setTCronExpression((String) taskInfo.get("tCronExpression"));
|
|
|
+ task.setTQuartzTask((String) taskInfo.get("tQuartzTask"));
|
|
|
+ task.setTRunInfo((String) taskInfo.get("tRunInfo"));
|
|
|
+ task.setTAnalysisReport((String) taskInfo.get("tAnalysisReport"));
|
|
|
+ task.setTTaskStatus(tTaskStatus);
|
|
|
+ task.setTTaskId(tTaskId);
|
|
|
+ boolean updateTask = trainTaskService.updateById(task);
|
|
|
+ if (!updateTask){
|
|
|
+ return AjaxResult.error("更新训练任务失败");
|
|
|
+ }
|
|
|
+
|
|
|
+ // 先删除原关联组件
|
|
|
+ QueryWrapper<Component> componentWrapper = new QueryWrapper<>();
|
|
|
+ componentWrapper.eq("TASK_ID", tTaskId)
|
|
|
+ .eq("TASK_TYPE", 0);
|
|
|
+ componentService.remove(componentWrapper);
|
|
|
+
|
|
|
+ List<Long> componentIds = new ArrayList<>();
|
|
|
+
|
|
|
+ // 1. 数据获取组件
|
|
|
+ Map<String, Object> dataAcquisition = (Map<String, Object>) taskInfo.get("dataAcquisition");
|
|
|
+ if (dataAcquisition != null) {
|
|
|
+ Boolean isEnable = parseEnableValue(dataAcquisition.get("enable"), "数据获取");
|
|
|
+ String type = (String) dataAcquisition.get("name");
|
|
|
+ Component component = createComponent(tTaskId, type, dataAcquisition, isEnable);
|
|
|
+ if (componentService.save(component)) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("数据获取组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 2. 数据处理组件
|
|
|
+ Map<String, Object> dataCleaning = (Map<String, Object>) taskInfo.get("dataCleaning");
|
|
|
+ if (dataCleaning != null) {
|
|
|
+ Boolean isEnable = parseEnableValue(dataCleaning.get("enable"), "数据处理");
|
|
|
+ String type = (String) dataCleaning.get("name");
|
|
|
+ Component component = createComponent(tTaskId, type, dataCleaning, isEnable);
|
|
|
+ if (componentService.save(component)) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("数据处理组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 3. 限电清洗组件
|
|
|
+ Map<String, Object> powerRationing = (Map<String, Object>) taskInfo.get("powerRationing");
|
|
|
+ if (powerRationing != null) {
|
|
|
+ Boolean isEnable = parseEnableValue(powerRationing.get("enable"), "限电清洗");
|
|
|
+ String type = (String) powerRationing.get("name");
|
|
|
+ Component component = createComponent(tTaskId, type, powerRationing, isEnable);
|
|
|
+ if (componentService.save(component)) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("限电清洗组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 4. 模型组件 + 关联模型测试组件
|
|
|
+ Map<String, Object> model = (Map<String, Object>) taskInfo.get("model");
|
|
|
+ if (model != null) {
|
|
|
+ Boolean modelIsEnable = parseEnableValue(model.get("enable"), "模型");
|
|
|
+ String modelType = (String) model.get("name");
|
|
|
+ Component modelComponent = createComponent(tTaskId, modelType, model, modelIsEnable);
|
|
|
+ if (componentService.save(modelComponent)) {
|
|
|
+ componentIds.add(modelComponent.getComponentId());
|
|
|
+
|
|
|
+ // 训练类模型关联测试组件
|
|
|
+ if ("LSTM-训练".equals(modelType) || "机器学习模型-训练".equals(modelType)) {
|
|
|
+ Map<String, Object> modelTest = (Map<String, Object>) taskInfo.get("modelTest");
|
|
|
+ if (modelTest != null) {
|
|
|
+ Boolean testIsEnable = parseEnableValue(modelTest.get("enable"), "模型测试");
|
|
|
+ String testType = (String) modelTest.get("name");
|
|
|
+ Component testComponent = createComponent(tTaskId, testType, modelTest, testIsEnable);
|
|
|
+ if (componentService.save(testComponent)) {
|
|
|
+ componentIds.add(testComponent.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException(testType + "组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException(modelType + "组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 5. 后处理组件
|
|
|
+ Map<String, Object> postProcessing = (Map<String, Object>) taskInfo.get("processing");
|
|
|
+ if (postProcessing != null) {
|
|
|
+ Boolean isEnable = parseEnableValue(postProcessing.get("enable"), "后处理");
|
|
|
+ String type = (String) postProcessing.get("name");
|
|
|
+ Component component = createComponent(tTaskId, type, postProcessing, isEnable);
|
|
|
+ if (componentService.save(component)) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("后处理组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 6. 分析报告组件
|
|
|
+ Map<String, Object> analysisReport = (Map<String, Object>) taskInfo.get("report");
|
|
|
+ if (analysisReport != null) {
|
|
|
+ Boolean isEnable = parseEnableValue(analysisReport.get("enable"), "分析报告");
|
|
|
+ String type = (String) analysisReport.get("name");
|
|
|
+ Component component = createComponent(tTaskId, type, analysisReport, isEnable);
|
|
|
+ if (componentService.save(component)) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("分析报告组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 更新任务关联组件ID
|
|
|
+ if (!componentIds.isEmpty()) {
|
|
|
+ String componentIdsStr = componentIds.stream()
|
|
|
+ .map(String::valueOf)
|
|
|
+ .collect(java.util.stream.Collectors.joining(","));
|
|
|
+ task.setTComponentIds(componentIdsStr);
|
|
|
+ trainTaskService.updateById(task);
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ return AjaxResult.success("更新训练任务成功");
|
|
|
+ }catch (Exception e){
|
|
|
+ e.printStackTrace();
|
|
|
+ return AjaxResult.error("更新训练任务失败" + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////
|
|
|
/**
|