# -*- coding: utf-8 -*-
"""
@File    : tree_service_impl.py
@Time    : 2023/1/14 16:10
@Author  : geekbing
@LastEditTime : 2023/7/27 21:10
@LastEditors : geekbing
@Description : Tree-structure (directory tree) operations.
"""

from ast import literal_eval
import traceback
from typing import Callable, Dict, List, Optional

from loguru import logger

from crud.base_crud import GenericCURD
from lunarlink.dto.tree_dto import TreeOut, TreeUniqueIn, TreeUpdateIn
from lunarlink.models import API, Case, Relation
from lunarlink.utils.response import (
    TREE_GET_SUCCESS,
    TREE_UPDATE_SUCCESS,
    StandResponse,
)
from lunarlink.utils.tree import get_tree_max_id
from lunarlink.utils.enums.TreeTypeEnum import TreeType


class TreeService:
    """Read and update the directory trees stored in ``Relation`` rows.

    A tree is a list of node dicts; every node has an ``"id"`` and may have a
    ``"children"`` list of nested nodes. Nodes are annotated with a
    ``"data_count"`` (number of APIs or cases below them) before being returned.
    """

    def __init__(self):
        self.model = Relation
        self.curd = GenericCURD(self.model)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _parse_tree(raw) -> List[Dict]:
        """Return the tree as Python objects.

        The stored ``tree`` value appears to be a Python-literal string when
        loaded from the database (hence ``literal_eval``), but it is already a
        list when it was just assigned in Python (e.g. a freshly created row).
        Both forms are accepted; an empty/missing value yields ``[]``.
        """
        if isinstance(raw, str):
            # literal_eval only evaluates literals, so it is safe on stored text.
            return literal_eval(raw) if raw.strip() else []
        return raw or []

    @staticmethod
    def _children(node: Dict) -> List[Dict]:
        """Return a node's children, treating a missing or ``None`` value as empty."""
        return node.get("children") or []

    @staticmethod
    def _attach_counts(project_id: int, tree: List[Dict], tree_type: TreeType) -> None:
        """Annotate every node of ``tree`` in place with its ``data_count``.

        API trees count ``API`` rows; every other tree type counts ``Case`` rows.

        :param project_id: project owning the APIs/cases.
        :param tree: list of root nodes, mutated in place.
        :param tree_type: decides which model is counted.
        """
        counter = (
            TreeService.add_api_count_to_node
            if tree_type == TreeType.API
            else TreeService.add_case_count_to_node
        )
        node_counts: Dict = {}
        for node in tree:
            counter(project_id, node, node_counts)
        for root_node in tree:
            TreeService.add_api_case_count_to_tree(root_node, node_counts)

    # ------------------------------------------------------------------ queries

    def get_or_create(self, query: TreeUniqueIn) -> StandResponse[TreeOut]:
        """Fetch the tree matching ``query``, creating a default one if absent.

        :param query: unique lookup; ``query.dict()`` is used as the filter and
            ``query.type`` decides whether API or case counts are attached.
        :return: response wrapping the annotated tree, its id and the current
            maximum node id.
        """
        default_tree = [
            {
                "id": 1,
                "label": "默认目录",
                "children": [],
            }
        ]
        tree_obj, is_created = self.curd.get_or_create(
            filter_kwargs=query.dict(),
            defaults={"tree": default_tree, "project_id": query.project_id},
        )
        if is_created:
            logger.info(f"tree created {query=}")
        else:
            logger.info(f"tree exist {query}")
        body: List[Dict] = self._parse_tree(tree_obj.tree)

        self._attach_counts(tree_obj.project_id, body, query.type)

        tree = {
            "tree": body,
            "id": tree_obj.id,
            "success": True,
            "max": get_tree_max_id(body),
        }
        return StandResponse[TreeOut](**TREE_GET_SUCCESS, data=TreeOut(**tree))

    @staticmethod
    def add_case_count_to_node(project_id: int, node: Dict, node_case_count: Dict):
        """Recursively record the number of cases directly attached to each node.

        :param project_id: project owning the cases.
        :param node: tree node; ``node["id"]`` is matched against ``Case.relation``.
        :param node_case_count: output mapping ``node id -> case count``, filled
            in place.
        """
        # NOTE(review): issues one COUNT query per node; a single grouped query
        # over all node ids would scale better for large trees.
        node_id = node["id"]
        node_case_count[node_id] = Case.objects.filter(
            project_id=project_id, relation=node_id
        ).count()
        for child in TreeService._children(node):
            TreeService.add_case_count_to_node(project_id, child, node_case_count)

    @staticmethod
    def add_api_count_to_node(project_id: int, node: Dict, node_api_count: Dict):
        """Recursively record the number of APIs directly attached to each node.

        :param project_id: project owning the APIs.
        :param node: tree node; ``node["id"]`` is matched against ``API.relation``.
        :param node_api_count: output mapping ``node id -> API count``, filled
            in place.
        """
        # NOTE(review): one COUNT query per node, same as add_case_count_to_node.
        node_id = node["id"]
        node_api_count[node_id] = API.objects.filter(
            project_id=project_id, relation=node_id
        ).count()
        for child in TreeService._children(node):
            TreeService.add_api_count_to_node(project_id, child, node_api_count)

    @staticmethod
    def add_api_case_count_to_tree(node: Dict, node_api_case_counts: Dict) -> int:
        """Write the subtree total of APIs/cases into ``node["data_count"]``.

        :param node: subtree root, mutated in place (recursively).
        :param node_api_case_counts: mapping ``node id -> direct count``; nodes
            absent from it count as 0.
        :return: the total for this subtree (own count plus all descendants).
        """
        data_count = node_api_case_counts.get(node["id"], 0)
        for child in TreeService._children(node):
            data_count += TreeService.add_api_case_count_to_tree(
                child, node_api_case_counts
            )
        node["data_count"] = data_count
        return data_count

    @staticmethod
    def check_related_api(node_id: int, project_id: int) -> bool:
        """Return True if any API in the project is attached to ``node_id``."""
        return API.objects.filter(relation=node_id, project_id=project_id).exists()

    @staticmethod
    def check_related_case(node_id: int, project_id: int) -> bool:
        """Return True if any case in the project is attached to ``node_id``."""
        return Case.objects.filter(relation=node_id, project_id=project_id).exists()

    def get_all_ids_from_tree(self, tree):
        """Return the ids of every node in ``tree``.

        Ids of the nodes at each level come first, followed by the ids found
        under each of those nodes, in node order.

        :param tree: list of nodes.
        :return: list of node ids.
        """
        ids = [node["id"] for node in tree]
        for node in tree:
            ids.extend(self.get_all_ids_from_tree(self._children(node)))
        return ids

    @staticmethod
    def _removed_node_has_relation(
        current_tree: List[Dict], kept_ids: set, has_relation: Callable[[int], bool]
    ) -> bool:
        """Depth-first search for a node absent from ``kept_ids`` that still has data.

        Stops at the first such node so no further relation queries are issued.
        """
        for node in current_tree:
            # A node missing from the new tree is being deleted.
            if node["id"] not in kept_ids and has_relation(node["id"]):
                return True
            children = TreeService._children(node)
            if children and TreeService._removed_node_has_relation(
                children, kept_ids, has_relation
            ):
                return True
        return False

    def _check_related_api_recursive(self, current_tree, payload_tree, project_id):
        """Return True if a node removed by ``payload_tree`` still has related APIs.

        Every node of ``current_tree`` whose id does not appear anywhere in
        ``payload_tree`` is treated as deleted and checked for attached APIs.
        """
        kept_ids = set(self.get_all_ids_from_tree(payload_tree))
        return self._removed_node_has_relation(
            current_tree,
            kept_ids,
            lambda node_id: self.check_related_api(node_id, project_id),
        )

    def _check_related_case_recursive(self, current_tree, payload_tree, project_id):
        """Return True if a node removed by ``payload_tree`` still has related cases.

        Every node of ``current_tree`` whose id does not appear anywhere in
        ``payload_tree`` is treated as deleted and checked for attached cases.
        """
        kept_ids = set(self.get_all_ids_from_tree(payload_tree))
        return self._removed_node_has_relation(
            current_tree,
            kept_ids,
            lambda node_id: self.check_related_case(node_id, project_id),
        )

    @staticmethod
    def get_current_tree(relation_obj) -> List:
        """Return the stored tree of ``relation_obj`` as Python objects.

        :param relation_obj: ``Relation`` row, or a falsy value if none was found.
        :return: the parsed tree, or ``[]`` when there is no row.
        """
        if not relation_obj:
            return []
        return TreeService._parse_tree(relation_obj.tree)

    # ------------------------------------------------------------------ updates

    def patch(self, tree_id: int, payload: TreeUpdateIn):
        """Replace the tree stored under ``tree_id`` with ``payload.tree``.

        The update is refused when the new tree is empty, when the tree does
        not exist, or when it would delete a node that still has related
        APIs (API trees) or cases (case trees).

        :param tree_id: primary key of the ``Relation`` row.
        :param payload: new tree plus its type.
        :return: success response with the annotated tree, or a ``"9999"``
            error response.
        """
        if not payload.tree:
            return StandResponse[Optional[TreeOut]](
                code="9999",
                success=False,
                msg="删除失败, 至少保留一个根目录",
                data=None,
            )

        relation_obj = self.curd.get_obj_by_pk(pk=tree_id)
        if relation_obj is None:
            # Previously this fell through to `relation_obj.project_id` and
            # raised AttributeError; report it explicitly instead.
            return StandResponse[Optional[TreeOut]](
                code="9999",
                success=False,
                msg="目录树不存在",
                data=None,
            )
        current_tree = self.get_current_tree(relation_obj)

        # A node present in the current tree but absent from the new one is
        # being deleted; refuse if it still owns APIs/cases.
        if payload.type == TreeType.API:
            if self._check_related_api_recursive(
                current_tree, payload.tree, relation_obj.project_id
            ):
                return StandResponse[Optional[TreeOut]](
                    code="9999",
                    success=False,
                    msg="目录有关联接口,不能删除",
                    data=None,
                )
        elif payload.type == TreeType.CASE:
            if self._check_related_case_recursive(
                current_tree, payload.tree, relation_obj.project_id
            ):
                return StandResponse[Optional[TreeOut]](
                    code="9999",
                    success=False,
                    msg="目录有关联用例,不能删除",
                    data=None,
                )

        try:
            tree_obj = self.curd.update_obj_by_pk(
                pk=tree_id, updater="", payload=payload.dict()
            )
        except Exception as e:
            return self._handle_exception(e)

        tree: List[Dict] = self._parse_tree(tree_obj.tree)
        self._attach_counts(tree_obj.project_id, tree, payload.type)

        return StandResponse[TreeOut](
            **TREE_UPDATE_SUCCESS,
            data=TreeOut(tree=tree, id=tree_obj.id, max=get_tree_max_id(tree)),
        )

    @staticmethod
    def _handle_exception(e: Exception) -> StandResponse[Optional[TreeOut]]:
        """Log ``e`` with its traceback and wrap it in a ``"9999"`` error response."""
        err: str = traceback.format_exc()
        logger.warning(f"Exception {e} occurred with traceback: {err}")
        # NOTE(review): the full traceback is returned to the client as `msg`,
        # which leaks internals; consider returning str(e) instead.
        return StandResponse[Optional[TreeOut]](
            code="9999", success=False, msg=err, data=None
        )


tree_service = TreeService()