| import sys, os
|
|
|
# Make modules that live next to this file (e.g. database, crud) and the
# sibling ../transformer_model directory (presumably model, tokenizer)
# importable by plain name.
sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'transformer_model'))

# Quiet TensorFlow's C++ logging. This must be set before TensorFlow is
# imported; here TF is only imported lazily inside lifespan().
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
|
|
|
|
| import time
|
| from contextlib import asynccontextmanager
|
| from typing import Optional
|
|
|
| from fastapi import Depends, FastAPI, HTTPException
|
| from fastapi.concurrency import run_in_threadpool
|
| from fastapi.middleware.cors import CORSMiddleware
|
| from fastapi.responses import FileResponse
|
| from pydantic import BaseModel, Field
|
| from sqlalchemy.orm import Session
|
|
|
| from database import Base, engine, get_db
|
| import crud as crud
|
|
|
# Model context length in tokens; the GPT is also built with max_len=SEQ_LEN.
SEQ_LEN = 256

# Artifact locations, resolved relative to this file's parent directory.
_PROJECT_ROOT = os.path.join(os.path.dirname(__file__), '..')
_SAVED_MODELS = os.path.join(_PROJECT_ROOT, 'saved_models')
WEIGHTS = os.path.join(_SAVED_MODELS, 'tinystories_model.weights.h5')
TOKENIZER = os.path.join(_SAVED_MODELS, 'tinystories_tokenizer.json')
INDEX_HTML = os.path.join(_PROJECT_ROOT, 'index.html')

# Temperature slider: range, step and default. These are served by
# /temperature-info and enforced by GenerateRequest validation.
TEMP_MIN, TEMP_MAX, TEMP_STEP, TEMP_DEFAULT = 0.1, 1.5, 0.1, 0.5

# Fixed repetition penalty passed to generate_text on every request.
REPETITION_PENALTY = 1.3

# Named slider zones as half-open [from, to) intervals, ordered low -> high.
# Each entry becomes a dict with keys: from, to, label, color, description.
TEMPERATURE_ZONES = [
    {"from": lower, "to": upper, "label": name, "color": hex_color, "description": blurb}
    for lower, upper, name, hex_color, blurb in (
        (0.1, 0.5, "Safe", "#2ecc71",
         "Focused and predictable. The model picks the most likely words, so stories stay simple, calm and easy to follow."),
        (0.5, 0.8, "Balanced", "#f1c40f",
         "A good mix of sense and surprise. Stories stay coherent but take a few fun turns. Recommended for most prompts."),
        (0.8, 1.1, "Creative", "#e67e22",
         "More imaginative. The model reaches for less obvious words and ideas, so stories get more varied and playful."),
        (1.1, 1.5, "Wild", "#e74c3c",
         "Unpredictable and quirky. Anything can happen, but the story may wander or stop making sense."),
    )
]

# Colour stops (normalised position in [0, 1], RGB) for the green -> yellow -> red
# slider gradient that temperature_color() interpolates across.
_GRADIENT = [(0.0, (46, 204, 113)), (0.5, (241, 196, 15)), (1.0, (231, 76, 60))]
|
|
|
|
|
def _clamp(v, lo, hi):
    """Limit *v* to the closed range [lo, hi].

    The upper bound is applied first and the lower bound second, so if
    ``lo > hi`` the result is ``lo``.
    """
    capped = min(hi, v)
    return max(lo, capped)
|
|
|
|
|
def temperature_color(temp: float) -> str:
    """Map *temp* onto the green -> yellow -> red gradient and return ``#rrggbb``.

    The temperature is first normalised to [0, 1] over TEMP_MIN..TEMP_MAX
    (values outside the range are clamped). Each RGB channel is then linearly
    interpolated between the two _GRADIENT stops that bracket that position.
    """
    position = _clamp((temp - TEMP_MIN) / (TEMP_MAX - TEMP_MIN), 0.0, 1.0)
    for (start_pos, start_rgb), (end_pos, end_rgb) in zip(_GRADIENT, _GRADIENT[1:]):
        if not (start_pos <= position <= end_pos):
            continue
        span = end_pos - start_pos
        # Guard against two stops at the same position (zero-width segment).
        frac = (position - start_pos) / span if span != 0 else 0.0
        red, green, blue = (
            round(lo_ch + (hi_ch - lo_ch) * frac) for lo_ch, hi_ch in zip(start_rgb, end_rgb)
        )
        return f"#{red:02x}{green:02x}{blue:02x}"
    # Not reached for a position clamped to [0, 1] with stops spanning 0.0..1.0;
    # kept as a safe default equal to the last stop's colour.
    return "#e74c3c"
|
|
|
|
|
def describe_temperature(temp: float):
    """Return ``(label, description, color)`` for a temperature value.

    Zones in TEMPERATURE_ZONES are half-open ``[from, to)`` intervals checked
    in order. A temperature that matches no zone maps to the nearest end:

    - below the first zone's lower bound -> the first zone;
    - anything else (e.g. exactly the top bound TEMP_MAX) -> the last zone.

    ``color`` is always the interpolated gradient colour from temperature_color().
    """
    color = temperature_color(temp)
    for zone in TEMPERATURE_ZONES:
        if zone["from"] <= temp < zone["to"]:
            return zone["label"], zone["description"], color

    # Previously every unmatched value fell through to the last ("Wild") zone,
    # so a very low temperature was described as wild. Clamp to the nearer edge.
    first = TEMPERATURE_ZONES[0]
    edge = first if temp < first["from"] else TEMPERATURE_ZONES[-1]
    return edge["label"], edge["description"], color
|
|
|
|
|
def derive_top_p(temp: float) -> float:
    """Derive a nucleus-sampling ``top_p`` from the temperature slider.

    ``top_p`` rises linearly from 0.80 at TEMP_MIN to 0.95 at TEMP_MAX. It is
    clamped outside that range and rounded to three decimals, so one slider
    drives both sampling knobs.
    """
    base, spread = 0.80, 0.15
    fraction = _clamp((temp - TEMP_MIN) / (TEMP_MAX - TEMP_MIN), 0.0, 1.0)
    return round(base + fraction * spread, 3)
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and GPT model at startup and publish them on ``app.state``.

    TensorFlow and the project ``model`` / ``tokenizer`` modules are imported
    here, at startup, rather than at module import time. After this runs,
    ``app.state`` carries ``tf``, ``model``, ``tokenizer`` and ``generate_text``
    for the request handlers. Nothing is torn down after the ``yield``.
    """
    print("Loading model...")
    import tensorflow as tf
    from model import GPT, generate_text
    from tokenizer import HFTokenizer

    tok = HFTokenizer()
    tok.load(TOKENIZER)

    # Architecture hyperparameters; these must match the saved weights file.
    gpt_config = dict(
        vocab_size=tok.vocab_size,
        d_model=640,
        num_heads=10,
        dff=2560,
        num_layers=10,
        max_len=SEQ_LEN,
    )
    gpt = GPT(**gpt_config)
    # One dummy forward pass before load_weights, presumably so the model's
    # variables exist (Keras builds them lazily on first call).
    gpt(tf.zeros((1, SEQ_LEN), dtype=tf.int32), training=False)
    gpt.load_weights(WEIGHTS)

    state = app.state
    state.tf, state.model, state.tokenizer, state.generate_text = tf, gpt, tok, generate_text
    print("Model loaded.")
    yield
|
|
|
|
|
app = FastAPI(
    title="TinyStories GPT",
    description="A small GPT language model that writes short children's stories from a prompt.",
    version="2.0.0",
    lifespan=lifespan,
)

# Create the history tables at import time. History is optional: if the
# database cannot be reached, DB_AVAILABLE is False and the app keeps serving
# generations without persisting them.
try:
    Base.metadata.create_all(bind=engine)
    DB_AVAILABLE = True
except Exception as e:
    print(f"DB unavailable, history disabled: {e}")
    DB_AVAILABLE = False

# Browsers send the Origin header as scheme://host[:port] with no trailing
# slash, and CORSMiddleware compares allow_origins entries by exact string
# match. The previous "https://avi080704-tinygpt.hf.space/" therefore never
# matched, so cross-origin requests from the Space were rejected.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://avi080704-tinygpt.hf.space"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``: a prompt plus sampling controls."""

    # Opening of the story; 1 to 500 characters.
    prompt: str = Field(
        ...,
        min_length=1,
        max_length=500,
        description="How the story should start.",
    )

    # Single creativity knob, validated against the slider range TEMP_MIN..TEMP_MAX.
    temperature: float = Field(
        default=TEMP_DEFAULT,
        ge=TEMP_MIN,
        le=TEMP_MAX,
        description="Creativity slider. Low (green) = safe and coherent, high (red) = wild and random.",
    )

    # Optional explicit top_p in [0.1, 1.0]; None means "derive from temperature".
    top_p: Optional[float] = Field(
        default=None,
        ge=0.1,
        le=1.0,
        description="Advanced override for nucleus sampling. Auto-derived from temperature if omitted.",
    )

    # Upper bound on tokens generated after the prompt (10 to 300).
    max_new_tokens: int = Field(
        default=200,
        ge=10,
        le=300,
        description="Roughly how long the story can get.",
    )
|
|
|
|
|
class GenerateResponse(BaseModel):
    """Response body for ``POST /generate``: the story plus the settings that produced it."""

    # Echo of the request prompt.
    prompt: str
    # Text decoded from the model output.
    generated_text: str

    # Temperature used, with its zone label/description and gradient colour.
    temperature: float
    temperature_label: str
    temperature_description: str
    temperature_color: str

    # Nucleus-sampling threshold actually used (client-supplied or derived).
    top_p: float
    max_new_tokens: int

    # Elapsed time for tokenizing, generating and decoding, in milliseconds.
    response_time_ms: float
|
|
|
|
|
class HealthResponse(BaseModel):
    """Response body for ``GET /health``."""

    # Liveness marker.
    status: str
    # Size of the loaded tokenizer's vocabulary.
    vocab_size: int
    # Human-readable parameter count, e.g. "12.3M".
    model_params: str
    # Model context length in tokens.
    seq_len: int
|
|
|
|
|
@app.get("/", include_in_schema=False)
def root():
    """Serve the single-page frontend (``index.html``).

    Raises:
        HTTPException(404): the frontend file is missing on disk.
    """
    # Starlette's FileResponse only stats the path when the response is sent,
    # so a missing file surfaces as an unhandled RuntimeError (HTTP 500).
    # Check up front and report a clear 404 instead.
    if not os.path.isfile(INDEX_HTML):
        raise HTTPException(status_code=404, detail="index.html not found")
    return FileResponse(INDEX_HTML)
|
|
|
|
|
@app.get("/temperature-info")
def temperature_info():
    """Return the slider configuration and colour zones.

    The frontend uses this to render a green -> red slider track.
    """
    return {
        "min": TEMP_MIN,
        "max": TEMP_MAX,
        "step": TEMP_STEP,
        "default": TEMP_DEFAULT,
        # Built from _GRADIENT (same hex format as temperature_color()) so the
        # advertised stops cannot drift from the colours the backend interpolates.
        "gradient": [f"#{r:02x}{g:02x}{b:02x}" for _, (r, g, b) in _GRADIENT],
        "zones": TEMPERATURE_ZONES,
    }
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
def health():
    """Report basic model metadata.

    The response covers vocabulary size, trainable-parameter count (in
    millions) and context length.
    """
    # Use the TensorFlow module and model that lifespan() stored on app.state,
    # as /generate does, rather than re-importing TensorFlow in the handler.
    tf = app.state.tf
    model = app.state.model
    total_params = int(sum(tf.size(w).numpy() for w in model.trainable_variables))
    return HealthResponse(
        status="healthy",
        vocab_size=app.state.tokenizer.vocab_size,
        model_params=f"{total_params / 1e6:.1f}M",
        seq_len=SEQ_LEN,
    )
|
|
|
|
|
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, db: Session = Depends(get_db)):
    """Generate a story continuation for ``request.prompt``.

    ``top_p`` is derived from the temperature via derive_top_p() when the
    client omits it. When the database is available, each successful
    generation is saved to history on a best-effort basis.

    Raises:
        HTTPException(422): the prompt encodes to zero tokens.
        HTTPException(500): any other failure while tokenizing, generating or decoding.
    """
    tf = app.state.tf
    model = app.state.model
    tokenizer = app.state.tokenizer
    generate_text = app.state.generate_text

    temperature = request.temperature
    top_p = request.top_p if request.top_p is not None else derive_top_p(temperature)
    label, description, color = describe_temperature(temperature)

    try:
        # perf_counter is monotonic, so the timing is unaffected by wall-clock jumps.
        start_time = time.perf_counter()

        prompt_tokens = tokenizer.encode(request.prompt)
        # An empty token list would become a (1, 0) tensor and fail deep inside
        # generation. Reject it explicitly. len() is used rather than truthiness
        # because the tokenizer's return type is not visible here.
        if len(prompt_tokens) == 0:
            raise HTTPException(status_code=422, detail="Prompt produced no tokens.")
        # Keep only the most recent tokens, leaving 10 of the SEQ_LEN positions free.
        if len(prompt_tokens) > SEQ_LEN - 10:
            prompt_tokens = prompt_tokens[-(SEQ_LEN - 10):]

        start_tokens = tf.constant([prompt_tokens], dtype=tf.int32)

        # Model inference is blocking; run it off the event loop.
        output = await run_in_threadpool(
            generate_text,
            model,
            start_tokens,
            request.max_new_tokens,
            temperature,
            None,  # NOTE(review): fifth positional arg of generate_text (likely top_k) — confirm
            top_p,
            getattr(tokenizer, "eos_id", None),
            REPETITION_PENALTY,
        )

        generated = tokenizer.decode(output[0].numpy().tolist())
        response_time_ms = round((time.perf_counter() - start_time) * 1000, 2)

        if DB_AVAILABLE:
            try:
                # The DB write is blocking I/O too; keep it off the event loop.
                await run_in_threadpool(
                    crud.save_generation,
                    db=db,
                    prompt=request.prompt,
                    generated_text=generated,
                    temperature=temperature,
                    top_p=top_p,
                    max_new_tokens=request.max_new_tokens,
                    response_time_ms=response_time_ms,
                )
            except Exception as e:
                # History is best-effort: a failed save must not fail the request.
                print(f"save_generation failed (continuing): {e}")

        return GenerateResponse(
            prompt=request.prompt,
            generated_text=generated,
            temperature=temperature,
            temperature_label=label,
            temperature_description=description,
            temperature_color=color,
            top_p=top_p,
            max_new_tokens=request.max_new_tokens,
            response_time_ms=response_time_ms,
        )

    except HTTPException:
        raise
    except Exception as e:
        # Chain the original exception so the traceback keeps the real cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
# Upper bound on rows a single /history request may return.
_HISTORY_MAX_LIMIT = 500


@app.get("/history")
def history(limit: int = 50, db: Session = Depends(get_db)):
    """Return up to ``limit`` stored generations as JSON-ready dicts.

    Rows come back in whatever order crud.get_all_generations yields them.
    ``limit`` is clamped to [1, _HISTORY_MAX_LIMIT], so one request cannot ask
    the database for an unbounded or negative number of rows.

    Raises:
        HTTPException(503): history is disabled because the database was
            unavailable at startup.
    """
    # Without this guard, a down database surfaces as an opaque 500 from crud.
    if not DB_AVAILABLE:
        raise HTTPException(status_code=503, detail="History is unavailable")

    limit = _clamp(limit, 1, _HISTORY_MAX_LIMIT)
    records = crud.get_all_generations(db, limit=limit)
    return [
        {
            "id": r.id,
            "prompt": r.prompt,
            "generated_text": r.generated_text,
            "temperature": r.temperature,
            "top_p": r.top_p,
            "max_new_tokens": r.max_new_tokens,
            "response_time_ms": r.response_time_ms,
            "created_at": r.created_at,
        }
        for r in records
    ]
|
|
|
|
|
@app.delete("/history/{generation_id}")
def delete_history_entry(generation_id: int, db: Session = Depends(get_db)):
    """Delete one stored generation by id.

    Returns ``{"deleted": generation_id}`` on success.

    Raises:
        HTTPException(503): history is disabled because the database was
            unavailable at startup.
        HTTPException(404): crud.delete_generation reported nothing was deleted.
    """
    # Same guard as /history: report the disabled feature instead of a 500.
    if not DB_AVAILABLE:
        raise HTTPException(status_code=503, detail="History is unavailable")

    if not crud.delete_generation(db, generation_id):
        raise HTTPException(status_code=404, detail="Record not found")
    return {"deleted": generation_id}
|
|
|