|
|
@ -1,16 +1,52 @@ |
|
|
from pathlib import Path |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
from ssh_config_utils.host import GlobalHost, Host |
|
|
|
|
|
|
|
|
from ssh_config_utils.host import GlobalHost, Host, HostAlias |
|
|
from ssh_config_utils.parser import parse_config_text |
|
|
from ssh_config_utils.parser import parse_config_text |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HostGroup: |
|
|
|
|
|
SEPARATOR = "/" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, name): |
|
|
|
|
|
self.parent = None |
|
|
|
|
|
self.name = name |
|
|
|
|
|
self.children = {} |
|
|
|
|
|
|
|
|
|
|
|
def add_child_group(self, name): |
|
|
|
|
|
child = HostGroup(name) |
|
|
|
|
|
child.parent = self |
|
|
|
|
|
self.children[name] = child |
|
|
|
|
|
return child |
|
|
|
|
|
|
|
|
|
|
|
def add_host(self, host): |
|
|
|
|
|
self.children[host.name] = host |
|
|
|
|
|
host.parent = self |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
|
def fullpath(self): |
|
|
|
|
|
if self.parent is None: |
|
|
|
|
|
return "" |
|
|
|
|
|
return HostGroup.SEPARATOR.join(self.parent.fullpath(), self.name) |
|
|
|
|
|
|
|
|
|
|
|
def hosts(self): |
|
|
|
|
|
for child in self.children.values(): |
|
|
|
|
|
if isinstance(child, HostGroup): |
|
|
|
|
|
yield from child.hosts() |
|
|
|
|
|
elif isinstance(child, Host): |
|
|
|
|
|
yield Host |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConfigFile: |
|
|
class ConfigFile: |
|
|
def __init__(self, global_host, hosts): |
|
|
def __init__(self, global_host, hosts): |
|
|
self.hosts = {} |
|
|
self.hosts = {} |
|
|
self.global_host = global_host |
|
|
self.global_host = global_host |
|
|
for host in hosts: |
|
|
for host in hosts: |
|
|
if isinstance(host.name, list): |
|
|
if isinstance(host.name, list): |
|
|
name = host.name[0] |
|
|
|
|
|
|
|
|
for name in host.name[1:]: |
|
|
|
|
|
alias = HostAlias(name, host) |
|
|
|
|
|
self.hosts.setdefault(name, alias) |
|
|
|
|
|
host.name = host.name[0] |
|
|
|
|
|
name = host.name |
|
|
else: |
|
|
else: |
|
|
name = host.name |
|
|
name = host.name |
|
|
self.hosts.setdefault(name, host) |
|
|
self.hosts.setdefault(name, host) |
|
|
@ -22,11 +58,9 @@ class ConfigFile: |
|
|
text = file.read() |
|
|
text = file.read() |
|
|
else: |
|
|
else: |
|
|
text = fp.read() |
|
|
text = fp.read() |
|
|
|
|
|
|
|
|
data = parse_config_text(text) |
|
|
data = parse_config_text(text) |
|
|
global_host = None |
|
|
global_host = None |
|
|
hosts = [] |
|
|
hosts = [] |
|
|
|
|
|
|
|
|
for host_data in data: |
|
|
for host_data in data: |
|
|
for key, values in host_data.items(): |
|
|
for key, values in host_data.items(): |
|
|
if len(values) == 1: |
|
|
if len(values) == 1: |
|
|
@ -45,7 +79,7 @@ class ConfigFile: |
|
|
|
|
|
|
|
|
def format(self, **format_data): |
|
|
def format(self, **format_data): |
|
|
hosts = sorted( |
|
|
hosts = sorted( |
|
|
self.hosts.values(), |
|
|
|
|
|
|
|
|
filter(lambda host: isinstance(host, Host), self.hosts.values()), |
|
|
key=lambda host: host.name |
|
|
key=lambda host: host.name |
|
|
if isinstance(host.name, str) |
|
|
if isinstance(host.name, str) |
|
|
else " ".join(host.name), |
|
|
else " ".join(host.name), |
|
|
@ -60,9 +94,35 @@ class ConfigFile: |
|
|
default_data = {"indent": " " * 2, "host_seperator": "\n" * 2} |
|
|
default_data = {"indent": " " * 2, "host_seperator": "\n" * 2} |
|
|
default_data.update(format_data) |
|
|
default_data.update(format_data) |
|
|
text = self.format(**default_data) |
|
|
text = self.format(**default_data) |
|
|
|
|
|
|
|
|
if isinstance(fp, (str, bytes, Path)): |
|
|
if isinstance(fp, (str, bytes, Path)): |
|
|
with open(fp, "w") as file: |
|
|
with open(fp, "w") as file: |
|
|
file.write(text) |
|
|
file.write(text) |
|
|
else: |
|
|
else: |
|
|
text = fp.write(text) |
|
|
text = fp.write(text) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GroupedConfigFile(ConfigFile): |
|
|
|
|
|
def __init__(self, global_host, hosts): |
|
|
|
|
|
self.global_host = global_host |
|
|
|
|
|
self.root_group = HostGroup("") |
|
|
|
|
|
host: Host |
|
|
|
|
|
for host in hosts: |
|
|
|
|
|
current = self.root_group |
|
|
|
|
|
parts = host.name.split(HostGroup.SEPARATOR) |
|
|
|
|
|
host.name = parts[-1] |
|
|
|
|
|
for path in parts[:-1]: |
|
|
|
|
|
current = current.add_child_group(path) |
|
|
|
|
|
current.add_host(host) |
|
|
|
|
|
|
|
|
|
|
|
def format(self, **format_data): |
|
|
|
|
|
hosts = sorted( |
|
|
|
|
|
self.root_group.hosts(), |
|
|
|
|
|
key=lambda host: host.name |
|
|
|
|
|
if isinstance(host.name, str) |
|
|
|
|
|
else " ".join(host.name), |
|
|
|
|
|
) |
|
|
|
|
|
if self.global_host is not None: |
|
|
|
|
|
hosts.insert(0, self.global_host) |
|
|
|
|
|
return format_data["host_seperator"].join( |
|
|
|
|
|
map(lambda host: host.format(**format_data), hosts) |
|
|
|
|
|
) |