diff --git a/compose_viz/parser.py b/compose_viz/parser.py index 345600b..608ef85 100644 --- a/compose_viz/parser.py +++ b/compose_viz/parser.py @@ -1,11 +1,12 @@ +import re from typing import Dict, List, Optional from ruamel.yaml import YAML from compose_viz.compose import Compose, Service from compose_viz.extends import Extends -from compose_viz.port import Port -from compose_viz.volume import Volume, VolumeType +from compose_viz.port import Port, Protocol +from compose_viz.volume import AccessMode, Volume, VolumeType class Parser: @@ -58,13 +59,87 @@ class Parser: service_ports: List[Port] = [] if service.get("ports"): - if type(service["ports"]) is list: - for port_data in service["ports"]: - if ':' not in port_data: - raise RuntimeError("Invalid ports input, aborting.") - spilt_data = port_data.split(":", 1) - service_ports.append(Port(host_port=spilt_data[0], - container_port=spilt_data[1])) + assert type(service["ports"]) is list + for port_data in service["ports"]: + if type(port_data) is dict: + # define a nested function to limit variable scope + def long_syntax(): + assert port_data["target"] + + container_port: str = port_data["target"] + host_port: str = "" + protocol: Protocol = Protocol.tcp + + if port_data.get("host_port"): + host_port = port_data["host_port"] + else: + host_port = container_port + + if port_data.get("host_ip"): + host_ip = str(port_data["host_ip"]) + host_port = f"{host_ip}:{host_port}" + + if port_data.get("protocol"): + protocol = Protocol(port_data["protocol"]) + + assert host_port, "Error parsing port, aborting." + + service_ports.append( + Port( + host_port=host_port, + container_port=container_port, + protocol=protocol, + ) + ) + + long_syntax() + elif type(port_data) is str: + # ports that needs to parse using regex: + # - "3000" + # - "3000-3005" + # - "8000:8000" + # - "9090-9091:8080-8081" + # - "49100:22" + # - "127.0.0.1:8001:8001" + # - "127.0.0.1:5000-5010:5000-5010" + # - "6060:6060/udp" + + def short_syntax(): + regex = r"(?P\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:)?((?P\d+(\-\d+)?):)?((?P\d+(\-\d+)?))?(/(?P\w+))?" # noqa: E501 + match = re.match(regex, port_data) + if match: + host_ip: Optional[str] = match.group("host_ip") + host_port: Optional[str] = match.group("host_port") + container_port: Optional[str] = match.group("container_port") + protocol: Optional[str] = match.group("protocol") + + assert container_port, "Invalid port format, aborting." + + if container_port and not host_port: + host_port = container_port + + if host_ip: + host_port = f"{host_ip}{host_port}" + + assert host_port, "Error while parsing port, aborting." + + if protocol: + service_ports.append( + Port( + host_port=host_port, + container_port=container_port, + protocol=Protocol[protocol], + ) + ) + else: + service_ports.append( + Port( + host_port=host_port, + container_port=container_port, + ) + ) + + short_syntax() service_depends_on: List[str] = [] if service.get("depends_on"): @@ -74,23 +149,28 @@ class Parser: if service.get("volumes"): for volume_data in service["volumes"]: if type(volume_data) is dict: - volume_source: str = None - volume_target: str = None - volume_type: VolumeType.volume = None - if volume_data.get("source"): - volume_source = volume_data["source"] - if volume_data.get("target"): - volume_target = volume_data["target"] + assert volume_data["source"] and volume_data["target"], "Invalid volume input, aborting." + + volume_source: str = volume_data["source"] + volume_target: str = volume_data["target"] + volume_type: VolumeType = VolumeType.volume + if volume_data.get("type"): volume_type = VolumeType[volume_data["type"]] - service_volumes.append(Volume(source=volume_source, - target=volume_target, - type=volume_type)) + + service_volumes.append(Volume(source=volume_source, target=volume_target, type=volume_type)) elif type(volume_data) is str: - if ':' not in volume_data: - raise RuntimeError("Invalid volume input, aborting.") - spilt_data = volume_data.split(":", 1) - service_volumes.append(Volume(source=spilt_data[0], target=spilt_data[1])) + assert ":" in volume_data, "Invalid volume input, aborting." + + spilt_data = volume_data.split(":") + if len(spilt_data) == 2: + service_volumes.append(Volume(source=spilt_data[0], target=spilt_data[1])) + elif len(spilt_data) == 3: + service_volumes.append( + Volume( + source=spilt_data[0], target=spilt_data[1], access_mode=AccessMode[spilt_data[2]] + ) + ) service_links: List[str] = [] if service.get("links"): diff --git a/compose_viz/volume.py b/compose_viz/volume.py index d90faa9..f86d203 100644 --- a/compose_viz/volume.py +++ b/compose_viz/volume.py @@ -5,13 +5,24 @@ class VolumeType(str, Enum): volume = "volume" bind = "bind" tmpfs = "tmpfs" + npipe = "npipe" + + +class AccessMode(str, Enum): + rw = "rw" + ro = "ro" + z = "z" + Z = "Z" class Volume: - def __init__(self, source: str, target: str, type: VolumeType = VolumeType.volume): + def __init__( + self, source: str, target: str, type: VolumeType = VolumeType.volume, access_mode: AccessMode = AccessMode.rw + ): self._source = source self._target = target self._type = type + self._access_mode = access_mode @property def source(self): @@ -24,3 +35,7 @@ class Volume: @property def type(self): return self._type + + @property + def access_mode(self): + return self._access_mode diff --git a/tests/test_parse_file.py b/tests/test_parse_file.py index 7dfe071..8dc3dbf 100644 --- a/tests/test_parse_file.py +++ b/tests/test_parse_file.py @@ -5,7 +5,7 @@ from compose_viz.extends import Extends from compose_viz.parser import Parser from compose_viz.port import Port from compose_viz.service import Service -from compose_viz.volume import Volume, VolumeType +from compose_viz.volume import AccessMode, Volume, VolumeType @pytest.mark.parametrize( @@ -428,12 +428,14 @@ from compose_viz.volume import Volume, VolumeType Service( name="common", image="busybox", - volumes=[Volume(source="common-volume", target="/var/lib/backup/data:rw")], + volumes=[ + Volume(source="common-volume", target="/var/lib/backup/data", access_mode=AccessMode.rw) + ], ), Service( name="cli", extends=Extends(service_name="common"), - volumes=[Volume(source="cli-volume", target="/var/lib/backup/data:ro")], + volumes=[Volume(source="cli-volume", target="/var/lib/backup/data", access_mode=AccessMode.ro)], ), ] ), diff --git a/tests/test_volume.py b/tests/test_volume.py index 8081ae3..d1d4cac 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -1,4 +1,4 @@ -from compose_viz.volume import Volume, VolumeType +from compose_viz.volume import AccessMode, Volume, VolumeType def test_volume_init_normal() -> None: @@ -8,6 +8,7 @@ def test_volume_init_normal() -> None: assert v.source == "./foo" assert v.target == "./bar" assert v.type == VolumeType.volume + assert v.access_mode == AccessMode.rw except Exception as e: assert False, e @@ -19,5 +20,18 @@ def test_volume_with_type() -> None: assert v.source == "./foo" assert v.target == "./bar" assert v.type == VolumeType.bind + assert v.access_mode == AccessMode.rw + except Exception as e: + assert False, e + + +def test_volume_with_access_mode() -> None: + try: + v = Volume(source="./foo", target="./bar", access_mode=AccessMode.ro) + + assert v.source == "./foo" + assert v.target == "./bar" + assert v.type == VolumeType.volume + assert v.access_mode == AccessMode.ro except Exception as e: assert False, e