TrainTaskController.java 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. package com.xvji.web.controller;
  import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
  import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
  import com.baomidou.mybatisplus.core.metadata.IPage;
  import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
  import com.xvji.common.core.domain.AjaxResult;
  import com.xvji.domain.Component;
  import com.xvji.domain.TrainTask;
  import com.xvji.domain.vo.PageResult;
  import com.xvji.domain.vo.TrainTaskVO;
  import com.xvji.mapper.TrainTaskMapper;
  import com.xvji.service.ComponentService;
  import com.xvji.service.TrainTaskService;
  import org.slf4j.Logger;
  import org.slf4j.LoggerFactory;
  import org.springframework.beans.BeanUtils;
  import org.springframework.beans.factory.annotation.Autowired;
  import org.springframework.transaction.annotation.Transactional;
  import org.springframework.transaction.interceptor.TransactionAspectSupport;
  import org.springframework.util.StringUtils;
  import org.springframework.web.bind.annotation.*;

  import java.util.ArrayList;
  import java.util.Arrays;
  import java.util.Collections;
  import java.util.List;
  import java.util.Map;
  import java.util.Objects;
  import java.util.stream.Collectors;
  26. /**
  27. * 训练任务控制器(兼容非Boolean类型enable,修复强转漏洞)
  28. */
  29. @RestController
  30. @RequestMapping("/task/train")
  31. public class TrainTaskController {
  32. private static final Logger log = LoggerFactory.getLogger(TrainTaskController.class);
  33. @Autowired
  34. private TrainTaskService trainTaskService;
  35. @Autowired
  36. private ComponentService componentService;
  37. @Autowired
  38. private TrainTaskMapper trainTaskMapper;
  39. /**
  40. * 新增训练任务及关联组件
  41. */
  42. @PostMapping("/addTask")
  43. @Transactional
  44. public AjaxResult addTrainTask(@RequestBody Map<String, Object> taskInfo) {
  45. try {
  46. TrainTask task = new TrainTask();
  47. task.setTTaskName((String) taskInfo.get("tTaskName"));
  48. task.setTCronExpression((String) taskInfo.get("tCronExpression"));
  49. task.setTQuartzTask((String) taskInfo.get("tQuartzTask"));
  50. task.setTRunInfo((String) taskInfo.get("tRunInfo"));
  51. task.setTAnalysisReport((String) taskInfo.get("tAnalysisReport"));
  52. Object tTaskIdObj = taskInfo.get("tTaskId");
  53. Long manualTaskId = null;
  54. if (tTaskIdObj != null) {
  55. try {
  56. manualTaskId = Long.parseLong(tTaskIdObj.toString());
  57. } catch (NumberFormatException e) {
  58. return AjaxResult.error("传入的任务ID格式不正确,请传数字类型");
  59. }
  60. TrainTask existTask = trainTaskMapper.selectById(manualTaskId);
  61. if (existTask != null) {
  62. return AjaxResult.error("任务ID已存在:" + manualTaskId + ",请更换其他ID");
  63. }
  64. task.setTTaskId(manualTaskId);
  65. }
  66. // 保存训练任务主表
  67. boolean taskAdded = trainTaskService.save(task);
  68. if (!taskAdded) {
  69. return AjaxResult.error("训练任务新增失败");
  70. }
  71. Long taskId = task.getTTaskId();
  72. List<Long> componentIds = new ArrayList<>();
  73. // 1. 数据获取组件
  74. Map<String, Object> dataAcquisition = (Map<String, Object>) taskInfo.get("dataAcquisition");
  75. if (dataAcquisition != null) {
  76. Boolean isEnable = parseEnableValue(dataAcquisition.get("enable"), "数据获取");
  77. String type = (String) dataAcquisition.get("name");
  78. Component component = createComponent(taskId, type, dataAcquisition, isEnable);
  79. if (componentService.save(component)) {
  80. componentIds.add(component.getComponentId());
  81. } else {
  82. throw new RuntimeException("数据获取组件保存失败");
  83. }
  84. }
  85. // 2. 数据处理组件
  86. Map<String, Object> dataCleaning = (Map<String, Object>) taskInfo.get("dataCleaning");
  87. if (dataCleaning != null) {
  88. Boolean isEnable = parseEnableValue(dataCleaning.get("enable"), "数据处理");
  89. String type = (String) dataCleaning.get("name");
  90. Component component = createComponent(taskId, type, dataCleaning, isEnable);
  91. if (componentService.save(component)) {
  92. componentIds.add(component.getComponentId());
  93. } else {
  94. throw new RuntimeException("数据处理组件保存失败");
  95. }
  96. }
  97. // 3. 限电清洗组件
  98. Map<String, Object> powerRationing = (Map<String, Object>) taskInfo.get("powerRationing");
  99. if (powerRationing != null) {
  100. Boolean isEnable = parseEnableValue(powerRationing.get("enable"), "限电清洗");
  101. String type = (String) powerRationing.get("name");
  102. Component component = createComponent(taskId, type, powerRationing, isEnable);
  103. if (componentService.save(component)) {
  104. componentIds.add(component.getComponentId());
  105. } else {
  106. throw new RuntimeException("限电清洗组件保存失败");
  107. }
  108. }
  109. // 4. 模型组件 + 关联模型测试组件
  110. Map<String, Object> model = (Map<String, Object>) taskInfo.get("model");
  111. if (model != null) {
  112. Boolean modelIsEnable = parseEnableValue(model.get("enable"), "模型");
  113. String modelType = (String) model.get("name");
  114. Component modelComponent = createComponent(taskId, modelType, model, modelIsEnable);
  115. if (componentService.save(modelComponent)) {
  116. componentIds.add(modelComponent.getComponentId());
  117. // 训练类模型关联测试组件
  118. if ("LSTM-训练".equals(modelType) || "机器学习模型-训练".equals(modelType)) {
  119. Map<String, Object> modelTest = (Map<String, Object>) taskInfo.get("modelTest");
  120. if (modelTest != null) {
  121. Boolean testIsEnable = parseEnableValue(modelTest.get("enable"), "模型测试");
  122. String testType = (String) modelTest.get("name");
  123. Component testComponent = createComponent(taskId, testType, modelTest, testIsEnable);
  124. if (componentService.save(testComponent)) {
  125. componentIds.add(testComponent.getComponentId());
  126. } else {
  127. throw new RuntimeException(testType + "组件保存失败");
  128. }
  129. }
  130. }
  131. } else {
  132. throw new RuntimeException(modelType + "组件保存失败");
  133. }
  134. }
  135. // 5. 后处理组件
  136. Map<String, Object> postProcessing = (Map<String, Object>) taskInfo.get("postProcessing");
  137. if (postProcessing != null) {
  138. Boolean isEnable = parseEnableValue(postProcessing.get("enable"), "后处理");
  139. String type = (String) postProcessing.get("name");
  140. Component component = createComponent(taskId, type, postProcessing, isEnable);
  141. if (componentService.save(component)) {
  142. componentIds.add(component.getComponentId());
  143. } else {
  144. throw new RuntimeException("后处理组件保存失败");
  145. }
  146. }
  147. // 6. 分析报告组件
  148. Map<String, Object> analysisReport = (Map<String, Object>) taskInfo.get("analysisReport");
  149. if (analysisReport != null) {
  150. Boolean isEnable = parseEnableValue(analysisReport.get("enable"), "分析报告");
  151. String type = (String) analysisReport.get("name");
  152. Component component = createComponent(taskId, type, analysisReport, isEnable);
  153. if (componentService.save(component)) {
  154. componentIds.add(component.getComponentId());
  155. } else {
  156. throw new RuntimeException("分析报告组件保存失败");
  157. }
  158. }
  159. // 更新任务关联组件ID
  160. if (!componentIds.isEmpty()) {
  161. String componentIdsStr = componentIds.stream()
  162. .map(String::valueOf)
  163. .collect(java.util.stream.Collectors.joining(","));
  164. task.setTComponentIds(componentIdsStr);
  165. trainTaskService.updateById(task);
  166. }
  167. return AjaxResult.success("训练任务新增成功");
  168. } catch (Exception e) {
  169. e.printStackTrace();
  170. return AjaxResult.error("新增失败:" + e.getMessage());
  171. }
  172. }
  173. /**
  174. * 查询所有任务以及组件参数
  175. * @param pageNum
  176. * @param pageSize
  177. * @param taskName
  178. * @return
  179. */
  180. @GetMapping("/queryTasks")
  181. public AjaxResult queryTrainTasks(
  182. @RequestParam(defaultValue = "1") int pageNum,
  183. @RequestParam(defaultValue = "10") int pageSize,
  184. @RequestParam(required = false) String taskName) {
  185. try {
  186. Page<TrainTask> page = new Page<>(pageNum, pageSize);
  187. QueryWrapper<TrainTask> queryWrapper = new QueryWrapper<>();
  188. if (StringUtils.hasText(taskName)) {
  189. queryWrapper.like("T_TASK_NAME", taskName);
  190. }
  191. queryWrapper.orderByDesc("T_CREATE_TIME");
  192. Page<TrainTask> taskPage = trainTaskService.page(page, queryWrapper);
  193. List<TrainTaskVO> taskVOList = taskPage.getRecords().stream().map(task -> {
  194. TrainTaskVO vo = new TrainTaskVO();
  195. BeanUtils.copyProperties(task, vo);
  196. // 解析组件ID并查询组件
  197. String componentIds = task.getTComponentIds();
  198. if (StringUtils.hasText(componentIds)) {
  199. List<Long> ids = Arrays.stream(componentIds.split(","))
  200. .map(Long::parseLong)
  201. .collect(Collectors.toList());
  202. List<Component> components = componentService.listByIds(ids);
  203. vo.setComponents(components);
  204. }
  205. return vo;
  206. }).collect(Collectors.toList());
  207. PageResult<TrainTaskVO> result = new PageResult<>(
  208. taskPage.getTotal(),
  209. pageNum,
  210. pageSize,
  211. taskVOList
  212. );
  213. return AjaxResult.success(result);
  214. } catch (Exception e) {
  215. e.printStackTrace();
  216. return AjaxResult.error("查询训练任务失败:" + e.getMessage());
  217. }
  218. }
  219. ///////////////////////////////////////////////////////////////////////////
  220. /**
  221. * 使用解析后的isEnable,删除错误强转
  222. */
  223. private Component createComponent(Long taskId, String componentType, Map<String, Object> config, Boolean isEnable) {
  224. Component component = new Component();
  225. component.setTaskId(taskId);
  226. component.setTaskType(0); // 训练任务固定标识
  227. component.setComponentType(componentType);
  228. component.setParamsMap((Map<String, Object>) config.get("value")); // 组件参数
  229. component.setIsEnable(isEnable); // 仅使用解析后的enable值,无强转
  230. component.setInterfaceUrl((String) config.get("interfaceUrl")); // 接口地址
  231. return component;
  232. }
  233. /**
  234. * 统一解析enable值,兼容非Boolean类型
  235. */
  236. private Boolean parseEnableValue(Object enableObj, String componentName) {
  237. if (enableObj instanceof Boolean) {
  238. return (Boolean) enableObj;
  239. }
  240. // 非Boolean类型或null,默认false并记录日志
  241. if (enableObj != null) {
  242. log.warn("组件[{}]的enable参数类型错误非boolean类型,实际类型为{} 值{},默认设为false",
  243. componentName, enableObj.getClass().getSimpleName(), enableObj);
  244. } else {
  245. log.warn("组件[{}]未传递enable参数,已默认设为false", componentName);
  246. }
  247. return false;
  248. }
  249. }