Skip to content
83 changes: 72 additions & 11 deletions podman/domain/containers_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

logger = logging.getLogger("podman.containers")

NAMED_VOLUME_PATTERN = re.compile(r'[a-zA-Z0-9][a-zA-Z0-9_.-]*')
NAMED_VOLUME_PATTERN = re.compile(r"[a-zA-Z0-9][a-zA-Z0-9_.-]*")


class CreateMixin: # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -375,14 +375,58 @@ def create(
payload = api.prepare_body(payload)

response = self.client.post(
"/containers/create", headers={"content-type": "application/json"}, data=payload
"/containers/create",
headers={"content-type": "application/json"},
data=payload,
)
response.raise_for_status(not_found=ImageNotFound)

container_id = response.json()["Id"]

return self.get(container_id)

@staticmethod
def _convert_env_list_to_dict(env_list):
"""Convert a list of environment variables to a dictionary.

Args:
env_list (List[str]): List of environment variables in the format ["KEY=value"]

Returns:
Dict[str, str]: Dictionary of environment variables

Raises:
ValueError: If any environment variable is not in the correct format
"""
if not isinstance(env_list, list):
raise TypeError(f"Expected list, got {type(env_list).__name__}")

env_dict = {}

for env_var in env_list:
if not isinstance(env_var, str):
raise TypeError(
f"Environment variable must be a string, "
f"got {type(env_var).__name__}: {repr(env_var)}"
)

# Handle empty strings
if not env_var.strip():
raise ValueError(f"Environment variable cannot be empty")
if "=" not in env_var:
raise ValueError(
f"Environment variable '{env_var}' is not in the correct format. "
"Expected format: 'KEY=value'"
)
key, value = env_var.split("=", 1) # Split on first '=' only

# Validate key is not empty
if not key.strip():
raise ValueError(f"Environment variable has empty key: '{env_var}'")

env_dict[key] = value
return env_dict

# pylint: disable=too-many-locals,too-many-statements,too-many-branches
@staticmethod
def _render_payload(kwargs: MutableMapping[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -410,6 +454,23 @@ def _render_payload(kwargs: MutableMapping[str, Any]) -> dict[str, Any]:
with suppress(KeyError):
del args[key]

# Handle environment variables
environment = args.pop("environment", None)
if environment is not None:
if isinstance(environment, list):
try:
environment = CreateMixin._convert_env_list_to_dict(environment)
except ValueError as e:
raise ValueError(
"Failed to convert environment variables list to dictionary. "
f"Error: {str(e)}"
) from e
elif not isinstance(environment, dict):
raise TypeError(
"Environment variables must be provided as either a dictionary "
"or a list of strings in the format ['KEY=value']"
)

# These keywords are not supported for various reasons.
unsupported_keys = set(args.keys()).intersection(
(
Expand Down Expand Up @@ -466,9 +527,9 @@ def to_bytes(size: Union[int, str, None]) -> Union[int, None]:
try:
return int(size)
except ValueError as bad_size:
mapping = {'b': 0, 'k': 1, 'm': 2, 'g': 3}
mapping_regex = ''.join(mapping.keys())
search = re.search(rf'^(\d+)([{mapping_regex}])$', size.lower())
mapping = {"b": 0, "k": 1, "m": 2, "g": 3}
mapping_regex = "".join(mapping.keys())
search = re.search(rf"^(\d+)([{mapping_regex}])$", size.lower())
if search:
return int(search.group(1)) * (1024 ** mapping[search.group(2)])
raise TypeError(
Expand Down Expand Up @@ -497,7 +558,7 @@ def to_bytes(size: Union[int, str, None]) -> Union[int, None]:
"dns_search": pop("dns_search"),
"dns_server": pop("dns"),
"entrypoint": pop("entrypoint"),
"env": pop("environment"),
"env": environment,
"env_host": pop("env_host"), # TODO document, podman only
"expose": {},
"groups": pop("group_add"),
Expand Down Expand Up @@ -607,7 +668,7 @@ def to_bytes(size: Union[int, str, None]) -> Union[int, None]:
if _k in bool_options and v is True:
options.append(option_name)
elif _k in regular_options:
options.append(f'{option_name}={v}')
options.append(f"{option_name}={v}")
elif _k in simple_options:
options.append(v)

Expand Down Expand Up @@ -709,12 +770,12 @@ def parse_host_port(_container_port, _protocol, _host):

for item in args.pop("volumes", {}).items():
key, value = item
extended_mode = value.get('extended_mode', [])
extended_mode = value.get("extended_mode", [])
if not isinstance(extended_mode, list):
raise ValueError("'extended_mode' value should be a list")

options = extended_mode
mode = value.get('mode')
mode = value.get("mode")
if mode is not None:
if not isinstance(mode, str):
raise ValueError("'mode' value should be a str")
Expand All @@ -729,10 +790,10 @@ def parse_host_port(_container_port, _protocol, _host):
params["volumes"].append(volume)
else:
mount_point = {
"destination": value['bind'],
"destination": value["bind"],
"options": options,
"source": key,
"type": 'bind',
"type": "bind",
}
params["mounts"].append(mount_point)

Expand Down
38 changes: 38 additions & 0 deletions podman/tests/integration/test_container_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,44 @@ def test_container_extra_hosts(self):
for hosts_entry in formatted_hosts:
self.assertIn(hosts_entry, logs)

def test_container_environment_variables(self):
"""Test environment variables passed to the container."""
with self.subTest("Check environment variables as dictionary"):
env_dict = {"MY_VAR": "123", "ANOTHER_VAR": "456"}
container = self.client.containers.create(
self.alpine_image, command=["env"], environment=env_dict
)
self.containers.append(container)

container_env = container.attrs.get('Config', {}).get('Env', [])
for key, value in env_dict.items():
self.assertIn(f"{key}={value}", container_env)

container.start()
container.wait()
logs = b"\n".join(container.logs()).decode()

for key, value in env_dict.items():
self.assertIn(f"{key}={value}", logs)

with self.subTest("Check environment variables as list"):
env_list = ["MY_VAR=123", "ANOTHER_VAR=456"]
container = self.client.containers.create(
self.alpine_image, command=["env"], environment=env_list
)
self.containers.append(container)

container_env = container.attrs.get('Config', {}).get('Env', [])
for env in env_list:
self.assertIn(env, container_env)

container.start()
container.wait()
logs = b"\n".join(container.logs()).decode()

for env in env_list:
self.assertIn(env, logs)

def _test_memory_limit(self, parameter_name, host_config_name, set_mem_limit=False):
"""Base for tests which checks memory limits"""
memory_limit_tests = [
Expand Down
Loading