123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- package com.xvji.web.controller;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.xvji.common.core.domain.AjaxResult;
import com.xvji.domain.Component;
import com.xvji.domain.TrainTask;
import com.xvji.domain.vo.PageResult;
import com.xvji.domain.vo.TrainTaskVO;
import com.xvji.mapper.TrainTaskMapper;
import com.xvji.service.ComponentService;
import com.xvji.service.TrainTaskService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.interceptor.TransactionAspectSupport;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
- /**
- * 训练任务控制器(兼容非Boolean类型enable,修复强转漏洞)
- */
- @RestController
- @RequestMapping("/task/train")
- public class TrainTaskController {
- private static final Logger log = LoggerFactory.getLogger(TrainTaskController.class);
- @Autowired
- private TrainTaskService trainTaskService;
- @Autowired
- private ComponentService componentService;
- @Autowired
- private TrainTaskMapper trainTaskMapper;
- /**
- * 新增训练任务及关联组件
- */
- @PostMapping("/addTask")
- @Transactional
- public AjaxResult addTrainTask(@RequestBody Map<String, Object> taskInfo) {
- try {
- TrainTask task = new TrainTask();
- task.setTTaskName((String) taskInfo.get("tTaskName"));
- task.setTCronExpression((String) taskInfo.get("tCronExpression"));
- task.setTQuartzTask((String) taskInfo.get("tQuartzTask"));
- task.setTRunInfo((String) taskInfo.get("tRunInfo"));
- task.setTAnalysisReport((String) taskInfo.get("tAnalysisReport"));
- Object tTaskIdObj = taskInfo.get("tTaskId");
- Long manualTaskId = null;
- if (tTaskIdObj != null) {
- try {
- manualTaskId = Long.parseLong(tTaskIdObj.toString());
- } catch (NumberFormatException e) {
- return AjaxResult.error("传入的任务ID格式不正确,请传数字类型");
- }
- TrainTask existTask = trainTaskMapper.selectById(manualTaskId);
- if (existTask != null) {
- return AjaxResult.error("任务ID已存在:" + manualTaskId + ",请更换其他ID");
- }
- task.setTTaskId(manualTaskId);
- }
- // 保存训练任务主表
- boolean taskAdded = trainTaskService.save(task);
- if (!taskAdded) {
- return AjaxResult.error("训练任务新增失败");
- }
- Long taskId = task.getTTaskId();
- List<Long> componentIds = new ArrayList<>();
- // 1. 数据获取组件
- Map<String, Object> dataAcquisition = (Map<String, Object>) taskInfo.get("dataAcquisition");
- if (dataAcquisition != null) {
- Boolean isEnable = parseEnableValue(dataAcquisition.get("enable"), "数据获取");
- String type = (String) dataAcquisition.get("name");
- Component component = createComponent(taskId, type, dataAcquisition, isEnable);
- if (componentService.save(component)) {
- componentIds.add(component.getComponentId());
- } else {
- throw new RuntimeException("数据获取组件保存失败");
- }
- }
- // 2. 数据处理组件
- Map<String, Object> dataCleaning = (Map<String, Object>) taskInfo.get("dataCleaning");
- if (dataCleaning != null) {
- Boolean isEnable = parseEnableValue(dataCleaning.get("enable"), "数据处理");
- String type = (String) dataCleaning.get("name");
- Component component = createComponent(taskId, type, dataCleaning, isEnable);
- if (componentService.save(component)) {
- componentIds.add(component.getComponentId());
- } else {
- throw new RuntimeException("数据处理组件保存失败");
- }
- }
- // 3. 限电清洗组件
- Map<String, Object> powerRationing = (Map<String, Object>) taskInfo.get("powerRationing");
- if (powerRationing != null) {
- Boolean isEnable = parseEnableValue(powerRationing.get("enable"), "限电清洗");
- String type = (String) powerRationing.get("name");
- Component component = createComponent(taskId, type, powerRationing, isEnable);
- if (componentService.save(component)) {
- componentIds.add(component.getComponentId());
- } else {
- throw new RuntimeException("限电清洗组件保存失败");
- }
- }
- // 4. 模型组件 + 关联模型测试组件
- Map<String, Object> model = (Map<String, Object>) taskInfo.get("model");
- if (model != null) {
- Boolean modelIsEnable = parseEnableValue(model.get("enable"), "模型");
- String modelType = (String) model.get("name");
- Component modelComponent = createComponent(taskId, modelType, model, modelIsEnable);
- if (componentService.save(modelComponent)) {
- componentIds.add(modelComponent.getComponentId());
- // 训练类模型关联测试组件
- if ("LSTM-训练".equals(modelType) || "机器学习模型-训练".equals(modelType)) {
- Map<String, Object> modelTest = (Map<String, Object>) taskInfo.get("modelTest");
- if (modelTest != null) {
- Boolean testIsEnable = parseEnableValue(modelTest.get("enable"), "模型测试");
- String testType = (String) modelTest.get("name");
- Component testComponent = createComponent(taskId, testType, modelTest, testIsEnable);
- if (componentService.save(testComponent)) {
- componentIds.add(testComponent.getComponentId());
- } else {
- throw new RuntimeException(testType + "组件保存失败");
- }
- }
- }
- } else {
- throw new RuntimeException(modelType + "组件保存失败");
- }
- }
- // 5. 后处理组件
- Map<String, Object> postProcessing = (Map<String, Object>) taskInfo.get("postProcessing");
- if (postProcessing != null) {
- Boolean isEnable = parseEnableValue(postProcessing.get("enable"), "后处理");
- String type = (String) postProcessing.get("name");
- Component component = createComponent(taskId, type, postProcessing, isEnable);
- if (componentService.save(component)) {
- componentIds.add(component.getComponentId());
- } else {
- throw new RuntimeException("后处理组件保存失败");
- }
- }
- // 6. 分析报告组件
- Map<String, Object> analysisReport = (Map<String, Object>) taskInfo.get("analysisReport");
- if (analysisReport != null) {
- Boolean isEnable = parseEnableValue(analysisReport.get("enable"), "分析报告");
- String type = (String) analysisReport.get("name");
- Component component = createComponent(taskId, type, analysisReport, isEnable);
- if (componentService.save(component)) {
- componentIds.add(component.getComponentId());
- } else {
- throw new RuntimeException("分析报告组件保存失败");
- }
- }
- // 更新任务关联组件ID
- if (!componentIds.isEmpty()) {
- String componentIdsStr = componentIds.stream()
- .map(String::valueOf)
- .collect(java.util.stream.Collectors.joining(","));
- task.setTComponentIds(componentIdsStr);
- trainTaskService.updateById(task);
- }
- return AjaxResult.success("训练任务新增成功");
- } catch (Exception e) {
- e.printStackTrace();
- return AjaxResult.error("新增失败:" + e.getMessage());
- }
- }
- /**
- * 查询所有任务以及组件参数
- * @param pageNum
- * @param pageSize
- * @param taskName
- * @return
- */
- @GetMapping("/queryTasks")
- public AjaxResult queryTrainTasks(
- @RequestParam(defaultValue = "1") int pageNum,
- @RequestParam(defaultValue = "10") int pageSize,
- @RequestParam(required = false) String taskName) {
- try {
- Page<TrainTask> page = new Page<>(pageNum, pageSize);
- QueryWrapper<TrainTask> queryWrapper = new QueryWrapper<>();
- if (StringUtils.hasText(taskName)) {
- queryWrapper.like("T_TASK_NAME", taskName);
- }
- queryWrapper.orderByDesc("T_CREATE_TIME");
- Page<TrainTask> taskPage = trainTaskService.page(page, queryWrapper);
- List<TrainTaskVO> taskVOList = taskPage.getRecords().stream().map(task -> {
- TrainTaskVO vo = new TrainTaskVO();
- BeanUtils.copyProperties(task, vo);
- // 解析组件ID并查询组件
- String componentIds = task.getTComponentIds();
- if (StringUtils.hasText(componentIds)) {
- List<Long> ids = Arrays.stream(componentIds.split(","))
- .map(Long::parseLong)
- .collect(Collectors.toList());
- List<Component> components = componentService.listByIds(ids);
- vo.setComponents(components);
- }
- return vo;
- }).collect(Collectors.toList());
- PageResult<TrainTaskVO> result = new PageResult<>(
- taskPage.getTotal(),
- pageNum,
- pageSize,
- taskVOList
- );
- return AjaxResult.success(result);
- } catch (Exception e) {
- e.printStackTrace();
- return AjaxResult.error("查询训练任务失败:" + e.getMessage());
- }
- }
- ///////////////////////////////////////////////////////////////////////////
- /**
- * 使用解析后的isEnable,删除错误强转
- */
- private Component createComponent(Long taskId, String componentType, Map<String, Object> config, Boolean isEnable) {
- Component component = new Component();
- component.setTaskId(taskId);
- component.setTaskType(0); // 训练任务固定标识
- component.setComponentType(componentType);
- component.setParamsMap((Map<String, Object>) config.get("value")); // 组件参数
- component.setIsEnable(isEnable); // 仅使用解析后的enable值,无强转
- component.setInterfaceUrl((String) config.get("interfaceUrl")); // 接口地址
- return component;
- }
- /**
- * 统一解析enable值,兼容非Boolean类型
- */
- private Boolean parseEnableValue(Object enableObj, String componentName) {
- if (enableObj instanceof Boolean) {
- return (Boolean) enableObj;
- }
- // 非Boolean类型或null,默认false并记录日志
- if (enableObj != null) {
- log.warn("组件[{}]的enable参数类型错误非boolean类型,实际类型为{} 值{},默认设为false",
- componentName, enableObj.getClass().getSimpleName(), enableObj);
- } else {
- log.warn("组件[{}]未传递enable参数,已默认设为false", componentName);
- }
- return false;
- }
- }
|