You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

50 lines
1.5 KiB

  1. import edlib
  2. import re
# Word separator used by ed_words: matches a single space that is immediately
# followed by a non-space character.
# NOTE(review): in a run of several spaces only the LAST one acts as a
# separator, so the extra spaces stay attached to the preceding word
# (e.g. "a  b" -> ["a ", "b"]). A leading space yields an empty first token
# (" a" -> ["", "a"]), and a trailing space stays on the last word
# ("a " -> ["a "]). The "+" can never repeat because the lookahead forbids a
# following space. Confirm this tokenization is intended.
space = re.compile('(?: (?=[^ ]))+')
  4. def ed_words(sentence1,sentence2):
  5. words1 = space.split(sentence1)
  6. words2 = space.split(sentence2)
  7. all = set(words1).union(set(words2))
  8. translation = {}
  9. for i,word in enumerate(all):
  10. translation[word] = i
  11. ed = edlib.align(
  12. bytes(translation[word] for word in words1),
  13. bytes(translation[word] for word in words2)
  14. )['editDistance']
  15. l = max(map(len,(sentence1,sentence2)))
  16. return ed/l
  17. def cluster_by_ed(sentences,threshold):
  18. '''algorithm calculates word edit distance between words, and so long as it is above a threshold adds to cluster. If above threshold, start new cluster and add new word'''
  19. ret = []
  20. sentence_list = list(sentences)
  21. cont = True
  22. index = 0
  23. while index < len(sentence_list):
  24. current = [sentence_list[index]]
  25. index += 1
  26. while index < len(sentence_list):
  27. ed = ed_words(current[0],sentence_list[index])
  28. if ed < threshold:
  29. current.append(sentence_list[index])
  30. index += 1
  31. else:
  32. break
  33. ret.append(current)
  34. return ret
  35. if __name__ == "__main__":
  36. import argparse
  37. import json
  38. import pprint
  39. parser = argparse.ArgumentParser()
  40. parser.add_argument('threshold',type=float)
  41. args = parser.parse_args()
  42. with open('test.json') as file:
  43. data = json.load(file)
  44. l = cluster_by_ed(data,args.threshold)
  45. print(len(l))
  46. if input('pprint?: ') == 'y':
  47. pprint.pprint(l)