Handle custom-domain Gradio CORS preflights

#41
Files changed (1) hide show
  1. app.py +56 -17
app.py CHANGED
@@ -385,7 +385,6 @@ logger.info("All routes configured")
385
 
386
  # Mount the REST API on /api
387
  from fastapi import FastAPI, Request
388
- from fastapi.middleware.cors import CORSMiddleware
389
  from fastapi.responses import RedirectResponse
390
  from starlette.middleware.base import BaseHTTPMiddleware
391
  from api import api_app
@@ -407,30 +406,78 @@ class RootRedirectMiddleware(BaseHTTPMiddleware):
407
 
408
 
409
  class StringifiedGradioJSONMiddleware:
410
- """Normalize JSON bodies double-encoded by the custom-domain proxy.
411
 
412
  Requests sent through index.openhands.dev can arrive at Gradio as a JSON
413
  string containing the real request object, which makes FastAPI validation
414
  reject interactive callbacks with 422. Direct HF Space traffic already sends
415
  proper JSON objects, so this only rewrites bodies that decode to strings.
 
 
 
 
 
416
  """
417
 
 
 
418
  def __init__(self, app):
419
  self.app = app
420
 
421
  async def __call__(self, scope, receive, send):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  if (
423
  scope["type"] == "http"
424
  and scope.get("method") == "POST"
425
  and scope.get("path", "").startswith("/gradio_api/")
426
  ):
427
- headers = {
428
- key.decode("latin-1").lower(): value.decode("latin-1")
429
- for key, value in scope.get("headers", [])
430
- }
431
  content_type = headers.get("content-type", "")
432
  if "application/json" not in content_type:
433
- return await self.app(scope, receive, send)
434
 
435
  body_parts = []
436
  while True:
@@ -466,9 +513,9 @@ class StringifiedGradioJSONMiddleware:
466
  "more_body": False,
467
  }
468
 
469
- return await self.app(scope, replay_receive, send)
470
 
471
- return await self.app(scope, receive, send)
472
 
473
 
474
  # Create a parent FastAPI app with redirect_slashes=False to prevent
@@ -482,14 +529,6 @@ root_app.mount("/api", api_app)
482
 
483
  # Mount Gradio app at root path
484
  app = gr.mount_gradio_app(root_app, demo, path="/")
485
- app = CORSMiddleware(
486
- app,
487
- allow_origins=["https://index.openhands.dev"],
488
- allow_credentials=True,
489
- allow_methods=["*"],
490
- allow_headers=["*"],
491
- expose_headers=["*"],
492
- )
493
  app = StringifiedGradioJSONMiddleware(app)
494
  logger.info("REST API mounted at /api, Gradio app mounted at /")
495
 
 
385
 
386
  # Mount the REST API on /api
387
  from fastapi import FastAPI, Request
 
388
  from fastapi.responses import RedirectResponse
389
  from starlette.middleware.base import BaseHTTPMiddleware
390
  from api import api_app
 
406
 
407
 
408
  class StringifiedGradioJSONMiddleware:
409
+ """Normalize custom-domain Gradio requests before they reach Gradio.
410
 
411
  Requests sent through index.openhands.dev can arrive at Gradio as a JSON
412
  string containing the real request object, which makes FastAPI validation
413
  reject interactive callbacks with 422. Direct HF Space traffic already sends
414
  proper JSON objects, so this only rewrites bodies that decode to strings.
415
+
416
+ The custom domain also loads the Gradio frontend from Vercel while
417
+ window.gradio_config.root points to the hf.space runtime. Chrome therefore
418
+ requires successful credentialed CORS preflights for queue and heartbeat
419
+ endpoints.
420
  """
421
 
422
+ CUSTOM_DOMAIN_ORIGIN = "https://index.openhands.dev"
423
+
424
  def __init__(self, app):
425
  self.app = app
426
 
427
  async def __call__(self, scope, receive, send):
428
+ origin = None
429
+ if scope["type"] == "http":
430
+ headers = {
431
+ key.decode("latin-1").lower(): value.decode("latin-1")
432
+ for key, value in scope.get("headers", [])
433
+ }
434
+ origin = headers.get("origin")
435
+
436
+ should_apply_cors = (
437
+ scope["type"] == "http"
438
+ and scope.get("path", "").startswith("/gradio_api/")
439
+ and origin == self.CUSTOM_DOMAIN_ORIGIN
440
+ )
441
+
442
+ if should_apply_cors:
443
+ cors_headers = [
444
+ (b"access-control-allow-origin", self.CUSTOM_DOMAIN_ORIGIN.encode("latin-1")),
445
+ (b"access-control-allow-credentials", b"true"),
446
+ (b"access-control-allow-methods", b"DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT"),
447
+ (b"access-control-allow-headers", b"*"),
448
+ (b"access-control-expose-headers", b"*"),
449
+ (b"vary", b"Origin"),
450
+ ]
451
+
452
+ if scope.get("method") == "OPTIONS":
453
+ await send({
454
+ "type": "http.response.start",
455
+ "status": 200,
456
+ "headers": cors_headers,
457
+ })
458
+ await send({"type": "http.response.body", "body": b""})
459
+ return
460
+
461
+ async def cors_send(message):
462
+ if message["type"] == "http.response.start":
463
+ message["headers"] = [
464
+ (key, value)
465
+ for key, value in message.get("headers", [])
466
+ if not key.lower().startswith(b"access-control-")
467
+ and key.lower() != b"vary"
468
+ ] + cors_headers
469
+ await send(message)
470
+ else:
471
+ cors_send = send
472
+
473
  if (
474
  scope["type"] == "http"
475
  and scope.get("method") == "POST"
476
  and scope.get("path", "").startswith("/gradio_api/")
477
  ):
 
 
 
 
478
  content_type = headers.get("content-type", "")
479
  if "application/json" not in content_type:
480
+ return await self.app(scope, receive, cors_send)
481
 
482
  body_parts = []
483
  while True:
 
513
  "more_body": False,
514
  }
515
 
516
+ return await self.app(scope, replay_receive, cors_send)
517
 
518
+ return await self.app(scope, receive, cors_send)
519
 
520
 
521
  # Create a parent FastAPI app with redirect_slashes=False to prevent
 
529
 
530
  # Mount Gradio app at root path
531
  app = gr.mount_gradio_app(root_app, demo, path="/")
 
 
 
 
 
 
 
 
532
  app = StringifiedGradioJSONMiddleware(app)
533
  logger.info("REST API mounted at /api, Gradio app mounted at /")
534