You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

88 lines
2.1 KiB

  1. class BSTNode:
  2. def __init__(self, value):
  3. self.value = value
  4. self.left = None
  5. self.right = None
  6. class BST:
  7. def __init__(self):
  8. self.root = None
  9. self.size = 0
  10. def clear(self):
  11. self.root = None
  12. self.size = 0
  13. def add(self, value):
  14. to_add = BSTNode(value)
  15. if self.root is None:
  16. self.root = to_add
  17. self.size += 1
  18. else:
  19. current = self.root
  20. parent = None
  21. while current is not None:
  22. if value > current.value:
  23. parent = current
  24. current = current.right
  25. elif value < current.value:
  26. parent = current
  27. current = current.left
  28. else:
  29. return
  30. if value > parent.value:
  31. parent.right = to_add
  32. else:
  33. parent.left = to_add
  34. self.size += 1
  35. def inOrder(self):
  36. ret = []
  37. inOrderRec(self.root, ret)
  38. return ret
  39. @classmethod
  40. def balance(cls, list_param):
  41. tree = cls()
  42. add_rec(0, len(list_param) - 1, list_param, tree)
  43. return tree
  44. def height(self):
  45. return rec_height(self.root)
  46. def inOrderRec(top, ret):
  47. if top is None:
  48. return
  49. if top.left:
  50. inOrderRec(top.left, ret)
  51. ret.append(top)
  52. if top.right:
  53. inOrderRec(top.right, ret)
  54. def rec_height(node: BSTNode):
  55. if node is None:
  56. return 0
  57. return 1 + max(map(rec_height, (node.left, node.right)))
  58. def add_rec(start, end, list_param, tree: BST):
  59. if end - start >= 0:
  60. mid = (start + end) // 2
  61. tree.add(list_param[mid])
  62. add_rec(start, mid - 1, list_param, tree)
  63. add_rec(mid + 1, end, list_param, tree)
  64. if __name__ == "__main__":
  65. test = 17
  66. list_param = list(range(test))
  67. tree = BST.balance(list_param)
  68. tree2 = BST()
  69. for i in range(test):
  70. tree2.add(i)
  71. r = list(node.value for node in tree.inOrder())
  72. r2 = list(node.value for node in tree2.inOrder())
  73. assert r == r2