diff --git a/ctabus.py b/ctabus.py index 17154e2..6e968cd 100644 --- a/ctabus.py +++ b/ctabus.py @@ -1,5 +1,6 @@ from urllib.parse import urlencode from urllib.request import urlopen +from disk_cache import disk_cache import json from sensitive import api @@ -25,13 +26,24 @@ def get_times(stop_id, api_key=api, timeout=None): return get_data('getpredictions', api_key, stpid=stop_id, timeout=timeout) +@disk_cache def get_routes(api_key=api, timeout=None): return get_data('getroutes', api_key, timeout=timeout) +@disk_cache def get_directions(route, api_key=api, timeout=None): return get_data('getdirections', api_key, rt=route, timeout=timeout) +@disk_cache def get_stops(route, direction, api_key=api, timeout=None): - return get_data('getstops', api_key, rt=route, dir=direction, timeout=timeout) + return get_data('getstops', api_key, rt=route, dir=direction, + timeout=timeout) + + +@disk_cache +def get_name_from_direction(route, direction, api_key=api, timeout=None): + test_stop = get_stops(route, direction, api_key=api_key, + timeout=timeout)['stops'][0]['stpid'] + return get_times(test_stop, api_key=api, timeout=timeout)['prd'][0]['des'] diff --git a/disk_cache.py b/disk_cache.py index a6af36b..b5a9b3c 100644 --- a/disk_cache.py +++ b/disk_cache.py @@ -1,34 +1,54 @@ import pickle import os -import lmza +import lzma cache_path = os.path.abspath(os.path.join(__file__, "..", "__pycache__")) if not os.path.exists(cache_path): os.mkdir(cache_path) +def make_key(*args, **kwargs): + + return args, tuple(sorted( + kwargs.items(), key=lambda item: item[0])) + + class disk_cache: + """Decorator to make persistent cache""" + caches = [] + def __init__(self, func): self.fname = "{}.{}.dc".format(func.__module__, func.__name__) - fname = os.path.join(cache_path, fname) + self.fname = os.path.join(cache_path, self.fname) self.func = func - self.cache = self.load_cache() + self.load_cache() + disk_cache.caches.append(self) - def call(self, *args, **kwargs): - key = args + tuple(sorted(self.kwargs.items(), key=lambda item)) + def __call__(self, *args, **kwargs): + key = make_key(*args, **kwargs) try: - return self.cache[key] + res = self.cache[key] + return res except KeyError: + self.fresh = True res = self.func(*args, **kwargs) self.cache[key] = res + return res def load_cache(self): try: - with lmza.open(self.fname, 'rb') as file: + with lzma.open(self.fname, 'rb') as file: cache = pickle.load(file) + self.fresh = False except FileNotFoundError: cache = {} + self.fresh = True self.cache = cache def save_cache(self): - with lmza.open(self.fname, 'wb') as file: - pickle.dump(self.cache, file) + with lzma.open(self.fname, 'wb') as file: + pickle.dump(self.cache, file, pickle.HIGHEST_PROTOCOL) + + def delete_cache(self): + os.remove(self.fname) + self.cache = {} + self.fresh = True diff --git a/main.py b/main.py index 1068888..a794c48 100755 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ #!/usr/bin/python3 from dateutil.parser import parse as date_parse from dateutil import tz +from disk_cache import disk_cache, make_key import argparse import ctabus import datetime @@ -14,6 +15,7 @@ import subprocess import os.path as osp import sys CHICAGO_TZ = tz.gettz("America/Chicago") +DATETIME_FORMAT = "%A, %B %e, %Y %H:%M:%S" # https://stackoverflow.com/a/5967539 @@ -73,7 +75,8 @@ def pprint_delta(delta): def gen_list(objs, data, *displays, key=None, sort=0, num_pic=True): - from print2d import print2d + from print2d import create_table, render_table + # sort based on column number k = displays[sort] display_data = {obj[k]: obj[data] for obj in objs} srt_keys = sorted(display_data.keys(), key=key) @@ -87,15 +90,8 @@ def gen_list(objs, data, *displays, key=None, sort=0, num_pic=True): if num_pic: display = [[i] + data for i, data in enumerate(display)] - opts = { - 'spacer': ' ', - 'seperator': ' ', - 'interactive': True, - 'bottom': '=', - 'l_end': '<', - 'r_end': '>', - } - print2d(display, **opts) + table = create_table(display, DATETIME_FORMAT) + render_table(table) if num_pic: which = None while not which: @@ -105,7 +101,7 @@ def gen_list(objs, data, *displays, key=None, sort=0, num_pic=True): quit() try: which = srt_keys[int(which)] - except ValueError: + except (ValueError, IndexError): which = None return display_data[which] else: @@ -170,7 +166,11 @@ def main(args): data = ctabus.get_directions(route)['directions'] # direction if not args.direction: - direction = gen_list(data, 'dir', 'dir') + for direction_obj in data: + friendly_name = ctabus.get_name_from_direction( + route, direction_obj['dir']) + direction_obj['friendly_name'] = friendly_name + direction = gen_list(data, 'dir', 'dir', 'friendly_name') else: s = Search(args.direction) direction = sorted((obj['dir'] for obj in data), key=s)[0] @@ -185,6 +185,10 @@ def main(args): else: stop_id = args.arg data = ctabus.get_times(stop_id) + info = data['prd'][0] + key = make_key(info['rt'], info['rtdir'], ctabus.api, None) + if key not in ctabus.get_name_from_direction.cache.keys(): + ctabus.get_name_from_direction.cache[key] = info['des'] if args.periodic is not None: _done = False while not _done: @@ -215,7 +219,14 @@ if __name__ == '__main__': parser.add_argument('-r', '--route', default=None) parser.add_argument('-d', '--direction', default=None) parser.add_argument('-t', '--disable_toast', action='store_false') + parser.add_argument('-k', '--kill-cache', action="store_true") parser.add_argument('arg', nargs='+', metavar='(stop-id | cross streets)') args = parser.parse_args() sys.stderr = open(osp.join(osp.dirname(__file__), 'stderr.log'), 'w') + if args.kill_cache: + for cache_obj in disk_cache.caches: + cache_obj.kill_cache() main(args) + for cache_obj in disk_cache.caches: + if cache_obj.fresh: + cache_obj.save_cache() diff --git a/print2d.py b/print2d.py index 100b993..67cf1ce 100644 --- a/print2d.py +++ b/print2d.py @@ -1,5 +1,34 @@ +from terminaltables.terminal_io import terminal_size +from terminaltables import AsciiTable +from textwrap import fill +from pydoc import pipepager, tempfilepager, plainpager, plain import datetime -from pydoc import pager +import os +import sys + + +def getpager(): + """Decide what method to use for paging through text.""" + if not hasattr(sys.stdin, "isatty"): + return plainpager + if not hasattr(sys.stdout, "isatty"): + return plainpager + if not sys.stdin.isatty() or not sys.stdout.isatty(): + return plainpager + use_pager = os.environ.get('MANPAGER') or os.environ.get('PAGER') + if use_pager: + if sys.platform == 'win32': # pipes completely broken in Windows + return lambda text: tempfilepager(plain(text), use_pager) + elif os.environ.get('TERM') in ('dumb', 'emacs'): + return lambda text: pipepager(plain(text), use_pager) + else: + return lambda text: pipepager(text, use_pager) + if os.environ.get('TERM') in ('dumb', 'emacs'): + return plainpager + if sys.platform == 'win32': + return lambda text: tempfilepager(plain(text), 'more <') + if hasattr(os, 'system') and os.system('(less) 2>/dev/null') == 0: + return lambda text: pipepager(text, 'less -X') def str_coerce(s, **kwargs): @@ -9,48 +38,33 @@ def str_coerce(s, **kwargs): return str(s) -def print2d(list_param, - datetime_format="%A, %B %e, %Y %H:%M:%S", - seperator=' | ', - spacer=' ', - bottom='=', - l_end='|', r_end='|', - interactive=False - ): - list_param = [[str_coerce(s, datetime_format=datetime_format) - for s in row] for row in list_param] - - max_col = [] +def create_table(list_param, datetime_format): + rows = [] for row in list_param: - for i, col in enumerate(row): - try: - max_col[i] = max(max_col[i], len(col)) - except IndexError: - max_col.append(len(col)) - - fmt_row = '{content}' - if l_end: - fmt_row = '{} {}'.format(l_end, fmt_row) - if r_end: - fmt_row = '{} {}'.format(fmt_row, r_end) - - done = [] - for row in list_param: - content = seperator.join(col.ljust(max_col[i], spacer if i < len( - row)-1 or r_end else ' ') for i, col in enumerate(row)) - done.append(fmt_row.format(content=content)) + rows.append([]) + for item in row: + rows[-1].append(str_coerce(item, datetime_format=datetime_format)) + return AsciiTable(rows) - if bottom: - bottom = bottom*len(done[0]) - row_sep = ('\n'+bottom+'\n') - else: - row_sep = '\n' - final = row_sep.join(done) - if bottom: - final = '\n'.join((bottom, final, bottom)) + +def render_table(table: AsciiTable, interactive=True): + '''Do all wrapping to make the table fit in screen''' + table.inner_row_border = True + data = table.table_data + terminal_width = terminal_size()[0] + n_cols = len(data[0]) + even_distribution = terminal_width // n_cols + for row_num, row in enumerate(data): + for col_num, col_data in enumerate(row): + if len(col_data) > even_distribution: + if col_num != n_cols - 1: + data[row_num][col_num] = fill(col_data, even_distribution) + else: + data[row_num][col_num] = '' + data[row_num][col_num] = fill( + col_data, table.column_max_width(col_num)) if interactive: - if not bottom: - final += '\n' - pager(final) + pager = getpager() + pager(table.table) else: - return final + print(table.table diff --git a/requirements.txt b/requirements.txt index db25db4..0c78596 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ edlib -python-dateutil \ No newline at end of file +python-dateutil +terminaltables