Instructions to use dvdface/next-frame-predict with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- TF-Keras
How to use dvdface/next-frame-predict with TF-Keras:
# Note: 'keras<3.x' or 'tf_keras' must be installed (legacy)
# See https://github.com/keras-team/tf-keras for more details.
from huggingface_hub import from_pretrained_keras
model = from_pretrained_keras("dvdface/next-frame-predict")
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import argparse | |
| import math | |
| import re | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import List, Optional, Sequence, Tuple, Union | |
| import numpy as np | |
| import tensorflow as tf | |
| from infer.io import save_image, save_sequence_grid | |
| from infer.predictor import Predictor | |
def _natural_key(s: str):
    """Return a sort key that orders embedded numbers numerically.

    The string is split into alternating text and digit runs. Digit runs are
    converted to ``int`` and text runs are lower-cased, so ``"frame2"`` sorts
    before ``"frame10"``.

    ``str.isdecimal`` is used instead of ``str.isdigit``: the latter also
    accepts characters such as superscript digits, which the regex digit class
    never splits out and ``int`` cannot parse, so a text run made only of such
    characters used to raise ``ValueError``.
    """
    return [int(t) if t.isdecimal() else t.lower() for t in re.split(r"(\d+)", s)]
def _timestamp_from_name(p: Path) -> Optional[int]:
    """Return the integer formed by the leading digits of ``p.stem``.

    Returns ``None`` when the stem does not start with a digit or the digits
    cannot be converted to ``int``.
    """
    leading = re.match(r"^(\d+)", p.stem)
    if leading is None:
        return None
    digits = leading.group(1)
    try:
        value = int(digits)
    except ValueError:
        return None
    return value
def _parse_unix_timestamp_ms(path: Union[str, Path]) -> Optional[int]:
    """
    Interpret an all-digit file stem as a Unix timestamp and return milliseconds.

    Mirror notebook logic; the unit is inferred from the number of digits:
    - >=19 digits: ns -> ms
    - >=16 digits: us -> ms
    - >=13 digits: ms
    - >=10 digits: s -> ms
    else: None

    Returns None as well when the stem contains anything other than decimal
    digits.
    """
    stem = Path(path).stem
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # (e.g. superscript digits) that int() rejects, which used to raise
    # ValueError instead of reporting "no timestamp".
    if not stem.isdecimal():
        return None
    raw = int(stem)
    digits = len(stem)
    if digits >= 19:
        return raw // 1_000_000
    if digits >= 16:
        return raw // 1_000
    if digits >= 13:
        return raw
    if digits >= 10:
        return raw * 1_000
    return None
| def _has_valid_adjacent_timestamp_gap(paths: Sequence[Path], max_gap_ms: int = 32) -> bool: | |
| timestamps = [_parse_unix_timestamp_ms(p) for p in paths] | |
| if any(ts is None for ts in timestamps): | |
| return True | |
| for prev_ts, cur_ts in zip(timestamps, timestamps[1:]): | |
| if cur_ts < prev_ts or (cur_ts - prev_ts) > max_gap_ms: | |
| return False | |
| return True | |
def _gap_ms(ts0: int, ts1: int) -> float:
    """Return the forward gap ``ts1 - ts0`` as a float, or 0.0 if it is not positive."""
    # Timestamps come from filenames and are treated as already being in ms;
    # this is only used by the legacy top-k mode, which splits clips first.
    elapsed = ts1 - ts0
    return float(elapsed) if elapsed > 0 else 0.0
def _resize_np(x: np.ndarray, hw: Tuple[int, int]) -> np.ndarray:
    """Resize ``x`` to spatial size ``hw`` (bilinear, antialiased) and return a NumPy array.

    ``x`` must have exactly four dimensions (presumably [T,H,W,C]); the result
    is float32 because the input is converted to a float32 tensor first.
    """
    # Unpacking into four names rejects any input that is not rank 4.
    _frames, _height, _width, _channels = x.shape
    tensor = tf.convert_to_tensor(x, tf.float32)
    resized = tf.image.resize(tensor, hw, method="bilinear", antialias=True)
    return resized.numpy()
def _ssim_adjacent(frames01: np.ndarray, ssim_hw: Tuple[int, int]) -> np.ndarray:
    """Return the SSIM of every pair of consecutive frames.

    ``frames01`` is [T,H,W,C] float32 in 0..1. Frames are resized to
    ``ssim_hw`` first unless they already have that spatial size. The result
    has one entry per adjacent pair (length T-1).
    """
    target_hw = (ssim_hw[0], ssim_hw[1])
    stack = frames01 if frames01.shape[1:3] == target_hw else _resize_np(frames01, ssim_hw)
    # tf.image.ssim compares two batches [N,H,W,C] element-wise, so pair each
    # frame with its successor by offsetting the stack by one.
    earlier = tf.convert_to_tensor(stack[:-1], tf.float32)
    later = tf.convert_to_tensor(stack[1:], tf.float32)
    return tf.image.ssim(earlier, later, max_val=1.0).numpy()
def _find_motion_segments_in_clip(frames01: np.ndarray, threshold: float, ssim_hw: Tuple[int, int]) -> List[Tuple[int, int, float]]:
    """
    Find maximal runs of adjacent-frame transitions whose (1-SSIM) exceeds ``threshold``.

    Each run is returned as ``(start, end, score)`` where ``[start, end)`` is
    the frame range covered by the run and ``score`` is the mean (1-SSIM) over
    its transitions. Clips with fewer than two frames yield no segments.
    Used by legacy top-k selection mode.
    """
    if frames01.shape[0] < 2:
        return []

    motion = 1.0 - _ssim_adjacent(frames01, ssim_hw)  # one value per transition
    moving = motion > threshold
    n_transitions = len(moving)

    found: List[Tuple[int, int, float]] = []
    run_start: Optional[int] = None
    # Iterate one position past the end so a run that reaches the last
    # transition is closed by the same branch as any other run.
    for pos in range(n_transitions + 1):
        active = pos < n_transitions and bool(moving[pos])
        if active:
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            mean_motion = float(np.mean(motion[run_start:pos])) if pos > run_start else 0.0
            # Transitions run_start..pos-1 touch frames run_start..pos, hence
            # the exclusive frame end is pos + 1.
            found.append((run_start, pos + 1, mean_motion))
            run_start = None
    return found
def _split_by_timestamp_gap(paths: List[Path], max_gap_ms: int) -> List[List[Path]]:
    """Split an ordered list of frame paths into clips at large timestamp gaps.

    A new clip starts whenever the timestamp parsed from a file name is more
    than ``max_gap_ms`` later than the previous parsed timestamp. Frames whose
    name has no parseable timestamp are treated as continuous with their
    neighbours and never start a new clip.

    Returns the clips in input order; every input path appears in exactly one
    clip, and an empty input yields an empty list.
    """
    clips: List[List[Path]] = []
    cur: List[Path] = []
    prev_ts_ms: Optional[int] = None
    for p in paths:
        ts_ms = _parse_unix_timestamp_ms(p)
        if ts_ms is None:
            # If we can't parse timestamp, keep as continuous.
            cur.append(p)
            continue
        if prev_ts_ms is not None and (ts_ms - prev_ts_ms) > int(max_gap_ms):
            if cur:
                clips.append(cur)
            cur = [p]
        else:
            # Either the gap is small enough or this is the first timestamped
            # frame. In the latter case the frame must be appended: resetting
            # `cur` here used to drop any un-timestamped frames seen earlier.
            cur.append(p)
        prev_ts_ms = ts_ms
    if cur:
        clips.append(cur)
    return clips
@dataclass
class Segment:
    """A scored run of frames inside one clip.

    ``[start, end)`` is a half-open index range into the clip identified by
    ``clip_index``. The class must be a dataclass because callers construct it
    with keyword arguments, and ``length`` must be a property because callers
    compare ``seg.length`` directly with integers.
    """

    clip_index: int
    start: int  # inclusive index in clip
    end: int  # exclusive index in clip
    score: float

    @property
    def length(self) -> int:
        """Number of frames covered by the segment; never negative."""
        return max(0, self.end - self.start)

    def sort_key(self) -> Tuple[float, int]:
        """Return ``(score, length)``; sort descending for higher score first, then longer."""
        return (self.score, self.length)

    def overlaps(self, other: "Segment") -> bool:
        """Return True when both segments are in the same clip and share at least one frame."""
        if self.clip_index != other.clip_index:
            return False
        # Half-open ranges intersect exactly when each starts before the other ends.
        return self.start < other.end and other.start < self.end
def _ssim_score_np(img_a: np.ndarray, img_b: np.ndarray, c1: float = 0.01**2, c2: float = 0.03**2) -> float:
    """Return a global SSIM score between two images.

    Mirrors the notebook SSIM: statistics are taken over the whole image
    (mean, variance, covariance) with no gaussian window. ``c1`` and ``c2``
    are the stabilising constants of the luminance and contrast terms.
    """
    mean_a = float(np.mean(img_a))
    mean_b = float(np.mean(img_b))
    var_a = float(np.var(img_a))
    var_b = float(np.var(img_b))
    covariance = float(np.mean((img_a - mean_a) * (img_b - mean_b)))

    luminance_num = 2.0 * mean_a * mean_b + c1
    contrast_num = 2.0 * covariance + c2
    luminance_den = mean_a * mean_a + mean_b * mean_b + c1
    contrast_den = var_a + var_b + c2

    # The small epsilon keeps the division safe even if c1/c2 are passed as 0.
    return float((luminance_num * contrast_num) / (luminance_den * contrast_den + 1e-8))
def _to_gray_resized01(rgb01: np.ndarray, hw: Tuple[int, int]) -> np.ndarray:
    """Resize an [H,W,3] image in 0..1 to ``hw`` and average its channels into [H,W] gray."""
    # tf.image.resize is given a batch, so add a leading axis and strip it afterwards.
    batch = tf.convert_to_tensor(rgb01[None, ...], tf.float32)
    resized = tf.image.resize(batch, hw, method="bilinear", antialias=True)
    gray = tf.reduce_mean(resized[0], axis=-1)
    return gray.numpy()
def _find_motion_segments_notebook_style(
    clip_paths: Sequence[Path],
    threshold: float,
    ssim_hw: Tuple[int, int],
) -> List[Tuple[int, int, float]]:
    """
    Mirror notebook find_motion_segments:
    - compute delta_ssim = 1 - ssim_score_np(gray[i-1], gray[i])
    - enter motion when delta > threshold (start=i-1)
    - exit when delta <= threshold (append (start, i))
    - if still in motion at end, append (start, len-1)
    We also compute score as mean delta over the segment transitions.

    Returns a list of (start, end, score) tuples in clip order; fewer than two
    paths yield an empty list. Images are read one at a time, so only two gray
    frames are held at once.
    """
    if len(clip_paths) < 2:
        return []

    # Streamed implementation to avoid loading the whole clip into memory.
    # Pillow is imported lazily, after the early return above.
    from PIL import Image

    def load_gray(p: Path) -> np.ndarray:
        # The context manager closes the underlying file deterministically
        # instead of leaving it to garbage collection.
        with Image.open(p) as img:
            rgb = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
        return _to_gray_resized01(rgb, ssim_hw)

    segments: List[Tuple[int, int, float]] = []
    in_motion = False
    start_idx = -1
    acc = 0.0  # sum of deltas inside the current motion run
    acc_n = 0  # number of transitions inside the current motion run

    prev_gray = load_gray(clip_paths[0])
    for i in range(1, len(clip_paths)):
        cur_gray = load_gray(clip_paths[i])
        delta_ssim = 1.0 - _ssim_score_np(prev_gray, cur_gray)
        if delta_ssim > threshold:
            if in_motion:
                acc += float(delta_ssim)
                acc_n += 1
            else:
                # Motion starts at the earlier frame of this transition.
                in_motion = True
                start_idx = i - 1
                acc = float(delta_ssim)
                acc_n = 1
        elif in_motion and delta_ssim <= threshold:
            end_idx = i  # as notebook
            score = (acc / acc_n) if acc_n else 0.0
            segments.append((start_idx, end_idx, score))
            in_motion = False
            acc = 0.0
            acc_n = 0
        prev_gray = cur_gray

    if in_motion:
        # Still moving at the end of the clip: close at the last frame index.
        end_idx = len(clip_paths) - 1
        score = (acc / acc_n) if acc_n else 0.0
        segments.append((start_idx, end_idx, score))
    return segments
def _topk_motion_windows(
    frames01: np.ndarray,
    ssim_hw: Tuple[int, int],
    window_len: int,
    step: int = 1,
) -> List[Tuple[int, int, float]]:
    """
    Score every sliding window of ``window_len`` frames by its mean motion (1-SSIM).

    Returns [(start, end, score)] with end = start + window_len, in order of
    increasing start and advancing by ``step``. Returns an empty list when the
    window is shorter than two frames or longer than the clip.
    """
    if window_len < 2 or frames01.shape[0] < window_len:
        return []

    motion = 1.0 - _ssim_adjacent(frames01, ssim_hw)  # one value per transition

    def mean_motion(first: int) -> float:
        # A window of L frames starting at `first` spans the L-1 transitions
        # first .. first+L-2.
        inside = motion[first : first + window_len - 1]
        return float(np.mean(inside)) if inside.size else 0.0

    final_start = frames01.shape[0] - window_len
    return [
        (first, first + window_len, mean_motion(first))
        for first in range(0, final_start + 1, step)
    ]
def pick_best_segments(
    frames_dir: Path,
    threshold: float = 0.005,
    ssim_hw: Tuple[int, int] = (64, 64),
    max_gap_ms: int = 32,
    min_len: int = 8,
    topk: int = 50,
    k: int = 4,
) -> Tuple[List[List[Path]], List[Segment]]:
    """Pick up to ``k`` non-overlapping high-motion segments from a frame directory.

    Steps:
    1. List image files in ``frames_dir`` (natural filename order) and split
       them into clips at timestamp gaps larger than ``max_gap_ms``.
    2. For every clip with at least ``min_len`` frames, collect candidates:
       threshold-based motion segments of at least ``min_len`` frames, plus
       the ``topk`` highest-scoring sliding windows of ``max(min_len, 8)`` frames.
    3. Greedily keep the best candidates (score, then length) that do not
       overlap an already picked segment, until ``k`` are picked.

    Returns two parallel lists: the clip (list of paths) each segment belongs
    to, and the segments themselves. If no candidate exists, the longest clip
    is returned whole with score 0.0.

    Raises:
        ValueError: if ``frames_dir`` contains no supported image files.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    paths = [p for p in frames_dir.iterdir() if p.is_file() and p.suffix.lower() in image_suffixes]
    paths.sort(key=lambda p: _natural_key(p.name))
    if not paths:
        raise ValueError(f"No frames found in {frames_dir}")

    clips = _split_by_timestamp_gap(paths, max_gap_ms=max_gap_ms)
    candidates: List[Tuple[List[Path], Segment]] = []
    for ci, clip in enumerate(clips):
        if len(clip) < min_len:
            continue

        from PIL import Image

        arr = []
        for p in clip:
            # Context manager so each file handle is closed as soon as the
            # pixels have been copied into the array.
            with Image.open(p) as img:
                arr.append(np.asarray(img.convert("RGB"), dtype=np.float32))
        x01 = np.clip(np.stack(arr, axis=0) / 255.0, 0.0, 1.0)

        # Primary: threshold-based motion segments (not always present)
        cand: List[Segment] = []
        for s, e, sc in _find_motion_segments_in_clip(x01, threshold=threshold, ssim_hw=ssim_hw):
            seg = Segment(clip_index=ci, start=s, end=e, score=sc)
            if seg.length >= min_len:
                cand.append(seg)

        # Fallback/augment: always take top motion windows so we can pick "bigger change"
        # even when thresholding yields nothing.
        windows = _topk_motion_windows(x01, ssim_hw=ssim_hw, window_len=max(min_len, 8), step=1)
        windows.sort(key=lambda t: t[2], reverse=True)
        cand.extend(Segment(clip_index=ci, start=s, end=e, score=sc) for s, e, sc in windows[:topk])

        candidates.extend((clip, seg) for seg in cand)

    if not candidates:
        # Nothing scored: fall back to the longest clip as a whole. The index
        # is found positionally (first longest clip) rather than by comparing
        # lists of paths for equality.
        best_ci = max(range(len(clips)), key=lambda i: len(clips[i]))
        clip = clips[best_ci]
        seg = Segment(clip_index=best_ci, start=0, end=len(clip), score=0.0)
        return [clip], [seg]

    # Sort all candidates by score then length
    candidates.sort(key=lambda cs: (cs[1].score, cs[1].length), reverse=True)

    picked_clips: List[List[Path]] = []
    picked_segs: List[Segment] = []
    for clip, seg in candidates:
        if any(seg.overlaps(s) for s in picked_segs):
            continue
        picked_clips.append(clip)
        picked_segs.append(seg)
        if len(picked_segs) >= k:
            break
    return picked_clips, picked_segs
def find_all_motion_segments(
    frames_dir: Path,
    threshold: float,
    ssim_hw: Tuple[int, int],
    max_gap_ms: int,
    min_len: int,
) -> Tuple[List[List[Path]], List[Segment]]:
    """Return every threshold-based motion segment in ``frames_dir``, best first.

    Mirrors the notebook: the directory is treated as ONE clip and is not
    split by timestamp gap, so ``max_gap_ms`` is accepted but not used here.
    Segments shorter than ``min_len`` frames are dropped. The two returned
    lists are parallel (clip, segment) and ordered by score, then length,
    descending.

    Raises:
        ValueError: if ``frames_dir`` contains no supported image files.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
    frames = sorted(
        (p for p in frames_dir.iterdir() if p.is_file() and p.suffix.lower() in image_suffixes),
        key=lambda p: _natural_key(p.name),
    )
    if not frames:
        raise ValueError(f"No frames found in {frames_dir}")

    kept: List[Tuple[List[Path], Segment]] = []
    for first, last, mean_delta in _find_motion_segments_notebook_style(frames, threshold=threshold, ssim_hw=ssim_hw):
        # The notebook's end index is treated as inclusive; convert it to an
        # exclusive slice end, clamped to the clip length.
        stop = min(last + 1, len(frames))
        if stop - first < min_len:
            continue
        kept.append((frames, Segment(clip_index=0, start=first, end=stop, score=float(mean_delta))))

    # Stable descending sort: ties keep their discovery order.
    kept.sort(key=lambda pair: (pair[1].score, pair[1].length), reverse=True)
    return [clip for clip, _ in kept], [seg for _, seg in kept]
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Options are declared as an ordered table so that the help output lists
    them in the same order as they are defined here.
    """
    option_specs = (
        ("--data_dir", dict(type=str, required=True, help="Directory with frame images (e.g. .../data/data)")),
        ("--model_dir", dict(type=str, default="savedmodel", help="SavedModel directory")),
        ("--out_dir", dict(type=str, default="dataset_outputs", help="Output directory")),
        ("--threshold", dict(type=float, default=0.005, help="Motion threshold on (1-SSIM)")),
        ("--ssim_hw", dict(type=str, default="64,64", help="SSIM resize H,W")),
        ("--max_gap_ms", dict(type=int, default=32, help="Max adjacent timestamp gap (ms)")),
        ("--min_len", dict(type=int, default=8, help="Min segment length (frames)")),
        ("--sequence_len", dict(type=int, default=4, help="Model input sequence length (frames)")),
        (
            "--pad_last_frame",
            dict(
                type=str,
                default="none",
                choices=["none", "zero", "one", "repeat"],
                help="When model expects 4 frames but you provide 3, pad the last frame with this mode.",
            ),
        ),
        ("--grid_cols", dict(type=int, default=8, help="Columns for saved grids")),
        ("--topk_windows", dict(type=int, default=50, help="How many top motion windows to consider per clip")),
        ("--num_tests", dict(type=int, default=4, help="How many motion regions to run")),
        (
            "--mode",
            dict(
                type=str,
                default="topk",
                choices=["topk", "all_motion"],
                help="topk: pick non-overlapping top regions; all_motion: run every threshold-based motion segment",
            ),
        ),
    )

    parser = argparse.ArgumentParser(description="SSIM-filter dataset and generate teacher-forcing predictions.")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def main() -> None:
    """CLI entry point: select motion segments and write teacher-forced predictions.

    For each selected segment, every window of ``--sequence_len`` consecutive
    frames that passes the adjacent-timestamp-gap check is fed to the
    predictor. The predicted last frame and the matching ground-truth frame
    are saved under ``<out_dir>/test/<n>/{pred,gt}``, together with two
    overview grids and a ``meta.txt`` describing the segment.

    Raises:
        ValueError: if ``--ssim_hw`` is not two positive integers ``H,W`` or
            ``--sequence_len`` is smaller than 1.
    """
    args = _parse_args()
    data_dir = Path(args.data_dir)
    out_dir = Path(args.out_dir)

    # Validate up front so a malformed value fails with a clear message
    # before any output directory is created or the model is loaded.
    try:
        hw_values = [int(part.strip()) for part in args.ssim_hw.split(",")]
    except ValueError as err:
        raise ValueError(f"--ssim_hw must be two integers 'H,W', got {args.ssim_hw!r}") from err
    if len(hw_values) != 2 or any(v <= 0 for v in hw_values):
        raise ValueError(f"--ssim_hw must be two positive integers 'H,W', got {args.ssim_hw!r}")
    ssim_hw: Tuple[int, int] = (hw_values[0], hw_values[1])

    L = int(args.sequence_len)
    if L < 1:
        raise ValueError(f"--sequence_len must be >= 1, got {L}")

    out_dir.mkdir(parents=True, exist_ok=True)

    if args.mode == "all_motion":
        clips, segs = find_all_motion_segments(
            data_dir,
            threshold=args.threshold,
            ssim_hw=ssim_hw,
            max_gap_ms=args.max_gap_ms,
            min_len=args.min_len,
        )
        print(f"Found motion segments: {len(segs)}")
    else:
        clips, segs = pick_best_segments(
            data_dir,
            threshold=args.threshold,
            ssim_hw=ssim_hw,
            max_gap_ms=args.max_gap_ms,
            min_len=args.min_len,
            topk=args.topk_windows,
            k=args.num_tests,
        )

    pred = Predictor(args.model_dir)

    from PIL import Image

    def load_rgb(p: Path) -> np.ndarray:
        # Context manager closes the file handle as soon as pixels are copied.
        with Image.open(p) as img:
            return np.asarray(img.convert("RGB"), dtype=np.float32)

    test_root = out_dir / "test"
    test_root.mkdir(parents=True, exist_ok=True)

    for idx, (clip, seg) in enumerate(zip(clips, segs), start=1):
        seg_paths = clip[seg.start : seg.end]
        if len(seg_paths) < L:
            # Too short to form even one input window.
            continue

        test_dir = test_root / str(idx)
        (test_dir / "pred").mkdir(parents=True, exist_ok=True)
        (test_dir / "gt").mkdir(parents=True, exist_ok=True)

        print(f"[test/{idx}] clip={seg.clip_index} start={seg.start} end={seg.end} len={len(seg_paths)} score={seg.score:.6f}")
        print(f"[test/{idx}] first={seg_paths[0].name} last={seg_paths[-1].name}")

        frames = np.stack([load_rgb(p) for p in seg_paths], axis=0)
        frames01 = np.clip(frames / 255.0, 0.0, 1.0).astype(np.float32)

        preds_last: List[np.ndarray] = []
        gt_last: List[np.ndarray] = []
        kept = 0
        skipped = 0
        for i in range(L - 1, len(frames01)):
            # Teacher forcing with the same "max adjacent timestamp gap" constraint as notebook.
            # The window is the L frames ending at index i (inclusive).
            seq_paths = seg_paths[i - (L - 1) : i + 1]
            if not _has_valid_adjacent_timestamp_gap(seq_paths, args.max_gap_ms):
                skipped += 1
                continue
            window = frames01[i - (L - 1) : i + 1]
            y = pred.predict_last_frame(window, pad_last_frame=args.pad_last_frame)[0]
            preds_last.append(y)
            gt_last.append(frames01[i])
            kept += 1
            stem = seg_paths[i].stem
            save_image(test_dir / "pred" / f"pred_{i:04d}_{stem}.png", y)
            save_image(test_dir / "gt" / f"gt_{i:04d}_{stem}.png", frames01[i])

        if preds_last:
            save_sequence_grid(test_dir / "pred_last_grid.png", np.stack(preds_last, axis=0), cols=args.grid_cols)
            save_sequence_grid(test_dir / "gt_grid.png", np.stack(gt_last, axis=0), cols=args.grid_cols)

        (test_dir / "meta.txt").write_text(
            "\n".join(
                [
                    f"clip_index={seg.clip_index}",
                    f"start={seg.start}",
                    f"end={seg.end}",
                    f"len={len(seg_paths)}",
                    f"score_mean_1_minus_ssim={seg.score:.6f}",
                    f"first_frame={seg_paths[0].name}",
                    f"last_frame={seg_paths[-1].name}",
                    f"sequence_len={L}",
                    f"threshold={args.threshold}",
                    f"ssim_hw={ssim_hw[0]},{ssim_hw[1]}",
                    f"max_gap_ms={args.max_gap_ms}",
                    f"kept_predictions={kept}",
                    f"skipped_by_gap={skipped}",
                    "",
                ]
            ),
            encoding="utf-8",
        )

    print(f"Wrote tests under: {test_root}")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()