# AnyTreeTools.py: helpers on top of anytree for storing values at name paths and dumping them.
import anytree
from typing import Any, List, Optional

class TreeNode(anytree.node.Node):
    """An anytree ``Node`` that carries an optional ``value`` payload.

    ``value`` defaults to ``None`` and may also be passed as a keyword
    argument, e.g. ``TreeNode('a', parent=root, value=1)``.
    """

    def __init__(self, name, parent=None, children=None, **kwargs) -> None:
        # anytree's Node.__init__ copies **kwargs onto the instance. Pull
        # ``value`` out first so that assigning the default afterwards cannot
        # clobber a value supplied by the caller.
        value = kwargs.pop('value', None)
        super().__init__(name, parent=parent, children=children, **kwargs)
        self.value : Optional[Any] = value

    def get_full_path(self) -> str:
        """Return the separator-joined names from the absolute root down to this node.

        The result starts with the separator, e.g. ``'/./a/b'`` for node ``b``
        under ``a`` under a root named ``'.'``.
        """
        return self.separator.join([''] + [str(node.name) for node in self.path])

class RawStyle(anytree.render.AbstractStyle):
    """
    Render style that draws no tree guides at all.

    Vertical, continuation and end markers are all empty strings, so every
    ``RenderTree`` row is printed with an empty prefix: one node per line.

    >>> from anytree import Node, RenderTree
    >>> root = Node("root")
    >>> s0 = Node("sub0", parent=root)
    >>> s0b = Node("sub0B", parent=s0)
    >>> s0a = Node("sub0A", parent=s0)
    >>> s1 = Node("sub1", parent=root)
    >>> print(RenderTree(root, style=RawStyle()))
    Node('/root')
    Node('/root/sub0')
    Node('/root/sub0/sub0B')
    Node('/root/sub0/sub0A')
    Node('/root/sub1')
    """

    def __init__(self):
        no_guide = ''
        super().__init__(no_guide, no_guide, no_guide)

def get_subnode(resolver : anytree.Resolver, root : TreeNode, path : List[str], default : Optional[Any] = None):
    """
    Walk ``path`` down from ``root`` and return the node it names.

    Each element of ``path`` is resolved with ``resolver.get`` relative to the
    node reached so far. If any step raises ``anytree.ChildResolverError``,
    ``default`` is returned instead. An empty ``path`` returns ``root``.
    """
    current = root
    try:
        for step in path:
            current = resolver.get(current, step)
    except anytree.ChildResolverError:
        return default
    return current

def set_subnode_value(resolver : anytree.Resolver, root : TreeNode, path : List[str], value : Any):
    """
    Assign ``value`` to the node at ``path`` below ``root``, creating missing nodes.

    Each element of ``path`` is resolved with ``resolver.get`` from the node
    reached so far. When that raises ``anytree.ChildResolverError``, a new
    ``TreeNode`` named after the element is attached there as a child and the
    walk continues from it. An empty ``path`` assigns to ``root`` itself.
    """
    target = root
    for name in path:
        try:
            child = resolver.get(target, name)
        except anytree.ChildResolverError:
            # Missing level: grow the tree so the walk can continue.
            child = TreeNode(name, parent=target)
        target = child
    target.value = value

def dump_subtree(root : TreeNode):
    """
    Collect ``(path, value)`` pairs for every node in ``root``'s subtree that holds a value.

    Nodes are visited in pre-order: ``root`` first, then each child's subtree
    in child order. Nodes whose ``value`` is ``None`` are skipped, and so is
    the absolute tree root itself.

    ``path`` is the node's ``get_full_path()`` with the leading
    ``<separator><absolute-root-name>`` component removed. For example,
    ``'/./a/b'`` becomes ``'/a/b'`` when the root is the ``'.'`` placeholder.

    Raises:
        TypeError: if ``root`` is not a ``TreeNode``.
    """
    if not isinstance(root, TreeNode):
        raise TypeError('root must be a TreeNode')
    # Every full path in this tree starts with the same "<sep><root name>"
    # prefix; strip exactly that instead of a hard-coded 2 characters, which
    # only worked for a root named '.'.
    prefix_len = len(root.separator) + len(str(root.root.name))
    results = []
    for node in anytree.PreOrderIter(root):
        path = node.get_full_path()[prefix_len:]
        if not path:
            continue  # the absolute root has no path relative to itself
        value = node.value
        if value is None:
            continue
        results.append((path, value))
    return results