NisabaRelief / dev_scripts /process_images.py
boatbomber's picture
Refactor process_images and improve filtering out bad inputs
3f8604c
Raw
History Blame Contribute Delete
12.2 kB
"""Process a directory of images through NisabaRelief and save as PNG."""
import argparse
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from PIL import Image
from rich.console import Console
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
ProgressColumn,
SpinnerColumn,
Task,
TextColumn,
TimeElapsedColumn,
)
from rich.text import Text
from nisaba_relief import NisabaRelief
from nisaba_relief.constants import MAX_TILE, MIN_IMAGE_DIMENSION
# Lift Pillow's pixel-count limit for this process (None disables the check),
# so very large source images can be opened. Only appropriate for trusted inputs.
Image.MAX_IMAGE_PIXELS = None
# Lower-case file suffixes (including the dot) that main() treats as input images.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp", ".gif"}
# Skip-reason key -> message fragment for the "Skipping N ..." summary lines.
# Every label is passed through str.format(min_size=...), so only "{min_size}"
# may appear as a placeholder.
SKIP_LABELS = {
    "small": "image(s) smaller than {min_size}px",
    "empty": "mostly-empty image(s)",
    "bw": "black-and-white image(s)",
    "corrupt": "corrupt/truncated image(s)",
}
class SimpleTimeRemainingColumn(ProgressColumn):
    """Estimate remaining time from the average duration of recent iterations.

    The averaging window is 0.5% of the task total (minimum 1, maximum 200).
    The estimate is only recomputed when ``task.completed`` increases, so the
    displayed value stays stable between steps. The very first completed step
    is not measured; it only establishes the timing baseline.

    A placeholder (``-:--:--``) is shown while no measurement exists yet or
    when the task has no known total (indeterminate progress).
    """

    def __init__(self) -> None:
        super().__init__()
        # Completed count / elapsed seconds seen at the previous recomputation.
        self._last_completed: float = 0
        self._last_elapsed: float = 0.0
        # Most recent per-step durations in seconds, trimmed to ``_window`` entries.
        self._durations: list[float] = []
        # Window length; 0 means "not derived from the task total yet".
        self._window: int = 0
        # Last rendered text, returned unchanged until a new step completes.
        self._cached: Text = self._placeholder()

    @staticmethod
    def _placeholder() -> Text:
        """Return the text shown when no estimate can be made."""
        return Text("-:--:--", style="progress.remaining")

    def render(self, task: Task) -> Text:
        """Return the ETA text for ``task``, recomputing only after progress."""
        if task.completed <= self._last_completed:
            return self._cached

        # Derive the window lazily: the total may not be known at construction.
        if not self._window and task.total:
            self._window = min(max(1, int(task.total * 0.005)), 200)

        elapsed = task.finished_time if task.finished else task.elapsed
        if not elapsed or not task.completed:
            self._last_completed = task.completed
            self._cached = self._placeholder()
            return self._cached

        step_duration = elapsed - self._last_elapsed
        steps = task.completed - self._last_completed
        # Skip the first observation: there is no earlier baseline to diff against.
        if steps > 0 and self._last_completed > 0:
            self._durations.append(step_duration / steps)
            if self._window and len(self._durations) > self._window:
                self._durations = self._durations[-self._window :]
        self._last_completed = task.completed
        self._last_elapsed = elapsed

        # No samples yet, or an indeterminate task (total is None): no estimate.
        if not self._durations or task.total is None:
            self._cached = self._placeholder()
            return self._cached

        avg = sum(self._durations) / len(self._durations)
        # Clamp so a task advanced past its total never shows a negative ETA.
        remaining = max(0, task.total - task.completed)
        eta_seconds = avg * remaining
        hours, rem = divmod(int(eta_seconds), 3600)
        minutes, seconds = divmod(rem, 60)
        if hours:
            self._cached = Text(
                f"{hours}:{minutes:02d}:{seconds:02d}", style="progress.remaining"
            )
        else:
            self._cached = Text(f"{minutes}:{seconds:02d}", style="progress.remaining")
        return self._cached
def _make_progress(label: str) -> Progress:
    """Create a Progress display using the column layout shared by this script.

    ``label`` is the text (or Rich format template) shown right after the
    spinner; the remaining columns are the bar, the completed/total counter,
    the elapsed time, a literal "eta" caption and the rolling ETA estimate.
    """
    columns = [
        SpinnerColumn(),
        TextColumn(label),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("eta"),
        SimpleTimeRemainingColumn(),
    ]
    return Progress(*columns)
def _classify_histogram(
    img: Image.Image,
    uniform_threshold: float,
    sat_threshold: float = 0.03,
    mid_threshold: float = 0.28,
    sample_size: int = 256,
) -> str | None:
    """Inspect a thumbnail of ``img`` and return a skip reason, or None to keep it.

    ``img`` is shrunk in place to at most ``sample_size`` pixels per side, then:

    * ``"bw"`` is returned when fewer than ``sat_threshold`` of the pixels have
      saturation above 30 and fewer than ``mid_threshold`` of them fall in the
      grayscale mid-tone range 45..204 (line art, text screenshots).
    * ``"empty"`` is returned when some band of 11 adjacent gray levels that is
      centered at level 10 or brighter holds at least ``uniform_threshold`` of
      the pixels. This check is skipped unless ``uniform_threshold`` is below 1.
    """
    box = (sample_size, sample_size)
    # draft() lets the JPEG decoder downscale while decoding; for other formats
    # it does nothing and thumbnail() performs the whole reduction.
    img.draft("RGB", box)
    img.thumbnail(box, Image.NEAREST)

    gray_hist = img.convert("L").histogram()
    pixel_count = sum(gray_hist)

    # Black-and-white test: hardly any saturated pixels and hardly any mid-tones.
    saturation_hist = img.convert("HSV").split()[1].histogram()
    saturated_fraction = sum(saturation_hist[31:]) / pixel_count
    if saturated_fraction < sat_threshold:
        midtone_fraction = sum(gray_hist[45:205]) / pixel_count
        if midtone_fraction < mid_threshold:
            return "bw"

    # Dominant-tone test over every 11-level band (center +/- 5).
    if uniform_threshold < 1:
        reach = 5
        band_totals = [
            sum(gray_hist[level - reach : level + reach + 1])
            for level in range(reach, 256 - reach)
        ]
        peak_total = max(band_totals)
        # index() picks the darkest band when several tie for the maximum.
        peak_level = band_totals.index(peak_total) + reach
        # Bands centered below level 10 are near-black and never count as "empty".
        if peak_level >= 10 and peak_total / pixel_count >= uniform_threshold:
            return "empty"
    return None
def _check_image(
    src: Path, dst: Path, min_size: int, max_uniform: float
) -> tuple[Path, Path, str]:
    """Classify a single image for filtering. Returns (src, dst, status).

    ``status`` is one of:

    * ``"small"``   - the longest side is below ``min_size`` or the shortest
      side is below ``MIN_IMAGE_DIMENSION``;
    * a reason returned by ``_classify_histogram`` (``"bw"`` or ``"empty"``);
    * ``"corrupt"`` - opening or decoding raised ``OSError`` or ``SyntaxError``;
    * ``"process"`` - the image passed every check.

    ``src`` and ``dst`` are echoed back unchanged so results can be matched to
    their inputs when this runs in a worker pool.
    """
    try:
        with warnings.catch_warnings():
            # Install the filter BEFORE opening the file so warnings emitted
            # while Pillow parses headers/metadata are suppressed as well.
            # NOTE(review): catch_warnings() mutates process-global state and
            # is documented as not thread-safe; confirm that is acceptable for
            # callers that run this function from several threads.
            warnings.simplefilter("ignore", UserWarning)
            with Image.open(src) as img:
                if max(img.size) < min_size or min(img.size) < MIN_IMAGE_DIMENSION:
                    return src, dst, "small"
                reason = _classify_histogram(img, max_uniform)
                if reason:
                    return src, dst, reason
    except (OSError, SyntaxError):
        return src, dst, "corrupt"
    return src, dst, "process"
def _parse_args() -> argparse.Namespace:
    """Parse this script's command-line options from ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Process images through NisabaRelief and save as PNG."
    )
    # One (flag, add_argument keyword arguments) pair per option, in --help order.
    option_table = (
        (
            "--input-dir",
            dict(type=Path, required=True, help="Source image directory"),
        ),
        (
            "--output-dir",
            dict(
                type=Path,
                required=True,
                help="Destination directory (created if needed)",
            ),
        ),
        (
            "--max-size",
            dict(
                type=int,
                default=MAX_TILE * 5,
                help="Downsample images larger than this before processing (default: %(default)s)",
            ),
        ),
        (
            "--min-size",
            dict(
                type=int,
                default=1536,
                help="Skip images where max dimension < this (default: %(default)s)",
            ),
        ),
        (
            "--max-uniform",
            dict(
                type=float,
                default=0.65,
                help="Skip images where this fraction of pixels share a single non-black color (default: %(default)s, set to 1 to disable)",
            ),
        ),
        ("--seed", dict(type=int, default=None, help="Reproducibility seed")),
        (
            "--weights-dir",
            dict(type=Path, default=None, help="Local weights directory"),
        ),
        ("--batch-size", dict(type=int, default=None, help="Tile batch size")),
        (
            "--num-steps",
            dict(type=int, default=2, help="Solver steps (default: %(default)s)"),
        ),
        ("--device", dict(default="cuda", help="Torch device (default: %(default)s)")),
        (
            "--overwrite",
            dict(action="store_true", help="Re-process even if output file exists"),
        ),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def _gather_candidates(
    input_images: list[Path], output_dir: Path, overwrite: bool
) -> tuple[list[tuple[Path, Path]], int]:
    """Pair every input with its output path and drop those already done.

    The output path is ``output_dir / "<input stem>.png"``. An input is kept
    when ``overwrite`` is set or its output file does not exist yet.

    Returns ``(candidates, skipped_existing)``: the ``(src, dst)`` pairs that
    still need work, and how many inputs were dropped because their output
    already exists.
    """
    pending: list[tuple[Path, Path]] = []
    already_done = 0
    with _make_progress("Gathering candidates") as bar:
        scan_task = bar.add_task("Scanning", total=len(input_images))
        for source in input_images:
            target = output_dir / f"{source.stem}.png"
            if overwrite or not target.exists():
                pending.append((source, target))
            else:
                already_done += 1
            bar.advance(scan_task)
    return pending, already_done
def _filter_candidates(
    candidates: list[tuple[Path, Path]], min_size: int, max_uniform: float
) -> tuple[list[tuple[Path, Path]], dict[str, int]]:
    """Run parallel image checks (size + histogram). Returns (to_process, skipped_counts).

    Each ``(src, dst)`` pair is classified by ``_check_image`` on a pool of
    8 threads. Pairs whose status is ``"process"`` are returned sorted, so the
    order does not depend on which check finished first; every other status is
    tallied in ``skipped_counts`` (status -> number of images).

    If anything interrupts the scan - Ctrl-C or an exception raised by a
    check - the checks that have not started yet are cancelled before the
    exception propagates, instead of being left to run to completion.
    """
    to_process: list[tuple[Path, Path]] = []
    skipped: dict[str, int] = {}
    executor = ThreadPoolExecutor(max_workers=8)
    try:
        futures = [
            executor.submit(_check_image, src, dst, min_size, max_uniform)
            for src, dst in candidates
        ]
        with _make_progress("Filtering candidates") as progress:
            task = progress.add_task("Filtering", total=len(futures))
            for future in as_completed(futures):
                src, dst, status = future.result()
                if status == "process":
                    to_process.append((src, dst))
                else:
                    skipped[status] = skipped.get(status, 0) + 1
                progress.advance(task)
    except BaseException:
        # Covers KeyboardInterrupt as well as errors re-raised by result():
        # drop the queued work and do not block on checks already running.
        executor.shutdown(wait=False, cancel_futures=True)
        raise
    executor.shutdown()
    to_process.sort()
    return to_process, skipped
def _process_image(src: Path, dst: Path, model: NisabaRelief, max_size: int) -> None:
    """Load, optionally downsample, run model, restore size, and save a single image.

    Args:
        src: Image file to read.
        dst: Path the result is written to.
        model: Object whose ``process(image, show_pbar=False)`` produces the
            output image (presumably a PIL image - it is resized and saved
            through the PIL API here).
        max_size: Inputs whose longest side exceeds this are shrunk first, with
            both sides floored to a multiple of 16 (never below 16).

    The result is resized back to the source dimensions before saving whenever
    its size differs from the original.
    """
    # Close the source file as soon as the pixels are converted; convert()
    # returns an independent in-memory image, so nothing is lost on close.
    with Image.open(src) as opened:
        image = opened.convert("RGB")
    original_size = image.size

    longest_side = max(image.size)
    if longest_side > max_size:
        scale = max_size / longest_side
        # Flooring to a multiple of 16 would yield 0 for a very thin image and
        # make resize() fail, so clamp each side to at least one 16-pixel unit.
        new_size = (
            max(16, round(image.width * scale) // 16 * 16),
            max(16, round(image.height * scale) // 16 * 16),
        )
        image = image.resize(new_size, Image.LANCZOS)

    result = model.process(image, show_pbar=False)
    if result.size != original_size:
        result = result.resize(original_size, Image.LANCZOS)
    result.save(dst)
def main():
    """Command-line entry point: scan, filter, then process the input directory.

    Steps:
      1. Collect files in ``--input-dir`` whose suffix is in IMAGE_EXTENSIONS.
      2. Drop inputs whose output PNG already exists (unless ``--overwrite``).
      3. Drop inputs rejected by the size / histogram / corruption checks.
      4. Run the remaining images through the model one at a time and save them.

    Returns early, after printing a message, when the input directory is
    missing, contains no images, or nothing is left to process.
    """
    args = _parse_args()
    console = Console()
    input_dir: Path = args.input_dir
    output_dir: Path = args.output_dir

    if not input_dir.is_dir():
        console.print(f"[red]Input directory not found:[/red] [cyan]{input_dir}[/cyan]")
        return

    input_images = sorted(
        p for p in input_dir.iterdir() if p.suffix.lower() in IMAGE_EXTENSIONS
    )
    if not input_images:
        console.print(f"[red]No images found in[/red] [cyan]{input_dir}[/cyan]")
        return

    output_dir.mkdir(parents=True, exist_ok=True)
    candidates, skipped_existing = _gather_candidates(
        input_images, output_dir, args.overwrite
    )
    to_process, skipped = _filter_candidates(candidates, args.min_size, args.max_uniform)

    if skipped_existing:
        console.print(
            f"[dim]Skipping {skipped_existing} already-processed image(s)[/dim]"
        )
    # Iterate SKIP_LABELS (not `skipped`) so the summary order is fixed.
    for reason, label in SKIP_LABELS.items():
        if count := skipped.get(reason):
            console.print(
                f"[dim]Skipping {count} {label.format(min_size=args.min_size)}[/dim]"
            )

    if not to_process:
        if skipped:
            # Some inputs were rejected by the filters rather than already
            # done, so do not claim that everything has been processed.
            console.print(
                "[yellow]No images left to process after filtering.[/yellow]"
            )
        else:
            console.print("[green]All images already processed.[/green]")
        return

    console.print(
        f"Processing [bold]{len(to_process)}[/bold] / {len(input_images)} images "
        f"[dim]({input_dir} -> {output_dir})[/dim]"
    )

    # Forward optional settings only when given, so the model's own defaults
    # apply otherwise.
    model_kwargs = dict(num_steps=args.num_steps, device=args.device)
    if args.seed is not None:
        model_kwargs["seed"] = args.seed
    if args.weights_dir is not None:
        model_kwargs["weights_dir"] = args.weights_dir
    if args.batch_size is not None:
        model_kwargs["batch_size"] = args.batch_size
    model = NisabaRelief(**model_kwargs)

    progress = _make_progress("[progress.description]{task.description}")
    with progress:
        task = progress.add_task("Processing", total=len(to_process))
        for src, dst in to_process:
            progress.update(task, description=f"[cyan]{src.name}[/cyan]")
            _process_image(src, dst, model, args.max_size)
            progress.advance(task)

    console.print(
        f"[green]Done.[/green] {len(to_process)} image(s) saved to [cyan]{output_dir}[/cyan]"
    )
# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()