diff --git a/ssh_config_utils/config_file.py b/ssh_config_utils/config_file.py index d3ab062..883d54f 100644 --- a/ssh_config_utils/config_file.py +++ b/ssh_config_utils/config_file.py @@ -1,6 +1,6 @@ from pathlib import Path -from ssh_config_utils.host import GlobalHost, Host +from ssh_config_utils.host import GlobalHost, Host, HostAlias from ssh_config_utils.parser import parse_config_text @@ -40,7 +40,11 @@ class ConfigFile: self.global_host = global_host for host in hosts: if isinstance(host.name, list): - name = host.name[0] + for name in host.name[1:]: + alias = HostAlias(name, host) + self.hosts.setdefault(name, alias) + host.name = host.name[0] + name = host.name else: name = host.name self.hosts.setdefault(name, host) @@ -52,11 +56,9 @@ class ConfigFile: text = file.read() else: text = fp.read() - data = parse_config_text(text) global_host = None hosts = [] - for host_data in data: for key, values in host_data.items(): if len(values) == 1: @@ -75,7 +77,7 @@ class ConfigFile: def format(self, **format_data): hosts = sorted( - self.hosts.values(), + filter(lambda host: isinstance(host, Host), self.hosts.values()), key=lambda host: host.name if isinstance(host.name, str) else " ".join(host.name), @@ -90,7 +92,6 @@ class ConfigFile: default_data = {"indent": " " * 2, "host_seperator": "\n" * 2} default_data.update(format_data) text = self.format(**default_data) - if isinstance(fp, (str, bytes, Path)): with open(fp, "w") as file: file.write(text) diff --git a/ssh_config_utils/host.py b/ssh_config_utils/host.py index ce3db20..450733f 100644 --- a/ssh_config_utils/host.py +++ b/ssh_config_utils/host.py @@ -46,17 +46,14 @@ class Host: self.parent = None def __getattr__(self, name): - if name in Host.__ignore__: return super().__getattribute__(name) return self.options[name] def __setattr__(self, name, value): - if name in Host.__ignore__: super().__setattr__(name, value) else: - proper_name = get_proper_name(name) self.options[CamelCase_to_snake(proper_name)] = value @@ -70,7 +67,13 @@ class Host: return self.parent.fullpath() def format(self, **format_data): - return serialize_host(self.fullname(), self.options, format_data) + name = self.fullname() + if len(self.aliases) > 0: + name = [name] + for alias in self.aliases: + name.append(aliases.name) + + return serialize_host(name, self.options, format_data) class GlobalHost(Host):