"""
Model downloader utility for PumaGuard.
"""
import datetime
import hashlib
import json
import logging
import os
import shutil
from pathlib import (
Path,
)
import requests
import yaml
# Shared logger used by every helper in this module.
logger = logging.getLogger("PumaGuard")

# Download URLs are built as MODEL_BASE_URI/MODEL_TAG/<file name>.
# NOTE(review): MODEL_TAG looks like a git commit hash -- confirm.
MODEL_TAG = "82ec09d65cabd06d46aeefed3a0317200888367d"
MODEL_BASE_URI = (
    "https://github.com/PEEC-Nature-Youth-Group/pumaguard-models/raw"
)

# The registry file lives next to this module and is read at import time.
_settings_file = Path(__file__).with_name("model-registry.yaml")
if not _settings_file.exists():
    raise FileNotFoundError("Could not open model registry")

with _settings_file.open(encoding="utf-8") as fd_registry:
    MODEL_REGISTRY: dict[str, dict[str, str | dict[str, dict[str, str]]]] = (
        yaml.safe_load(fd_registry)
    )
[docs]
def create_registry(models_dir: Path):
    """
    Create a new registry file in the cache directory.

    This file stores the checksums of the models cached. If the registry
    file already exists it is left untouched.

    Args:
        models_dir: Cache directory the registry is written to; it must
            already exist
    """
    # NOTE: the file name is misspelled ("resgistry"). It is kept as-is so
    # that registry files written by earlier runs are still recognised.
    registry_file = models_dir / "model-resgistry.json"
    # Take the timestamp once so "created" and "last-updated" are identical
    # for a freshly created registry.
    timestamp = datetime.datetime.now().isoformat()
    try:
        # Mode "x" refuses to open an existing file, so the existence check
        # and the creation happen in a single step.
        with open(registry_file, "x", encoding="utf-8") as fd:
            logger.debug("Creating new registry at %s", registry_file)
            json.dump(
                {
                    "version": "1.0",
                    "created": timestamp,
                    "last-updated": timestamp,
                    "models": MODEL_REGISTRY,
                    "cached-models": {},
                },
                fd,
                indent=2,
                ensure_ascii=False,
            )
    except FileExistsError:
        # A registry is already present; nothing to do.
        return
    logger.info("Created model registry at %s", registry_file)
[docs]
def get_models_directory() -> Path:
    """
    Return the directory models are stored in, creating it if needed.

    The directory is $XDG_DATA_HOME/pumaguard/models when XDG_DATA_HOME is
    set to a non-empty value, otherwise ~/.local/share/pumaguard/models.
    As a side effect the directory is created and create_registry() is
    called on it.
    """
    data_root = os.environ.get("XDG_DATA_HOME")
    base_dir = (
        Path(data_root) if data_root else Path.home() / ".local" / "share"
    )
    cache_dir = base_dir / "pumaguard" / "models"
    cache_dir.mkdir(parents=True, exist_ok=True)
    create_registry(cache_dir)
    return cache_dir
[docs]
def verify_file_checksum(file_path: Path, expected_sha256: str) -> bool:
    """
    Verify file checksum.

    Args:
        file_path: File whose content is hashed
        expected_sha256: Expected SHA256 digest as a hex string; letter
            case and surrounding whitespace are ignored

    Returns:
        bool: True if the SHA256 digest of the file equals expected_sha256
    """
    sha256_hash = hashlib.sha256()
    with open(file_path, "rb") as f:
        # Read in chunks so the whole file is never held in memory at once.
        while chunk := f.read(64 * 1024):
            sha256_hash.update(chunk)
    # hexdigest() is always lowercase, so normalise the expected value:
    # an uppercase digest or a trailing newline must not cause a mismatch.
    return sha256_hash.hexdigest() == expected_sha256.strip().lower()
[docs]
def download_file(
    url: str,
    destination: Path,
    expected_sha256: str | None = None,
    print_progress: bool = True,
) -> bool:
    """
    Download a file from URL to destination with progress reporting.

    A file that was only partially written, or that fails checksum
    verification, is removed again before this function returns or raises.

    Args:
        url: URL to download from
        destination: Local file path to save to
        expected_sha256: Optional SHA256 checksum for verification
        print_progress: Print a progress line to standard output while
            downloading (only when the server reports a content length)

    Returns:
        bool: True if download and verification successful, False on an
        HTTP error status or a checksum mismatch

    Raises:
        Exception: Any error other than requests.HTTPError is logged and
            re-raised unchanged
    """
    started_writing = False  # True once this call has opened destination
    completed = False  # True only after download and verification passed
    try:
        logger.info("Downloading %s to %s", url, destination)
        ca_bundle = _find_ca_bundle()
        # The context manager closes the streamed response (and releases
        # its connection) even when the download is aborted midway.
        with requests.get(
            url,
            stream=True,
            timeout=60,
            verify=ca_bundle or True,
        ) as response:
            logger.debug("response: %s", response)
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            # A percentage can only be shown when the total size is known.
            show_progress = print_progress and total_size > 0
            downloaded = 0
            mib = 1024 * 1024
            with open(destination, "wb") as f:
                started_writing = True
                for chunk in response.iter_content(chunk_size=25 * 1024):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    if show_progress:
                        percent = (downloaded / total_size) * 100
                        print(
                            f"\rDownload progress: {percent:.1f}% "
                            f"({downloaded / mib:.1f}/"
                            f"{total_size / mib:.1f} MB)",
                            end="",
                            flush=True,
                        )
            if show_progress and downloaded:
                # Terminate the "\r"-rewritten progress line so later
                # output starts on a fresh line.
                print()
        logger.info("Done downloading %s", url)
        # Verify checksum if provided
        if expected_sha256:
            if not verify_file_checksum(destination, expected_sha256):
                logger.error(
                    "Checksum verification failed for %s", destination
                )
                # The corrupted file is removed in the finally clause.
                return False
            logger.debug("Checksum verification passed for %s", destination)
        logger.info("Successfully downloaded %s", destination)
        completed = True
        return True
    except requests.HTTPError as e:
        logger.error("Failed to download %s: %s", url, e)
        return False
    except Exception:
        logger.exception("Uncaught exception while downloading %s", url)
        raise
    finally:
        # Remove only what this call wrote: a pre-existing destination is
        # left alone when the request failed before any data was written.
        if started_writing and not completed:
            destination.unlink(missing_ok=True)


def _find_ca_bundle() -> str | None:
    """
    Return the CA bundle path to verify TLS with, or None for the default.

    Priority:
      1. explicit PumaGuard var
      2. then common envs
      3. then system bundle
    """
    for var in (
        "PUMAGUARD_CA_BUNDLE",
        "REQUESTS_CA_BUNDLE",
        "SSL_CERT_FILE",
    ):
        val = os.environ.get(var)
        # Ignore variables that point at a file that does not exist.
        if val and Path(val).exists():
            return val
    # Debian/Ubuntu system bundle
    sys_bundle = "/etc/ssl/certs/ca-certificates.crt"
    if Path(sys_bundle).exists():
        return sys_bundle
    return None
[docs]
def assemble_model_fragments(
fragment_paths: list[Path],
output_path: Path,
expected_sha256: str | None = None,
) -> bool:
"""
Assemble model fragments into a single file (equivalent
to 'cat file* > output').
Args:
fragment_paths: List of paths to fragment files (in order)
output_path: Path where assembled file should be written
Returns:
bool: True if assembly successful
"""
try:
logger.info(
"Assembling %d fragments into %s", len(fragment_paths), output_path
)
with open(output_path, "wb") as output_file:
for i, fragment_path in enumerate(fragment_paths):
if not fragment_path.exists():
logger.error("Fragment %s does not exist", fragment_path)
return False
logger.debug(
"Adding fragment %d/%d: %s",
i + 1,
len(fragment_paths),
fragment_path,
)
with open(fragment_path, "rb") as fragment_file:
# Copy fragment to output file in chunks
while True:
chunk = fragment_file.read(8192)
if not chunk:
break
output_file.write(chunk)
# Verify checksum if provided
if expected_sha256:
if not verify_file_checksum(output_path, expected_sha256):
logger.error(
"Checksum verification failed for %s", output_path
)
output_path.unlink() # Remove corrupted file
return False
logger.debug("Checksum verification passed for %s", output_path)
logger.info("Successfully assembled model: %s", output_path)
return True
except OSError as e:
logger.error("Failed to assemble fragments: %s", e)
if output_path.exists():
output_path.unlink() # Clean up partial file
return False
[docs]
def download_model_fragments(
    fragment_urls: list[str],
    models_dir: Path,
    print_progress: bool = True,
) -> list[Path]:
    """
    Download all fragments for a split model.

    A fragment whose file already exists in models_dir is not downloaded
    again; existing files are not checksum-verified here.

    Args:
        fragment_urls: List of URLs to download fragments from
        models_dir: Directory to store fragments
        print_progress: Passed through to download_file()

    Returns:
        List[Path]: Paths to downloaded fragment files, in the order of
        fragment_urls

    Raises:
        RuntimeError: If download_file() reports failure for a fragment
    """
    fragment_paths: list[Path] = []
    for url in fragment_urls:
        # The fragment file name is the last path component of the URL.
        fragment_path = models_dir / url.rsplit("/", 1)[-1]
        if not fragment_path.exists() and not download_file(
            url, fragment_path, print_progress=print_progress
        ):
            raise RuntimeError(f"Failed to download fragment: {url}")
        fragment_paths.append(fragment_path)
    return fragment_paths
# pylint: disable=too-many-branches
[docs]
def ensure_model_available(
    model_name: str, print_progress: bool = True
) -> Path:
    """
    Ensure a model is available locally, downloading and assembling
    if necessary.

    A cached model file is reused when its checksum matches the registry;
    otherwise it is deleted and fetched again. For fragmented models, a
    cached fragment with a matching checksum is reused as well.

    Args:
        model_name: Name of the model (must be in MODEL_REGISTRY)
        print_progress: Passed through to download_file()

    Returns:
        Path: Path to the local model file

    Raises:
        ValueError: If model_name not in registry
        RuntimeError: If the registry entry is malformed, or download or
            assembly fails
    """
    if model_name not in MODEL_REGISTRY:
        raise ValueError(
            f"Unknown model: {model_name}. "
            f"Available models: {list(MODEL_REGISTRY.keys())}"
        )
    model_info = MODEL_REGISTRY[model_name]

    # Every code path below needs the model checksum, so validate it once
    # up front instead of failing after fragments were already downloaded.
    sha256 = model_info.get("sha256")
    if not isinstance(sha256, str):
        raise RuntimeError(
            f"Invalid or missing sha256 for model: {model_name}"
        )

    models_dir = get_models_directory()
    model_path = models_dir / model_name
    logger.debug("model_path = %s", model_path)

    # Check if model already exists and is valid
    if model_path.exists():
        if verify_file_checksum(model_path, sha256):
            logger.debug(
                "Model %s already available at %s", model_name, model_path
            )
            return model_path
        logger.warning(
            "Model %s exists but failed checksum, re-downloading", model_name
        )
        model_path.unlink()

    # Handle single-file models
    if "fragments" not in model_info:
        url = MODEL_BASE_URI + "/" + MODEL_TAG + "/" + model_name
        if not download_file(url, model_path, sha256, print_progress):
            raise RuntimeError(f"Failed to download model: {model_name}")
        return model_path

    # Handle fragmented models
    fragments = model_info["fragments"]
    # Check the type before using the mapping (including for the log line).
    if not isinstance(fragments, dict):
        raise RuntimeError("Unexpected type for fragment_urls")
    logger.info(
        "Downloading fragmented model %s (%d fragments)",
        model_name,
        len(fragments),
    )
    logger.debug("fragment_urls = %s", fragments)

    # Download all fragments
    fragment_paths: list[Path] = []
    for fragment_name, fragment_data in fragments.items():
        fragment_path = models_dir / fragment_name
        fragment_sha256 = fragment_data["sha256"]
        if fragment_path.exists() and verify_file_checksum(
            fragment_path, fragment_sha256
        ):
            # A verified copy is already cached; skip the network round trip.
            logger.debug("Fragment %s already cached", fragment_name)
        else:
            url = MODEL_BASE_URI + "/" + MODEL_TAG + "/" + fragment_name
            if not download_file(
                url,
                fragment_path,
                fragment_sha256,
                print_progress=print_progress,
            ):
                raise RuntimeError(
                    f"Failed to download fragment: {fragment_name}"
                )
        fragment_paths.append(fragment_path)

    # Assemble fragments into final model
    if not assemble_model_fragments(fragment_paths, model_path, sha256):
        raise RuntimeError(
            f"Failed to assemble model fragments for: {model_name}"
        )
    return model_path
[docs]
def list_available_models() -> list[str]:
    """
    List all available models in the registry.

    Returns:
        list[str]: Names of the models in MODEL_REGISTRY, in registry order
    """
    return list(MODEL_REGISTRY)
[docs]
def clear_model_cache():
    """
    Remove the model cache directory and everything stored in it.
    """
    cache_dir = get_models_directory()
    if not cache_dir.exists():
        return
    shutil.rmtree(cache_dir)
    logger.info("Cleared model cache: %s", cache_dir)
[docs]
def update_model():
    """
    Update a model in the cache.

    Placeholder: the function has no body besides this docstring, so
    calling it does nothing and returns None.
    """
[docs]
def export_registry():
    """
    Print the model registry to standard out as YAML.
    """
    registry_yaml = yaml.dump(MODEL_REGISTRY)
    print(registry_yaml)
[docs]
def cache_models(print_progress: bool = True):
    """
    Cache all available models.

    Calls ensure_model_available() for every model in MODEL_REGISTRY.

    Args:
        print_progress: Passed through to ensure_model_available()
    """
    for model_name in MODEL_REGISTRY:
        logger.info("Caching %s", model_name)
        ensure_model_available(model_name, print_progress=print_progress)