Source code for pumaguard.utils

"""
Some utility functions.
"""

# pylint: disable=wrong-import-position

import datetime
import gc
import hashlib
import logging
import os
import shutil
import threading
from pathlib import (
    Path,
)
from typing import (
    Tuple,
)

import keras  # type: ignore
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL

matplotlib.use("Agg")
import ultralytics

from pumaguard.model_downloader import (
    ensure_model_available,
)
from pumaguard.presets import (
    Settings,
)

logger = logging.getLogger("PumaGuard")

_MODEL_CACHE = {}
_CACHE_LOCK = threading.Lock()


[docs] def get_cached_model(model_type: str, model_path: Path): """ Get a cached model or load it if not in cache. Thread-safe model caching to prevent memory leaks. Args: model_type: Either 'classifier' or 'detector' model_path: Path to the model file Returns: The loaded model from cache or freshly loaded """ cache_key = f"{model_type}:{model_path}" with _CACHE_LOCK: if cache_key not in _MODEL_CACHE: logger.info( "Loading %s model from %s (caching for reuse)", model_type, model_path, ) if model_type == "classifier": _MODEL_CACHE[cache_key] = keras.models.load_model( str(model_path) ) elif model_type == "detector": _MODEL_CACHE[cache_key] = ultralytics.YOLO(str(model_path)) else: raise ValueError(f"Unknown model type: {model_type}") logger.info("Model cached successfully") else: logger.debug("Using cached %s model", model_type) return _MODEL_CACHE[cache_key]
[docs] def clear_model_cache(): """ Clear the model cache. Use this if you need to reload models or free memory. **Warning**: Next classification will need to reload models from disk. """ with _CACHE_LOCK: logger.info("Clearing model cache") _MODEL_CACHE.clear() if hasattr(keras.backend, "clear_session"): keras.backend.clear_session() gc.collect()
[docs] def get_duration( start_time: datetime.datetime, end_time: datetime.datetime ) -> float: """ Get duration between start and end time in seconds. Args: start_time (datetime.timezone): The start time. end_time (datetime.timezone): The end time. Returns: float: The duration in seconds. """ duration = end_time - start_time return duration / datetime.timedelta(microseconds=1) / 1e6
[docs] def copy_images(work_directory, lion_images, no_lion_images): """ Copy images to work directory. """ print( "Copying images to working directory " + os.path.realpath(work_directory) ) for image in lion_images: shutil.copy(image, f"{work_directory}/lion") for image in no_lion_images: shutil.copy(image, f"{work_directory}/no_lion") print("Copied all images")
[docs] def get_md5(filepath: str) -> str: """ Compute the MD5 hash for a file. """ hasher = hashlib.md5() with open(filepath, "rb") as f: while True: data = f.read(65536) if not data: break hasher.update(data) return hasher.hexdigest()
[docs] def get_sha256(filepath: str) -> str: """ Compute the SHA-256 hash for a file. """ hasher = hashlib.sha256() with open(filepath, "rb") as f: while True: data = f.read(65536) if not data: break hasher.update(data) return hasher.hexdigest()
[docs] def classify_image(presets: Settings, image_path: str) -> float: """ Classify the image and print out the result. Args: presets (BasePreset): An instance of the BasePreset class containing image processing settings. model (keras.Model): A pre-trained Keras model used for image classification. image_path (str): The file path to the image to be classified. Returns: float: The classification result as a float value. Prints: The color mode being used, the image being classified, and the time taken for classification. """ model_file = "model-ringtails.h5" logger.debug("classifying image %s using external model", image_path) logger.debug("loading model %s", model_file) classifier_model = keras.models.load_model( os.path.join(presets.base_output_directory, model_file) ) feature_extractor = keras.applications.Xception( weights="imagenet", include_top=True ) try: img_array = prepare_image(image_path, presets.image_dimensions) except Exception as e: logger.error("Failed to load or preprocess image: %s", e) raise start_time = datetime.datetime.now() # Extract features using Xception features = feature_extractor.predict(img_array) # Predict with the trained classifier prediction = classifier_model.predict(features) end_time = datetime.datetime.now() logger.debug( "Classification took %.2f seconds", get_duration(start_time, end_time) ) # Adjusted: Assuming index 0 is 'lion' lion_probability = float(prediction[0][0]) logger.debug("predicted lion probability %.2f", lion_probability) return lion_probability
[docs] def prepare_image(img_path: str, image_dimensions: Tuple[int, int]): """ Prepare the image. """ img = keras.preprocessing.image.load_img( img_path, target_size=image_dimensions ) img_array = keras.preprocessing.image.img_to_array(img) img_array = np.expand_dims(img_array, axis=0) img_array = keras.applications.xception.preprocess_input(img_array) return img_array
[docs] def cache_model_two_stage( yolo_model_filename: str, classifier_model_filename: str, print_progress: bool = True, ): """ Caches the model weights. """ _ = ensure_model_available(classifier_model_filename, print_progress) _ = ensure_model_available(yolo_model_filename, print_progress)
[docs] def classify_image_two_stage( presets: Settings, image_path: str, print_progress: bool = True, intermediate_dir: str | None = None, ) -> float: """ Classify the image using two-stage approach: YOLO detection + EfficientNet classification. Args: presets (Preset): An instance of the Preset class containing settings. image_path (str): The file path to the image to be classified. Args: presets (Preset): Settings preset. image_path (str): Path to image file. print_progress (bool): Whether to print model download progress. intermediate_dir (str | None): If provided, store visualization and CSV summaries inside this directory instead of CWD. Returns: float: Maximum puma probability from all detections (0.0 if no detections) """ def expand_box(xyxy, crop_expand, width, height): x1, y1, x2, y2 = xyxy w, h = x2 - x1, y2 - y1 dx, dy = w * crop_expand, h * crop_expand return [ max(0, int(x1 - dx)), max(0, int(y1 - dy)), min(width - 1, int(x2 + dx)), min(height - 1, int(y2 + dy)), ] def prob_puma_from_crop(pil_img): arr = keras.utils.img_to_array( pil_img.resize((image_size, image_size)) ) arr = np.expand_dims(arr, 0) arr = keras.applications.efficientnet_v2.preprocess_input(arr) start_time = datetime.datetime.now() p = float(classifier.predict(arr, verbose=0).ravel()[0]) end_time = datetime.datetime.now() logger.debug( "Classification took %.6f seconds", get_duration(start_time, end_time), ) return p logger.debug("classifying image %s using two-stage approach", image_path) assert presets is not None classifier_model_path = ensure_model_available( presets.classifier_model_filename, print_progress, ) yolo_model_path = ensure_model_available( presets.yolo_model_filename, print_progress, ) image_size = 384 # must match training iou_thresh = 0.45 # YOLO NMS IoU crop_expand = 0.15 # padding around detected box for crop detector = get_cached_model("detector", yolo_model_path) classifier = get_cached_model("classifier", classifier_model_path) best_t = presets.puma_threshold start_time = datetime.datetime.now() image_file = Path(image_path) if not image_file.exists(): logger.error("Could not find file %s", image_file) raise FileNotFoundError(f"Could not find file {image_file}") try: PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True with PIL.Image.open(image_path) as img: image = img.convert("RGB") PIL.ImageFile.LOAD_TRUNCATED_IMAGES = False width, height = image.size except FileNotFoundError: logger.error("Could not find file %s", image_file) raise end_time = datetime.datetime.now() logger.debug( "Loading of image %s took %.6f seconds", image_path, get_duration(start_time, end_time), ) start_time = datetime.datetime.now() res = detector.predict( image, imgsz=640, conf=presets.yolo_conf_thresh, iou=iou_thresh, max_det=presets.yolo_max_dets, verbose=False, ) end_time = datetime.datetime.now() logger.debug( "YOLO step took %.6f seconds", get_duration(start_time, end_time) ) boxes = ( res[0].boxes.xyxy.cpu().numpy() if res and res[0].boxes is not None and res[0].boxes.xyxy is not None else [] ) logger.debug("boxes:\n%s", boxes) logger.debug( "box sizes: %s", [ float((x2 - x1) * (y2 - y1) / image_size / image_size) for _, (x1, y1, x2, y2) in enumerate(boxes) ], ) det_probs, crops_xyxy, crops_imgs = [], [], [] for _, (x1, y1, x2, y2) in enumerate(boxes): # Filter crops smaller than min_size fraction if (x2 - x1) * ( y2 - y1 ) / image_size / image_size < presets.yolo_min_size: logger.debug( "ignoring bounding box below threshold: %s", [float(x1), float(y1), float(x2), float(y2)], ) continue x1e, y1e, x2e, y2e = expand_box( [x1, y1, x2, y2], crop_expand, width, height ) crop = image.crop((x1e, y1e, x2e, y2e)) p = prob_puma_from_crop(crop) det_probs.append(p) crops_xyxy.append((x1e, y1e, x2e, y2e)) crops_imgs.append(crop) rows = 1 + ((len(crops_imgs) + 3) // 4) fig = plt.figure(figsize=(max(8, min(16, 4 * 4)), max(5, 3 * rows))) # Original with boxes ax = fig.add_subplot(rows, 1, 1) ax.imshow(image) ax.axis("off") for p, (x1e, y1e, x2e, y2e) in zip(det_probs, crops_xyxy): rect = plt.Rectangle( (x1e, y1e), x2e - x1e, y2e - y1e, fill=False, color="lime", linewidth=2, ) ax.add_patch(rect) ax.text( x1e, max(0, y1e - 5), f"{p:.2f}", color="black", bbox={"facecolor": "lime", "alpha": 0.7, "pad": 2}, ) title_probs = ( ", ".join(f"{i}:{p:.2f}" for i, p in enumerate(det_probs)) if det_probs else "no detections" ) ax.set_title(f"{image_path} — det_probs: {title_probs}") idx = 0 for r in range(1, rows): for c in range(1, 5): if idx >= len(crops_imgs): break axc = fig.add_subplot(rows, 4, r * 4 + c) axc.imshow(crops_imgs[idx]) axc.axis("off") lbl = "Puma" if det_probs[idx] >= best_t else "Not-puma" axc.set_title(f"det {idx}{det_probs[idx]:.3f}{lbl}") idx += 1 # Determine output paths out_jpg = f"{image_file.stem}_viz.jpg" if intermediate_dir: out_jpg = str(Path(intermediate_dir) / out_jpg) plt.tight_layout() plt.savefig(out_jpg, dpi=160, format="jpeg") plt.close(fig) logger.debug("Freeing memory") gc.collect() logger.debug("probabilities: %s", det_probs) if len(det_probs) == 0: logger.debug("no detections") return max(det_probs) if len(det_probs) > 0 else 0