|
@@ -1,26 +1,40 @@
|
|
|
package com.xvji.web.controller;
|
|
|
|
|
|
+import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
|
|
import com.xvji.common.core.domain.AjaxResult;
|
|
|
import com.xvji.domain.Component;
|
|
|
import com.xvji.domain.TrainTask;
|
|
|
+import com.xvji.domain.vo.PageResult;
|
|
|
+import com.xvji.domain.vo.TrainTaskVO;
|
|
|
import com.xvji.mapper.TrainTaskMapper;
|
|
|
import com.xvji.service.ComponentService;
|
|
|
import com.xvji.service.TrainTaskService;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
+import org.springframework.util.StringUtils;
|
|
|
import org.springframework.web.bind.annotation.*;
|
|
|
+import org.springframework.beans.BeanUtils;
|
|
|
+import com.baomidou.mybatisplus.core.metadata.IPage;
|
|
|
+import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
|
|
+import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
|
|
|
|
|
import java.util.ArrayList;
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
+import java.util.stream.Collectors;
|
|
|
|
|
|
/**
|
|
|
- * 训练任务控制器(对应train_task表及关联组件)
|
|
|
+ * 训练任务控制器(兼容非Boolean类型enable,修复强转漏洞)
|
|
|
*/
|
|
|
@RestController
|
|
|
@RequestMapping("/task/train")
|
|
|
public class TrainTaskController {
|
|
|
|
|
|
+ private static final Logger log = LoggerFactory.getLogger(TrainTaskController.class);
|
|
|
+
|
|
|
@Autowired
|
|
|
private TrainTaskService trainTaskService;
|
|
|
|
|
@@ -32,7 +46,6 @@ public class TrainTaskController {
|
|
|
|
|
|
/**
|
|
|
* 新增训练任务及关联组件
|
|
|
- * 逻辑说明:与预测任务一致,额外处理训练分析报告字段(传则存,不传则空)
|
|
|
*/
|
|
|
@PostMapping("/addTask")
|
|
|
@Transactional
|
|
@@ -48,13 +61,11 @@ public class TrainTaskController {
|
|
|
Object tTaskIdObj = taskInfo.get("tTaskId");
|
|
|
Long manualTaskId = null;
|
|
|
if (tTaskIdObj != null) {
|
|
|
- //校验ID格式
|
|
|
try {
|
|
|
manualTaskId = Long.parseLong(tTaskIdObj.toString());
|
|
|
} catch (NumberFormatException e) {
|
|
|
return AjaxResult.error("传入的任务ID格式不正确,请传数字类型");
|
|
|
}
|
|
|
- //校验ID是否已存在
|
|
|
TrainTask existTask = trainTaskMapper.selectById(manualTaskId);
|
|
|
if (existTask != null) {
|
|
|
return AjaxResult.error("任务ID已存在:" + manualTaskId + ",请更换其他ID");
|
|
@@ -62,6 +73,7 @@ public class TrainTaskController {
|
|
|
task.setTTaskId(manualTaskId);
|
|
|
}
|
|
|
|
|
|
+ // 保存训练任务主表
|
|
|
boolean taskAdded = trainTaskService.save(task);
|
|
|
if (!taskAdded) {
|
|
|
return AjaxResult.error("训练任务新增失败");
|
|
@@ -70,83 +82,104 @@ public class TrainTaskController {
|
|
|
Long taskId = task.getTTaskId();
|
|
|
List<Long> componentIds = new ArrayList<>();
|
|
|
|
|
|
- //数据获取组件
|
|
|
+ // 1. 数据获取组件
|
|
|
Map<String, Object> dataAcquisition = (Map<String, Object>) taskInfo.get("dataAcquisition");
|
|
|
- if (dataAcquisition != null && (Boolean) dataAcquisition.get("enable")) {
|
|
|
- Component component = createComponent(taskId, "数据获取", dataAcquisition);
|
|
|
- boolean saved = componentService.save(component);
|
|
|
- if (saved) {
|
|
|
+ if (dataAcquisition != null) {
|
|
|
+ Boolean isEnable = parseEnableValue(dataAcquisition.get("enable"), "数据获取");
|
|
|
+ String type = (String) dataAcquisition.get("name");
|
|
|
+ Component component = createComponent(taskId, type, dataAcquisition, isEnable);
|
|
|
+ if (componentService.save(component)) {
|
|
|
componentIds.add(component.getComponentId());
|
|
|
} else {
|
|
|
throw new RuntimeException("数据获取组件保存失败");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // 数据处理组件
|
|
|
- Map<String, Object> dataPreprocess = (Map<String, Object>) taskInfo.get("dataCleaning");
|
|
|
- if (dataPreprocess != null && (Boolean) dataPreprocess.get("enable")) {
|
|
|
- Component component = createComponent(taskId, "数据处理", dataPreprocess);
|
|
|
- boolean saved = componentService.save(component);
|
|
|
- if (saved) {
|
|
|
+ // 2. 数据处理组件
|
|
|
+ Map<String, Object> dataCleaning = (Map<String, Object>) taskInfo.get("dataCleaning");
|
|
|
+ if (dataCleaning != null) {
|
|
|
+ Boolean isEnable = parseEnableValue(dataCleaning.get("enable"), "数据处理");
|
|
|
+ String type = (String) dataCleaning.get("name");
|
|
|
+ Component component = createComponent(taskId, type, dataCleaning, isEnable);
|
|
|
+ if (componentService.save(component)) {
|
|
|
componentIds.add(component.getComponentId());
|
|
|
} else {
|
|
|
throw new RuntimeException("数据处理组件保存失败");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- //限电清洗-光伏 组件
|
|
|
+ // 3. 限电清洗组件
|
|
|
Map<String, Object> powerRationing = (Map<String, Object>) taskInfo.get("powerRationing");
|
|
|
- if (powerRationing != null && (Boolean) powerRationing.get("enable")) {
|
|
|
- Component component = createComponent(taskId, "限电清洗-光伏", powerRationing);
|
|
|
- boolean saved = componentService.save(component);
|
|
|
- if (saved) {
|
|
|
+ if (powerRationing != null) {
|
|
|
+ Boolean isEnable = parseEnableValue(powerRationing.get("enable"), "限电清洗");
|
|
|
+ String type = (String) powerRationing.get("name");
|
|
|
+ Component component = createComponent(taskId, type, powerRationing, isEnable);
|
|
|
+ if (componentService.save(component)) {
|
|
|
componentIds.add(component.getComponentId());
|
|
|
} else {
|
|
|
- throw new RuntimeException("限电清洗-光伏组件保存失败");
|
|
|
+ throw new RuntimeException("限电清洗组件保存失败");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- //光伏物理模型组件
|
|
|
+ // 4. 模型组件 + 关联模型测试组件
|
|
|
Map<String, Object> model = (Map<String, Object>) taskInfo.get("model");
|
|
|
- if (model != null && (Boolean) model.get("enable")) {
|
|
|
- Component component = createComponent(taskId, "光伏物理模型", model);
|
|
|
- boolean saved = componentService.save(component);
|
|
|
- if (saved) {
|
|
|
- componentIds.add(component.getComponentId());
|
|
|
+ if (model != null) {
|
|
|
+ Boolean modelIsEnable = parseEnableValue(model.get("enable"), "模型");
|
|
|
+ String modelType = (String) model.get("name");
|
|
|
+ Component modelComponent = createComponent(taskId, modelType, model, modelIsEnable);
|
|
|
+ if (componentService.save(modelComponent)) {
|
|
|
+ componentIds.add(modelComponent.getComponentId());
|
|
|
+
|
|
|
+ // 训练类模型关联测试组件
|
|
|
+ if ("LSTM-训练".equals(modelType) || "机器学习模型-训练".equals(modelType)) {
|
|
|
+ Map<String, Object> modelTest = (Map<String, Object>) taskInfo.get("modelTest");
|
|
|
+ if (modelTest != null) {
|
|
|
+ Boolean testIsEnable = parseEnableValue(modelTest.get("enable"), "模型测试");
|
|
|
+ String testType = (String) modelTest.get("name");
|
|
|
+ Component testComponent = createComponent(taskId, testType, modelTest, testIsEnable);
|
|
|
+ if (componentService.save(testComponent)) {
|
|
|
+ componentIds.add(testComponent.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException(testType + "组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
} else {
|
|
|
- throw new RuntimeException("光伏物理模型组件保存失败");
|
|
|
+ throw new RuntimeException(modelType + "组件保存失败");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // 后处理
|
|
|
+ // 5. 后处理组件
|
|
|
Map<String, Object> postProcessing = (Map<String, Object>) taskInfo.get("postProcessing");
|
|
|
- if (postProcessing != null && (Boolean) postProcessing.get("enable")) {
|
|
|
- Component component = createComponent(taskId, "后处理", postProcessing);
|
|
|
- boolean saved = componentService.save(component);
|
|
|
- if (saved) {
|
|
|
+ if (postProcessing != null) {
|
|
|
+ Boolean isEnable = parseEnableValue(postProcessing.get("enable"), "后处理");
|
|
|
+ String type = (String) postProcessing.get("name");
|
|
|
+ Component component = createComponent(taskId, type, postProcessing, isEnable);
|
|
|
+ if (componentService.save(component)) {
|
|
|
componentIds.add(component.getComponentId());
|
|
|
} else {
|
|
|
throw new RuntimeException("后处理组件保存失败");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // 分析报告组件
|
|
|
+ // 6. 分析报告组件
|
|
|
Map<String, Object> analysisReport = (Map<String, Object>) taskInfo.get("analysisReport");
|
|
|
- if (analysisReport != null && (Boolean) analysisReport.get("enable")) {
|
|
|
- Component component = createComponent(taskId, "分析报告", analysisReport);
|
|
|
- boolean saved = componentService.save(component);
|
|
|
- if (saved) {
|
|
|
+ if (analysisReport != null) {
|
|
|
+ Boolean isEnable = parseEnableValue(analysisReport.get("enable"), "分析报告");
|
|
|
+ String type = (String) analysisReport.get("name");
|
|
|
+ Component component = createComponent(taskId, type, analysisReport, isEnable);
|
|
|
+ if (componentService.save(component)) {
|
|
|
componentIds.add(component.getComponentId());
|
|
|
} else {
|
|
|
throw new RuntimeException("分析报告组件保存失败");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- //更新任务表中的关联组件ID
|
|
|
+ // 更新任务关联组件ID
|
|
|
if (!componentIds.isEmpty()) {
|
|
|
- String componentIdsStr = String.join(",", componentIds.stream()
|
|
|
+ String componentIdsStr = componentIds.stream()
|
|
|
.map(String::valueOf)
|
|
|
- .collect(java.util.stream.Collectors.toList()));
|
|
|
+ .collect(java.util.stream.Collectors.joining(","));
|
|
|
task.setTComponentIds(componentIdsStr);
|
|
|
trainTaskService.updateById(task);
|
|
|
}
|
|
@@ -159,28 +192,89 @@ public class TrainTaskController {
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 辅助方法:创建组件对象(训练任务专用,taskType固定为0)
|
|
|
- * @param taskId 训练任务ID
|
|
|
- * @param componentType 组件类型(如数据预处理、模型训练)
|
|
|
- * @param config 组件配置(包含enable、value、interfaceUrl)
|
|
|
+ * 查询所有任务以及组件参数
|
|
|
+ * @param pageNum
|
|
|
+ * @param pageSize
|
|
|
+ * @param taskName
|
|
|
+ * @return
|
|
|
*/
|
|
|
- private Component createComponent(Long taskId, String componentType, Map<String, Object> config) {
|
|
|
+ @GetMapping("/queryTasks")
|
|
|
+ public AjaxResult queryTrainTasks(
|
|
|
+ @RequestParam(defaultValue = "1") int pageNum,
|
|
|
+ @RequestParam(defaultValue = "10") int pageSize,
|
|
|
+ @RequestParam(required = false) String taskName) {
|
|
|
+
|
|
|
+ try {
|
|
|
+ Page<TrainTask> page = new Page<>(pageNum, pageSize);
|
|
|
+
|
|
|
+ QueryWrapper<TrainTask> queryWrapper = new QueryWrapper<>();
|
|
|
+ if (StringUtils.hasText(taskName)) {
|
|
|
+ queryWrapper.like("T_TASK_NAME", taskName);
|
|
|
+ }
|
|
|
+ queryWrapper.orderByDesc("T_CREATE_TIME");
|
|
|
+
|
|
|
+ Page<TrainTask> taskPage = trainTaskService.page(page, queryWrapper);
|
|
|
+ List<TrainTaskVO> taskVOList = taskPage.getRecords().stream().map(task -> {
|
|
|
+ TrainTaskVO vo = new TrainTaskVO();
|
|
|
+ BeanUtils.copyProperties(task, vo);
|
|
|
+
|
|
|
+ // 解析组件ID并查询组件
|
|
|
+ String componentIds = task.getTComponentIds();
|
|
|
+ if (StringUtils.hasText(componentIds)) {
|
|
|
+ List<Long> ids = Arrays.stream(componentIds.split(","))
|
|
|
+ .map(Long::parseLong)
|
|
|
+ .collect(Collectors.toList());
|
|
|
+ List<Component> components = componentService.listByIds(ids);
|
|
|
+ vo.setComponents(components);
|
|
|
+ }
|
|
|
+ return vo;
|
|
|
+ }).collect(Collectors.toList());
|
|
|
+
|
|
|
+ PageResult<TrainTaskVO> result = new PageResult<>(
|
|
|
+ taskPage.getTotal(),
|
|
|
+ pageNum,
|
|
|
+ pageSize,
|
|
|
+ taskVOList
|
|
|
+ );
|
|
|
+
|
|
|
+ return AjaxResult.success(result);
|
|
|
+ } catch (Exception e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ return AjaxResult.error("查询训练任务失败:" + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ ///////////////////////////////////////////////////////////////////////////
|
|
|
+ /**
|
|
|
+ * 使用解析后的isEnable,删除错误强转
|
|
|
+ */
|
|
|
+ private Component createComponent(Long taskId, String componentType, Map<String, Object> config, Boolean isEnable) {
|
|
|
Component component = new Component();
|
|
|
component.setTaskId(taskId);
|
|
|
- component.setTaskType(0); // 训练任务固定标识:0-训练任务
|
|
|
+ component.setTaskType(0); // 训练任务固定标识
|
|
|
component.setComponentType(componentType);
|
|
|
- component.setParamsMap((Map<String, Object>) config.get("value")); // 组件参数(JSON字符串存储)
|
|
|
- component.setIsEnable((Boolean) config.get("enable")); // 组件启用状态
|
|
|
- component.setInterfaceUrl((String) config.get("interfaceUrl")); // 接口地址:传则存,不传则空
|
|
|
+ component.setParamsMap((Map<String, Object>) config.get("value")); // 组件参数
|
|
|
+ component.setIsEnable(isEnable); // 仅使用解析后的enable值,无强转
|
|
|
+ component.setInterfaceUrl((String) config.get("interfaceUrl")); // 接口地址
|
|
|
return component;
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * 查询所有训练任务(与预测任务查询逻辑一致)
|
|
|
+ * 统一解析enable值,兼容非Boolean类型
|
|
|
*/
|
|
|
- @GetMapping
|
|
|
- public AjaxResult getAllTrainTasks() {
|
|
|
- List<TrainTask> tasks = trainTaskService.getAllTrainTasks();
|
|
|
- return AjaxResult.success(tasks);
|
|
|
+ private Boolean parseEnableValue(Object enableObj, String componentName) {
|
|
|
+ if (enableObj instanceof Boolean) {
|
|
|
+ return (Boolean) enableObj;
|
|
|
+ }
|
|
|
+ // 非Boolean类型或null,默认false并记录日志
|
|
|
+ if (enableObj != null) {
|
|
|
+ log.warn("组件[{}]的enable参数类型错误非boolean类型,实际类型为{} 值{},默认设为false",
|
|
|
+ componentName, enableObj.getClass().getSimpleName(), enableObj);
|
|
|
+ } else {
|
|
|
+ log.warn("组件[{}]未传递enable参数,已默认设为false", componentName);
|
|
|
+ }
|
|
|
+ return false;
|
|
|
}
|
|
|
+
|
|
|
}
|