diff --git a/README.md b/README.md index c43f783..97c35e3 100644 --- a/README.md +++ b/README.md @@ -315,6 +315,20 @@ Exit codes: - `LTP_DOWNLOAD_SHA256_`: version-specific expected SHA-256 for the downloaded LanguageTool archive, for example `LTP_DOWNLOAD_SHA256_6_9_SNAPSHOT`. - `LTP_DOWNLOAD_SHA256`: fallback expected SHA-256 for the downloaded LanguageTool archive. - `LTP_BYPASS_VERIFIED_DOWNLOADS`: set to `true` to skip SHA-256 verification. +- `LTP_MAX_DOWNLOAD_BYTES`: maximum downloaded ZIP size in bytes. + - default: `536870912` (512 MiB) +- `LTP_SAFE_ZIP_MAX_ARCHIVE_BYTES`: maximum total compressed member size in bytes. + - default: `536870912` (512 MiB) +- `LTP_SAFE_ZIP_MAX_EXTRACTED_BYTES`: maximum total extracted size in bytes. + - default: `805306368` (768 MiB) +- `LTP_SAFE_ZIP_MAX_MEMBERS`: maximum ZIP member count. + - default: `5000` +- `LTP_SAFE_ZIP_MAX_MEMBER_EXTRACTED_BYTES`: maximum extracted size for a single ZIP member in bytes. + - default: `134217728` (128 MiB) +- `LTP_SAFE_ZIP_MAX_MEMBER_COMPRESSION_RATIO`: maximum compression ratio for a single ZIP member. + - default: `100.0` +- `LTP_SAFE_ZIP_MAX_TOTAL_COMPRESSION_RATIO`: maximum compression ratio for the whole ZIP archive. + - default: `10.0` Downloaded zips are verified with SHA-256 when a checksum is available. Checksums are resolved in this order: 1. `LTP_DOWNLOAD_SHA256_`, where non-alphanumeric characters in the version are replaced with `_` and the name is uppercased. diff --git a/language_tool_python/download_lt.py b/language_tool_python/download_lt.py index 4e9aa55..8252412 100755 --- a/language_tool_python/download_lt.py +++ b/language_tool_python/download_lt.py @@ -27,8 +27,10 @@ from ._deprecated import deprecated from .config_file import LanguageToolConfig from .exceptions import JavaError, PathError +from .safe_zip import SafeZipExtractor from .utils import ( LTP_JAR_DIR_PATH_ENV_VAR, + get_env_int, get_language_tool_download_path, ) @@ -55,6 +57,9 @@ LT_SNAPSHOT_CURRENT_VERSION = "6.9-SNAPSHOT" LTP_DOWNLOAD_SHA256_ENV_VAR = "LTP_DOWNLOAD_SHA256" LTP_BYPASS_VERIFIED_DOWNLOADS_ENV_VAR = "LTP_BYPASS_VERIFIED_DOWNLOADS" +LTP_MAX_DOWNLOAD_BYTES_ENV_VAR = "LTP_MAX_DOWNLOAD_BYTES" +DOWNLOAD_CHUNK_BYTES = 1024 * 1024 +_SAFE_ZIP_EXTRACTOR = SafeZipExtractor() with ( importlib.resources.as_file( @@ -76,6 +81,12 @@ ) +MAX_DOWNLOAD_BYTES = get_env_int( + LTP_MAX_DOWNLOAD_BYTES_ENV_VAR, + 512 * 1024 * 1024, +) # 512 MiB, latest snapshot: 246.58 MiB archive + + def _get_zip_hash(version_name: str) -> Optional[str]: """Get the expected SHA-256 hash for a given version of LanguageTool. This function checks for environment variables that may specify the expected hash for the given version. It normalizes the version name to construct the environment variable name. If no specific environment variable is found for the version, it falls back to a general environment variable or a manifest lookup. If the bypass environment variable is set, it will skip verification and return None. @@ -109,6 +120,39 @@ def _get_zip_hash(version_name: str) -> Optional[str]: return None +def _validate_download_size(content_length: Optional[str]) -> Optional[int]: + """ + Validate the HTTP Content-Length header before downloading a ZIP file. + + :param content_length: The Content-Length header value, if present. + :type content_length: Optional[str] + :return: The parsed content length, or None when the header is missing. + :rtype: Optional[int] + :raises PathError: If the header is invalid or exceeds the download size limit. + """ + if content_length is None: + return None + + try: + total = int(content_length) + except ValueError as e: + err = f"Invalid Content-Length header: {content_length!r}." + raise PathError(err) from e + + if total < 0: + err = f"Invalid Content-Length header: {content_length!r}." + raise PathError(err) + + if total > MAX_DOWNLOAD_BYTES: + err = ( + f"Refusing to download {total} bytes. " + f"Maximum allowed download size is {MAX_DOWNLOAD_BYTES} bytes." + ) + raise PathError(err) + + return total + + def parse_java_version(version_text: str) -> Tuple[int, int]: """ Parse the Java version from a given version text. @@ -261,8 +305,15 @@ def unzip_file(temp_file_name: str, directory_to_extract_to: Path) -> None: """ logger.info("Unzipping %s to %s", temp_file_name, directory_to_extract_to) - with zipfile.ZipFile(temp_file_name, "r") as zip_ref: - zip_ref.extractall(directory_to_extract_to) + with ( + tempfile.TemporaryDirectory(dir=directory_to_extract_to.parent) as temp_dir, + zipfile.ZipFile(temp_file_name, "r") as zip_ref, + ): + _SAFE_ZIP_EXTRACTOR.extractall( + zip_ref, + directory_to_extract_to, + work_dir=Path(temp_dir), + ) @deprecated( @@ -419,8 +470,6 @@ def _get_remote_zip( except requests.exceptions.Timeout as e: err = f"Request to {self.download_url} timed out." raise TimeoutError(err) from e - content_length = req.headers.get("Content-Length") - total = int(content_length) if content_length is not None else None if req.status_code == 404: err = f"Could not find at URL {self.download_url}. The given version may not exist or is no longer available." raise PathError(err) @@ -430,14 +479,25 @@ def _get_remote_zip( if req.status_code != 200: err = f"Failed to download from {self.download_url}. HTTP status code: {req.status_code}." raise PathError(err) + content_length = req.headers.get("Content-Length") + total = _validate_download_size(content_length) progress = tqdm.tqdm( unit="B", unit_scale=True, total=total, desc=f"Downloading LanguageTool {self.version_name}", ) - for chunk in req.iter_content(chunk_size=1024): + downloaded_bytes = 0 + for chunk in req.iter_content(chunk_size=DOWNLOAD_CHUNK_BYTES): if chunk: # filter out keep-alive new chunks + downloaded_bytes += len(chunk) + if downloaded_bytes > MAX_DOWNLOAD_BYTES: + progress.close() + err = ( + f"Refusing to download more than {MAX_DOWNLOAD_BYTES} bytes " + f"from {self.download_url}." + ) + raise PathError(err) sha256.update(chunk) progress.update(len(chunk)) downloaded_file.write(chunk) @@ -708,13 +768,17 @@ def download(self) -> None: if self not in self.get_installed_versions(): with ( - tempfile.TemporaryDirectory() as temp_dir, + tempfile.TemporaryDirectory(dir=download_folder) as temp_dir, tempfile.NamedTemporaryFile( suffix=".zip", dir=temp_dir ) as downloaded_file, self._get_remote_zip(downloaded_file) as zip_file, ): - zip_file.extractall(download_folder) + _SAFE_ZIP_EXTRACTOR.extractall( + zip_file, + download_folder, + work_dir=Path(temp_dir), + ) @property def version_name(self) -> str: @@ -790,8 +854,7 @@ def download(self) -> None: Download and install this snapshot version of LanguageTool. This method checks Java compatibility, downloads the snapshot ZIP file, - and extracts it to the download folder. For snapshots, the extracted - directory is renamed to match the expected version name if necessary. + and extracts it to the download folder using the requested snapshot name. """ confirm_java_compatibility(self._version_name) @@ -803,33 +866,41 @@ def download(self) -> None: return if self not in self.get_installed_versions(): - # For snapshots, pass expected_dirname to rename the extracted folder with ( - tempfile.TemporaryDirectory() as temp_dir, + tempfile.TemporaryDirectory(dir=download_folder) as temp_dir, tempfile.NamedTemporaryFile( suffix=".zip", dir=temp_dir ) as downloaded_file, self._get_remote_zip(downloaded_file) as zip_file, ): - lt_dir = zip_file.infolist()[0].filename - expected_dirname = f"LanguageTool-{self.version_name}/" - if lt_dir != expected_dirname: - with ( - tempfile.NamedTemporaryFile( - suffix=".zip", dir=temp_dir - ) as temp_file, - zipfile.ZipFile(temp_file, "w") as renamed_zip, - ): - for item in zip_file.infolist(): - buffer = zip_file.read(item.filename) - new_name = item.filename.replace( - lt_dir, expected_dirname, 1 - ) - renamed_zip.writestr(new_name, buffer) - temp_file.seek(0) - renamed_zip.extractall(download_folder) - else: - zip_file.extractall(download_folder) + snapshot_extract_dir = Path(temp_dir) / "snapshot" + _SAFE_ZIP_EXTRACTOR.extractall( + zip_file, + snapshot_extract_dir, + work_dir=Path(temp_dir), + ) + extracted_roots = list(snapshot_extract_dir.iterdir()) + if len(extracted_roots) != 1 or not extracted_roots[0].is_dir(): + err = ( + "Expected snapshot archive to contain exactly one " + "root directory." + ) + raise PathError(err) + + expected_dir = download_folder / f"LanguageTool-{self.version_name}" + if expected_dir.exists() or expected_dir.is_symlink(): + err = ( + "Refusing to overwrite existing LanguageTool snapshot " + f"directory: {expected_dir}." + ) + raise PathError(err) + + logger.debug( + "Renaming extracted snapshot directory %s to %s", + extracted_roots[0], + expected_dir, + ) + extracted_roots[0].rename(expected_dir) @property def version_name(self) -> str: diff --git a/language_tool_python/safe_zip.py b/language_tool_python/safe_zip.py new file mode 100644 index 0000000..1d96398 --- /dev/null +++ b/language_tool_python/safe_zip.py @@ -0,0 +1,663 @@ +"""Safe ZIP extraction utilities.""" + +import contextlib +import logging +import re +import shutil +import stat +import tempfile +import zipfile +from dataclasses import dataclass +from pathlib import Path, PurePosixPath +from typing import Optional + +from .exceptions import PathError +from .utils import get_env_float, get_env_int + +logger = logging.getLogger(__name__) + +LTP_SAFE_ZIP_MAX_ARCHIVE_BYTES_ENV_VAR = "LTP_SAFE_ZIP_MAX_ARCHIVE_BYTES" +LTP_SAFE_ZIP_MAX_EXTRACTED_BYTES_ENV_VAR = "LTP_SAFE_ZIP_MAX_EXTRACTED_BYTES" +LTP_SAFE_ZIP_MAX_MEMBERS_ENV_VAR = "LTP_SAFE_ZIP_MAX_MEMBERS" +LTP_SAFE_ZIP_MAX_MEMBER_EXTRACTED_BYTES_ENV_VAR = ( + "LTP_SAFE_ZIP_MAX_MEMBER_EXTRACTED_BYTES" +) +LTP_SAFE_ZIP_MAX_MEMBER_COMPRESSION_RATIO_ENV_VAR = ( + "LTP_SAFE_ZIP_MAX_MEMBER_COMPRESSION_RATIO" +) +LTP_SAFE_ZIP_MAX_TOTAL_COMPRESSION_RATIO_ENV_VAR = ( + "LTP_SAFE_ZIP_MAX_TOTAL_COMPRESSION_RATIO" +) +DEFAULT_MAX_ARCHIVE_BYTES = get_env_int( + LTP_SAFE_ZIP_MAX_ARCHIVE_BYTES_ENV_VAR, 512 * 1024 * 1024 +) # 512 MiB, latest snapshot: 246.15 MiB compressed members +DEFAULT_MAX_EXTRACTED_BYTES = get_env_int( + LTP_SAFE_ZIP_MAX_EXTRACTED_BYTES_ENV_VAR, 768 * 1024 * 1024 +) # 768 MiB, latest snapshot: 394.48 MiB extracted +DEFAULT_MAX_MEMBERS = get_env_int( + LTP_SAFE_ZIP_MAX_MEMBERS_ENV_VAR, + 5_000, +) # latest snapshot: 2,051 members +DEFAULT_COPY_CHUNK_BYTES = 1024 * 1024 # I/O chunk size +DEFAULT_MAX_MEMBER_EXTRACTED_BYTES = get_env_int( + LTP_SAFE_ZIP_MAX_MEMBER_EXTRACTED_BYTES_ENV_VAR, 128 * 1024 * 1024 +) # 128 MiB, latest snapshot: 32.91 MiB largest member +DEFAULT_MAX_MEMBER_COMPRESSION_RATIO = get_env_float( + LTP_SAFE_ZIP_MAX_MEMBER_COMPRESSION_RATIO_ENV_VAR, + 100.0, +) # latest snapshot: 57.89 max member ratio +DEFAULT_MAX_TOTAL_COMPRESSION_RATIO = get_env_float( + LTP_SAFE_ZIP_MAX_TOTAL_COMPRESSION_RATIO_ENV_VAR, + 10.0, +) # latest snapshot: 1.60 total ratio +RESERVED_WINDOWS_FILENAMES = { + "CON", + "PRN", + "AUX", + "NUL", + *(f"COM{index}" for index in range(1, 10)), + *(f"LPT{index}" for index in range(1, 10)), +} + + +@dataclass(frozen=True) +class SafeZipLimits: + """ + Limits applied while validating and extracting a ZIP archive. + + Values are expressed in bytes unless otherwise stated. + """ + + max_archive_bytes: int = DEFAULT_MAX_ARCHIVE_BYTES + max_extracted_bytes: int = DEFAULT_MAX_EXTRACTED_BYTES + max_members: int = DEFAULT_MAX_MEMBERS + copy_chunk_bytes: int = DEFAULT_COPY_CHUNK_BYTES + max_member_compression_ratio: float = DEFAULT_MAX_MEMBER_COMPRESSION_RATIO + max_total_compression_ratio: float = DEFAULT_MAX_TOTAL_COMPRESSION_RATIO + max_member_extracted_bytes: int = DEFAULT_MAX_MEMBER_EXTRACTED_BYTES + + +class SafeZipExtractor: + """ + Extract ZIP archives after validating paths, member types, and size limits. + """ + + def __init__(self, limits: Optional[SafeZipLimits] = None) -> None: + """ + Initialize the safe extractor. + + :param limits: Optional extraction limits. Defaults to SafeZipLimits(). + :type limits: Optional[SafeZipLimits] + """ + self.limits = limits or SafeZipLimits() + + def extractall( + self, + zip_file: zipfile.ZipFile, + destination: Path, + work_dir: Optional[Path] = None, + ) -> None: + """ + Safely extract all ZIP members into destination. + + Extraction first happens inside a private directory, then validated + top-level entries are moved into the final destination. + + :param zip_file: The open ZIP archive to extract. + :type zip_file: zipfile.ZipFile + :param destination: Directory where ZIP contents should be placed. + :type destination: Path + :param work_dir: Optional parent directory for temporary extraction. + :type work_dir: Optional[Path] + :raises PathError: If the archive or destination is unsafe. + """ + destination = Path(destination) + + logger.debug( + "Starting safe ZIP extraction to %s (work_dir=%s)", + destination, + work_dir, + ) + + if work_dir is None: + destination.parent.mkdir(parents=True, exist_ok=True) + with tempfile.TemporaryDirectory(dir=destination.parent) as temp_dir: + self._extractall_to_directory(zip_file, destination, Path(temp_dir)) + else: + self._extractall_to_directory(zip_file, destination, work_dir) + + logger.debug("Completed safe ZIP extraction to %s", destination) + + def _normalize_member_path(self, filename: str) -> PurePosixPath: + """ + Normalize and validate a ZIP member path. + + :param filename: Raw ZIP member filename. + :type filename: str + :return: A safe relative POSIX path. + :rtype: PurePosixPath + :raises PathError: If the path is absolute, traverses, or is unsafe. + """ + if not filename or "\x00" in filename: + err = f"Unsafe ZIP member name: {filename!r}." + raise PathError(err) + + if any(ord(character) < 32 for character in filename): + err = f"Unsafe ZIP member name: {filename!r}." + raise PathError(err) + + normalized = filename.replace("\\", "/") + + if normalized.startswith("/") or re.match(r"^[A-Za-z]:", normalized): + err = f"Unsafe ZIP member path: {filename!r}." + raise PathError(err) + + parts = normalized.split("/") + + if parts[-1] == "": + parts = parts[:-1] + + if not parts or any(part in {"", ".", ".."} for part in parts): + err = f"Unsafe ZIP member path: {filename!r}." + raise PathError(err) + + for part in parts: + if ":" in part or part.endswith((" ", ".")): + err = f"Unsafe ZIP member path: {filename!r}." + raise PathError(err) + + windows_name = part.rstrip(" .").split(".", 1)[0].upper() + + if windows_name in RESERVED_WINDOWS_FILENAMES: + err = f"Unsafe ZIP member path: {filename!r}." + raise PathError(err) + + member_path = PurePosixPath(*parts) + + if member_path.is_absolute() or any(part == ".." for part in member_path.parts): + err = f"Unsafe ZIP member path: {filename!r}." + raise PathError(err) + + return member_path + + def _validate_member_type(self, member: zipfile.ZipInfo) -> None: + """ + Reject symlinks and unsupported ZIP member types. + + :param member: ZIP member metadata. + :type member: zipfile.ZipInfo + :raises PathError: If the member is not a regular file or directory. + """ + mode = member.external_attr >> 16 + file_type = stat.S_IFMT(mode) + + if stat.S_ISLNK(mode): + err = f"Refusing to extract symlink from ZIP archive: {member.filename!r}." + raise PathError(err) + + if file_type == 0: + return + + if member.is_dir(): + if stat.S_ISDIR(mode): + return + elif stat.S_ISREG(mode): + return + + err = f"Refusing to extract unsupported ZIP member type: {member.filename!r}." + raise PathError(err) + + def _validate_member_compression_ratio(self, member: zipfile.ZipInfo) -> None: + """ + Reject a member with a suspicious compression ratio. + + :param member: ZIP member metadata. + :type member: zipfile.ZipInfo + :raises PathError: If the compressed size is invalid or the ratio is too high. + """ + if member.file_size == 0: + return + + if member.compress_size == 0: + err = ( + f"Refusing ZIP member with zero compressed size and non-zero " + f"expanded size: {member.filename!r}." + ) + raise PathError(err) + + ratio = member.file_size / member.compress_size + + if ratio > self.limits.max_member_compression_ratio: + err = ( + f"Refusing ZIP member with suspicious compression ratio " + f"{ratio:.1f}: {member.filename!r}. " + f"Maximum allowed ratio is " + f"{self.limits.max_member_compression_ratio:.1f}." + ) + raise PathError(err) + + def _zip_target(self, destination: Path, member_path: PurePosixPath) -> Path: + """ + Resolve a member target and ensure it stays inside destination. + + :param destination: Extraction root directory. + :type destination: Path + :param member_path: Normalized ZIP member path. + :type member_path: PurePosixPath + :return: The filesystem target for the member. + :rtype: Path + :raises PathError: If the target escapes destination. + """ + target = destination.joinpath(*member_path.parts) + destination_resolved = destination.resolve(strict=True) + target_resolved = target.resolve(strict=False) + + if destination_resolved != target_resolved and ( + destination_resolved not in target_resolved.parents + ): + err = f"Unsafe ZIP member path: {str(member_path)!r}." + raise PathError(err) + + return target + + def _validate_members( + self, + members: list[zipfile.ZipInfo], + ) -> list[tuple[zipfile.ZipInfo, PurePosixPath]]: + """ + Validate all ZIP members before writing any file. + + :param members: ZIP members to validate. + :type members: list[zipfile.ZipInfo] + :return: Members paired with normalized safe paths. + :rtype: list[tuple[zipfile.ZipInfo, PurePosixPath]] + :raises PathError: If a member is unsafe or archive limits are exceeded. + """ + if len(members) > self.limits.max_members: + err = ( + f"Refusing to extract {len(members)} ZIP members. " + f"Maximum allowed member count is {self.limits.max_members}." + ) + raise PathError(err) + + total_compressed = 0 + total_uncompressed = 0 + seen_paths: set[str] = set() + seen_file_paths: set[str] = set() + validated_members: list[tuple[zipfile.ZipInfo, PurePosixPath]] = [] + + for member in members: + member_path = self._normalize_member_path(member.filename) + path_key = "/".join(part.casefold() for part in member_path.parts) + + if path_key in seen_paths: + err = ( + f"Refusing to extract duplicate ZIP member path: " + f"{member.filename!r}." + ) + raise PathError(err) + + seen_paths.add(path_key) + + ancestor_keys = [ + "/".join(part.casefold() for part in member_path.parts[:index]) + for index in range(1, len(member_path.parts)) + ] + if any(ancestor in seen_file_paths for ancestor in ancestor_keys): + err = ( + f"Refusing to extract ZIP member below file path: " + f"{member.filename!r}." + ) + raise PathError(err) + + if not member.is_dir(): + descendant_prefix = f"{path_key}/" + if any( + existing.startswith(descendant_prefix) for existing in seen_paths + ): + err = ( + f"Refusing to extract ZIP file over directory path: " + f"{member.filename!r}." + ) + raise PathError(err) + seen_file_paths.add(path_key) + + self._validate_member_type(member) + + if member.compress_size < 0 or member.file_size < 0: + err = f"Invalid ZIP member size: {member.filename!r}." + raise PathError(err) + + self._validate_member_compression_ratio(member) + + total_compressed += member.compress_size + total_uncompressed += member.file_size + + if total_compressed > self.limits.max_archive_bytes: + err = ( + f"Refusing to extract ZIP archive with {total_compressed} " + f"compressed member bytes. Maximum allowed size is " + f"{self.limits.max_archive_bytes} bytes." + ) + raise PathError(err) + + if total_uncompressed > self.limits.max_extracted_bytes: + err = ( + f"Refusing to extract {total_uncompressed} bytes. " + f"Maximum allowed extracted size is " + f"{self.limits.max_extracted_bytes} bytes." + ) + raise PathError(err) + + validated_members.append((member, member_path)) + + if total_compressed > 0: + total_ratio = total_uncompressed / total_compressed + + if total_ratio > self.limits.max_total_compression_ratio: + err = ( + f"Refusing ZIP archive with suspicious total compression ratio " + f"{total_ratio:.1f}. Maximum allowed ratio is " + f"{self.limits.max_total_compression_ratio:.1f}." + ) + raise PathError(err) + + logger.debug( + "Validated ZIP archive: members=%d, compressed=%d bytes, " + "uncompressed=%d bytes", + len(validated_members), + total_compressed, + total_uncompressed, + ) + + return validated_members + + def _ensure_safe_parent(self, destination: Path, target: Path) -> None: + """ + Ensure the target parent is inside destination and not symlinked. + + :param destination: Extraction root directory. + :type destination: Path + :param target: Target path about to be written. + :type target: Path + :raises PathError: If a parent directory is unsafe. + """ + destination_resolved = destination.resolve(strict=True) + parent_resolved = target.parent.resolve(strict=True) + + if destination_resolved != parent_resolved and ( + destination_resolved not in parent_resolved.parents + ): + err = f"Unsafe ZIP extraction parent path: {target.parent}." + raise PathError(err) + + try: + relative_parent = target.parent.relative_to(destination) + except ValueError as e: + err = f"Unsafe ZIP extraction parent path: {target.parent}." + raise PathError(err) from e + + current = destination + + for part in relative_parent.parts: + current = current / part + + if current.is_symlink(): + err = f"Refusing to extract through symlinked directory: {current}." + raise PathError(err) + + if not current.is_dir(): + err = f"Refusing to extract through non-directory path: {current}." + raise PathError(err) + + current_resolved = current.resolve(strict=True) + + if destination_resolved != current_resolved and ( + destination_resolved not in current_resolved.parents + ): + err = f"Unsafe ZIP extraction directory path: {current}." + raise PathError(err) + + def _copy_member( + self, + zip_file: zipfile.ZipFile, + member: zipfile.ZipInfo, + target: Path, + ) -> None: + """ + Copy one validated file member without overwriting existing paths. + + :param zip_file: The open ZIP archive. + :type zip_file: zipfile.ZipFile + :param member: ZIP member metadata. + :type member: zipfile.ZipInfo + :param target: Destination file path. + :type target: Path + :raises PathError: If the target is unsafe or size checks fail. + """ + if target.exists() or target.is_symlink(): + err = f"Refusing to overwrite existing path while extracting ZIP: {target}." + raise PathError(err) + + if target.parent.is_symlink(): + err = ( + f"Refusing to extract into symlinked parent directory: {target.parent}." + ) + raise PathError(err) + + bytes_written = 0 + + try: + with ( + zip_file.open(member, "r") as source, + open(target, "xb") as target_file, + ): + while True: + chunk = source.read(self.limits.copy_chunk_bytes) + + if not chunk: + break + + bytes_written += len(chunk) + + if bytes_written > member.file_size: + err = ( + f"ZIP member expanded beyond declared size: " + f"{member.filename!r}." + ) + raise PathError(err) + + if bytes_written > self.limits.max_member_extracted_bytes: + err = ( + f"Refusing to extract ZIP member larger than " + f"{self.limits.max_member_extracted_bytes} bytes: " + f"{member.filename!r}." + ) + raise PathError(err) + + target_file.write(chunk) + + except Exception: + with contextlib.suppress(OSError): + target.unlink() + raise + + if bytes_written != member.file_size: + with contextlib.suppress(OSError): + target.unlink() + + err = ( + f"ZIP member extracted size mismatch for {member.filename!r}: " + f"expected {member.file_size} bytes, wrote {bytes_written} bytes." + ) + raise PathError(err) + + def _extract_to_private_directory( + self, + zip_file: zipfile.ZipFile, + destination: Path, + ) -> None: + """ + Extract validated members into a private temporary directory. + + :param zip_file: The open ZIP archive. + :type zip_file: zipfile.ZipFile + :param destination: Private extraction directory. + :type destination: Path + :raises PathError: If validation or extraction fails. + """ + validated_members = self._validate_members(zip_file.infolist()) + + destination.mkdir(parents=True, exist_ok=True) + + if destination.is_symlink(): + err = f"Refusing to extract into symlinked destination: {destination}." + raise PathError(err) + + destination_resolved = destination.resolve(strict=True) + + if not destination_resolved.is_dir(): + err = f"ZIP extraction destination is not a directory: {destination}." + raise PathError(err) + + for member, member_path in validated_members: + target = self._zip_target(destination, member_path) + + if member.is_dir(): + if target.exists() and not target.is_dir(): + err = ( + f"Refusing to overwrite existing path while extracting ZIP: " + f"{target}." + ) + raise PathError(err) + + target.mkdir(parents=True, exist_ok=True) + self._ensure_safe_parent(destination, target) + + if target.is_symlink(): + err = ( + f"Refusing to create or use symlinked ZIP directory: {target}." + ) + raise PathError(err) + + target_resolved = target.resolve(strict=True) + + if destination_resolved != target_resolved and ( + destination_resolved not in target_resolved.parents + ): + err = f"Unsafe ZIP directory path after creation: {target}." + raise PathError(err) + + continue + + target.parent.mkdir(parents=True, exist_ok=True) + self._ensure_safe_parent(destination, target) + self._copy_member(zip_file, member, target) + + logger.debug("Finished extracting ZIP members into %s", destination) + + def _make_private_extract_dir(self, base_dir: Path) -> Path: + """ + Create a private temporary directory under base_dir. + + :param base_dir: Parent directory for temporary extraction. + :type base_dir: Path + :return: The created private directory. + :rtype: Path + :raises PathError: If base_dir is a symlink. + """ + base_dir.mkdir(parents=True, exist_ok=True) + + if base_dir.is_symlink(): + err = ( + f"Refusing to create private extraction directory inside symlink: " + f"{base_dir}." + ) + raise PathError(err) + + extract_dir = Path( + tempfile.mkdtemp( + prefix="zip-extract-", + dir=base_dir, + ) + ) + + with contextlib.suppress(OSError): + extract_dir.chmod(0o700) + + logger.debug("Created private ZIP extraction directory: %s", extract_dir) + return extract_dir + + def _extractall_to_directory( + self, + zip_file: zipfile.ZipFile, + final_directory: Path, + private_work_dir: Path, + ) -> None: + """ + Extract into a private directory and move safe top-level entries. + + :param zip_file: The open ZIP archive. + :type zip_file: zipfile.ZipFile + :param final_directory: Final extraction destination. + :type final_directory: Path + :param private_work_dir: Parent directory for private extraction. + :type private_work_dir: Path + :raises PathError: If extraction or the final move is unsafe. + """ + extract_dir = self._make_private_extract_dir(private_work_dir) + + try: + self._extract_to_private_directory(zip_file, extract_dir) + + final_directory.mkdir(parents=True, exist_ok=True) + + if final_directory.is_symlink(): + err = f"Refusing to extract into symlinked destination: {final_directory}." + raise PathError(err) + + final_directory_resolved = final_directory.resolve(strict=True) + + if not final_directory_resolved.is_dir(): + err = ( + f"ZIP extraction destination is not a directory: {final_directory}." + ) + raise PathError(err) + + destinations: list[tuple[Path, Path]] = [] + for child in extract_dir.iterdir(): + if child.is_symlink(): + err = f"Refusing to move symlinked extracted path: {child}." + raise PathError(err) + + destination = final_directory / child.name + destination_resolved = destination.resolve(strict=False) + + if final_directory_resolved != destination_resolved and ( + final_directory_resolved not in destination_resolved.parents + ): + err = f"Unsafe extracted ZIP destination path: {destination}." + raise PathError(err) + + if destination.exists() or destination.is_symlink(): + err = ( + f"Refusing to overwrite existing path while extracting ZIP: " + f"{destination}." + ) + raise PathError(err) + + destinations.append((child, destination)) + + for child, destination in destinations: + child.rename(destination) + + logger.debug( + "Moved %d top-level ZIP entries to %s", + len(destinations), + final_directory, + ) + + except Exception: + with contextlib.suppress(OSError): + shutil.rmtree(extract_dir) + raise diff --git a/language_tool_python/utils.py b/language_tool_python/utils.py index 0860da9..5769155 100644 --- a/language_tool_python/utils.py +++ b/language_tool_python/utils.py @@ -3,6 +3,7 @@ import contextlib import locale import logging +import math import os import subprocess import urllib.parse @@ -63,6 +64,66 @@ def parse_url(url_str: str) -> str: return urllib.parse.urlparse(url_str).geturl() +def get_env_int(env_var: str, default: int) -> int: + """ + Read a positive integer from the environment. + + :param env_var: Environment variable name. + :type env_var: str + :param default: Value to use when the environment variable is absent. + :type default: int + :return: Configured integer value, or the default. + :rtype: int + :raises PathError: If the configured value is invalid. + """ + configured = os.environ.get(env_var) + + if configured is None: + return default + + try: + value = int(configured) + except ValueError as e: + err = f"Invalid integer configured by {env_var}: {configured!r}." + raise PathError(err) from e + + if value <= 0: + err = f"Invalid integer configured by {env_var}: {configured!r}." + raise PathError(err) + + return value + + +def get_env_float(env_var: str, default: float) -> float: + """ + Read a positive float from the environment. + + :param env_var: Environment variable name. + :type env_var: str + :param default: Value to use when the environment variable is absent. + :type default: float + :return: Configured float value, or the default. + :rtype: float + :raises PathError: If the configured value is invalid. + """ + configured = os.environ.get(env_var) + + if configured is None: + return default + + try: + value = float(configured) + except ValueError as e: + err = f"Invalid float configured by {env_var}: {configured!r}." + raise PathError(err) from e + + if not math.isfinite(value) or value <= 0: + err = f"Invalid float configured by {env_var}: {configured!r}." + raise PathError(err) + + return value + + class TextStatus(Enum): CORRECT = "correct" FAULTY = "faulty" diff --git a/tests/test_download.py b/tests/test_download.py index 798dd99..a21c7e9 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -1,17 +1,25 @@ """Tests for the download/language functionality of LanguageTool.""" +import contextlib import hashlib +import importlib import io import re +import shutil +import uuid import zipfile from collections.abc import Iterator +from contextlib import contextmanager +from pathlib import Path from unittest.mock import MagicMock, patch import pytest +from language_tool_python import download_lt from language_tool_python.download_lt import ( LTP_BYPASS_VERIFIED_DOWNLOADS_ENV_VAR, LTP_DOWNLOAD_SHA256_ENV_VAR, + LTP_MAX_DOWNLOAD_BYTES_ENV_VAR, LocalLanguageTool, ) from language_tool_python.exceptions import LanguageToolError, PathError @@ -43,6 +51,22 @@ def make_zip_payload(files: dict[str, bytes]) -> bytes: return buffer.getvalue() +@contextmanager +def workspace_temp_dir() -> Iterator[Path]: + """ + Create a temporary directory inside the repository workspace. + """ + root = Path.cwd() / ".test_download_tmp" + path = root / uuid.uuid4().hex + path.mkdir(parents=True) + try: + yield path + finally: + shutil.rmtree(path, ignore_errors=True) + with contextlib.suppress(OSError): + root.rmdir() + + def test_install_inexistent_version() -> None: """ Test that attempting to download a non-existent LanguageTool version raises an error. @@ -135,6 +159,111 @@ def test_http_get_other_error_codes() -> None: local_language_tool._get_remote_zip(out_file) +def test_http_get_rejects_oversized_content_length( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Test that oversized ZIP downloads are rejected before streaming. + """ + payload = make_zip_payload( + {"LanguageTool-6.9-SNAPSHOT/languagetool-server.jar": b"jar"} + ) + response = MockDownloadResponse(payload) + response.headers["Content-Length"] = "2" + monkeypatch.setattr(download_lt, "MAX_DOWNLOAD_BYTES", 1) + + with ( + patch( + "language_tool_python.download_lt.requests.get", + return_value=response, + ), + pytest.raises(PathError, match="Maximum allowed download size"), + ): + LocalLanguageTool.from_version_name()._get_remote_zip(io.BytesIO()) + + +def test_max_download_bytes_uses_env_override( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Test that the download size limit can be configured from the environment. + """ + try: + with monkeypatch.context() as env: + env.setenv(LTP_MAX_DOWNLOAD_BYTES_ENV_VAR, "123") + reloaded_download_lt = importlib.reload(download_lt) + + assert reloaded_download_lt.MAX_DOWNLOAD_BYTES == 123 + finally: + importlib.reload(download_lt) + + +def test_http_get_rejects_oversized_stream( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Test that downloads are still size-limited when Content-Length is missing. + """ + payload = make_zip_payload( + {"LanguageTool-6.9-SNAPSHOT/languagetool-server.jar": b"jar"} + ) + response = MockDownloadResponse(payload) + response.headers = {} + monkeypatch.setattr(download_lt, "MAX_DOWNLOAD_BYTES", len(payload) - 1) + + with ( + patch( + "language_tool_python.download_lt.requests.get", + return_value=response, + ), + pytest.raises(PathError, match="Refusing to download more than"), + ): + LocalLanguageTool.from_version_name()._get_remote_zip(io.BytesIO()) + + +def test_http_get_rejects_oversized_stream_with_small_content_length( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Test that a lying Content-Length cannot bypass the streamed download limit. + """ + payload = make_zip_payload( + {"LanguageTool-6.9-SNAPSHOT/languagetool-server.jar": b"jar"} + ) + response = MockDownloadResponse(payload) + response.headers["Content-Length"] = "1" + monkeypatch.setattr(download_lt, "MAX_DOWNLOAD_BYTES", len(payload) - 1) + + with ( + patch( + "language_tool_python.download_lt.requests.get", + return_value=response, + ), + pytest.raises(PathError, match="Refusing to download more than"), + ): + LocalLanguageTool.from_version_name()._get_remote_zip(io.BytesIO()) + + +@pytest.mark.parametrize("content_length", ["not-a-number", "-1"]) # type: ignore[untyped-decorator] +def test_http_get_rejects_invalid_content_length( + content_length: str, +) -> None: + """ + Test that invalid Content-Length values are rejected before streaming. + """ + response = MockDownloadResponse(b"") + response.headers["Content-Length"] = content_length + + with ( + patch( + "language_tool_python.download_lt.requests.get", + return_value=response, + ), + pytest.raises(PathError, match="Invalid Content-Length"), + ): + LocalLanguageTool.from_version_name()._get_remote_zip(io.BytesIO()) + + def test_http_get_verifies_configured_sha256( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -248,6 +377,42 @@ def test_http_get_bypass_skips_sha256_verification( ] +def test_snapshot_download_renames_archive_root_to_requested_date( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Test that date-pinned snapshots are installed under the requested date name. + """ + requested_snapshot = "20240101" + payload = make_zip_payload( + {"LanguageTool-6.9-SNAPSHOT/languagetool-server.jar": b"jar"} + ) + local_language_tool = LocalLanguageTool.from_version_name(requested_snapshot) + monkeypatch.setattr(download_lt, "confirm_java_compatibility", lambda _: None) + + with ( + workspace_temp_dir() as temp_dir, + patch( + "language_tool_python.download_lt.requests.get", + return_value=MockDownloadResponse(payload), + ), + ): + monkeypatch.setattr( + download_lt, "get_language_tool_download_path", lambda: temp_dir + ) + local_language_tool.download() + + expected_dir = temp_dir / f"LanguageTool-{requested_snapshot}" + assert (expected_dir / "languagetool-server.jar").read_bytes() == b"jar" + assert not (temp_dir / "LanguageTool-6.9-SNAPSHOT").exists() + assert local_language_tool.get_directory_path() == expected_dir + + with patch("language_tool_python.download_lt.requests.get") as get_mock: + local_language_tool.download() + + get_mock.assert_not_called() + + def test_install_oldest_supported_version() -> None: """ Test that downloading the oldest supported LanguageTool version works correctly. diff --git a/tests/test_safe_zip.py b/tests/test_safe_zip.py new file mode 100644 index 0000000..d690305 --- /dev/null +++ b/tests/test_safe_zip.py @@ -0,0 +1,564 @@ +"""Tests for safe ZIP extraction.""" + +import contextlib +import hashlib +import importlib +import io +import os +import shutil +import stat +import uuid +import zipfile +from collections.abc import Iterator +from contextlib import contextmanager +from pathlib import Path + +import pytest + +from language_tool_python import safe_zip, utils +from language_tool_python.exceptions import PathError +from language_tool_python.safe_zip import SafeZipExtractor, SafeZipLimits + + +def make_zip_payload(files: dict[str, bytes]) -> bytes: + """ + Create an in-memory ZIP payload for safe extraction tests. + """ + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w") as zip_file: + for filename, payload in files.items(): + zip_file.writestr(filename, payload) + return buffer.getvalue() + + +def make_deflated_zip_payload(files: dict[str, bytes]) -> bytes: + """ + Create an in-memory ZIP payload using DEFLATE compression. + """ + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as zip_file: + for filename, payload in files.items(): + zip_file.writestr(filename, payload) + return buffer.getvalue() + + +def make_zip_payload_from_info(member: zipfile.ZipInfo, payload: bytes) -> bytes: + """ + Create an in-memory ZIP payload with explicit member metadata. + """ + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w") as zip_file: + zip_file.writestr(member, payload) + return buffer.getvalue() + + +def make_symlink_or_skip( + target: Path, + link: Path, + *, + target_is_directory: bool = False, +) -> None: + """ + Create a symlink, or skip the test if the platform disallows it. + """ + try: + os.symlink(target, link, target_is_directory=target_is_directory) + except (NotImplementedError, OSError) as error: + pytest.skip(f"Cannot create symlink for this test: {error}") + + +@contextmanager +def workspace_temp_dir() -> Iterator[Path]: + """ + Create a temporary directory inside the repository workspace. + """ + root = Path.cwd() / ".test_safe_zip_tmp" + path = root / uuid.uuid4().hex + path.mkdir(parents=True) + try: + yield path + finally: + shutil.rmtree(path, ignore_errors=True) + with contextlib.suppress(OSError): + root.rmdir() + + +def test_safe_extract_allows_regular_zip() -> None: + """ + Test that a regular ZIP is extracted by the safe extractor. + """ + payload = make_zip_payload( + { + "LanguageTool-6.9-SNAPSHOT/": b"", + "LanguageTool-6.9-SNAPSHOT/languagetool-server.jar": b"jar", + } + ) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + ): + SafeZipExtractor().extractall(zip_file, temp_dir) + assert ( + temp_dir / "LanguageTool-6.9-SNAPSHOT" / "languagetool-server.jar" + ).read_bytes() == b"jar" + + +def test_safe_zip_limits_use_env_overrides( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Test that safe ZIP limits can be configured from the environment. + """ + try: + with monkeypatch.context() as env: + env.setenv(safe_zip.LTP_SAFE_ZIP_MAX_ARCHIVE_BYTES_ENV_VAR, "11") + env.setenv(safe_zip.LTP_SAFE_ZIP_MAX_EXTRACTED_BYTES_ENV_VAR, "22") + env.setenv(safe_zip.LTP_SAFE_ZIP_MAX_MEMBERS_ENV_VAR, "33") + env.setenv( + safe_zip.LTP_SAFE_ZIP_MAX_MEMBER_EXTRACTED_BYTES_ENV_VAR, + "44", + ) + env.setenv( + safe_zip.LTP_SAFE_ZIP_MAX_MEMBER_COMPRESSION_RATIO_ENV_VAR, + "55.5", + ) + env.setenv( + safe_zip.LTP_SAFE_ZIP_MAX_TOTAL_COMPRESSION_RATIO_ENV_VAR, + "66.5", + ) + + reloaded_safe_zip = importlib.reload(safe_zip) + limits = reloaded_safe_zip.SafeZipLimits() + + assert limits.max_archive_bytes == 11 + assert limits.max_extracted_bytes == 22 + assert limits.max_members == 33 + assert limits.max_member_extracted_bytes == 44 + assert limits.max_member_compression_ratio == 55.5 + assert limits.max_total_compression_ratio == 66.5 + finally: + importlib.reload(safe_zip) + + +@pytest.mark.parametrize("configured", ["nan", "inf"]) # type: ignore[untyped-decorator] +def test_safe_zip_float_env_rejects_non_finite_values( + monkeypatch: pytest.MonkeyPatch, + configured: str, +) -> None: + """ + Test that non-finite ratio limits are rejected. + """ + env_var = "LTP_TEST_SAFE_ZIP_FLOAT" + monkeypatch.setenv(env_var, configured) + + with pytest.raises(PathError, match="Invalid float configured"): + utils.get_env_float(env_var, 1.0) + + +@pytest.mark.parametrize( # type: ignore[untyped-decorator] + "filename", + [ + "../outside.txt", + "LanguageTool/../../outside.txt", + "/absolute.txt", + "C:/absolute.txt", + "D:\\somewhere\\file", + "\\\\server\\share\\outside.txt", + "..\\outside.txt", + "LanguageTool\\..\\outside.txt", + "....\\evil", + "foo....\\evil", + "LanguageTool//file.txt", + "LanguageTool/./file.txt", + "LanguageTool/file.txt:stream", + "LanguageTool/CON", + "LanguageTool/NUL.txt", + "LanguageTool/com1", + "LanguageTool/LPT1.log", + "LanguageTool/AUX/report.txt", + "LanguageTool/trailing-space ", + "LanguageTool/trailing-dot.", + ], +) +def test_safe_extract_rejects_unsafe_member_names( + filename: str, +) -> None: + """ + Test that unsafe ZIP member names are rejected. + """ + payload = make_zip_payload({filename: b"nope"}) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="Unsafe ZIP member"), + ): + SafeZipExtractor().extractall(zip_file, temp_dir) + + +def test_safe_extract_rejects_duplicate_member_paths() -> None: + """ + Test that duplicate ZIP member paths are rejected before extraction. + """ + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w") as zip_file: + zip_file.writestr("LanguageTool/file.txt", b"one") + zip_file.writestr("LanguageTool/FILE.txt", b"two") + buffer.seek(0) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(buffer) as zip_file, + pytest.raises(PathError, match="duplicate ZIP member path"), + ): + SafeZipExtractor().extractall(zip_file, temp_dir) + + +def test_safe_extract_rejects_file_directory_conflict() -> None: + """ + Test that archives cannot contain both a file and children below that file path. + """ + payload = make_zip_payload( + { + "LanguageTool/path": b"file", + "LanguageTool/path/child.txt": b"child", + } + ) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="below file path|file over directory path"), + ): + SafeZipExtractor().extractall(zip_file, temp_dir) + + +def test_safe_extract_rejects_file_directory_conflict_in_reverse_order() -> None: + """ + Test that archives cannot replace a directory path with a file path. + """ + payload = make_zip_payload( + { + "LanguageTool/path/child.txt": b"child", + "LanguageTool/path": b"file", + } + ) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="below file path|file over directory path"), + ): + SafeZipExtractor().extractall(zip_file, temp_dir) + + +def test_safe_extract_rejects_zip_symlink() -> None: + """ + Test that ZIP symlink entries are rejected. + """ + member = zipfile.ZipInfo("LanguageTool/link") + member.create_system = 3 + member.external_attr = (stat.S_IFLNK | 0o777) << 16 + payload = make_zip_payload_from_info(member, b"target") + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="symlink"), + ): + SafeZipExtractor().extractall(zip_file, temp_dir) + + +def test_safe_extract_rejects_symlinked_destination() -> None: + """ + Test that the final destination itself cannot be a symlink. + """ + payload = make_zip_payload({"LanguageTool/file.txt": b"jar"}) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + ): + real_destination = temp_dir / "real-destination" + real_destination.mkdir() + destination_link = temp_dir / "destination-link" + make_symlink_or_skip( + real_destination, + destination_link, + target_is_directory=True, + ) + + with pytest.raises(PathError, match="symlinked destination"): + SafeZipExtractor().extractall( + zip_file, + destination_link, + work_dir=temp_dir / "work", + ) + + assert not (real_destination / "LanguageTool").exists() + + +def test_safe_extract_rejects_existing_symlink_in_destination() -> None: + """ + Test that an existing destination symlink cannot redirect extracted content. + """ + payload = make_zip_payload({"LanguageTool/file.txt": b"jar"}) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + ): + destination = temp_dir / "destination" + destination.mkdir() + outside = temp_dir / "outside" + outside.mkdir() + make_symlink_or_skip( + outside, + destination / "LanguageTool", + target_is_directory=True, + ) + + with pytest.raises( + PathError, + match="Unsafe extracted ZIP destination path|overwrite existing path", + ): + SafeZipExtractor().extractall( + zip_file, + destination, + work_dir=temp_dir / "work", + ) + + assert not (outside / "file.txt").exists() + + +def test_safe_extract_rejects_symlinked_work_dir() -> None: + """ + Test that the private extraction work directory cannot be a symlink. + """ + payload = make_zip_payload({"LanguageTool/file.txt": b"jar"}) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + ): + work_target = temp_dir / "work-target" + work_target.mkdir() + work_link = temp_dir / "work-link" + make_symlink_or_skip(work_target, work_link, target_is_directory=True) + + with pytest.raises(PathError, match="private extraction directory"): + SafeZipExtractor().extractall( + zip_file, + temp_dir / "destination", + work_dir=work_link, + ) + + +def test_safe_extract_rejects_special_zip_member_type() -> None: + """ + Test that non-file, non-directory ZIP entries are rejected. + """ + member = zipfile.ZipInfo("LanguageTool/fifo") + member.create_system = 3 + member.external_attr = (stat.S_IFIFO | 0o644) << 16 + payload = make_zip_payload_from_info(member, b"") + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="unsupported ZIP member type"), + ): + SafeZipExtractor().extractall(zip_file, temp_dir) + + +def test_safe_extract_allows_multiple_safe_roots() -> None: + """ + Test that safe extraction does not require a LanguageTool-specific root. + """ + payload = make_zip_payload( + { + "first/file.txt": b"one", + "second/file.txt": b"two", + } + ) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + ): + destination = temp_dir / "destination" + work_dir = temp_dir / "work" + SafeZipExtractor().extractall(zip_file, destination, work_dir=work_dir) + + assert (destination / "first" / "file.txt").read_bytes() == b"one" + assert (destination / "second" / "file.txt").read_bytes() == b"two" + + +def test_safe_extract_rejects_existing_destination_path() -> None: + """ + Test that extraction never overwrites an existing final destination path. + """ + payload = make_zip_payload({"file.txt": b"new"}) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + ): + destination = temp_dir / "destination" + destination.mkdir() + existing_file = destination / "file.txt" + existing_file.write_bytes(b"old") + + with pytest.raises(PathError, match="overwrite existing path"): + SafeZipExtractor().extractall( + zip_file, + destination, + work_dir=temp_dir / "work", + ) + + assert existing_file.read_bytes() == b"old" + + +def test_safe_extract_rejects_too_many_members() -> None: + """ + Test that ZIP archives with too many entries are rejected. + """ + payload = make_zip_payload( + { + "LanguageTool/one.txt": b"one", + "LanguageTool/two.txt": b"two", + } + ) + extractor = SafeZipExtractor(SafeZipLimits(max_members=1)) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="Maximum allowed member count"), + ): + extractor.extractall(zip_file, temp_dir) + + +def test_safe_extract_rejects_too_much_uncompressed_data() -> None: + """ + Test that ZIP archives with too much uncompressed data are rejected. + """ + payload = make_zip_payload({"LanguageTool/file.txt": b"four"}) + extractor = SafeZipExtractor(SafeZipLimits(max_extracted_bytes=3)) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="Maximum allowed extracted size"), + ): + extractor.extractall(zip_file, temp_dir) + + +def test_safe_extract_rejects_oversized_member_during_copy() -> None: + """ + Test that per-member extracted size limits are enforced while copying. + """ + payload = make_zip_payload({"LanguageTool/file.txt": b"four"}) + extractor = SafeZipExtractor( + SafeZipLimits( + max_extracted_bytes=100, + max_member_extracted_bytes=3, + ) + ) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="ZIP member larger"), + ): + extractor.extractall(zip_file, temp_dir) + + +def test_safe_extract_rejects_too_much_compressed_data() -> None: + """ + Test that local ZIP extraction also applies the compressed-size limit. + """ + payload = make_zip_payload({"LanguageTool/file.txt": b"data"}) + extractor = SafeZipExtractor(SafeZipLimits(max_archive_bytes=1)) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="compressed member bytes"), + ): + extractor.extractall(zip_file, temp_dir) + + +def test_safe_extract_rejects_suspicious_member_compression_ratio() -> None: + """ + Test that a single member with an abusive compression ratio is rejected. + """ + payload = make_deflated_zip_payload({"LanguageTool/file.txt": b"A" * 4096}) + extractor = SafeZipExtractor( + SafeZipLimits( + max_member_compression_ratio=2.0, + max_total_compression_ratio=10_000.0, + ) + ) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="suspicious compression ratio"), + ): + extractor.extractall(zip_file, temp_dir) + + +def test_safe_extract_rejects_suspicious_total_compression_ratio() -> None: + """ + Test that an archive with an abusive total compression ratio is rejected. + """ + payload = make_deflated_zip_payload({"LanguageTool/file.txt": b"A" * 4096}) + extractor = SafeZipExtractor( + SafeZipLimits( + max_member_compression_ratio=10_000.0, + max_total_compression_ratio=2.0, + ) + ) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + pytest.raises(PathError, match="suspicious total compression ratio"), + ): + extractor.extractall(zip_file, temp_dir) + + +def test_safe_extract_checks_total_compression_ratio_after_all_members() -> None: + """ + Test that total ratio checks are based on the final archive ratio. + """ + already_compressed = b"".join( + hashlib.sha256(index.to_bytes(4, "big")).digest() for index in range(2048) + ) + payload = make_deflated_zip_payload( + { + "LanguageTool/compressible.txt": b"A" * 4096, + "LanguageTool/already-compressed.bin": already_compressed, + } + ) + extractor = SafeZipExtractor( + SafeZipLimits( + max_member_compression_ratio=1_000.0, + max_total_compression_ratio=5.0, + ) + ) + + with ( + workspace_temp_dir() as temp_dir, + zipfile.ZipFile(io.BytesIO(payload)) as zip_file, + ): + extractor.extractall(zip_file, temp_dir) + + assert ( + temp_dir / "LanguageTool" / "compressible.txt" + ).read_bytes() == b"A" * 4096 + assert ( + temp_dir / "LanguageTool" / "already-compressed.bin" + ).read_bytes() == already_compressed