"""Binary search trees: a plain (unbalanced) BST and a balanced-build variant."""

from typing import List, Optional


class BSTNode:
    """A single node of a binary search tree."""

    def __init__(self, value):
        self.value = value
        self.left = None  # Optional[BSTNode]: subtree of smaller values
        self.right = None  # Optional[BSTNode]: subtree of larger values

    def size(self) -> int:
        """Return the number of nodes in the subtree rooted at this node.

        Uses an explicit stack instead of recursion so that degenerate
        (list-shaped) subtrees do not exceed Python's recursion limit.
        """
        count = 0
        stack = [self]
        while stack:
            node = stack.pop()
            count += 1
            if node.left is not None:
                stack.append(node.left)
            if node.right is not None:
                stack.append(node.right)
        return count

    def height(self) -> int:
        """Return the number of levels in this subtree (a lone leaf has height 1).

        Walks the tree level by level rather than recursively, for the same
        recursion-limit reason as ``size``.
        """
        levels = 0
        frontier = [self]
        while frontier:
            levels += 1
            frontier = [
                child
                for node in frontier
                for child in (node.left, node.right)
                if child is not None
            ]
        return levels


class BST:
    """An unbalanced binary search tree that silently ignores duplicate values.

    ``size`` is a plain attribute holding the number of stored values; it is
    maintained by ``add`` and reset by ``clear``.
    """

    def __init__(self):
        self.root = None  # Optional[BSTNode]; None means the tree is empty
        self.size = 0  # number of distinct values stored

    def clear(self):
        """Remove every value from the tree."""
        self.root = None
        self.size = 0

    def add(self, value):
        """Insert *value*; do nothing if an equal value is already present.

        A value that is neither greater nor less than an existing node's value
        is treated as a duplicate and not inserted.
        """
        if self.root is None:
            self.root = BSTNode(value)
            self.size += 1
            return
        current = self.root
        while True:
            if value > current.value:
                if current.right is None:
                    current.right = BSTNode(value)
                    break
                current = current.right
            elif value < current.value:
                if current.left is None:
                    current.left = BSTNode(value)
                    break
                current = current.left
            else:
                return  # duplicate: tree and size unchanged
        self.size += 1

    def lookup(self, value) -> "Optional[BSTNode]":
        """Return the node holding *value*, or ``None`` if it is not in the tree.

        Safe on an empty tree. The None check now happens before
        ``current.value`` is read; the old order raised AttributeError for
        every missing value.
        """
        current = self.root
        while current is not None:
            if value < current.value:
                current = current.left
            elif value > current.value:
                current = current.right
            else:
                # Neither smaller nor larger: the same test ``add`` uses to
                # detect a duplicate, so this is the matching node.
                return current
        return None

    def in_order(self) -> "List[BSTNode]":
        """Return all nodes as a list in ascending order of value."""
        ret = []
        in_order_rec(self.root, ret)
        return ret

    def height(self) -> int:
        """Return the number of levels in the tree; 0 for an empty tree."""
        if self.root is None:
            return 0
        return self.root.height()


class BalancedBST(BST):
    """A BST meant to be built in balanced shape via ``balanced``."""

    def by_index(self, i):
        """Return the node at breadth-first (heap) position *i*, or ``None``.

        Positions follow array-heap numbering: the root is 0, and node k's
        children are 2k+1 (left) and 2k+2 (right). If the path to *i* runs off
        the tree, ``None`` is returned instead of raising AttributeError.
        Non-positive *i* yields the root (``None`` for an empty tree).
        """
        # Climb from i to the root, recording whether each step came from a
        # right child (even index) or a left child (odd index).
        path = []
        while i > 0:
            path.append((i & 1) == 0)
            i = (i - 1) // 2
        # Replay the recorded steps top-down from the root.
        current = self.root
        while path and current is not None:
            current = current.right if path.pop() else current.left
        return current

    @classmethod
    def balanced(cls, list_param):
        """Build a tree from *list_param* by inserting medians first.

        The result is height-balanced when *list_param* is sorted and free of
        duplicates. Any other order still yields a valid BST, but not
        necessarily a balanced one.
        """
        tree = cls()
        add_rec(0, len(list_param) - 1, list_param, tree)
        return tree


def in_order_rec(top, ret):
    """Append the nodes of the subtree rooted at *top* to *ret* in ascending order.

    The name is historical: the traversal uses an explicit stack so that
    degenerate trees (e.g. a plain BST fed sorted input) do not exceed
    Python's recursion limit. A ``None`` *top* appends nothing.
    """
    stack = []
    node = top
    while stack or node is not None:
        # Descend as far left as possible, remembering the ancestors.
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()
        ret.append(node)
        node = node.right


def add_rec(start, end, list_param, tree: BST):
    """Insert ``list_param[start..end]`` (inclusive) into *tree*, middle element first.

    Recursing on both halves after inserting the midpoint produces a
    balanced shape for sorted input. Recursion depth is about log2(n).
    """
    if start > end:
        return
    mid = (start + end) // 2
    tree.add(list_param[mid])
    add_rec(start, mid - 1, list_param, tree)
    add_rec(mid + 1, end, list_param, tree)


if __name__ == "__main__":
    test = 17
    list_param = list(range(test))
    tree = BalancedBST.balanced(list_param)
    tree2 = BST()
    for i in range(test):
        tree2.add(i)
    r = [node.value for node in tree.in_order()]
    r2 = [node.value for node in tree2.in_order()]
    assert r == r2
    tree.by_index(6)