# next-frame-predict / modeling.py
# Last commit by dvdface — 8b78124 (verified): "Simplify project structure: rename dirs and files"
import os
from pathlib import Path
from huggingface_hub import snapshot_download
from .configuration import PredNetConfig
from .infer.predictor import Predictor
class PredNetModel:
    """Black-box TF SavedModel wrapper compatible with AutoModel via trust_remote_code.

    Holds a ``PredNetConfig`` and a ``Predictor``. Every ``predict_*`` method
    forwards its arguments directly to the ``Predictor`` method of the same name.
    """

    # Keyword arguments of ``from_pretrained`` that are forwarded to
    # ``snapshot_download`` when the model has to be fetched from the Hub.
    # Only keys with a non-None value are forwarded; all other kwargs are ignored.
    _HUB_DOWNLOAD_KWARGS = (
        "revision",
        "cache_dir",
        "token",
        "local_files_only",
        "force_download",
        "proxies",
    )

    def __init__(self, config: PredNetConfig, predictor: Predictor):
        """Wrap an already-constructed predictor.

        Args:
            config: Model configuration, exposed as ``self.config``.
            predictor: Object that performs the actual inference.
        """
        self.config = config
        self._predictor = predictor

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from a local directory or a Hugging Face Hub repo.

        Args:
            pretrained_model_name_or_path: Path to an existing local directory,
                or otherwise a Hub repo id that is downloaded with
                ``snapshot_download``.
            **kwargs: Optional. ``config``: a ``PredNetConfig`` to use instead
                of loading one from the model directory. Keys listed in
                ``_HUB_DOWNLOAD_KWARGS`` with non-None values are forwarded to
                ``snapshot_download``. Any other keyword is ignored.

        Returns:
            A ``PredNetModel`` whose ``Predictor`` is built from the
            ``savedmodel`` subdirectory of the resolved model directory.

        Raises:
            FileNotFoundError: If the resolved model directory has no
                ``savedmodel`` subdirectory.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            local_dir = pretrained_model_name_or_path
        else:
            # Forward only the Hub options the caller actually set, so that
            # snapshot_download's own defaults apply otherwise.
            hub_kwargs = {
                key: kwargs[key]
                for key in cls._HUB_DOWNLOAD_KWARGS
                if kwargs.get(key) is not None
            }
            local_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)

        config = kwargs.get("config")
        if not isinstance(config, PredNetConfig):
            # No usable config was supplied; read it from the model directory.
            config = PredNetConfig.from_pretrained(local_dir)

        savedmodel_dir = Path(local_dir) / "savedmodel"
        if not savedmodel_dir.is_dir():
            # Fail early with a clear message instead of an opaque loader error.
            raise FileNotFoundError(
                f"Expected a TF SavedModel directory at {savedmodel_dir}"
            )

        predictor = Predictor(
            model_dir=str(savedmodel_dir),
            resize_hw=tuple(config.resize_hw),
        )
        return cls(config, predictor)

    def predict_sequence(self, frames, pad_last_frame="none"):
        """Return ``Predictor.predict_sequence(frames, pad_last_frame=...)`` unchanged."""
        return self._predictor.predict_sequence(frames, pad_last_frame=pad_last_frame)

    def predict_last_frame(self, frames, pad_last_frame="none"):
        """Return ``Predictor.predict_last_frame(frames, pad_last_frame=...)`` unchanged."""
        return self._predictor.predict_last_frame(frames, pad_last_frame=pad_last_frame)

    def predict_outputs(self, frames, pad_last_frame="none"):
        """Return ``Predictor.predict_outputs(frames, pad_last_frame=...)`` unchanged."""
        return self._predictor.predict_outputs(frames, pad_last_frame=pad_last_frame)