Compare commits

...

13 Commits
master ... neo

  1. 4
      .gitignore
  2. 1
      requirements.txt
  3. 72
      ssh_config_utils/config_file.py
  4. 63
      ssh_config_utils/host.py
  5. 15
      ssh_config_utils/parser.py
  6. 35
      tests/tests.py

4
.gitignore

@ -1,2 +1,6 @@
__pycache__ __pycache__
/tests/write_test /tests/write_test
.dir-locals.el
.#*
*.egg-info

1
requirements.txt

@ -0,0 +1 @@
wrapt

72
ssh_config_utils/config_file.py

@ -1,16 +1,52 @@
from pathlib import Path from pathlib import Path
from ssh_config_utils.host import GlobalHost, Host
from ssh_config_utils.host import GlobalHost, Host, HostAlias
from ssh_config_utils.parser import parse_config_text from ssh_config_utils.parser import parse_config_text
class HostGroup:
SEPARATOR = "/"
def __init__(self, name):
self.parent = None
self.name = name
self.children = {}
def add_child_group(self, name):
child = HostGroup(name)
child.parent = self
self.children[name] = child
return child
def add_host(self, host):
self.children[host.name] = host
host.parent = self
@property
def fullpath(self):
if self.parent is None:
return ""
return HostGroup.SEPARATOR.join(self.parent.fullpath(), self.name)
def hosts(self):
for child in self.children.values():
if isinstance(child, HostGroup):
yield from child.hosts()
elif isinstance(child, Host):
yield Host
class ConfigFile: class ConfigFile:
def __init__(self, global_host, hosts): def __init__(self, global_host, hosts):
self.hosts = {} self.hosts = {}
self.global_host = global_host self.global_host = global_host
for host in hosts: for host in hosts:
if isinstance(host.name, list): if isinstance(host.name, list):
name = host.name[0]
for name in host.name[1:]:
alias = HostAlias(name, host)
self.hosts.setdefault(name, alias)
host.name = host.name[0]
name = host.name
else: else:
name = host.name name = host.name
self.hosts.setdefault(name, host) self.hosts.setdefault(name, host)
@ -22,11 +58,9 @@ class ConfigFile:
text = file.read() text = file.read()
else: else:
text = fp.read() text = fp.read()
data = parse_config_text(text) data = parse_config_text(text)
global_host = None global_host = None
hosts = [] hosts = []
for host_data in data: for host_data in data:
for key, values in host_data.items(): for key, values in host_data.items():
if len(values) == 1: if len(values) == 1:
@ -45,7 +79,7 @@ class ConfigFile:
def format(self, **format_data): def format(self, **format_data):
hosts = sorted( hosts = sorted(
self.hosts.values(),
filter(lambda host: isinstance(host, Host), self.hosts.values()),
key=lambda host: host.name key=lambda host: host.name
if isinstance(host.name, str) if isinstance(host.name, str)
else " ".join(host.name), else " ".join(host.name),
@ -60,9 +94,35 @@ class ConfigFile:
default_data = {"indent": " " * 2, "host_seperator": "\n" * 2} default_data = {"indent": " " * 2, "host_seperator": "\n" * 2}
default_data.update(format_data) default_data.update(format_data)
text = self.format(**default_data) text = self.format(**default_data)
if isinstance(fp, (str, bytes, Path)): if isinstance(fp, (str, bytes, Path)):
with open(fp, "w") as file: with open(fp, "w") as file:
file.write(text) file.write(text)
else: else:
text = fp.write(text) text = fp.write(text)
class GroupedConfigFile(ConfigFile):
def __init__(self, global_host, hosts):
self.global_host = global_host
self.root_group = HostGroup("")
host: Host
for host in hosts:
current = self.root_group
parts = host.name.split(HostGroup.SEPARATOR)
host.name = parts[-1]
for path in parts[:-1]:
current = current.add_child_group(path)
current.add_host(host)
def format(self, **format_data):
hosts = sorted(
self.root_group.hosts(),
key=lambda host: host.name
if isinstance(host.name, str)
else " ".join(host.name),
)
if self.global_host is not None:
hosts.insert(0, self.global_host)
return format_data["host_seperator"].join(
map(lambda host: host.format(**format_data), hosts)
)

63
ssh_config_utils/host.py

@ -1,32 +1,79 @@
import wrapt
from ssh_config_utils.serializer import serialize_host, get_proper_name from ssh_config_utils.serializer import serialize_host, get_proper_name
from ssh_config_utils.parser import CamelCase_to_snake, KEYWORDS_LOWER_TRANSLATE from ssh_config_utils.parser import CamelCase_to_snake, KEYWORDS_LOWER_TRANSLATE
class HostAlias(wrapt.ObjectProxy):
def __init__(self, name, host):
super().__init__(host)
self._self_name = name
host.aliases.append(self)
@property
def name(self):
return self._self_name
@name.setter
def name(self, name):
self._self_name = name
@name.deleter
def name(self):
del self._self_name
@property
def target(self):
return self.__wrapped__
@target.setter
def target(self, target):
try:
self.aliases.remove(self)
except ValueError:
pass
self.__wrapped__ = target
self.aliases.append(self)
class Host: class Host:
def __init__(self, name, options):
self.options = options
__ignore__ = {"aliases", "name", "options", "parent"}
def __init__(self, name: str, options):
self.aliases = []
self.name = name self.name = name
self.options = options
self.parent = None
def __getattr__(self, name): def __getattr__(self, name):
if name in {"name", "options"}:
if name in Host.__ignore__:
return super().__getattribute__(name) return super().__getattribute__(name)
return self.options[name] return self.options[name]
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name in {"options", "name"}:
if name in Host.__ignore__:
super().__setattr__(name, value) super().__setattr__(name, value)
else: else:
proper_name = get_proper_name(name) proper_name = get_proper_name(name)
self.options[CamelCase_to_snake(proper_name)] = value self.options[CamelCase_to_snake(proper_name)] = value
def __str__(self): def __str__(self):
return serialize_host(self.name, self.options, dict(indent=" " * 2)) return serialize_host(self.name, self.options, dict(indent=" " * 2))
def fullname(self):
if self.parent is None:
return self.name
else:
return self.parent.fullpath()
def format(self, **format_data): def format(self, **format_data):
return serialize_host(self.name, self.options, format_data)
name = self.fullname()
if len(self.aliases) > 0:
name = [name]
for alias in self.aliases:
name.append(aliases.name)
return serialize_host(name, self.options, format_data)
class GlobalHost(Host): class GlobalHost(Host):

15
ssh_config_utils/parser.py

@ -13,7 +13,17 @@ KEYWORDS_LOWER_TRANSLATE = dict(zip(KEYWORDS_LOWER, KEYWORDS))
HOST_SPLITTER = re.compile( HOST_SPLITTER = re.compile(
r"host(?:.(?!\bhost\b))+", flags=re.MULTILINE | re.DOTALL | re.IGNORECASE r"host(?:.(?!\bhost\b))+", flags=re.MULTILINE | re.DOTALL | re.IGNORECASE
) )
KEY_VALUE = re.compile(r"(?P<key>\S+)(?:[ \t]+|[ \t]*=[ \t]*)(?P<values>.*)")
KEY_VALUE = re.compile(
r"""
(?P<key>\S+) #anything that isn't a space is part of the key
(?:
[ \t]+ #seperated by whitespace
|
[ \t]*=[ \t]* #seperated by equals
)
(?P<values>.*)""",
re.VERBOSE,
)
def CamelCase_to_snake(text): # noqa def CamelCase_to_snake(text): # noqa
@ -21,13 +31,16 @@ def CamelCase_to_snake(text): # noqa
def parse_config_text(text): def parse_config_text(text):
# Step 0: strip leaing whitespaces and comments
text_no_leading_whitespace = re.sub(r"^[ \t]+", "", text, flags=re.MULTILINE) text_no_leading_whitespace = re.sub(r"^[ \t]+", "", text, flags=re.MULTILINE)
text_no_comments = re.sub( text_no_comments = re.sub(
r"^\#.*", "", text_no_leading_whitespace, flags=re.MULTILINE r"^\#.*", "", text_no_leading_whitespace, flags=re.MULTILINE
) )
# Step 1: split file by Host blocks
hosts = HOST_SPLITTER.finditer(text_no_comments) hosts = HOST_SPLITTER.finditer(text_no_comments)
for host in hosts: for host in hosts:
host_dict = {} host_dict = {}
# Step 2: go line by line looking for key=value
for line in host.group(0).split("\n"): for line in host.group(0).split("\n"):
match = KEY_VALUE.match(line) match = KEY_VALUE.match(line)
if match: if match:

35
tests/tests.py

@ -0,0 +1,35 @@
import os
import sys
import unittest
import tempfile
from ssh_config_utils.parser import CamelCase_to_snake
samples_dir = os.path.join(sys.path[0], "sample_configs")
CONFIG_FILES = dict(
(file, os.path.join(samples_dir, file)) for file in os.listdir(samples_dir)
)
class ReadWriteTest(unittest.TestCase):
def setUp(self):
self.write_file = tempfile.TemporaryFile(mode="r+")
def tearDown(self):
self.write_file.close()
class TestTempFile(ReadWriteTest):
def test_me(self):
self.write_file.write("hello")
self.write_file.seek(0)
self.assertEqual(self.write_file.read(), "hello")
class TestParser(ReadWriteTest):
def test_camel_case_to_snake_case(self):
initial_vals = "ThisIsCamelCase", "This"
expected_vals = "this_is_camel_case", "this"
for initial, expected in zip(initial_vals, expected_vals):
with self.subTest(initial=initial, expected=expected):
self.assertEqual(CamelCase_to_snake(initial), expected)
Loading…
Cancel
Save