From 0c1d49ddce5d9c744dac408780f4936a9d1cfb2a Mon Sep 17 00:00:00 2001 From: Raphael Roberts Date: Tue, 20 Aug 2019 02:02:27 -0500 Subject: [PATCH] Dealt with Host possibly having more than one entry and fixed CamelCase_to_snake --- ssh_config_utils/config_file.py | 19 +++++++++++++++---- ssh_config_utils/parser.py | 2 +- ssh_config_utils/serializer.py | 2 ++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/ssh_config_utils/config_file.py b/ssh_config_utils/config_file.py index aa23779..2340e39 100644 --- a/ssh_config_utils/config_file.py +++ b/ssh_config_utils/config_file.py @@ -7,9 +7,13 @@ from ssh_config_utils.parser import parse_config_text class ConfigFile: def __init__(self, global_host, hosts): self.hosts = {} - global_host = global_host + self.global_host = global_host for host in hosts: - self.hosts.setdefault(host.name, host) + if isinstance(host.name, list): + name = host.name[0] + else: + name = host.name + self.hosts.setdefault(name, host) @classmethod def read_file(cls, fp): @@ -27,7 +31,7 @@ class ConfigFile: for key, values in host_data.items(): if len(values) == 1: host_data[key] = values[0] - if host_data["host"] == "*": + if "*" in host_data["host"]: if global_host is None: del host_data["host"] global_host = GlobalHost(host_data) @@ -40,7 +44,14 @@ class ConfigFile: return self.format(indent=" " * 2, host_seperator="\n" * 2) def format(self, **format_data): - hosts = sorted(self.hosts.values(), key=lambda host: host.name) + hosts = sorted( + self.hosts.values(), + key=lambda host: host.name + if isinstance(host.name, str) + else " ".join(host.name), + ) + if self.global_host is not None: + hosts.insert(0, self.global_host) return format_data["host_seperator"].join( map(lambda host: host.format(**format_data), hosts) ) diff --git a/ssh_config_utils/parser.py b/ssh_config_utils/parser.py index d9df5a2..208d858 100644 --- a/ssh_config_utils/parser.py +++ b/ssh_config_utils/parser.py @@ -17,7 +17,7 @@ KEY_VALUE = re.compile(r"(?P\S+)(?:[ \t]+|[ \t]*=[ \t]*)(?P.*)") def CamelCase_to_snake(text): # noqa - return "_".join(map(str.lower, re.findall(r"[A-Z][a-z]*", text))) + return "_".join(map(str.lower, re.findall(r"[A-Z][a-z]*|[0-9]+", text))) def parse_config_text(text): diff --git a/ssh_config_utils/serializer.py b/ssh_config_utils/serializer.py index 56f214b..247bd52 100644 --- a/ssh_config_utils/serializer.py +++ b/ssh_config_utils/serializer.py @@ -16,6 +16,8 @@ def quote(term): def serialize_host(name, data, format_data): + if isinstance(name, list): + name = " ".join(map(quote, name)) lines = ["Host {}".format(name)] for key, value in sorted(data.items(), key=lambda item: item[0]):