|
@@ -0,0 +1,186 @@
|
|
|
+package com.xvji.web.controller;
|
|
|
+
|
|
|
+import com.xvji.common.core.domain.AjaxResult;
|
|
|
+import com.xvji.domain.Component;
|
|
|
+import com.xvji.domain.TrainTask;
|
|
|
+import com.xvji.mapper.TrainTaskMapper;
|
|
|
+import com.xvji.service.ComponentService;
|
|
|
+import com.xvji.service.TrainTaskService;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.transaction.annotation.Transactional;
|
|
|
+import org.springframework.web.bind.annotation.*;
|
|
|
+
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+
|
|
|
+/**
|
|
|
+ * 训练任务控制器(对应train_task表及关联组件)
|
|
|
+ */
|
|
|
+@RestController
|
|
|
+@RequestMapping("/task/train")
|
|
|
+public class TrainTaskController {
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private TrainTaskService trainTaskService;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private ComponentService componentService;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private TrainTaskMapper trainTaskMapper;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 新增训练任务及关联组件
|
|
|
+ * 逻辑说明:与预测任务一致,额外处理训练分析报告字段(传则存,不传则空)
|
|
|
+ */
|
|
|
+ @PostMapping("/addTask")
|
|
|
+ @Transactional
|
|
|
+ public AjaxResult addTrainTask(@RequestBody Map<String, Object> taskInfo) {
|
|
|
+ try {
|
|
|
+ TrainTask task = new TrainTask();
|
|
|
+ task.setTTaskName((String) taskInfo.get("tTaskName"));
|
|
|
+ task.setTCronExpression((String) taskInfo.get("tCronExpression"));
|
|
|
+ task.setTQuartzTask((String) taskInfo.get("tQuartzTask"));
|
|
|
+ task.setTRunInfo((String) taskInfo.get("tRunInfo"));
|
|
|
+ task.setTAnalysisReport((String) taskInfo.get("tAnalysisReport"));
|
|
|
+
|
|
|
+ Object tTaskIdObj = taskInfo.get("tTaskId");
|
|
|
+ Long manualTaskId = null;
|
|
|
+ if (tTaskIdObj != null) {
|
|
|
+ //校验ID格式
|
|
|
+ try {
|
|
|
+ manualTaskId = Long.parseLong(tTaskIdObj.toString());
|
|
|
+ } catch (NumberFormatException e) {
|
|
|
+ return AjaxResult.error("传入的任务ID格式不正确,请传数字类型");
|
|
|
+ }
|
|
|
+ //校验ID是否已存在
|
|
|
+ TrainTask existTask = trainTaskMapper.selectById(manualTaskId);
|
|
|
+ if (existTask != null) {
|
|
|
+ return AjaxResult.error("任务ID已存在:" + manualTaskId + ",请更换其他ID");
|
|
|
+ }
|
|
|
+ task.setTTaskId(manualTaskId);
|
|
|
+ }
|
|
|
+
|
|
|
+ boolean taskAdded = trainTaskService.save(task);
|
|
|
+ if (!taskAdded) {
|
|
|
+ return AjaxResult.error("训练任务新增失败");
|
|
|
+ }
|
|
|
+
|
|
|
+ Long taskId = task.getTTaskId();
|
|
|
+ List<Long> componentIds = new ArrayList<>();
|
|
|
+
|
|
|
+ //数据获取组件
|
|
|
+ Map<String, Object> dataAcquisition = (Map<String, Object>) taskInfo.get("dataAcquisition");
|
|
|
+ if (dataAcquisition != null && (Boolean) dataAcquisition.get("enable")) {
|
|
|
+ Component component = createComponent(taskId, "数据获取", dataAcquisition);
|
|
|
+ boolean saved = componentService.save(component);
|
|
|
+ if (saved) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("数据获取组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 数据处理组件
|
|
|
+ Map<String, Object> dataPreprocess = (Map<String, Object>) taskInfo.get("dataCleaning");
|
|
|
+ if (dataPreprocess != null && (Boolean) dataPreprocess.get("enable")) {
|
|
|
+ Component component = createComponent(taskId, "数据处理", dataPreprocess);
|
|
|
+ boolean saved = componentService.save(component);
|
|
|
+ if (saved) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("数据处理组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ //限电清洗-光伏 组件
|
|
|
+ Map<String, Object> powerRationing = (Map<String, Object>) taskInfo.get("powerRationing");
|
|
|
+ if (powerRationing != null && (Boolean) powerRationing.get("enable")) {
|
|
|
+ Component component = createComponent(taskId, "限电清洗-光伏", powerRationing);
|
|
|
+ boolean saved = componentService.save(component);
|
|
|
+ if (saved) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("限电清洗-光伏组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ //光伏物理模型组件
|
|
|
+ Map<String, Object> model = (Map<String, Object>) taskInfo.get("model");
|
|
|
+ if (model != null && (Boolean) model.get("enable")) {
|
|
|
+ Component component = createComponent(taskId, "光伏物理模型", model);
|
|
|
+ boolean saved = componentService.save(component);
|
|
|
+ if (saved) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("光伏物理模型组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 后处理
|
|
|
+ Map<String, Object> postProcessing = (Map<String, Object>) taskInfo.get("postProcessing");
|
|
|
+ if (postProcessing != null && (Boolean) postProcessing.get("enable")) {
|
|
|
+ Component component = createComponent(taskId, "后处理", postProcessing);
|
|
|
+ boolean saved = componentService.save(component);
|
|
|
+ if (saved) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("后处理组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 分析报告组件
|
|
|
+ Map<String, Object> analysisReport = (Map<String, Object>) taskInfo.get("analysisReport");
|
|
|
+ if (analysisReport != null && (Boolean) analysisReport.get("enable")) {
|
|
|
+ Component component = createComponent(taskId, "分析报告", analysisReport);
|
|
|
+ boolean saved = componentService.save(component);
|
|
|
+ if (saved) {
|
|
|
+ componentIds.add(component.getComponentId());
|
|
|
+ } else {
|
|
|
+ throw new RuntimeException("分析报告组件保存失败");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ //更新任务表中的关联组件ID
|
|
|
+ if (!componentIds.isEmpty()) {
|
|
|
+ String componentIdsStr = String.join(",", componentIds.stream()
|
|
|
+ .map(String::valueOf)
|
|
|
+ .collect(java.util.stream.Collectors.toList()));
|
|
|
+ task.setTComponentIds(componentIdsStr);
|
|
|
+ trainTaskService.updateById(task);
|
|
|
+ }
|
|
|
+
|
|
|
+ return AjaxResult.success("训练任务新增成功");
|
|
|
+ } catch (Exception e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ return AjaxResult.error("新增失败:" + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 辅助方法:创建组件对象(训练任务专用,taskType固定为0)
|
|
|
+ * @param taskId 训练任务ID
|
|
|
+ * @param componentType 组件类型(如数据预处理、模型训练)
|
|
|
+ * @param config 组件配置(包含enable、value、interfaceUrl)
|
|
|
+ */
|
|
|
+ private Component createComponent(Long taskId, String componentType, Map<String, Object> config) {
|
|
|
+ Component component = new Component();
|
|
|
+ component.setTaskId(taskId);
|
|
|
+ component.setTaskType(0); // 训练任务固定标识:0-训练任务
|
|
|
+ component.setComponentType(componentType);
|
|
|
+ component.setParamsMap((Map<String, Object>) config.get("value")); // 组件参数(JSON字符串存储)
|
|
|
+ component.setIsEnable((Boolean) config.get("enable")); // 组件启用状态
|
|
|
+ component.setInterfaceUrl((String) config.get("interfaceUrl")); // 接口地址:传则存,不传则空
|
|
|
+ return component;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 查询所有训练任务(与预测任务查询逻辑一致)
|
|
|
+ */
|
|
|
+ @GetMapping
|
|
|
+ public AjaxResult getAllTrainTasks() {
|
|
|
+ List<TrainTask> tasks = trainTaskService.getAllTrainTasks();
|
|
|
+ return AjaxResult.success(tasks);
|
|
|
+ }
|
|
|
+}
|