Spaces:
Build error
Build error
| import sys | |
| from pathlib import Path | |
| from typing import Any, Generator | |
| if sys.version_info[:2] >= (3, 11): | |
| from typing import Annotated | |
| else: | |
| from typing_extensions import Annotated | |
| from fastapi import Depends, FastAPI, Request, status | |
| from fastapi.exceptions import RequestValidationError | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse | |
| from gymnasium.wrappers.record_video import RecordVideo | |
| from loguru import logger | |
| from litrl.algo.mcts.agent import MCTSAgent | |
| from litrl.common.agent import RandomAgent | |
| from litrl.env.make import make | |
| from litrl.env.typing import GymId | |
| from src.app_state import AppState | |
| from src.huggingface.huggingface_client import HuggingFaceClient | |
| from src.typing import BotResponseType, CpuConfig, GridResponseType | |
def stream_mp4(mp4_path: Path, chunk_size: int = 64 * 1024) -> StreamingResponse:
    """Stream an MP4 file from disk as an HTTP response.

    Args:
        mp4_path: Path of the MP4 file to send.
        chunk_size: Number of bytes read from disk and sent per chunk.

    Returns:
        A ``StreamingResponse`` with ``media_type="video/mp4"`` whose body is
        the raw bytes of ``mp4_path``.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
    """
    if chunk_size <= 0:
        msg = f"chunk_size must be positive, got {chunk_size}"
        raise ValueError(msg)

    def iter_file() -> Generator[bytes, Any, None]:
        # Iterating a binary file yields b"\n"-delimited "lines"; for video data
        # that gives arbitrarily sized (possibly huge) chunks, so read fixed blocks.
        with mp4_path.open(mode="rb") as env_file:
            while chunk := env_file.read(chunk_size):
                yield chunk

    return StreamingResponse(content=iter_file(), media_type="video/mp4")
def get_app_state() -> AppState:
    """FastAPI dependency that supplies the ``AppState`` used by an endpoint.

    Note:
        ``AppState()`` is called on every invocation. Whether game state
        persists across requests therefore depends on ``AppState``'s
        constructor (e.g. a singleton), which is not visible here — TODO confirm.
    """
    state = AppState()
    return state
def create_app() -> FastAPI:  # noqa: C901 # TODO move to routes
    """Build the FastAPI application for the game / bot backend.

    Defines the endpoint handlers and a request-validation error handler as
    closures, installs a permissive CORS middleware on ``app``, and returns it.

    NOTE(review): no route decorators (``@app.get`` / ``@app.post``) or
    ``app.exception_handler(...)`` registration are visible for the handlers
    below, so as written they are defined but never attached to ``app``.
    Confirm how (or whether) they are registered.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI()

    async def redirect_to_docs() -> RedirectResponse:
        """Redirect the client to the interactive API docs at ``/docs``."""
        return RedirectResponse("/docs")

    def endpoint_play(
        action: int,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Apply the player's ``action`` via ``app_state.step`` and return the result."""
        return app_state.step(action)

    def endpoint_observe(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Return the result of ``app_state.observe()``."""
        return app_state.observe()

    def endpoint_bot_play(
        cpu_config: CpuConfig,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> BotResponseType:
        """Configure the bot, let it choose and play an action, and report the outcome.

        Returns:
            The grid and done flag after the bot's move, plus the action it chose.
        """
        app_state.set_config(cpu_config)
        action = app_state.get_action()
        response = app_state.step(action)
        return BotResponseType(
            grid=response.grid,
            done=response.done,
            action=action,
        )

    def endpoint_bot_progress(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> float:
        """Report the MCTS bot's search progress as ``root.visits / simulations``.

        Returns 1.0 when the agent is not an ``MCTSAgent`` or when the agent's
        ``mcts`` attribute is ``None``.

        Raises:
            ValueError: If the agent is an ``MCTSAgent`` but
                ``cpu_config.simulations`` is ``None``.
        """
        if not isinstance(app_state.agent, MCTSAgent):
            return 1.0
        simulations = app_state.cpu_config.simulations
        if simulations is None:
            msg = "cpu_config.simulations must be set when the agent is an MCTSAgent"
            raise ValueError(msg)
        if app_state.agent.mcts is None:
            return 1.0
        return float(
            app_state.agent.mcts.root.visits / simulations,
        )  # TODO why not recognized as float?

    def endpoint_reset(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Reset the game via ``app_state.reset()`` and return the result."""
        return app_state.reset()

    def endpoint_get_huggingface_video(
        env_id: GymId,
        hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
    ) -> StreamingResponse:
        """Stream the MP4 stored in ``hf_client.mp4_paths`` for ``env_id``."""
        return stream_mp4(mp4_path=hf_client.mp4_paths[env_id])

    def endpoint_get_env_video(
        env_id: GymId,
    ) -> StreamingResponse:
        """Record one episode of a random agent in ``env_id`` and stream it as MP4.

        The video is written under the relative folder ``tmp``.

        Raises:
            ValueError: If ``RecordVideo`` has no video recorder after ``reset``.
        """
        env = make(id=env_id, render_mode="rgb_array")
        env = RecordVideo(
            env=env,
            video_folder="tmp",
        )
        try:
            env.reset(seed=123)
            if env.video_recorder is None:
                msg = "env.video_recorder is None"
                raise ValueError(msg)
            agent = RandomAgent[Any, Any]()
            terminated, truncated = False, False
            while not (terminated or truncated):
                action = agent.get_action(env=env)  # type: ignore[arg-type]
                _, _, terminated, truncated, _ = env.step(action=action)
                env.render()
            # Finalize the recorder so the MP4 on disk is complete before streaming it.
            env.video_recorder.close()
            mp4_path = Path(env.video_recorder.path)
        finally:
            # Release the environment even if recording fails part-way.
            env.close()
        return stream_mp4(mp4_path=mp4_path)

    async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
        """Log a request that failed validation and return a 422 JSON error.

        The body is ``{"status_code": 10422, "message": <error text>, "data": None}``.
        """
        logger.debug(f"url: {request.url}")
        if hasattr(request, "_body"):
            # ``_body`` is a private Starlette attribute, presumably present only once
            # the body was read. Decode leniently so non-UTF-8 input can't crash this handler.
            logger.debug(f"body: {request._body.decode(errors='replace')}")  # noqa: SLF001
        logger.debug(f"header: {request.headers}")
        logger.error(f"{request}: {exc}")
        # Flatten the multi-line validation message and collapse whitespace runs.
        exc_str = " ".join(str(exc).split())
        content = {"status_code": 10422, "message": exc_str, "data": None}
        return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)

    app.add_middleware(
        middleware_class=CORSMiddleware,
        # Must be a sequence of origins; a bare "*" only worked via substring matching.
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app