diff --git a/fileglancer/app.py b/fileglancer/app.py index c655feab..9be8db90 100644 --- a/fileglancer/app.py +++ b/fileglancer/app.py @@ -865,14 +865,20 @@ async def get_profile(username: str = Depends(get_current_user)): # SSH Key Management endpoints @app.get("/api/ssh-keys", response_model=sshkeys.SSHKeyListResponse, - description="List all SSH keys in the user's ~/.ssh directory") + description="List Fileglancer-managed SSH keys") async def list_ssh_keys(username: str = Depends(get_current_user)): - """List SSH keys for the authenticated user""" + """List SSH keys with 'fileglancer' in the comment""" with _get_user_context(username): try: ssh_dir = sshkeys.get_ssh_directory() keys = sshkeys.list_ssh_keys(ssh_dir) - return sshkeys.SSHKeyListResponse(keys=keys) + exists, unmanaged, missing_pubkey = sshkeys.check_id_ed25519_status(ssh_dir) + return sshkeys.SSHKeyListResponse( + keys=keys, + unmanaged_id_ed25519_exists=unmanaged, + id_ed25519_exists=exists, + id_ed25519_missing_pubkey=missing_pubkey + ) except Exception as e: logger.error(f"Error listing SSH keys for {username}: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -951,20 +957,56 @@ async def authorize_ssh_key(username: str = Depends(get_current_user)): logger.error(f"Error authorizing SSH key for {username}: {e}") raise HTTPException(status_code=500, detail=str(e)) - @app.get("/api/ssh-keys/content", response_model=sshkeys.SSHKeyContent, + @app.post("/api/ssh-keys/regenerate-public", + response_model=sshkeys.SSHKeyInfo, + description="Regenerate public key from private key") + async def regenerate_public_key( + request: sshkeys.GenerateKeyRequest = Body(default=sshkeys.GenerateKeyRequest()), + username: str = Depends(get_current_user) + ): + """Regenerate the public key from the private key. + + If the private key is encrypted, provide the passphrase in the request body. + """ + with _get_user_context(username): + try: + ssh_dir = sshkeys.get_ssh_directory() + key_info = sshkeys.regenerate_public_key( + ssh_dir, + passphrase=request.passphrase + ) + return key_info + + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + except RuntimeError as e: + # Check for passphrase errors + if "passphrase" in str(e).lower(): + raise HTTPException(status_code=401, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + logger.error(f"Error regenerating public key for {username}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @app.get("/api/ssh-keys/content", description="Get the content of the default SSH key (id_ed25519)") async def get_ssh_key_content( key_type: str = Query(..., description="Type of key to fetch: 'public' or 'private'"), username: str = Depends(get_current_user) ): - """Get the public or private key content for copying""" + """Get the public or private key content for copying. + + Returns plain text response with secure bytearray handling that wipes + the key content from memory after sending. + """ if key_type not in ("public", "private"): raise HTTPException(status_code=400, detail="key_type must be 'public' or 'private'") with _get_user_context(username): try: ssh_dir = sshkeys.get_ssh_directory() - return sshkeys.get_key_content(ssh_dir, "id_ed25519", key_type) + key_buffer = sshkeys.get_key_content(ssh_dir, "id_ed25519", key_type) + return sshkeys.SSHKeyContentResponse(key_buffer) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) @@ -972,6 +1014,32 @@ async def get_ssh_key_content( logger.error(f"Error getting SSH key content for {username}: {e}") raise HTTPException(status_code=500, detail=str(e)) + @app.post("/api/ssh-keys/generate-temp", + description="Generate a temporary SSH key and return private key for one-time copy") + async def generate_temp_ssh_key( + request: sshkeys.GenerateKeyRequest = Body(default=sshkeys.GenerateKeyRequest()), + username: str = Depends(get_current_user) + ): + """Generate a temporary SSH key, add to authorized_keys, return private key. + + The private key is streamed securely and the temporary files are deleted + after the response is sent. Key info is included in response headers: + - X-SSH-Key-Filename + - X-SSH-Key-Type + - X-SSH-Key-Fingerprint + - X-SSH-Key-Comment + """ + with _get_user_context(username): + try: + ssh_dir = sshkeys.get_ssh_directory() + return sshkeys.generate_temp_key_and_authorize(ssh_dir, request.passphrase) + + except RuntimeError as e: + raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + logger.error(f"Error generating temp SSH key for {username}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + # File content endpoint @app.head("/api/content/{path_name:path}") async def head_file_content(path_name: str, diff --git a/fileglancer/sshkeys.py b/fileglancer/sshkeys.py index 888df280..381fbfa7 100644 --- a/fileglancer/sshkeys.py +++ b/fileglancer/sshkeys.py @@ -4,14 +4,16 @@ in a user's ~/.ssh directory. """ +import gc import os import shutil import subprocess import tempfile -from typing import List, Optional +from typing import Dict, List, Optional +from fastapi.responses import Response from loguru import logger -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, SecretStr # Constants AUTHORIZED_KEYS_FILENAME = "authorized_keys" @@ -77,6 +79,18 @@ class SSHKeyInfo(BaseModel): class SSHKeyListResponse(BaseModel): """Response containing a list of SSH keys""" keys: List[SSHKeyInfo] = Field(description="List of SSH keys") + unmanaged_id_ed25519_exists: bool = Field( + default=False, + description="True if id_ed25519 exists but is not managed by fileglancer" + ) + id_ed25519_exists: bool = Field( + default=False, + description="True if id_ed25519 private key exists" + ) + id_ed25519_missing_pubkey: bool = Field( + default=False, + description="True if id_ed25519 exists and is managed but .pub file is missing" + ) class GenerateKeyResponse(BaseModel): @@ -85,14 +99,118 @@ class GenerateKeyResponse(BaseModel): message: str = Field(description="Status message") -class SSHKeyContent(BaseModel): - """SSH key content - only fetched on demand""" - key: str = Field(description="The requested key content") - - class GenerateKeyRequest(BaseModel): """Request body for generating an SSH key""" - passphrase: Optional[str] = Field(default=None, description="Optional passphrase to protect the private key") + passphrase: Optional[SecretStr] = Field(default=None, description="Optional passphrase to protect the private key") + + +def _wipe_bytearray(data: bytearray) -> None: + """Securely wipe a bytearray by overwriting with zeros. + + Args: + data: The bytearray to wipe + """ + data[:] = b'\x00' * len(data) + + +class SSHKeyContentResponse(Response): + """Secure streaming response for SSH key content that minimizes memory exposure. + + This response class streams the key content line-by-line via ASGI to minimize + buffering in h11/uvicorn. By sending each chunk with "more_body": True, h11 + can flush each line immediately rather than buffering the entire key in memory. + + After streaming completes, the original bytearray is wiped with zeros and + garbage collection is triggered to clean up any intermediate copies. + + Uses memoryview to avoid creating unnecessary copies when iterating over lines. + """ + + media_type = "text/plain" + charset = "utf-8" + + def __init__( + self, + key_content: bytearray, + status_code: int = 200, + headers: Optional[Dict[str, str]] = None, + ): + """Initialize the response with key content stored in a bytearray. + + Args: + key_content: The SSH key content as a mutable bytearray + status_code: HTTP status code (default 200) + headers: Optional additional headers + """ + # Store the key buffer - do NOT convert to bytes + self._key_buffer = key_content + + # Merge Content-Length header with any provided headers + # This is required because we don't pass content to the parent + all_headers = dict(headers) if headers else {} + all_headers["content-length"] = str(len(key_content)) + + # Initialize parent without content - we'll send body directly in __call__ + super().__init__( + status_code=status_code, + headers=all_headers, + media_type=self.media_type, + ) + + def _iter_lines(self): + """Yield memoryview slices of the key buffer line by line. + + Using memoryview avoids creating copies of the data. Each yielded + view is a reference to a slice of the original bytearray. + + Yields: + memoryview: A view into the buffer for each line (including newline) + """ + buffer = self._key_buffer + view = memoryview(buffer) + start = 0 + + while start < len(buffer): + # Find next newline + try: + end = buffer.index(b'\n', start) + 1 + except ValueError: + # No more newlines, yield the rest + end = len(buffer) + + yield view[start:end] + start = end + + async def __call__(self, scope, receive, send): + """Stream the response line-by-line, then wipe the sensitive buffer. + + Sends each line with "more_body": True to signal h11 that it can flush + immediately, reducing memory buffering. After all lines are sent, sends + an empty body with "more_body": False to complete the response. + + The bytearray is wiped in the finally block regardless of success or error. + """ + await send({ + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + }) + try: + for line in self._iter_lines(): + await send({ + "type": "http.response.body", + "body": line, + "more_body": True, + }) + # Final empty body to signal end of response + await send({ + "type": "http.response.body", + "body": b"", + "more_body": False, + }) + finally: + _wipe_bytearray(self._key_buffer) + gc.collect() def get_ssh_directory() -> str: @@ -200,49 +318,45 @@ def parse_public_key(pubkey_path: str, ssh_dir: str) -> SSHKeyInfo: ) -def _read_file_secure(file_path: str) -> str: - """Read a file securely using a mutable bytearray that is cleared after use. +def read_file_to_bytearray(file_path: str) -> bytearray: + """Read a file into a mutable bytearray for secure handling. - This function reads sensitive data (like private keys) into a mutable bytearray - which is explicitly overwritten with zeros after converting to a string. This - reduces the window during which the sensitive data exists in memory and prevents - it from persisting in immutable strings that could appear in core dumps. + This function reads file contents into a mutable bytearray that can be + explicitly wiped from memory when no longer needed. The caller is responsible + for wiping the bytearray when done (e.g., by passing it to SSHKeyContentResponse + which wipes it after sending). Args: file_path: Path to the file to read Returns: - The file contents as a string + The file contents as a mutable bytearray Note: - While the returned string is still immutable and will exist in memory, - this approach minimizes the exposure by clearing the mutable buffer - immediately after use. + The returned bytearray should be wiped with _wipe_bytearray() when no + longer needed. SSHKeyContentResponse handles this automatically. """ # Get file size to allocate bytearray file_size = os.path.getsize(file_path) # Read into mutable bytearray - sensitive_buffer = bytearray(file_size) + key_buffer = bytearray(file_size) - try: - with open(file_path, 'rb') as f: - bytes_read = f.readinto(sensitive_buffer) - if bytes_read != file_size: - raise IOError(f"Expected to read {file_size} bytes, but read {bytes_read}") + with open(file_path, 'rb') as f: + bytes_read = f.readinto(key_buffer) + if bytes_read != file_size: + # Wipe partial data on error + _wipe_bytearray(key_buffer) + raise IOError(f"Expected to read {file_size} bytes, but read {bytes_read}") - # Convert to string (this creates an immutable copy) - result = sensitive_buffer.decode('utf-8') + return key_buffer - return result - finally: - # Explicitly overwrite the mutable buffer with zeros - for i in range(len(sensitive_buffer)): - sensitive_buffer[i] = 0 +def get_key_content(ssh_dir: str, filename: str, key_type: str = "public") -> bytearray: + """Get the content of an SSH key as a mutable bytearray. -def get_key_content(ssh_dir: str, filename: str, key_type: str = "public") -> SSHKeyContent: - """Get the content of an SSH key (public or private). + The returned bytearray should be passed to SSHKeyContentResponse, which + will wipe it after sending the response. Do not convert to str or bytes. Args: ssh_dir: Path to the .ssh directory @@ -250,38 +364,34 @@ def get_key_content(ssh_dir: str, filename: str, key_type: str = "public") -> SS key_type: Type of key to fetch: 'public' or 'private' Returns: - SSHKeyContent with the requested key content + bytearray containing the key content (caller must wipe when done) Raises: - ValueError: If the key doesn't exist or is invalid + ValueError: If the key doesn't exist or key_type is invalid """ if key_type == "public": - pubkey_path = safe_join_path(ssh_dir, f"{filename}.pub") - if not os.path.exists(pubkey_path): + key_path = safe_join_path(ssh_dir, f"{filename}.pub") + if not os.path.exists(key_path): raise ValueError(f"Public key '{filename}' not found") - with open(pubkey_path, 'r') as f: - return SSHKeyContent(key=f.read().strip()) - elif key_type == "private": - private_key_path = safe_join_path(ssh_dir, filename) - if not os.path.exists(private_key_path): + key_path = safe_join_path(ssh_dir, filename) + if not os.path.exists(key_path): raise ValueError(f"Private key '{filename}' not found") - # Use secure reading for private keys - key_content = _read_file_secure(private_key_path) - return SSHKeyContent(key=key_content) + else: + raise ValueError(f"Invalid key_type: {key_type}") - raise ValueError(f"Invalid key_type: {key_type}") + return read_file_to_bytearray(key_path) def is_key_in_authorized_keys(ssh_dir: str, fingerprint: str) -> bool: - """Check if a key with the given fingerprint is in authorized_keys. + """Check if a key with the given fingerprint is in authorized_keys with 'fileglancer' comment. Args: ssh_dir: Path to the .ssh directory fingerprint: The SHA256 fingerprint to look for Returns: - True if the key is in authorized_keys, False otherwise + True if the key is in authorized_keys with 'fileglancer' in the comment, False otherwise """ authorized_keys_path = os.path.join(ssh_dir, AUTHORIZED_KEYS_FILENAME) @@ -290,6 +400,7 @@ def is_key_in_authorized_keys(ssh_dir: str, fingerprint: str) -> bool: try: # Get fingerprints of all keys in authorized_keys + # Output format: "256 SHA256:xxxxx comment (ED25519)" result = subprocess.run( ['ssh-keygen', '-lf', authorized_keys_path], capture_output=True, @@ -301,9 +412,9 @@ def is_key_in_authorized_keys(ssh_dir: str, fingerprint: str) -> bool: logger.warning(f"Could not check authorized_keys: {result.stderr}") return False - # Check each line for the fingerprint + # Check each line for the fingerprint AND 'fileglancer' in comment for line in result.stdout.strip().split('\n'): - if fingerprint in line: + if fingerprint in line and 'fileglancer' in line: return True return False @@ -312,27 +423,109 @@ def is_key_in_authorized_keys(ssh_dir: str, fingerprint: str) -> bool: return False +def _parse_authorized_keys_fileglancer(ssh_dir: str) -> Dict[str, SSHKeyInfo]: + """Parse authorized_keys and return keys with 'fileglancer' in the comment. + + Args: + ssh_dir: Path to the .ssh directory + + Returns: + Dict mapping fingerprint to SSHKeyInfo for keys with 'fileglancer' comment + """ + keys_by_fingerprint: Dict[str, SSHKeyInfo] = {} + authorized_keys_path = os.path.join(ssh_dir, AUTHORIZED_KEYS_FILENAME) + + if not os.path.exists(authorized_keys_path): + return keys_by_fingerprint + + try: + # Use ssh-keygen to get fingerprints and comments from authorized_keys + # Output format: "256 SHA256:xxxxx comment (ED25519)" + result = subprocess.run( + ['ssh-keygen', '-lf', authorized_keys_path], + capture_output=True, + text=True, + timeout=30 + ) + + if result.returncode != 0: + logger.warning(f"Could not read authorized_keys: {result.stderr}") + return keys_by_fingerprint + + for line in result.stdout.strip().split('\n'): + if not line or 'fileglancer' not in line: + continue + + # Parse: "256 SHA256:xxxxx comment text (ED25519)" + parts = line.split() + if len(parts) < 4: + continue + + fingerprint = parts[1] + key_type_raw = parts[-1] # e.g., "(ED25519)" + + # Extract key type, converting to ssh- format + key_type = key_type_raw.strip('()') + if key_type == "ED25519": + key_type = "ssh-ed25519" + elif key_type == "RSA": + key_type = "ssh-rsa" + elif key_type == "ECDSA": + key_type = "ecdsa-sha2-nistp256" + elif key_type == "DSA": + key_type = "ssh-dss" + else: + key_type = f"ssh-{key_type.lower()}" + + # Comment is everything between fingerprint and key type + comment = " ".join(parts[2:-1]) + + keys_by_fingerprint[fingerprint] = SSHKeyInfo( + filename=fingerprint, # Use fingerprint as filename placeholder + key_type=key_type, + fingerprint=fingerprint, + comment=comment, + has_private_key=False, # Don't search for private keys + is_authorized=True + ) + + except Exception as e: + logger.warning(f"Error parsing authorized_keys: {e}") + + return keys_by_fingerprint + + def list_ssh_keys(ssh_dir: str) -> List[SSHKeyInfo]: - """List all SSH keys in the given directory. + """List SSH keys with 'fileglancer' in the comment. + + Collects keys from both .pub files and authorized_keys, filtering to only + those with 'fileglancer' in the comment. Keys are deduplicated by fingerprint, + preferring .pub file info when available. Args: ssh_dir: Path to the .ssh directory Returns: - List of SSHKeyInfo objects + List of SSHKeyInfo objects for fileglancer-managed keys """ - keys = [] + keys_by_fingerprint: Dict[str, SSHKeyInfo] = {} if not os.path.exists(ssh_dir): - return keys + return [] - # Find all .pub files + # First, get keys from authorized_keys with 'fileglancer' comment + keys_by_fingerprint = _parse_authorized_keys_fileglancer(ssh_dir) + + # Then, scan .pub files and override/add entries with 'fileglancer' comment + # .pub files take precedence since they have better filename info for filename in os.listdir(ssh_dir): if filename.endswith('.pub'): try: pubkey_path = safe_join_path(ssh_dir, filename) key_info = parse_public_key(pubkey_path, ssh_dir) - keys.append(key_info) + # Only include keys with 'fileglancer' in the comment + if 'fileglancer' in key_info.comment: + keys_by_fingerprint[key_info.fingerprint] = key_info except ValueError as e: logger.warning(f"Skipping suspicious filename {filename}: {e}") continue @@ -340,15 +533,17 @@ def list_ssh_keys(ssh_dir: str) -> List[SSHKeyInfo]: logger.warning(f"Could not parse key {filename}: {e}") continue - # Sort by filename - keys.sort(key=lambda k: k.filename) + keys = list(keys_by_fingerprint.values()) + + # Sort by filename, but put id_ed25519 first + keys.sort(key=lambda k: (0 if k.filename == "id_ed25519" else 1, k.filename)) - logger.info(f"Listed {len(keys)} SSH keys in {ssh_dir}") + logger.info(f"Listed {len(keys)} SSH keys with 'fileglancer' comment in {ssh_dir}") return keys -def generate_ssh_key(ssh_dir: str, passphrase: Optional[str] = None) -> SSHKeyInfo: +def generate_ssh_key(ssh_dir: str, passphrase: Optional[SecretStr] = None) -> SSHKeyInfo: """Generate the default ed25519 SSH key (id_ed25519). Args: @@ -376,15 +571,19 @@ def generate_ssh_key(ssh_dir: str, passphrase: Optional[str] = None) -> SSHKeyIn raise ValueError(f"SSH key '{key_name}' already exists") # Build ssh-keygen command + passphrase_str = passphrase.get_secret_value() if passphrase else "" cmd = [ 'ssh-keygen', '-t', 'ed25519', - '-N', passphrase or '', # Empty string if no passphrase + '-N', passphrase_str, '-f', key_path, + '-C', 'fileglancer', ] logger.info(f"Generating SSH key: {key_name}") + # Set restrictive umask to ensure private key is created with secure permissions + old_umask = os.umask(0o077) try: result = subprocess.run( cmd, @@ -396,7 +595,7 @@ def generate_ssh_key(ssh_dir: str, passphrase: Optional[str] = None) -> SSHKeyIn if result.returncode != 0: raise RuntimeError(f"ssh-keygen failed: {result.stderr}") - # Set correct permissions + # Set correct permissions explicitly as well os.chmod(key_path, 0o600) os.chmod(pubkey_path, 0o644) @@ -409,11 +608,86 @@ def generate_ssh_key(ssh_dir: str, passphrase: Optional[str] = None) -> SSHKeyIn raise RuntimeError("Key generation timed out") except FileNotFoundError: raise RuntimeError("ssh-keygen not found on system") + finally: + os.umask(old_umask) + + +def regenerate_public_key( + ssh_dir: str, + key_name: str = "id_ed25519", + passphrase: Optional[SecretStr] = None +) -> SSHKeyInfo: + """Regenerate a public key from an existing private key. + + Uses ssh-keygen -y to extract the public key from the private key. + If the private key is encrypted, a passphrase must be provided. + + Args: + ssh_dir: Path to the .ssh directory + key_name: Name of the key (without extension), defaults to "id_ed25519" + passphrase: Passphrase for the private key (if encrypted) + + Returns: + SSHKeyInfo for the regenerated key + + Raises: + ValueError: If the private key doesn't exist + RuntimeError: If public key regeneration fails (e.g., wrong passphrase) + """ + privkey_path = os.path.join(ssh_dir, key_name) + pubkey_path = os.path.join(ssh_dir, f"{key_name}.pub") + + if not os.path.exists(privkey_path): + raise ValueError(f"Private key '{key_name}' not found") + + # Build ssh-keygen command to extract public key + # ssh-keygen -y -f private_key [-P passphrase] + passphrase_str = passphrase.get_secret_value() if passphrase else "" + + try: + # Use -y to read private key and output public key + result = subprocess.run( + ['ssh-keygen', '-y', '-P', passphrase_str, '-f', privkey_path], + capture_output=True, + text=True, + timeout=30 + ) + + if result.returncode != 0: + if "incorrect passphrase" in result.stderr.lower() or \ + "bad passphrase" in result.stderr.lower(): + raise RuntimeError("Incorrect passphrase for private key") + raise RuntimeError(f"Failed to extract public key: {result.stderr}") + + # The output is just "type base64key" without comment + # We need to add the comment + public_key_content = result.stdout.strip() + if not public_key_content: + raise RuntimeError("No public key output from ssh-keygen") + + # Write the public key file + with open(pubkey_path, 'w') as f: + f.write(public_key_content + '\n') + + # Set correct permissions + os.chmod(pubkey_path, 0o644) + + logger.info(f"Regenerated public key: {pubkey_path}") + + # Parse and return the key info + return parse_public_key(pubkey_path, ssh_dir) + + except subprocess.TimeoutExpired: + raise RuntimeError("Public key regeneration timed out") + except FileNotFoundError: + raise RuntimeError("ssh-keygen not found on system") def add_to_authorized_keys(ssh_dir: str, public_key: str) -> bool: """Add a public key to the authorized_keys file. + Enforces 'restrict' option and ensures 'fileglancer' is in the comment. + Args: ssh_dir: Path to the .ssh directory public_key: The public key content to add @@ -426,8 +700,43 @@ def add_to_authorized_keys(ssh_dir: str, public_key: str) -> bool: RuntimeError: If adding the key fails """ # Validate public key format (basic check) - if not public_key or not public_key.startswith(SSH_KEY_PREFIX): - raise ValueError("Invalid public key format") + if not public_key: + raise ValueError("Invalid public key format: empty") + + parts = public_key.strip().split() + + # Find the key type index (ssh-...) to handle existing options + try: + type_idx = next(i for i, part in enumerate(parts) if part.startswith(SSH_KEY_PREFIX)) + except StopIteration: + raise ValueError("Invalid public key format: key type not found") + + # Handle options + if type_idx > 0: + options = parts[0].split(',') + if "restrict" not in options: + options.insert(0, "restrict,pty") + new_options = ",".join(options) + else: + new_options = "restrict,pty" + + # Handle comment + key_parts = parts[type_idx:] + # key_parts is [type, blob, comment...] + if len(key_parts) < 2: + raise ValueError("Invalid public key format: incomplete") + + key_type = key_parts[0] + key_blob = key_parts[1] + comment_parts = key_parts[2:] + + if not any("fileglancer" in p for p in comment_parts): + comment_parts.append("fileglancer") + + new_comment = " ".join(comment_parts) + + # Reconstruct key line + final_key_line = f"{new_options} {key_type} {key_blob} {new_comment}" # Ensure .ssh directory exists ensure_ssh_directory_exists(ssh_dir) @@ -435,7 +744,7 @@ def add_to_authorized_keys(ssh_dir: str, public_key: str) -> bool: authorized_keys_path = os.path.join(ssh_dir, AUTHORIZED_KEYS_FILENAME) # Get fingerprint of the key we're adding to check if already present - # Write key to temp file to get its fingerprint + # We use the ORIGINAL key content for fingerprinting as options don't affect it try: with tempfile.NamedTemporaryFile(mode='w', suffix='.pub', delete=False) as tmp: tmp.write(public_key) @@ -454,7 +763,7 @@ def add_to_authorized_keys(ssh_dir: str, public_key: str) -> bool: except Exception as e: logger.warning(f"Could not check fingerprint, proceeding with add: {e}") - # Backup and append the key + # Backup and append the modified key try: # Backup existing file before modifying if os.path.exists(authorized_keys_path): @@ -474,13 +783,306 @@ def add_to_authorized_keys(ssh_dir: str, public_key: str) -> bool: # Append the key with open(authorized_keys_path, 'a') as f: - f.write(public_key + '\n') + f.write(final_key_line + '\n') # Ensure correct permissions os.chmod(authorized_keys_path, 0o600) - logger.info(f"Added key to {authorized_keys_path}") + logger.info(f"Added restricted key to {authorized_keys_path}") return True except Exception as e: raise RuntimeError(f"Failed to add key to authorized_keys: {e}") + + +def _get_key_info_line(key_path: str) -> Optional[str]: + """Get the ssh-keygen info line for a key (public or private). + + Args: + key_path: Path to the key file + + Returns: + The output line from ssh-keygen -lf, or None if it fails. + Format: "256 SHA256:xxxxx comment (ED25519)" + """ + try: + result = subprocess.run( + ['ssh-keygen', '-lf', key_path], + capture_output=True, + text=True, + timeout=10 + ) + if result.returncode == 0: + return result.stdout.strip() + return None + except Exception: + return None + + +def check_unmanaged_id_ed25519(ssh_dir: str) -> bool: + """Check if id_ed25519 exists but is not managed by fileglancer. + + Args: + ssh_dir: Path to the .ssh directory + + Returns: + True if id_ed25519 or id_ed25519.pub exists but isn't fileglancer-managed + """ + privkey_path = os.path.join(ssh_dir, "id_ed25519") + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + + privkey_exists = os.path.exists(privkey_path) + pubkey_exists = os.path.exists(pubkey_path) + + # If neither file exists, not unmanaged + if not privkey_exists and not pubkey_exists: + return False + + # If public key exists, check if it has 'fileglancer' in the comment + if pubkey_exists: + try: + with open(pubkey_path, 'r') as f: + content = f.read().strip() + + # If 'fileglancer' is in the comment, it's managed + if 'fileglancer' in content: + return False + # Otherwise it's unmanaged + return True + except Exception as e: + logger.warning(f"Error checking id_ed25519.pub: {e}") + # If we can't read it, assume it's unmanaged to be safe + return True + + # Public key doesn't exist but private key does + # Get the key info line which includes the comment + # Output format: "256 SHA256:xxxxx comment (ED25519)" + info_line = _get_key_info_line(privkey_path) + if info_line and 'fileglancer' in info_line: + return False # It's managed (has fileglancer in comment) + + # Could not get info or 'fileglancer' not in comment + return True + + +def check_id_ed25519_status(ssh_dir: str) -> tuple: + """Check the status of id_ed25519 key. + + Args: + ssh_dir: Path to the .ssh directory + + Returns: + Tuple of (exists, unmanaged, missing_pubkey): + - exists: True if id_ed25519 private key exists + - unmanaged: True if id_ed25519 exists but is not managed by fileglancer + - missing_pubkey: True if id_ed25519 is managed but .pub file is missing + """ + privkey_path = os.path.join(ssh_dir, "id_ed25519") + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + + privkey_exists = os.path.exists(privkey_path) + pubkey_exists = os.path.exists(pubkey_path) + + # If private key doesn't exist + if not privkey_exists: + # Check if only pubkey exists (unusual but possible) + if pubkey_exists: + try: + with open(pubkey_path, 'r') as f: + content = f.read().strip() + if 'fileglancer' in content: + return (True, False, False) # Managed via pubkey only + return (True, True, False) # Unmanaged + except Exception: + return (True, True, False) # Assume unmanaged + return (False, False, False) # Nothing exists + + # Private key exists + exists = True + + # Check if public key exists and has fileglancer + if pubkey_exists: + try: + with open(pubkey_path, 'r') as f: + content = f.read().strip() + if 'fileglancer' in content: + return (exists, False, False) # Managed, has pubkey + return (exists, True, False) # Unmanaged + except Exception: + return (exists, True, False) # Assume unmanaged + + # Private key exists but public key doesn't + # Check if private key has fileglancer in comment + info_line = _get_key_info_line(privkey_path) + if info_line and 'fileglancer' in info_line: + return (exists, False, True) # Managed but missing pubkey + + return (exists, True, False) # Unmanaged + + +class TempKeyResponse(SSHKeyContentResponse): + """Secure streaming response for temporary SSH key that deletes files after sending. + + Extends SSHKeyContentResponse to also delete the temporary key files + after the private key content has been streamed and wiped. + """ + + def __init__( + self, + key_content: bytearray, + temp_key_path: str, + temp_pubkey_path: str, + key_info: SSHKeyInfo, + status_code: int = 200, + ): + """Initialize the response with key content and paths to delete. + + Args: + key_content: The private key content as a mutable bytearray + temp_key_path: Path to temporary private key file to delete + temp_pubkey_path: Path to temporary public key file to delete + key_info: SSHKeyInfo to include in response headers + status_code: HTTP status code (default 200) + """ + self._temp_key_path = temp_key_path + self._temp_pubkey_path = temp_pubkey_path + + # Include key info in headers + headers = { + "X-SSH-Key-Filename": key_info.filename, + "X-SSH-Key-Type": key_info.key_type, + "X-SSH-Key-Fingerprint": key_info.fingerprint, + "X-SSH-Key-Comment": key_info.comment, + } + + super().__init__(key_content, status_code, headers) + + async def __call__(self, scope, receive, send): + """Stream the response, then wipe buffer and delete temp files.""" + try: + await super().__call__(scope, receive, send) + finally: + # Delete temporary files + self._cleanup_temp_files() + + def _cleanup_temp_files(self): + """Securely delete temporary key files.""" + for path in (self._temp_key_path, self._temp_pubkey_path): + try: + if os.path.exists(path): + # Overwrite with zeros before deleting + file_size = os.path.getsize(path) + with open(path, 'wb') as f: + f.write(b'\x00' * file_size) + os.unlink(path) + logger.info(f"Deleted temporary key file: {path}") + except Exception as e: + logger.warning(f"Failed to delete temp file {path}: {e}") + + +def generate_temp_key_and_authorize( + ssh_dir: str, + passphrase: Optional[SecretStr] = None +) -> TempKeyResponse: + """Generate a temporary SSH key, add to authorized_keys, and return private key. + + The key is generated to a temporary location, the public key is added to + authorized_keys with 'fileglancer' comment, and the private key is returned + via TempKeyResponse which will delete the temp files after streaming. + + Args: + ssh_dir: Path to the .ssh directory + passphrase: Optional passphrase to protect the private key + + Returns: + TempKeyResponse containing the private key (streams and deletes files) + + Raises: + RuntimeError: If key generation fails + """ + # Create temporary directory for key generation + temp_dir = tempfile.mkdtemp(prefix="fileglancer_ssh_") + temp_key_path = os.path.join(temp_dir, "temp_key") + temp_pubkey_path = os.path.join(temp_dir, "temp_key.pub") + + # Set restrictive umask to ensure private key is created with secure permissions + old_umask = os.umask(0o077) + try: + # Generate key to temp location + passphrase_str = passphrase.get_secret_value() if passphrase else "" + cmd = [ + 'ssh-keygen', + '-t', 'ed25519', + '-N', passphrase_str, + '-f', temp_key_path, + '-C', 'fileglancer', + ] + + logger.info("Generating temporary SSH key") + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=30 + ) + + if result.returncode != 0: + raise RuntimeError(f"ssh-keygen failed: {result.stderr}") + + # Set correct permissions explicitly as well + os.chmod(temp_key_path, 0o600) + os.chmod(temp_pubkey_path, 0o644) + + # Read public key and add to authorized_keys + with open(temp_pubkey_path, 'r') as f: + public_key = f.read().strip() + + ensure_ssh_directory_exists(ssh_dir) + add_to_authorized_keys(ssh_dir, public_key) + + # Get key info for headers + fingerprint = get_key_fingerprint(temp_pubkey_path) + + # Parse public key for type and comment + parts = public_key.split(None, 2) + key_type = parts[0] if len(parts) >= 1 else "ssh-ed25519" + comment = parts[2] if len(parts) > 2 else "fileglancer" + + key_info = SSHKeyInfo( + filename="temporary", + key_type=key_type, + fingerprint=fingerprint, + comment=comment, + has_private_key=False, # Not persisted + is_authorized=True + ) + + # Read private key into bytearray + private_key_buffer = read_file_to_bytearray(temp_key_path) + + logger.info("Temporary SSH key generated and added to authorized_keys") + + # Return response that will stream private key and delete temp files + # Note: temp_dir cleanup happens in TempKeyResponse._cleanup_temp_files + return TempKeyResponse( + private_key_buffer, + temp_key_path, + temp_pubkey_path, + key_info + ) + + except Exception as e: + # Clean up on error + try: + if os.path.exists(temp_key_path): + os.unlink(temp_key_path) + if os.path.exists(temp_pubkey_path): + os.unlink(temp_pubkey_path) + if os.path.exists(temp_dir): + os.rmdir(temp_dir) + except Exception: + pass + raise RuntimeError(f"Failed to generate temporary key: {e}") + finally: + os.umask(old_umask) diff --git a/frontend/src/components/SSHKeys.tsx b/frontend/src/components/SSHKeys.tsx index 2d84c22e..ad18ff0b 100644 --- a/frontend/src/components/SSHKeys.tsx +++ b/frontend/src/components/SSHKeys.tsx @@ -1,28 +1,62 @@ import { useState } from 'react'; -import { Button, Card, Typography } from '@material-tailwind/react'; +import { Button, Card, Input, Typography } from '@material-tailwind/react'; import { HiOutlinePlus, HiOutlineKey, - HiOutlineInformationCircle + HiOutlineInformationCircle, + HiOutlineExclamation, + HiOutlineRefresh } from 'react-icons/hi'; +import toast from 'react-hot-toast'; -import { useSSHKeysQuery } from '@/queries/sshKeyQueries'; +import { + useSSHKeysQuery, + useRegeneratePublicKeyMutation +} from '@/queries/sshKeyQueries'; +import type { TempKeyResult } from '@/queries/sshKeyQueries'; import SSHKeyCard from '@/components/ui/SSHKeys/SSHKeyCard'; import GenerateKeyDialog from '@/components/ui/SSHKeys/GenerateKeyDialog'; +import GenerateTempKeyDialog from '@/components/ui/SSHKeys/GenerateTempKeyDialog'; +import TempKeyDialog from '@/components/ui/SSHKeys/TempKeyDialog'; import { Spinner } from '@/components/ui/widgets/Loaders'; export default function SSHKeys() { const [showGenerateDialog, setShowGenerateDialog] = useState(false); - const { data: keys, isLoading, error, refetch } = useSSHKeysQuery(); + const [showGenerateTempDialog, setShowGenerateTempDialog] = useState(false); + const [tempKeyResult, setTempKeyResult] = useState( + null + ); + const [regeneratePassphrase, setRegeneratePassphrase] = useState(''); + const { data, isLoading, error, refetch } = useSSHKeysQuery(); + const regenerateMutation = useRegeneratePublicKeyMutation(); + + const handleRegenerate = async () => { + try { + await regenerateMutation.mutateAsync( + regeneratePassphrase ? { passphrase: regeneratePassphrase } : undefined + ); + setRegeneratePassphrase(''); + toast.success('Public key regenerated successfully'); + } catch (err) { + toast.error( + `Failed to regenerate public key: ${err instanceof Error ? err.message : 'Unknown error'}` + ); + } + }; - // Only show the id_ed25519 key - const defaultKey = keys?.find(key => key.filename === 'id_ed25519'); + const keys = data?.keys ?? []; + const hasKeys = keys.length > 0; + const unmanagedExists = data?.unmanaged_id_ed25519_exists ?? false; + const id_ed25519_exists = data?.id_ed25519_exists ?? false; + const id_ed25519_missing_pubkey = data?.id_ed25519_missing_pubkey ?? false; + // Can generate permanent key if id_ed25519 doesn't exist at all + const canGeneratePermanentKey = !unmanagedExists && !id_ed25519_exists; return ( <>
- SSH Key + Fileglancer-managed SSH Keys
@@ -30,24 +64,25 @@ export default function SSHKeys() {
- SSH keys allow you to securely connect to cluster nodes without - entering a password. Specifically, you need an ed25519 SSH key to - use Seqera Platform to run pipelines on the cluster. This page lets - you view your existing ed25519 SSH key or generate a new one. + Fileglancer-managed SSH keys allow you to securely connect to + cluster nodes without entering a password. Specifically, you need an + ed25519 SSH key to use Seqera Platform to run pipelines on the + cluster. This page shows SSH keys with "fileglancer" in the comment + and lets you generate a new one.
{isLoading ? (
- +
) : null} {error ? ( - Failed to load SSH key: {error.message} + Failed to load Fileglancer-managed SSH keys: {error.message} + + + + + ) : null} + + {!isLoading && !error && !hasKeys ? ( - No ed25519 SSH key found + No Fileglancer-managed SSH key found Generate an ed25519 SSH key to enable passwordless access to cluster nodes and integration with Seqera Platform. - ) : null} - {!isLoading && !error && defaultKey ? ( - + {!isLoading && !error && hasKeys ? ( +
+ {keys.map(key => ( + + ))} +
+ ) : null} + + {!isLoading && !error && canGeneratePermanentKey ? ( + +
+
+ + Generate Permanent Key + + + Creates the default id_ed25519 key pair in your ~/.ssh directory + and adds it to authorized_keys. The private key is stored on the + server. + +
+ +
+
+ ) : null} + + {!isLoading && !error ? ( + +
+
+ + Generate Temporary Key + + + Creates a key that is added to authorized_keys. The private key + is shown once for you to copy - it is not stored on the server. + +
+ +
+
) : null} + + + + setTempKeyResult(null)} + tempKeyResult={tempKeyResult} + /> ); } diff --git a/frontend/src/components/ui/SSHKeys/GenerateTempKeyDialog.tsx b/frontend/src/components/ui/SSHKeys/GenerateTempKeyDialog.tsx new file mode 100644 index 00000000..dd53fb05 --- /dev/null +++ b/frontend/src/components/ui/SSHKeys/GenerateTempKeyDialog.tsx @@ -0,0 +1,91 @@ +import { useState } from 'react'; +import type { Dispatch, SetStateAction } from 'react'; +import { Button, Input, Typography } from '@material-tailwind/react'; +import toast from 'react-hot-toast'; + +import FgDialog from '@/components/ui/Dialogs/FgDialog'; +import { Spinner } from '@/components/ui/widgets/Loaders'; +import { useGenerateTempKeyMutation } from '@/queries/sshKeyQueries'; +import type { TempKeyResult } from '@/queries/sshKeyQueries'; + +type GenerateTempKeyDialogProps = { + readonly showDialog: boolean; + readonly setShowDialog: Dispatch>; + readonly onKeyGenerated: (result: TempKeyResult) => void; +}; + +export default function GenerateTempKeyDialog({ + showDialog, + setShowDialog, + onKeyGenerated +}: GenerateTempKeyDialogProps) { + const generateMutation = useGenerateTempKeyMutation(); + const [passphrase, setPassphrase] = useState(''); + + const handleClose = () => { + setShowDialog(false); + setPassphrase(''); + }; + + const handleGenerate = async () => { + try { + const result = await generateMutation.mutateAsync( + passphrase ? { passphrase } : undefined + ); + handleClose(); + onKeyGenerated(result); + } catch (error) { + toast.error( + `Failed to generate key: ${error instanceof Error ? error.message : 'Unknown error'}` + ); + } + }; + + return ( + + + Generate Temporary SSH Key + + + + This will create a temporary ed25519 SSH key pair and add the public key + to your authorized_keys file. The private key will be shown once for you + to copy - it will not be stored on the server. + + +
+ + Passphrase (optional) + + setPassphrase(e.target.value)} + placeholder="Leave empty for no passphrase" + type="password" + value={passphrase} + /> + + A passphrase adds extra security but must be entered each time you use + the key. + +
+ +
+ + +
+
+ ); +} diff --git a/frontend/src/components/ui/SSHKeys/TempKeyDialog.tsx b/frontend/src/components/ui/SSHKeys/TempKeyDialog.tsx new file mode 100644 index 00000000..35b6845c --- /dev/null +++ b/frontend/src/components/ui/SSHKeys/TempKeyDialog.tsx @@ -0,0 +1,113 @@ +import { useState } from 'react'; +import { Button, Typography } from '@material-tailwind/react'; +import { + HiOutlineClipboardCopy, + HiOutlineExclamation, + HiOutlineCheck +} from 'react-icons/hi'; +import toast from 'react-hot-toast'; + +import FgDialog from '@/components/ui/Dialogs/FgDialog'; +import type { TempKeyResult } from '@/queries/sshKeyQueries'; + +type TempKeyDialogProps = { + readonly tempKeyResult: TempKeyResult | null; + readonly onClose: () => void; +}; + +export default function TempKeyDialog({ + tempKeyResult, + onClose +}: TempKeyDialogProps) { + const [copied, setCopied] = useState(false); + + if (!tempKeyResult) { + return null; + } + + const handleCopy = async () => { + try { + await navigator.clipboard.writeText(tempKeyResult.privateKey); + setCopied(true); + toast.success('Private key copied to clipboard'); + } catch { + toast.error('Failed to copy to clipboard'); + } + }; + + const handleClose = () => { + setCopied(false); + onClose(); + }; + + // Truncate fingerprint for display + const shortFingerprint = + tempKeyResult.keyInfo.fingerprint.replace('SHA256:', '').slice(0, 16) + + '...'; + + return ( + +
+ + + Temporary SSH Key Generated + +
+ +
+
+ + Copy this private key now - it will not be available again! + + + The private key is not stored on the server. You must copy it now + and save it securely on your local machine or intended application. + +
+ +
+ + Key Information + +
+
+ Type:{' '} + {tempKeyResult.keyInfo.key_type} +
+
+ Fingerprint:{' '} + {shortFingerprint} +
+
+ Comment:{' '} + {tempKeyResult.keyInfo.comment} +
+
+
+ +
+ + +
+
+
+ ); +} diff --git a/frontend/src/queries/sshKeyQueries.ts b/frontend/src/queries/sshKeyQueries.ts index cbee850c..d63951a7 100644 --- a/frontend/src/queries/sshKeyQueries.ts +++ b/frontend/src/queries/sshKeyQueries.ts @@ -31,6 +31,9 @@ export type SSHKeyContent = { */ type SSHKeyListResponse = { keys: SSHKeyInfo[]; + unmanaged_id_ed25519_exists: boolean; + id_ed25519_exists: boolean; + id_ed25519_missing_pubkey: boolean; }; /** @@ -47,10 +50,20 @@ export const sshKeyQueryKeys = { list: () => ['sshKeys', 'list'] as const }; +/** + * Result from fetching SSH keys including metadata + */ +export type SSHKeysResult = { + keys: SSHKeyInfo[]; + unmanaged_id_ed25519_exists: boolean; + id_ed25519_exists: boolean; + id_ed25519_missing_pubkey: boolean; +}; + /** * Fetches all SSH keys from the backend */ -const fetchSSHKeys = async (signal?: AbortSignal): Promise => { +const fetchSSHKeys = async (signal?: AbortSignal): Promise => { const response = await sendFetchRequest('/api/ssh-keys', 'GET', undefined, { signal }); @@ -62,16 +75,21 @@ const fetchSSHKeys = async (signal?: AbortSignal): Promise => { } const data = body as SSHKeyListResponse; - return data.keys ?? []; + return { + keys: data.keys ?? [], + unmanaged_id_ed25519_exists: data.unmanaged_id_ed25519_exists ?? false, + id_ed25519_exists: data.id_ed25519_exists ?? false, + id_ed25519_missing_pubkey: data.id_ed25519_missing_pubkey ?? false + }; }; /** * Query hook for fetching all SSH keys * - * @returns Query result with all SSH keys + * @returns Query result with SSH keys and metadata */ -export function useSSHKeysQuery(): UseQueryResult { - return useQuery({ +export function useSSHKeysQuery(): UseQueryResult { + return useQuery({ queryKey: sshKeyQueryKeys.list(), queryFn: ({ signal }) => fetchSSHKeys(signal) }); @@ -174,6 +192,9 @@ export function useAuthorizeSSHKeyMutation(): UseMutationResult< * Fetch SSH key content (public or private key) on demand. * This is not a hook - call it imperatively when user clicks copy. * + * Backend uses secure bytearray handling that wipes key content from + * memory after sending. Response is plain text for both key types. + * * @param keyType - Type of key to fetch: 'public' or 'private' * @returns Promise with the key content */ @@ -185,11 +206,122 @@ export async function fetchSSHKeyContent( 'GET' ); - const body = await getResponseJsonOrError(response); - if (!response.ok) { + const body = await getResponseJsonOrError(response); throwResponseNotOkError(response, body); } - return body as SSHKeyContent; + const keyText = await response.text(); + return { key: keyText }; +} + +/** + * Result from generating a temporary SSH key + */ +export type TempKeyResult = { + privateKey: string; + keyInfo: SSHKeyInfo; +}; + +/** + * Parameters for generating a temporary SSH key + */ +type GenerateTempKeyParams = { + passphrase?: string; +}; + +/** + * Mutation hook for generating a temporary SSH key. + * + * Generates a key, adds public key to authorized_keys, and returns + * the private key for one-time display. The temporary files are + * deleted on the server after the response is sent. + */ +export function useGenerateTempKeyMutation(): UseMutationResult< + TempKeyResult, + Error, + GenerateTempKeyParams | void +> { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: async (params?: GenerateTempKeyParams) => { + const response = await sendFetchRequest( + '/api/ssh-keys/generate-temp', + 'POST', + params?.passphrase ? { passphrase: params.passphrase } : undefined + ); + + if (!response.ok) { + const body = await getResponseJsonOrError(response); + throwResponseNotOkError(response, body); + } + + // Private key is in response body + const privateKey = await response.text(); + + // Key info is in response headers + const keyInfo: SSHKeyInfo = { + filename: response.headers.get('X-SSH-Key-Filename') ?? 'temporary', + key_type: response.headers.get('X-SSH-Key-Type') ?? 'ssh-ed25519', + fingerprint: response.headers.get('X-SSH-Key-Fingerprint') ?? '', + comment: response.headers.get('X-SSH-Key-Comment') ?? 'fileglancer', + has_private_key: false, + is_authorized: true + }; + + return { privateKey, keyInfo }; + }, + onSuccess: () => { + // Invalidate and refetch the list to show the new key + queryClient.invalidateQueries({ + queryKey: sshKeyQueryKeys.all + }); + } + }); +} + +/** + * Parameters for regenerating a public key + */ +type RegeneratePublicKeyParams = { + passphrase?: string; +}; + +/** + * Mutation hook for regenerating the public key from the private key. + * + * Use this when the .pub file is missing but the private key exists. + * If the private key is encrypted, provide the passphrase. + */ +export function useRegeneratePublicKeyMutation(): UseMutationResult< + SSHKeyInfo, + Error, + RegeneratePublicKeyParams | void +> { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: async (params?: RegeneratePublicKeyParams) => { + const response = await sendFetchRequest( + '/api/ssh-keys/regenerate-public', + 'POST', + params?.passphrase ? { passphrase: params.passphrase } : undefined + ); + + const body = await getResponseJsonOrError(response); + + if (!response.ok) { + throwResponseNotOkError(response, body); + } + + return body as SSHKeyInfo; + }, + onSuccess: () => { + // Invalidate and refetch the list + queryClient.invalidateQueries({ + queryKey: sshKeyQueryKeys.all + }); + } + }); } diff --git a/tests/test_sshkeys.py b/tests/test_sshkeys.py new file mode 100644 index 00000000..ca3708da --- /dev/null +++ b/tests/test_sshkeys.py @@ -0,0 +1,958 @@ +"Tests for SSH key management utilities with secure bytearray handling." + +import os +import stat +import subprocess +import tempfile +import pytest +from pydantic import SecretStr + +from fileglancer.sshkeys import ( + _wipe_bytearray, + read_file_to_bytearray, + get_key_content, + SSHKeyContentResponse, + TempKeyResponse, + generate_ssh_key, + generate_temp_key_and_authorize, + regenerate_public_key, + check_id_ed25519_status, + list_ssh_keys, + add_to_authorized_keys, + is_key_in_authorized_keys, + _parse_authorized_keys_fileglancer, +) + + +class TestWipeBytearray: + """Tests for the _wipe_bytearray helper function.""" + + def test_wipe_bytearray_zeros_all_bytes(self): + """Verify that _wipe_bytearray overwrites all bytes with zeros.""" + data = bytearray(b"sensitive data here") + original_length = len(data) + + _wipe_bytearray(data) + + assert len(data) == original_length + assert all(b == 0 for b in data) + + def test_wipe_bytearray_empty(self): + """Verify that wiping an empty bytearray doesn't raise.""" + data = bytearray() + _wipe_bytearray(data) + assert len(data) == 0 + + def test_wipe_bytearray_binary_data(self): + """Verify that binary data (non-UTF8) is properly wiped.""" + data = bytearray(bytes(range(256))) + _wipe_bytearray(data) + assert all(b == 0 for b in data) + + +class TestReadFileToBytearray: + """Tests for reading files into bytearrays.""" + + def test_read_file_into_bytearray(self): + """Verify file contents are read into a bytearray.""" + test_content = b"-----BEGIN OPENSSH PRIVATE KEY-----\ntest key content\n-----END OPENSSH PRIVATE KEY-----\n" + + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(test_content) + temp_path = f.name + + try: + result = read_file_to_bytearray(temp_path) + + assert isinstance(result, bytearray) + assert bytes(result) == test_content + + # Clean up the bytearray + _wipe_bytearray(result) + finally: + os.unlink(temp_path) + + def test_read_returns_mutable_bytearray(self): + """Verify the returned bytearray is mutable and can be wiped.""" + test_content = b"secret key data" + + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(test_content) + temp_path = f.name + + try: + result = read_file_to_bytearray(temp_path) + + # Verify it's mutable by modifying it + result[0] = 0 + assert result[0] == 0 + + # Verify we can wipe it completely + _wipe_bytearray(result) + assert all(b == 0 for b in result) + finally: + os.unlink(temp_path) + + def test_read_nonexistent_file_raises(self): + """Verify reading a nonexistent file raises an error.""" + with pytest.raises(FileNotFoundError): + read_file_to_bytearray("/nonexistent/path/to/key") + + +class TestGetKeyContent: + """Tests for the unified get_key_content function.""" + + def test_returns_bytearray_for_public_key(self): + """Verify get_key_content returns a bytearray for public keys.""" + with tempfile.TemporaryDirectory() as ssh_dir: + pubkey_content = b"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITest test@example.com" + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + + with open(pubkey_path, 'wb') as f: + f.write(pubkey_content) + + result = get_key_content(ssh_dir, "id_ed25519", "public") + + assert isinstance(result, bytearray) + assert bytes(result) == pubkey_content + + # Clean up + _wipe_bytearray(result) + + def test_returns_bytearray_for_private_key(self): + """Verify get_key_content returns a bytearray for private keys.""" + with tempfile.TemporaryDirectory() as ssh_dir: + private_key_content = b"-----BEGIN OPENSSH PRIVATE KEY-----\ntest\n-----END OPENSSH PRIVATE KEY-----\n" + private_key_path = os.path.join(ssh_dir, "id_ed25519") + + with open(private_key_path, 'wb') as f: + f.write(private_key_content) + + result = get_key_content(ssh_dir, "id_ed25519", "private") + + assert isinstance(result, bytearray) + assert bytes(result) == private_key_content + + # Clean up + _wipe_bytearray(result) + + def test_returned_bytearray_is_wipeable(self): + """Verify the returned bytearray can be securely wiped.""" + with tempfile.TemporaryDirectory() as ssh_dir: + key_content = b"secret key data" + key_path = os.path.join(ssh_dir, "id_ed25519") + + with open(key_path, 'wb') as f: + f.write(key_content) + + result = get_key_content(ssh_dir, "id_ed25519", "private") + + # Wipe and verify + _wipe_bytearray(result) + assert all(b == 0 for b in result) + + def test_nonexistent_public_key_raises(self): + """Verify requesting a nonexistent public key raises ValueError.""" + with tempfile.TemporaryDirectory() as ssh_dir: + with pytest.raises(ValueError, match="not found"): + get_key_content(ssh_dir, "nonexistent_key", "public") + + def test_nonexistent_private_key_raises(self): + """Verify requesting a nonexistent private key raises ValueError.""" + with tempfile.TemporaryDirectory() as ssh_dir: + with pytest.raises(ValueError, match="not found"): + get_key_content(ssh_dir, "nonexistent_key", "private") + + def test_invalid_key_type_raises(self): + """Verify invalid key_type raises ValueError.""" + with tempfile.TemporaryDirectory() as ssh_dir: + with pytest.raises(ValueError, match="Invalid key_type"): + get_key_content(ssh_dir, "id_ed25519", "invalid") + + +class TestSSHKeyContentResponse: + """Tests for the SSHKeyContentResponse secure streaming response class.""" + + @pytest.mark.asyncio + async def test_response_streams_content_line_by_line(self): + """Verify the response streams content line by line with more_body flag.""" + key_content = bytearray(b"line1\nline2\nline3\n") + response = SSHKeyContentResponse(key_content) + + sent_messages = [] + captured_bodies = [] + + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + sent_messages.append(message) + # Capture a copy of the body before it gets wiped + if message.get("type") == "http.response.body": + captured_bodies.append(bytes(message["body"])) + + scope = {"type": "http"} + + await response(scope, mock_receive, mock_send) + + # Verify response start was sent + assert sent_messages[0]["type"] == "http.response.start" + assert sent_messages[0]["status"] == 200 + + # Verify streaming: multiple body messages with more_body=True + # 3 lines + 1 final empty message = 4 body messages + body_messages = [m for m in sent_messages if m.get("type") == "http.response.body"] + assert len(body_messages) == 4 + + # All but the last should have more_body=True + for msg in body_messages[:-1]: + assert msg["more_body"] is True + + # Final message should have more_body=False and empty body + assert body_messages[-1]["more_body"] is False + assert captured_bodies[-1] == b"" + + # Reassembled content should match original + reassembled = b"".join(captured_bodies[:-1]) # Exclude final empty + assert reassembled == b"line1\nline2\nline3\n" + + @pytest.mark.asyncio + async def test_response_sends_single_line_content(self): + """Verify the response handles content without newlines.""" + key_content = bytearray(b"single line no newline") + response = SSHKeyContentResponse(key_content) + + sent_messages = [] + captured_bodies = [] + + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + sent_messages.append(message) + if message.get("type") == "http.response.body": + captured_bodies.append(bytes(message["body"])) + + scope = {"type": "http"} + + await response(scope, mock_receive, mock_send) + + # Verify response start was sent + assert sent_messages[0]["type"] == "http.response.start" + assert sent_messages[0]["status"] == 200 + + # 1 content chunk + 1 final empty = 2 body messages + body_messages = [m for m in sent_messages if m.get("type") == "http.response.body"] + assert len(body_messages) == 2 + + # First should have more_body=True, second more_body=False + assert body_messages[0]["more_body"] is True + assert body_messages[1]["more_body"] is False + + # Content should match + assert captured_bodies[0] == b"single line no newline" + assert captured_bodies[1] == b"" + + @pytest.mark.asyncio + async def test_response_wipes_bytearray_after_streaming(self): + """Verify the bytearray is wiped after streaming completes.""" + key_content = bytearray(b"sensitive\nprivate\nkey\n") + original_length = len(key_content) + response = SSHKeyContentResponse(key_content) + + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + pass + + scope = {"type": "http"} + + await response(scope, mock_receive, mock_send) + + # Verify the bytearray was wiped + assert len(key_content) == original_length + assert all(b == 0 for b in key_content) + + @pytest.mark.asyncio + async def test_response_wipes_bytearray_even_on_error(self): + """Verify the bytearray is wiped even if streaming fails.""" + key_content = bytearray(b"sensitive\ndata\n") + response = SSHKeyContentResponse(key_content) + + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + if message["type"] == "http.response.body": + raise Exception("Simulated send error") + + scope = {"type": "http"} + + with pytest.raises(Exception, match="Simulated send error"): + await response(scope, mock_receive, mock_send) + + # Verify the bytearray was still wiped despite the error + assert all(b == 0 for b in key_content) + + def test_response_has_correct_content_type(self): + """Verify the response has text/plain content type.""" + key_content = bytearray(b"test") + response = SSHKeyContentResponse(key_content) + + assert response.media_type == "text/plain" + + # Clean up + _wipe_bytearray(key_content) + + def test_response_accepts_custom_status_code(self): + """Verify custom status codes are supported.""" + key_content = bytearray(b"test") + response = SSHKeyContentResponse(key_content, status_code=201) + + assert response.status_code == 201 + + # Clean up + _wipe_bytearray(key_content) + + @pytest.mark.asyncio + async def test_response_streams_realistic_ssh_key(self): + """Verify streaming works with a realistic SSH private key format.""" + key_content = bytearray( + b"-----BEGIN OPENSSH PRIVATE KEY-----\n" + b"b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtz\n" + b"c2gtZWQyNTUxOQAAACBGVmJsZnRtcm5yYmx0c21ibmRjc2xibmRzY21ibmRzYwAA\n" + b"AIhkc21ibmRzY21ibmRzY21ibmRzY21ibmRzY21ibmRzY2RzbWJuZHNjbWJuZHNj\n" + b"-----END OPENSSH PRIVATE KEY-----\n" + ) + expected_content = bytes(key_content) + response = SSHKeyContentResponse(key_content) + + sent_messages = [] + captured_bodies = [] + + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + sent_messages.append(message) + if message.get("type") == "http.response.body": + captured_bodies.append(bytes(message["body"])) + + scope = {"type": "http"} + + await response(scope, mock_receive, mock_send) + + # Should have 5 lines + 1 final empty = 6 body messages + body_messages = [m for m in sent_messages if m.get("type") == "http.response.body"] + assert len(body_messages) == 6 + + # Verify each line message has more_body=True + for msg in body_messages[:-1]: + assert msg["more_body"] is True + + # Final message should be empty with more_body=False + assert body_messages[-1]["more_body"] is False + assert captured_bodies[-1] == b"" + + # Reassembled content should match original + reassembled = b"".join(captured_bodies[:-1]) + assert reassembled == expected_content + + # Verify bytearray was wiped + assert all(b == 0 for b in key_content) + + def test_iter_lines_yields_memoryview_slices(self): + """Verify _iter_lines yields memoryview slices without copying.""" + key_content = bytearray(b"line1\nline2\nline3\n") + response = SSHKeyContentResponse(key_content) + + lines = list(response._iter_lines()) + + # Should have 3 lines + assert len(lines) == 3 + + # Each should be a memoryview + for line in lines: + assert isinstance(line, memoryview) + + # Content should be correct + assert bytes(lines[0]) == b"line1\n" + assert bytes(lines[1]) == b"line2\n" + assert bytes(lines[2]) == b"line3\n" + + # Clean up + _wipe_bytearray(key_content) + + def test_iter_lines_handles_no_trailing_newline(self): + """Verify _iter_lines handles content without trailing newline.""" + key_content = bytearray(b"line1\nline2\nfinal") + response = SSHKeyContentResponse(key_content) + + lines = list(response._iter_lines()) + + # Should have 3 lines + assert len(lines) == 3 + + assert bytes(lines[0]) == b"line1\n" + assert bytes(lines[1]) == b"line2\n" + assert bytes(lines[2]) == b"final" # No newline + + # Clean up + _wipe_bytearray(key_content) + + @pytest.mark.asyncio + async def test_response_with_custom_headers(self): + """Verify custom headers are included in the response.""" + key_content = bytearray(b"test") + custom_headers = {"X-Custom-Header": "test-value"} + response = SSHKeyContentResponse(key_content, headers=custom_headers) + + sent_messages = [] + + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + sent_messages.append(message) + + scope = {"type": "http"} + + await response(scope, mock_receive, mock_send) + + # Check headers in the response start message + headers = dict(sent_messages[0]["headers"]) + assert b"x-custom-header" in headers or any( + h[0] == b"x-custom-header" for h in sent_messages[0]["headers"] + ) + + @pytest.mark.asyncio + async def test_response_sets_correct_content_length(self): + """Verify Content-Length header matches the bytearray length.""" + key_content = bytearray(b"test private key with specific length") + expected_length = len(key_content) + response = SSHKeyContentResponse(key_content) + + sent_messages = [] + + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + sent_messages.append(message) + + scope = {"type": "http"} + + await response(scope, mock_receive, mock_send) + + # Find Content-Length in headers + headers = sent_messages[0]["headers"] + content_length = None + for header_name, header_value in headers: + if header_name == b"content-length": + content_length = int(header_value.decode()) + break + + assert content_length == expected_length + + +class TestGenerateSSHKey: + """Tests for the generate_ssh_key function.""" + + def test_generate_key_no_passphrase(self): + """Verify generating a key with no passphrase.""" + with tempfile.TemporaryDirectory() as ssh_dir: + key_info = generate_ssh_key(ssh_dir, passphrase=None) + + assert key_info.filename == "id_ed25519" + key_path = os.path.join(ssh_dir, "id_ed25519") + assert os.path.exists(key_path) + assert os.path.exists(key_path + ".pub") + + # Verify it is NOT encrypted + check_cmd = ['ssh-keygen', '-y', '-f', key_path, '-P', ''] + result = subprocess.run(check_cmd, capture_output=True) + assert result.returncode == 0 + + def test_generate_key_with_passphrase(self): + """Verify generating a key with a passphrase.""" + with tempfile.TemporaryDirectory() as ssh_dir: + passphrase_str = "test-passphrase" + passphrase = SecretStr(passphrase_str) + key_info = generate_ssh_key(ssh_dir, passphrase=passphrase) + + assert key_info.filename == "id_ed25519" + key_path = os.path.join(ssh_dir, "id_ed25519") + assert os.path.exists(key_path) + assert os.path.exists(key_path + ".pub") + + # Verify it IS encrypted (fails with empty passphrase) + check_cmd_empty = ['ssh-keygen', '-y', '-f', key_path, '-P', ''] + result_empty = subprocess.run(check_cmd_empty, capture_output=True) + assert result_empty.returncode != 0 + + # Verify it accepts the correct passphrase + check_cmd_correct = ['ssh-keygen', '-y', '-f', key_path, '-P', passphrase_str] + result_correct = subprocess.run(check_cmd_correct, capture_output=True) + assert result_correct.returncode == 0 + + def test_generate_key_already_exists_raises(self): + """Verify generating a key when one already exists raises ValueError.""" + with tempfile.TemporaryDirectory() as ssh_dir: + # Create a dummy key file + key_path = os.path.join(ssh_dir, "id_ed25519") + with open(key_path, 'w') as f: + f.write("dummy key") + + with pytest.raises(ValueError, match="already exists"): + generate_ssh_key(ssh_dir) + + def test_generate_key_has_fileglancer_comment(self): + """Verify generated key has 'fileglancer' in the comment.""" + with tempfile.TemporaryDirectory() as ssh_dir: + key_info = generate_ssh_key(ssh_dir) + + assert key_info.comment == "fileglancer" + + # Also verify by reading the public key file + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + with open(pubkey_path, 'r') as f: + content = f.read() + assert "fileglancer" in content + + def test_generate_key_sets_correct_permissions(self): + """Verify generated keys have correct permissions.""" + with tempfile.TemporaryDirectory() as ssh_dir: + generate_ssh_key(ssh_dir) + + key_path = os.path.join(ssh_dir, "id_ed25519") + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + + # Private key should be 0o600 + key_mode = stat.S_IMODE(os.stat(key_path).st_mode) + assert key_mode == 0o600 + + # Public key should be 0o644 + pubkey_mode = stat.S_IMODE(os.stat(pubkey_path).st_mode) + assert pubkey_mode == 0o644 + + def test_generate_key_restores_umask(self): + """Verify umask is restored after key generation.""" + with tempfile.TemporaryDirectory() as ssh_dir: + # Set a known umask + original_umask = os.umask(0o022) + os.umask(original_umask) + + generate_ssh_key(ssh_dir) + + # Verify umask is restored + current_umask = os.umask(original_umask) + assert current_umask == original_umask + + +class TestRegeneratePublicKey: + """Tests for regenerating public keys from private keys.""" + + def test_regenerate_public_key_basic(self): + """Verify regenerating a public key from a private key.""" + with tempfile.TemporaryDirectory() as ssh_dir: + # First generate a key pair + generate_ssh_key(ssh_dir) + + # Delete the public key + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + os.unlink(pubkey_path) + assert not os.path.exists(pubkey_path) + + # Regenerate + key_info = regenerate_public_key(ssh_dir) + + assert os.path.exists(pubkey_path) + assert key_info.filename == "id_ed25519" + assert key_info.comment == "fileglancer" + + def test_regenerate_public_key_with_passphrase(self): + """Verify regenerating with a passphrase-protected key.""" + with tempfile.TemporaryDirectory() as ssh_dir: + passphrase = SecretStr("test-passphrase") + generate_ssh_key(ssh_dir, passphrase=passphrase) + + # Delete the public key + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + os.unlink(pubkey_path) + + # Regenerate with passphrase + key_info = regenerate_public_key(ssh_dir, passphrase=passphrase) + + assert os.path.exists(pubkey_path) + assert key_info.comment == "fileglancer" + + def test_regenerate_public_key_wrong_passphrase(self): + """Verify regenerating with wrong passphrase fails.""" + with tempfile.TemporaryDirectory() as ssh_dir: + passphrase = SecretStr("correct-passphrase") + generate_ssh_key(ssh_dir, passphrase=passphrase) + + # Delete the public key + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + os.unlink(pubkey_path) + + # Try to regenerate with wrong passphrase + wrong_passphrase = SecretStr("wrong-passphrase") + with pytest.raises(RuntimeError, match="passphrase"): + regenerate_public_key(ssh_dir, passphrase=wrong_passphrase) + + def test_regenerate_public_key_nonexistent_raises(self): + """Verify regenerating a nonexistent key raises ValueError.""" + with tempfile.TemporaryDirectory() as ssh_dir: + with pytest.raises(ValueError, match="not found"): + regenerate_public_key(ssh_dir) + + +class TestCheckId25519Status: + """Tests for check_id_ed25519_status function.""" + + def test_no_key_exists(self): + """Verify status when no key exists.""" + with tempfile.TemporaryDirectory() as ssh_dir: + exists, unmanaged, missing_pubkey = check_id_ed25519_status(ssh_dir) + + assert exists is False + assert unmanaged is False + assert missing_pubkey is False + + def test_managed_key_exists(self): + """Verify status when managed key exists.""" + with tempfile.TemporaryDirectory() as ssh_dir: + generate_ssh_key(ssh_dir) + + exists, unmanaged, missing_pubkey = check_id_ed25519_status(ssh_dir) + + assert exists is True + assert unmanaged is False + assert missing_pubkey is False + + def test_managed_key_missing_pubkey(self): + """Verify status when managed key exists but pubkey is missing.""" + with tempfile.TemporaryDirectory() as ssh_dir: + generate_ssh_key(ssh_dir) + + # Delete the public key + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + os.unlink(pubkey_path) + + exists, unmanaged, missing_pubkey = check_id_ed25519_status(ssh_dir) + + assert exists is True + assert unmanaged is False + assert missing_pubkey is True + + def test_unmanaged_key_exists(self): + """Verify status when unmanaged key exists.""" + with tempfile.TemporaryDirectory() as ssh_dir: + # Create an unmanaged key (no fileglancer comment) + key_path = os.path.join(ssh_dir, "id_ed25519") + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + + subprocess.run([ + 'ssh-keygen', '-t', 'ed25519', '-N', '', + '-f', key_path, '-C', 'user@host' + ], capture_output=True, check=True) + + exists, unmanaged, missing_pubkey = check_id_ed25519_status(ssh_dir) + + assert exists is True + assert unmanaged is True + assert missing_pubkey is False + + +class TestListSSHKeys: + """Tests for list_ssh_keys function.""" + + def test_empty_directory(self): + """Verify listing returns empty list for empty directory.""" + with tempfile.TemporaryDirectory() as ssh_dir: + keys = list_ssh_keys(ssh_dir) + assert keys == [] + + def test_list_managed_key(self): + """Verify listing includes managed keys.""" + with tempfile.TemporaryDirectory() as ssh_dir: + generate_ssh_key(ssh_dir) + + keys = list_ssh_keys(ssh_dir) + + assert len(keys) == 1 + assert keys[0].filename == "id_ed25519" + assert keys[0].comment == "fileglancer" + + def test_list_excludes_unmanaged_keys(self): + """Verify listing excludes keys without 'fileglancer' comment.""" + with tempfile.TemporaryDirectory() as ssh_dir: + # Create an unmanaged key + key_path = os.path.join(ssh_dir, "id_rsa") + subprocess.run([ + 'ssh-keygen', '-t', 'rsa', '-b', '2048', '-N', '', + '-f', key_path, '-C', 'user@host' + ], capture_output=True, check=True) + + keys = list_ssh_keys(ssh_dir) + assert len(keys) == 0 + + def test_list_sorts_id_ed25519_first(self): + """Verify id_ed25519 is sorted first.""" + with tempfile.TemporaryDirectory() as ssh_dir: + # Create multiple keys with fileglancer comment + for name in ["aaa_key", "id_ed25519", "zzz_key"]: + key_path = os.path.join(ssh_dir, name) + subprocess.run([ + 'ssh-keygen', '-t', 'ed25519', '-N', '', + '-f', key_path, '-C', 'fileglancer' + ], capture_output=True, check=True) + + keys = list_ssh_keys(ssh_dir) + + assert len(keys) == 3 + assert keys[0].filename == "id_ed25519" + # The rest should be alphabetical + assert keys[1].filename == "aaa_key" + assert keys[2].filename == "zzz_key" + + +class TestGenerateTempKeyAndAuthorize: + """Tests for generate_temp_key_and_authorize function.""" + + @pytest.mark.asyncio + async def test_generate_temp_key_basic(self): + """Verify generating a temporary key.""" + with tempfile.TemporaryDirectory() as ssh_dir: + response = generate_temp_key_and_authorize(ssh_dir) + + # Verify it's a TempKeyResponse + assert isinstance(response, TempKeyResponse) + + # Stream the response to get the private key + sent_messages = [] + captured_bodies = [] + + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + sent_messages.append(message) + if message.get("type") == "http.response.body": + captured_bodies.append(bytes(message["body"])) + + scope = {"type": "http"} + await response(scope, mock_receive, mock_send) + + # Reassemble private key + private_key = b"".join(captured_bodies[:-1]) + assert b"-----BEGIN OPENSSH PRIVATE KEY-----" in private_key + + # Verify authorized_keys was updated + auth_keys_path = os.path.join(ssh_dir, "authorized_keys") + assert os.path.exists(auth_keys_path) + with open(auth_keys_path, 'r') as f: + content = f.read() + assert "fileglancer" in content + + @pytest.mark.asyncio + async def test_generate_temp_key_with_passphrase(self): + """Verify generating a temp key with passphrase.""" + with tempfile.TemporaryDirectory() as ssh_dir: + passphrase = SecretStr("test-passphrase") + response = generate_temp_key_and_authorize(ssh_dir, passphrase=passphrase) + + # Stream the response + captured_bodies = [] + + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + if message.get("type") == "http.response.body": + captured_bodies.append(bytes(message["body"])) + + scope = {"type": "http"} + await response(scope, mock_receive, mock_send) + + # Reassemble private key + private_key = b"".join(captured_bodies[:-1]) + assert b"-----BEGIN OPENSSH PRIVATE KEY-----" in private_key + + @pytest.mark.asyncio + async def test_generate_temp_key_deletes_temp_files(self): + """Verify temp files are deleted after streaming.""" + with tempfile.TemporaryDirectory() as ssh_dir: + response = generate_temp_key_and_authorize(ssh_dir) + + # Get the temp file paths + temp_key_path = response._temp_key_path + temp_pubkey_path = response._temp_pubkey_path + + # Files should exist before streaming + assert os.path.exists(temp_key_path) + assert os.path.exists(temp_pubkey_path) + + # Stream the response + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + pass + + scope = {"type": "http"} + await response(scope, mock_receive, mock_send) + + # Files should be deleted after streaming + assert not os.path.exists(temp_key_path) + assert not os.path.exists(temp_pubkey_path) + + def test_generate_temp_key_restores_umask(self): + """Verify umask is restored after temp key generation.""" + with tempfile.TemporaryDirectory() as ssh_dir: + original_umask = os.umask(0o022) + os.umask(original_umask) + + # This will create the response but not stream it + # The umask should be restored after the try/finally + try: + generate_temp_key_and_authorize(ssh_dir) + except Exception: + pass + + current_umask = os.umask(original_umask) + assert current_umask == original_umask + + +class TestIsKeyInAuthorizedKeys: + """Tests for is_key_in_authorized_keys function.""" + + def test_key_not_in_empty_authorized_keys(self): + """Verify returns False when authorized_keys doesn't exist.""" + with tempfile.TemporaryDirectory() as ssh_dir: + result = is_key_in_authorized_keys(ssh_dir, "SHA256:abcdef123456") + assert result is False + + def test_key_in_authorized_keys(self): + """Verify returns True when key is in authorized_keys.""" + with tempfile.TemporaryDirectory() as ssh_dir: + # Generate a key and add to authorized_keys + generate_ssh_key(ssh_dir) + + # Read the public key and add to authorized_keys + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + with open(pubkey_path, 'r') as f: + pubkey = f.read().strip() + + add_to_authorized_keys(ssh_dir, pubkey) + + # Get the fingerprint + result = subprocess.run( + ['ssh-keygen', '-lf', pubkey_path], + capture_output=True, text=True + ) + fingerprint = result.stdout.split()[1] + + # Check + assert is_key_in_authorized_keys(ssh_dir, fingerprint) is True + + +class TestParseAuthorizedKeysFileglancer: + """Tests for _parse_authorized_keys_fileglancer function.""" + + def test_empty_authorized_keys(self): + """Verify returns empty dict when authorized_keys doesn't exist.""" + with tempfile.TemporaryDirectory() as ssh_dir: + result = _parse_authorized_keys_fileglancer(ssh_dir) + assert result == {} + + def test_parses_fileglancer_keys(self): + """Verify parses keys with fileglancer in comment.""" + with tempfile.TemporaryDirectory() as ssh_dir: + # Generate and authorize a key + generate_ssh_key(ssh_dir) + pubkey_path = os.path.join(ssh_dir, "id_ed25519.pub") + with open(pubkey_path, 'r') as f: + pubkey = f.read().strip() + add_to_authorized_keys(ssh_dir, pubkey) + + result = _parse_authorized_keys_fileglancer(ssh_dir) + + assert len(result) == 1 + key_info = list(result.values())[0] + assert "fileglancer" in key_info.comment + assert key_info.is_authorized is True + + def test_excludes_non_fileglancer_keys(self): + """Verify excludes keys without fileglancer in comment.""" + with tempfile.TemporaryDirectory() as ssh_dir: + # Create authorized_keys with a non-fileglancer key + auth_keys_path = os.path.join(ssh_dir, "authorized_keys") + + # Generate a key without fileglancer comment + key_path = os.path.join(ssh_dir, "other_key") + subprocess.run([ + 'ssh-keygen', '-t', 'ed25519', '-N', '', + '-f', key_path, '-C', 'user@host' + ], capture_output=True, check=True) + + with open(key_path + ".pub", 'r') as f: + pubkey = f.read().strip() + + with open(auth_keys_path, 'w') as f: + f.write(pubkey + "\n") + + result = _parse_authorized_keys_fileglancer(ssh_dir) + assert len(result) == 0 + + +class TestTempKeyResponse: + """Tests for TempKeyResponse class.""" + + @pytest.mark.asyncio + async def test_includes_key_info_in_headers(self): + """Verify key info is included in response headers.""" + from fileglancer.sshkeys import SSHKeyInfo + + key_content = bytearray(b"test private key") + key_info = SSHKeyInfo( + filename="test_key", + key_type="ssh-ed25519", + fingerprint="SHA256:abc123", + comment="fileglancer", + has_private_key=False, + is_authorized=True + ) + + with tempfile.NamedTemporaryFile(delete=False) as f1, \ + tempfile.NamedTemporaryFile(delete=False) as f2: + f1.write(b"private") + f2.write(b"public") + temp_key = f1.name + temp_pub = f2.name + + try: + response = TempKeyResponse(key_content, temp_key, temp_pub, key_info) + + sent_messages = [] + + async def mock_receive(): + return {"type": "http.request", "body": b""} + + async def mock_send(message): + sent_messages.append(message) + + scope = {"type": "http"} + await response(scope, mock_receive, mock_send) + + # Check headers + headers = dict(sent_messages[0]["headers"]) + header_names = [h[0] for h in sent_messages[0]["headers"]] + + assert b"x-ssh-key-fingerprint" in header_names + assert b"x-ssh-key-type" in header_names + finally: + # Files should be deleted by TempKeyResponse + pass