You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

91 lines
2.2 KiB

  1. class BSTNode:
  2. def __init__(self, value):
  3. self.value = value
  4. self.left = None
  5. self.right = None
  6. def height(self):
  7. lh = 0
  8. rh = 0
  9. if self.left is not None:
  10. lh = self.left.height()
  11. if self.right is not None:
  12. rh = self.right.height()
  13. return 1 + max(lh, rh)
  14. class BST:
  15. def __init__(self):
  16. self.root = None
  17. self.size = 0
  18. def clear(self):
  19. self.root = None
  20. self.size = 0
  21. def add(self, value):
  22. to_add = BSTNode(value)
  23. if self.root is None:
  24. self.root = to_add
  25. self.size += 1
  26. else:
  27. current = self.root
  28. parent = None
  29. while current is not None:
  30. if value > current.value:
  31. parent = current
  32. current = current.right
  33. elif value < current.value:
  34. parent = current
  35. current = current.left
  36. else:
  37. return
  38. if value > parent.value:
  39. parent.right = to_add
  40. else:
  41. parent.left = to_add
  42. self.size += 1
  43. def inOrder(self):
  44. ret = []
  45. inOrderRec(self.root, ret)
  46. return ret
  47. @classmethod
  48. def balance(cls, list_param):
  49. tree = cls()
  50. add_rec(0, len(list_param) - 1, list_param, tree)
  51. return tree
  52. def height(self):
  53. return self.root.height()
  54. def inOrderRec(top, ret):
  55. if top is None:
  56. return
  57. if top.left:
  58. inOrderRec(top.left, ret)
  59. ret.append(top)
  60. if top.right:
  61. inOrderRec(top.right, ret)
  62. def add_rec(start, end, list_param, tree: BST):
  63. if end - start >= 0:
  64. mid = (start + end) // 2
  65. tree.add(list_param[mid])
  66. add_rec(start, mid - 1, list_param, tree)
  67. add_rec(mid + 1, end, list_param, tree)
  68. if __name__ == "__main__":
  69. test = 17
  70. list_param = list(range(test))
  71. tree = BST.balance(list_param)
  72. tree2 = BST()
  73. for i in range(test):
  74. tree2.add(i)
  75. r = list(node.value for node in tree.inOrder())
  76. r2 = list(node.value for node in tree2.inOrder())
  77. assert r == r2