tinygpt / app /server.py
avi080704's picture
Update app/server.py
d89f393 verified
Raw
History Blame Contribute Delete
10.3 kB
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__)))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'transformer_model'))
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Set CUDA_VISIBLE_DEVICES=-1 in your environment to force CPU inference.
import time
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import Depends, FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from database import Base, engine, get_db
import crud as crud
# Max sequence length given to the model (GPT max_len) and used to cap prompt length.
SEQ_LEN = 256

# Artefact locations, resolved relative to this file's directory.
_APP_DIR = os.path.dirname(__file__)
_SAVED_MODELS_DIR = os.path.join(_APP_DIR, '..', 'saved_models')
WEIGHTS = os.path.join(_SAVED_MODELS_DIR, 'tinystories_model.weights.h5')
TOKENIZER = os.path.join(_SAVED_MODELS_DIR, 'tinystories_tokenizer.json')
INDEX_HTML = os.path.join(_APP_DIR, '..', 'index.html')

# Temperature is the single creativity slider: the low end is safe/coherent, the high end is wild.
TEMP_MIN, TEMP_MAX = 0.1, 1.5
TEMP_STEP = 0.1
TEMP_DEFAULT = 0.5

# Passed to generate_text as its repetition-penalty argument.
REPETITION_PENALTY = 1.3
# Descriptive slider zones ordered low -> high. Each row is
# (from, to, label, colour, description); the colours run green (safe) to red (wild)
# so the frontend can tint the track region by region.
_ZONE_SPECS = (
    (0.1, 0.5, "Safe", "#2ecc71",
     "Focused and predictable. The model picks the most likely words, so stories stay simple, calm and easy to follow."),
    (0.5, 0.8, "Balanced", "#f1c40f",
     "A good mix of sense and surprise. Stories stay coherent but take a few fun turns. Recommended for most prompts."),
    (0.8, 1.1, "Creative", "#e67e22",
     "More imaginative. The model reaches for less obvious words and ideas, so stories get more varied and playful."),
    (1.1, 1.5, "Wild", "#e74c3c",
     "Unpredictable and quirky. Anything can happen, but the story may wander or stop making sense."),
)
TEMPERATURE_ZONES = [
    {"from": lower, "to": upper, "label": name, "color": hex_color, "description": blurb}
    for lower, upper, name, hex_color, blurb in _ZONE_SPECS
]
del _ZONE_SPECS

# Stops for temperature_color: (position on the normalised 0..1 slider, RGB),
# giving a green -> yellow -> red gradient.
_GRADIENT = [
    (0.0, (46, 204, 113)),  # green  #2ecc71
    (0.5, (241, 196, 15)),  # yellow #f1c40f
    (1.0, (231, 76, 60)),   # red    #e74c3c
]
def _clamp(v, lo, hi):
    """Return ``v`` limited to the closed interval ``[lo, hi]``."""
    capped = v if v < hi else hi
    return capped if capped > lo else lo
def temperature_color(temp: float) -> str:
    """Map ``temp`` to a ``#rrggbb`` colour on the green -> yellow -> red gradient.

    The temperature is normalised over ``[TEMP_MIN, TEMP_MAX]`` and clamped to
    ``[0, 1]``. The pair of adjacent ``_GRADIENT`` stops enclosing that position
    is found, and each RGB channel is linearly interpolated between them.
    """
    span = TEMP_MAX - TEMP_MIN
    position = _clamp((temp - TEMP_MIN) / span, 0.0, 1.0)
    for (lo_pos, lo_rgb), (hi_pos, hi_rgb) in zip(_GRADIENT, _GRADIENT[1:]):
        if not (lo_pos <= position <= hi_pos):
            continue
        frac = (position - lo_pos) / (hi_pos - lo_pos) if hi_pos != lo_pos else 0.0
        channels = [round(a + (b - a) * frac) for a, b in zip(lo_rgb, hi_rgb)]
        return "#" + "".join(f"{ch:02x}" for ch in channels)
    # Not reached while the stops span 0..1; kept as a safe default (red).
    return "#e74c3c"
def describe_temperature(temp: float):
    """Return ``(label, description, color)`` for a temperature value.

    Zones in ``TEMPERATURE_ZONES`` are half-open ``[from, to)`` intervals.
    A value below the first zone is reported as the first zone, matching the
    colour, which ``temperature_color`` clamps to the low (green) end.
    A value at or above the last zone's upper bound (including ``TEMP_MAX``
    itself) is reported as the last zone.
    """
    color = temperature_color(temp)
    first_zone = TEMPERATURE_ZONES[0]
    if temp < first_zone["from"]:
        # Below the slider range: label it like the low end instead of falling
        # through to the last ("Wild") zone.
        return first_zone["label"], first_zone["description"], color
    for zone in TEMPERATURE_ZONES:
        if zone["from"] <= temp < zone["to"]:
            return zone["label"], zone["description"], color
    # temp at/above the top of the last zone (e.g. temp == TEMP_MAX)
    last_zone = TEMPERATURE_ZONES[-1]
    return last_zone["label"], last_zone["description"], color
def derive_top_p(temp: float) -> float:
    """Derive the nucleus-sampling ``top_p`` from the temperature slider.

    The temperature is normalised over ``[TEMP_MIN, TEMP_MAX]`` (clamped to
    ``[0, 1]``) and mapped linearly onto ``0.80 .. 0.95``, rounded to 3 decimals.
    This lets one slider control everything.
    """
    span = TEMP_MAX - TEMP_MIN
    fraction = _clamp((temp - TEMP_MIN) / span, 0.0, 1.0)
    top_p = 0.80 + fraction * 0.15
    return round(top_p, 3)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: load the tokenizer and GPT weights before serving.

    TensorFlow and the project's ``model``/``tokenizer`` modules are imported
    lazily here. The loaded objects are published on ``app.state`` (``tf``,
    ``model``, ``tokenizer``, ``generate_text``) for the request handlers.
    """
    print("Loading model...")
    import tensorflow as tf
    from model import GPT, generate_text
    from tokenizer import HFTokenizer

    tok = HFTokenizer()
    tok.load(TOKENIZER)

    # Architecture hyperparameters; these must match the saved weights file.
    hparams = dict(
        vocab_size=tok.vocab_size,
        d_model=640,
        num_heads=10,
        dff=2560,
        num_layers=10,
        max_len=SEQ_LEN,
    )
    gpt = GPT(**hparams)
    # One dummy forward pass before load_weights, presumably so the model's
    # variables exist to receive the saved weights.
    dummy_batch = tf.zeros((1, SEQ_LEN), dtype=tf.int32)
    gpt(dummy_batch, training=False)
    gpt.load_weights(WEIGHTS)

    state = app.state
    state.tf = tf
    state.model = gpt
    state.tokenizer = tok
    state.generate_text = generate_text
    print("Model loaded.")
    yield
# The ASGI application; `lifespan` loads the model before requests are served.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title="TinyStories GPT",
    description="A small GPT language model that writes short children's stories from a prompt.",
)
def _init_db() -> bool:
    """Create the tables if possible; return whether the database is usable.

    The database is optional: if it is unreachable (e.g. ephemeral hosting
    without Postgres) the app still serves generations, it just does not
    persist history.
    """
    try:
        Base.metadata.create_all(bind=engine)
    except Exception as exc:  # deliberate best-effort: any DB failure disables history
        print(f"DB unavailable, history disabled: {exc}")
        return False
    return True


DB_AVAILABLE = _init_db()
# Browsers send the Origin header as "scheme://host[:port]" with no path and no
# trailing slash, and CORSMiddleware compares it to allow_origins by exact
# string match. The entry must therefore not end in "/", or every cross-origin
# request from the Space would be rejected.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://avi080704-tinygpt.hf.space"],
    allow_methods=["*"],
    allow_headers=["*"],
)
class GenerateRequest(BaseModel):
    # Request body for POST /generate; pydantic enforces the bounds below.

    # Story opening supplied by the user.
    prompt: str = Field(
        ...,
        description="How the story should start.",
        min_length=1,
        max_length=500,
    )
    # The single creativity control; range and default mirror the slider constants.
    temperature: float = Field(
        TEMP_DEFAULT,
        description="Creativity slider. Low (green) = safe and coherent, high (red) = wild and random.",
        ge=TEMP_MIN,
        le=TEMP_MAX,
    )
    # Optional expert override; when None the handler derives it from temperature.
    top_p: Optional[float] = Field(
        None,
        description="Advanced override for nucleus sampling. Auto-derived from temperature if omitted.",
        ge=0.1,
        le=1.0,
    )
    # Generation length budget passed to generate_text.
    max_new_tokens: int = Field(
        200,
        description="Roughly how long the story can get.",
        ge=10,
        le=300,
    )
class GenerateResponse(BaseModel):
    # Response body of POST /generate. Field order is the serialized key order.
    prompt: str  # the prompt as the client sent it (before any token truncation)
    generated_text: str  # text decoded from the model's output tokens
    temperature: float  # temperature used for sampling
    temperature_label: str  # name of the slider zone containing the temperature
    temperature_description: str  # human-readable explanation of that zone
    temperature_color: str  # "#rrggbb" colour for the temperature on the green->red gradient
    top_p: float  # nucleus threshold used: client override or derived from temperature
    max_new_tokens: int  # generation length limit requested by the client
    response_time_ms: float  # elapsed time for encode + generate + decode, in milliseconds
class HealthResponse(BaseModel):
    # Response body of GET /health.
    status: str  # the handler always reports "healthy" when it answers
    vocab_size: int  # tokenizer vocabulary size
    model_params: str  # trainable parameter count formatted in millions, e.g. "12.3M"
    seq_len: int  # SEQ_LEN, the model's max sequence length
@app.get("/", include_in_schema=False)
def root():
    """Serve the frontend page (INDEX_HTML); hidden from the OpenAPI schema."""
    page_path = INDEX_HTML
    return FileResponse(page_path)
@app.get("/temperature-info")
def temperature_info():
    """Slider configuration + colour zones so the frontend can render a green->red track.

    ``gradient`` lists the hex colours of the ``_GRADIENT`` stops. It is derived
    from that table rather than duplicated, so the frontend track and
    ``temperature_color`` cannot drift apart.
    """
    return {
        "min": TEMP_MIN,
        "max": TEMP_MAX,
        "step": TEMP_STEP,
        "default": TEMP_DEFAULT,
        "gradient": ["#{:02x}{:02x}{:02x}".format(*rgb) for _, rgb in _GRADIENT],
        "zones": TEMPERATURE_ZONES,
    }
@app.get("/health", response_model=HealthResponse)
def health():
    """Report service status plus basic model facts.

    The facts are vocab size, trainable parameter count and sequence length.
    The parameter count is computed on the first call and cached on
    ``app.state``, since nothing in this module rebuilds the model after
    startup.
    """
    state = app.state
    model_params = getattr(state, "model_params", None)
    if model_params is None:
        # Reuse the TensorFlow module stored by lifespan (as /generate does)
        # instead of importing it again.
        tf = state.tf
        # int() keeps the running sum as a Python int rather than a fixed-width numpy scalar.
        total_params = sum(int(tf.size(w).numpy()) for w in state.model.trainable_variables)
        model_params = f"{total_params / 1e6:.1f}M"
        state.model_params = model_params
    return HealthResponse(
        status="healthy",
        vocab_size=state.tokenizer.vocab_size,
        model_params=model_params,
        seq_len=SEQ_LEN,
    )
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, db: Session = Depends(get_db)):
    """Generate a story continuation for ``request.prompt``.

    ``top_p`` is taken from the request when given, otherwise derived from the
    temperature. Model inference and the optional history write run in the
    threadpool so they do not block the event loop.

    Raises:
        HTTPException(422): the prompt encodes to no tokens (e.g. only whitespace).
        HTTPException(500): any other failure while generating.
    """
    tf = app.state.tf
    model = app.state.model
    tokenizer = app.state.tokenizer
    generate_text = app.state.generate_text

    temperature = request.temperature
    top_p = request.top_p if request.top_p is not None else derive_top_p(temperature)
    label, description, color = describe_temperature(temperature)

    try:
        # perf_counter is monotonic, so the measured latency is immune to wall-clock adjustments.
        start_time = time.perf_counter()
        prompt_tokens = tokenizer.encode(request.prompt)
        if len(prompt_tokens) == 0:
            raise HTTPException(status_code=422, detail="Prompt produced no tokens.")
        # Keep only the last SEQ_LEN - 10 prompt tokens; presumably this leaves
        # a little room in the context window for generated tokens.
        max_prompt_tokens = SEQ_LEN - 10
        if len(prompt_tokens) > max_prompt_tokens:
            prompt_tokens = prompt_tokens[-max_prompt_tokens:]
        start_tokens = tf.constant([prompt_tokens], dtype=tf.int32)
        output = await run_in_threadpool(
            generate_text,
            model,
            start_tokens,
            request.max_new_tokens,
            temperature,
            None,  # top_k (left unset; top_p is supplied instead)
            top_p,
            getattr(tokenizer, "eos_id", None),
            REPETITION_PENALTY,
        )
        generated = tokenizer.decode(output[0].numpy().tolist())
        response_time_ms = round((time.perf_counter() - start_time) * 1000, 2)

        if DB_AVAILABLE:
            try:
                # Synchronous DB I/O: run it off the event loop.
                await run_in_threadpool(
                    crud.save_generation,
                    db=db,
                    prompt=request.prompt,
                    generated_text=generated,
                    temperature=temperature,
                    top_p=top_p,
                    max_new_tokens=request.max_new_tokens,
                    response_time_ms=response_time_ms,
                )
            except Exception as e:  # best-effort: history loss must not fail the request
                print(f"save_generation failed (continuing): {e}")

        return GenerateResponse(
            prompt=request.prompt,
            generated_text=generated,
            temperature=temperature,
            temperature_label=label,
            temperature_description=description,
            temperature_color=color,
            top_p=top_p,
            max_new_tokens=request.max_new_tokens,
            response_time_ms=response_time_ms,
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/history")
def history(limit: int = 50, db: Session = Depends(get_db)):
    """Return stored generations fetched via ``crud.get_all_generations(limit=limit)``.

    Raises:
        HTTPException(503): the database was unreachable at startup
            (``DB_AVAILABLE`` is False), so history is disabled.
    """
    if not DB_AVAILABLE:
        # Querying an unreachable DB would presumably surface as an opaque 500;
        # report the disabled feature explicitly instead.
        raise HTTPException(status_code=503, detail="History is unavailable: database not connected")
    records = crud.get_all_generations(db, limit=limit)
    return [
        {
            "id": r.id,
            "prompt": r.prompt,
            "generated_text": r.generated_text,
            "temperature": r.temperature,
            "top_p": r.top_p,
            "max_new_tokens": r.max_new_tokens,
            "response_time_ms": r.response_time_ms,
            "created_at": r.created_at,
        }
        for r in records
    ]
@app.delete("/history/{generation_id}")
def delete_history_entry(generation_id: int, db: Session = Depends(get_db)):
    """Delete one stored generation by id.

    Raises:
        HTTPException(503): the database was unreachable at startup
            (``DB_AVAILABLE`` is False), so history is disabled.
        HTTPException(404): ``crud.delete_generation`` reported nothing deleted.
    """
    if not DB_AVAILABLE:
        raise HTTPException(status_code=503, detail="History is unavailable: database not connected")
    deleted = crud.delete_generation(db, generation_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Record not found")
    return {"deleted": generation_id}