# next-frame-predict / dataset_test.py
# dvdface's picture
# Simplify project structure: rename dirs and files
# 8b78124 verified
from __future__ import annotations
import argparse
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import tensorflow as tf
from infer.io import save_image, save_sequence_grid
from infer.predictor import Predictor
def _natural_key(s: str):
return [int(t) if t.isdigit() else t.lower() for t in re.split(r"(\d+)", s)]
def _timestamp_from_name(p: Path) -> Optional[int]:
m = re.match(r"^(\d+)", p.stem)
if not m:
return None
try:
return int(m.group(1))
except ValueError:
return None
def _parse_unix_timestamp_ms(path: Union[str, Path]) -> Optional[int]:
"""
Mirror notebook logic:
- >=19 digits: ns -> ms
- >=16 digits: us -> ms
- >=13 digits: ms
- >=10 digits: s -> ms
else: None
"""
stem = Path(path).stem
if not stem.isdigit():
return None
raw = int(stem)
digits = len(stem)
if digits >= 19:
return raw // 1_000_000
if digits >= 16:
return raw // 1_000
if digits >= 13:
return raw
if digits >= 10:
return raw * 1_000
return None
def _has_valid_adjacent_timestamp_gap(paths: Sequence[Path], max_gap_ms: int = 32) -> bool:
timestamps = [_parse_unix_timestamp_ms(p) for p in paths]
if any(ts is None for ts in timestamps):
return True
for prev_ts, cur_ts in zip(timestamps, timestamps[1:]):
if cur_ts < prev_ts or (cur_ts - prev_ts) > max_gap_ms:
return False
return True
def _gap_ms(ts0: int, ts1: int) -> float:
diff = ts1 - ts0
if diff <= 0:
return 0.0
# Treat timestamps as ms already (they come from filenames); this is only used
# by the legacy top-k mode which splits clips first.
return float(diff)
def _resize_np(x: np.ndarray, hw: Tuple[int, int]) -> np.ndarray:
    """Resize an image stack to ``hw`` with antialiased bilinear filtering.

    ``x`` must be rank-4: the shape unpack below raises otherwise. The input
    is converted to a float32 tensor, resized with ``tf.image.resize`` and
    handed back as a NumPy array.
    """
    _frames, _height, _width, _channels = x.shape  # rank-4 guard; values unused
    as_tensor = tf.convert_to_tensor(x, tf.float32)
    resized = tf.image.resize(as_tensor, hw, method="bilinear", antialias=True)
    return resized.numpy()
def _ssim_adjacent(frames01: np.ndarray, ssim_hw: Tuple[int, int]) -> np.ndarray:
    """Compute SSIM between every frame and its immediate successor.

    ``frames01`` is documented by the original code as [T,H,W,C] float32 in
    0..1. Frames are resized to ``ssim_hw`` first when their spatial size
    (axes 1 and 2) differs from it. The result has one value per adjacent
    pair, i.e. T-1 entries.
    """
    target_hw = (ssim_hw[0], ssim_hw[1])
    if frames01.shape[1:3] == target_hw:
        frames = frames01
    else:
        frames = _resize_np(frames01, ssim_hw)
    # tf.image.ssim takes batches shaped [N,H,W,C]; element k of ``earlier``
    # is compared with element k of ``later`` (frame k vs frame k+1).
    earlier = tf.convert_to_tensor(frames[:-1], tf.float32)
    later = tf.convert_to_tensor(frames[1:], tf.float32)
    return tf.image.ssim(earlier, later, max_val=1.0).numpy()
def _find_motion_segments_in_clip(frames01: np.ndarray, threshold: float, ssim_hw: Tuple[int, int]) -> List[Tuple[int, int, float]]:
"""
Return segments [start,end) where (1-SSIM) > threshold on adjacent pairs.
Used by legacy top-k selection mode.
"""
if frames01.shape[0] < 2:
return []
ssim = _ssim_adjacent(frames01, ssim_hw) # len T-1
motion = 1.0 - ssim
mask = motion > threshold
segs: List[Tuple[int, int, float]] = []
i = 0
while i < len(mask):
if not mask[i]:
i += 1
continue
j = i
while j < len(mask) and mask[j]:
j += 1
start = i
end = j + 1
score = float(np.mean(motion[i:j])) if j > i else 0.0
segs.append((start, end, score))
i = j
return segs
def _split_by_timestamp_gap(paths: List[Path], max_gap_ms: int) -> List[List[Path]]:
clips: List[List[Path]] = []
cur: List[Path] = []
prev_ts_ms: Optional[int] = None
for p in paths:
ts_ms = _parse_unix_timestamp_ms(p)
if ts_ms is None:
# If we can't parse timestamp, keep as continuous.
if not cur:
cur = [p]
else:
cur.append(p)
continue
if prev_ts_ms is None:
cur = [p]
prev_ts_ms = ts_ms
continue
if (ts_ms - prev_ts_ms) > int(max_gap_ms):
if cur:
clips.append(cur)
cur = [p]
else:
cur.append(p)
prev_ts_ms = ts_ms
if cur:
clips.append(cur)
return clips
@dataclass
class Segment:
    """A frame range inside one clip together with the score used to rank it."""

    clip_index: int  # index of the clip this range belongs to
    start: int  # first frame index in the clip (inclusive)
    end: int  # one past the last frame index in the clip (exclusive)
    score: float  # ranking score supplied by whoever builds the segment

    @property
    def length(self) -> int:
        """Number of frames in ``[start, end)``; never negative."""
        span = self.end - self.start
        return span if span > 0 else 0

    def sort_key(self) -> Tuple[float, int]:
        """Return ``(score, length)``; sort descending to get higher score first, then longer."""
        return self.score, self.length

    def overlaps(self, other: "Segment") -> bool:
        """Tell whether both segments lie in the same clip and their ranges intersect."""
        same_clip = self.clip_index == other.clip_index
        return same_clip and self.start < other.end and other.start < self.end
def _ssim_score_np(img_a: np.ndarray, img_b: np.ndarray, c1: float = 0.01**2, c2: float = 0.03**2) -> float:
# Mirror notebook SSIM (mean/var based, no gaussian window)
mu_a = float(img_a.mean())
mu_b = float(img_b.mean())
sigma_a = float(img_a.var())
sigma_b = float(img_b.var())
sigma_ab = float(((img_a - mu_a) * (img_b - mu_b)).mean())
numerator = (2.0 * mu_a * mu_b + c1) * (2.0 * sigma_ab + c2)
denominator = (mu_a * mu_a + mu_b * mu_b + c1) * (sigma_a + sigma_b + c2) + 1e-8
return float(numerator / denominator)
def _to_gray_resized01(rgb01: np.ndarray, hw: Tuple[int, int]) -> np.ndarray:
    """Resize one RGB image to ``hw`` and average its channels into a 2-D array.

    ``rgb01`` is documented by the original code as [H,W,3] in 0..1. A batch
    axis is added for ``tf.image.resize`` (antialiased bilinear) and removed
    again before the channel mean is taken.
    """
    batch = tf.convert_to_tensor(rgb01[None, ...], tf.float32)
    resized = tf.image.resize(batch, hw, method="bilinear", antialias=True)
    gray = tf.reduce_mean(resized[0], axis=-1)  # [H,W]
    return gray.numpy()
def _find_motion_segments_notebook_style(
clip_paths: Sequence[Path],
threshold: float,
ssim_hw: Tuple[int, int],
) -> List[Tuple[int, int, float]]:
"""
Mirror notebook find_motion_segments:
- compute delta_ssim = 1 - ssim_score_np(gray[i-1], gray[i])
- enter motion when delta > threshold (start=i-1)
- exit when delta <= threshold (append (start, i))
- if still in motion at end, append (start, len-1)
We also compute score as mean delta over the segment transitions.
"""
if len(clip_paths) < 2:
return []
# Streamed implementation to avoid loading the whole clip into memory.
from PIL import Image
def load_gray(p: Path) -> np.ndarray:
rgb = np.asarray(Image.open(p).convert("RGB"), dtype=np.float32) / 255.0
return _to_gray_resized01(rgb, ssim_hw)
segments: List[Tuple[int, int, float]] = []
in_motion = False
start_idx = -1
acc = 0.0
acc_n = 0
prev_gray = load_gray(clip_paths[0])
for i in range(1, len(clip_paths)):
cur_gray = load_gray(clip_paths[i])
delta_ssim = 1.0 - _ssim_score_np(prev_gray, cur_gray)
if (not in_motion) and delta_ssim > threshold:
in_motion = True
start_idx = i - 1
acc = float(delta_ssim)
acc_n = 1
elif in_motion and delta_ssim > threshold:
acc += float(delta_ssim)
acc_n += 1
elif in_motion and delta_ssim <= threshold:
end_idx = i # as notebook
score = (acc / acc_n) if acc_n else 0.0
segments.append((start_idx, end_idx, score))
in_motion = False
acc = 0.0
acc_n = 0
prev_gray = cur_gray
if in_motion:
end_idx = len(clip_paths) - 1
score = (acc / acc_n) if acc_n else 0.0
segments.append((start_idx, end_idx, score))
return segments
def _topk_motion_windows(
frames01: np.ndarray,
ssim_hw: Tuple[int, int],
window_len: int,
step: int = 1,
) -> List[Tuple[int, int, float]]:
"""
Sliding windows scored by mean motion (1-SSIM) inside the window.
Returns [(start,end,score)] where end=start+window_len.
"""
if window_len < 2 or frames01.shape[0] < window_len:
return []
ssim = _ssim_adjacent(frames01, ssim_hw) # len T-1
motion = 1.0 - ssim
out: List[Tuple[int, int, float]] = []
last_start = frames01.shape[0] - window_len
for start in range(0, last_start + 1, step):
# motion indices cover transitions; for frames [start, end) with end=start+L,
# the relevant motion slice is [start, end-1) (length L-1)
end = start + window_len
m = motion[start : end - 1]
score = float(np.mean(m)) if m.size else 0.0
out.append((start, end, score))
return out
def pick_best_segments(
    frames_dir: Path,
    threshold: float = 0.005,
    ssim_hw: Tuple[int, int] = (64, 64),
    max_gap_ms: int = 32,
    min_len: int = 8,
    topk: int = 50,
    k: int = 4,
) -> Tuple[List[List[Path]], List[Segment]]:
    """Pick up to ``k`` non-overlapping, highest-motion segments from a frame directory.

    Image files in ``frames_dir`` are naturally sorted by name and split into
    clips at timestamp gaps larger than ``max_gap_ms``. Clips shorter than
    ``min_len`` are ignored. Each remaining clip is loaded fully into memory
    and contributes two kinds of candidates:
    - threshold-based motion segments of at least ``min_len`` frames
    - the ``topk`` best sliding windows of ``max(min_len, 8)`` frames
    Candidates are ranked by (score, length) descending and taken greedily,
    skipping any that overlap an already picked segment of the same clip.

    Returns two parallel lists: the clip (list of paths) for each pick and
    the picked ``Segment``. If no candidate exists at all, the longest clip
    is returned whole with score 0.0. Because the ``k`` check runs after a
    segment is appended, at least one segment is returned even for ``k <= 0``.

    Raises:
        ValueError: if ``frames_dir`` contains no supported image files.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    paths = sorted(
        (entry for entry in frames_dir.iterdir() if entry.is_file() and entry.suffix.lower() in image_suffixes),
        key=lambda entry: _natural_key(entry.name),
    )
    if not paths:
        raise ValueError(f"No frames found in {frames_dir}")

    clips = _split_by_timestamp_gap(paths, max_gap_ms=max_gap_ms)
    window_len = max(min_len, 8)

    candidates: List[Tuple[List[Path], Segment]] = []
    for clip_no, clip in enumerate(clips):
        if len(clip) < min_len:
            continue

        from PIL import Image

        decoded = [np.asarray(Image.open(frame_path).convert("RGB"), dtype=np.float32) for frame_path in clip]
        frames01 = np.clip(np.stack(decoded, axis=0) / 255.0, 0.0, 1.0)

        # Primary source: threshold-based motion segments (may be absent).
        for first, stop, motion in _find_motion_segments_in_clip(frames01, threshold=threshold, ssim_hw=ssim_hw):
            segment = Segment(clip_index=clip_no, start=first, end=stop, score=motion)
            if segment.length >= min_len:
                candidates.append((clip, segment))

        # Always add the strongest sliding windows too, so there is something
        # to pick even when thresholding finds nothing.
        windows = _topk_motion_windows(frames01, ssim_hw=ssim_hw, window_len=window_len, step=1)
        strongest = sorted(windows, key=lambda w: w[2], reverse=True)[:topk]
        candidates.extend(
            (clip, Segment(clip_index=clip_no, start=first, end=stop, score=motion))
            for first, stop, motion in strongest
        )

    if not candidates:
        longest = max(clips, key=len)
        whole = Segment(clip_index=clips.index(longest), start=0, end=len(longest), score=0.0)
        return [longest], [whole]

    # Rank every candidate by score, then by length.
    ranked = sorted(candidates, key=lambda pair: pair[1].sort_key(), reverse=True)

    picked_clips: List[List[Path]] = []
    picked_segs: List[Segment] = []
    for clip, segment in ranked:
        if any(segment.overlaps(taken) for taken in picked_segs):
            continue
        picked_clips.append(clip)
        picked_segs.append(segment)
        if len(picked_segs) >= k:
            break
    return picked_clips, picked_segs
def find_all_motion_segments(
    frames_dir: Path,
    threshold: float,
    ssim_hw: Tuple[int, int],
    max_gap_ms: int,
    min_len: int,
) -> Tuple[List[List[Path]], List[Segment]]:
    """Return every threshold-based motion segment in a frame directory, best first.

    Mirrors the notebook: the naturally sorted image files are treated as ONE
    clip and are not split at timestamp gaps. Segment tuples from
    ``_find_motion_segments_notebook_style`` are converted to end-exclusive
    ranges and those shorter than ``min_len`` frames are dropped. The result
    is ordered by (score, length) descending.

    ``max_gap_ms`` is accepted but not used in this function; no
    timestamp-gap filtering happens here.

    Returns two parallel lists: the clip (the same path list, repeated) and
    the ``Segment`` for each kept segment; every segment has clip_index 0.

    Raises:
        ValueError: if ``frames_dir`` contains no supported image files.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    paths = sorted(
        (entry for entry in frames_dir.iterdir() if entry.is_file() and entry.suffix.lower() in image_suffixes),
        key=lambda entry: _natural_key(entry.name),
    )
    if not paths:
        raise ValueError(f"No frames found in {frames_dir}")

    # A single clip, as in the notebook (unless directory-of-directories).
    clip = paths
    total = len(clip)

    kept: List[Segment] = []
    for first, last, motion in _find_motion_segments_notebook_style(clip, threshold=threshold, ssim_hw=ssim_hw):
        # The notebook tuples mix inclusive and exclusive ends. Treat the end
        # as inclusive and clamp to the clip length, which gives a
        # conservative end-exclusive bound for slicing.
        stop = min(last + 1, total)
        if stop - first < min_len:
            continue
        kept.append(Segment(clip_index=0, start=first, end=stop, score=float(motion)))

    kept.sort(key=lambda seg: (seg.score, seg.length), reverse=True)
    return [clip] * len(kept), kept
def _parse_args() -> argparse.Namespace:
ap = argparse.ArgumentParser(description="SSIM-filter dataset and generate teacher-forcing predictions.")
ap.add_argument("--data_dir", type=str, required=True, help="Directory with frame images (e.g. .../data/data)")
ap.add_argument("--model_dir", type=str, default="savedmodel", help="SavedModel directory")
ap.add_argument("--out_dir", type=str, default="dataset_outputs", help="Output directory")
ap.add_argument("--threshold", type=float, default=0.005, help="Motion threshold on (1-SSIM)")
ap.add_argument("--ssim_hw", type=str, default="64,64", help="SSIM resize H,W")
ap.add_argument("--max_gap_ms", type=int, default=32, help="Max adjacent timestamp gap (ms)")
ap.add_argument("--min_len", type=int, default=8, help="Min segment length (frames)")
ap.add_argument("--sequence_len", type=int, default=4, help="Model input sequence length (frames)")
ap.add_argument(
"--pad_last_frame",
type=str,
default="none",
choices=["none", "zero", "one", "repeat"],
help="When model expects 4 frames but you provide 3, pad the last frame with this mode.",
)
ap.add_argument("--grid_cols", type=int, default=8, help="Columns for saved grids")
ap.add_argument("--topk_windows", type=int, default=50, help="How many top motion windows to consider per clip")
ap.add_argument("--num_tests", type=int, default=4, help="How many motion regions to run")
ap.add_argument(
"--mode",
type=str,
default="topk",
choices=["topk", "all_motion"],
help="topk: pick non-overlapping top regions; all_motion: run every threshold-based motion segment",
)
return ap.parse_args()
def main() -> None:
    """Select motion segments from a frame directory and write teacher-forcing predictions.

    Workflow:
    1. Parse the CLI options and validate ``--ssim_hw`` / ``--sequence_len``.
    2. Select segments with ``find_all_motion_segments`` (mode "all_motion")
       or ``pick_best_segments`` (mode "topk").
    3. For each segment with at least ``sequence_len`` frames, slide a window
       of ``sequence_len`` ground-truth frames over it. Windows that fail the
       adjacent-timestamp-gap check are skipped; for the others the predictor
       is called and the prediction plus the matching ground-truth frame are
       saved under ``<out_dir>/test/<n>/pred`` and ``.../gt``.
    4. Per segment, write overview grids (when anything was predicted) and a
       ``meta.txt`` summary.

    Raises:
        ValueError: if ``--ssim_hw`` is not two positive integers "H,W" or
            ``--sequence_len`` is smaller than 1.
    """
    args = _parse_args()

    # Validate numeric options first so bad input fails with a clear message
    # before any directory is created or any model is loaded.
    ssim_hw = tuple(int(x.strip()) for x in args.ssim_hw.split(","))  # type: ignore[assignment]
    if len(ssim_hw) != 2 or min(ssim_hw) <= 0:
        raise ValueError(f"--ssim_hw must be two positive integers 'H,W', got {args.ssim_hw!r}")
    L = int(args.sequence_len)
    if L < 1:
        # L < 1 would make the window slices below use negative indices.
        raise ValueError(f"--sequence_len must be >= 1, got {L}")

    data_dir = Path(args.data_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.mode == "all_motion":
        clips, segs = find_all_motion_segments(
            data_dir,
            threshold=args.threshold,
            ssim_hw=ssim_hw,  # type: ignore[arg-type]
            max_gap_ms=args.max_gap_ms,
            min_len=args.min_len,
        )
        print(f"Found motion segments: {len(segs)}")
    else:
        clips, segs = pick_best_segments(
            data_dir,
            threshold=args.threshold,
            ssim_hw=ssim_hw,  # type: ignore[arg-type]
            max_gap_ms=args.max_gap_ms,
            min_len=args.min_len,
            topk=args.topk_windows,
            k=args.num_tests,
        )

    pred = Predictor(args.model_dir)

    from PIL import Image

    def load_rgb(p: Path) -> np.ndarray:
        # The context manager closes the underlying file once the pixels have
        # been decoded, instead of leaving that to garbage collection.
        with Image.open(p) as im:
            return np.asarray(im.convert("RGB"), dtype=np.float32)

    test_root = out_dir / "test"
    test_root.mkdir(parents=True, exist_ok=True)

    for idx, (clip, seg) in enumerate(zip(clips, segs), start=1):
        seg_paths = clip[seg.start : seg.end]
        if len(seg_paths) < L:
            # Too short for even one input window. The output index is still
            # consumed, so directory numbers may have gaps.
            continue

        test_dir = test_root / str(idx)
        (test_dir / "pred").mkdir(parents=True, exist_ok=True)
        (test_dir / "gt").mkdir(parents=True, exist_ok=True)

        print(f"[test/{idx}] clip={seg.clip_index} start={seg.start} end={seg.end} len={len(seg_paths)} score={seg.score:.6f}")
        print(f"[test/{idx}] first={seg_paths[0].name} last={seg_paths[-1].name}")

        frames = np.stack([load_rgb(p) for p in seg_paths], axis=0)
        frames01 = np.clip(frames / 255.0, 0.0, 1.0).astype(np.float32)

        preds_last: List[np.ndarray] = []
        gt_last: List[np.ndarray] = []
        kept = 0
        skipped = 0
        for i in range(L - 1, len(frames01)):
            # Teacher forcing with the same "max adjacent timestamp gap" constraint as notebook.
            seq_paths = seg_paths[i - (L - 1) : i + 1]
            if not _has_valid_adjacent_timestamp_gap(seq_paths, args.max_gap_ms):
                skipped += 1
                continue

            window = frames01[i - (L - 1) : i + 1]
            # Element 0 of the predictor's result is used as the predicted frame.
            y = pred.predict_last_frame(window, pad_last_frame=args.pad_last_frame)[0]
            preds_last.append(y)
            gt_last.append(frames01[i])
            kept += 1

            stem = seg_paths[i].stem
            save_image(test_dir / "pred" / f"pred_{i:04d}_{stem}.png", y)
            save_image(test_dir / "gt" / f"gt_{i:04d}_{stem}.png", frames01[i])

        if preds_last:
            # Grids need at least one frame; np.stack fails on an empty list.
            save_sequence_grid(test_dir / "pred_last_grid.png", np.stack(preds_last, axis=0), cols=args.grid_cols)
            save_sequence_grid(test_dir / "gt_grid.png", np.stack(gt_last, axis=0), cols=args.grid_cols)

        (test_dir / "meta.txt").write_text(
            "\n".join(
                [
                    f"clip_index={seg.clip_index}",
                    f"start={seg.start}",
                    f"end={seg.end}",
                    f"len={len(seg_paths)}",
                    f"score_mean_1_minus_ssim={seg.score:.6f}",
                    f"first_frame={seg_paths[0].name}",
                    f"last_frame={seg_paths[-1].name}",
                    f"sequence_len={L}",
                    f"threshold={args.threshold}",
                    f"ssim_hw={ssim_hw[0]},{ssim_hw[1]}",
                    f"max_gap_ms={args.max_gap_ms}",
                    f"kept_predictions={kept}",
                    f"skipped_by_gap={skipped}",
                    "",
                ]
            ),
            encoding="utf-8",
        )

    print(f"Wrote tests under: {test_root}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()