# AuroraImageGen / app.py
# Author: lazarus19 — last commit 15db5ed ("Update app.py", verified), 2.88 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face Hub repo that supplies both the tokenizer and the weights.
MODEL_ID = "lazarus19/AuroraImageGen"

# Device / precision selection: half precision when CUDA is present,
# full float32 otherwise.
_cuda_ok = torch.cuda.is_available()
device = "cuda" if _cuda_ok else "cpu"
# NOTE(review): `device` is not referenced anywhere visible; weight placement
# is delegated to device_map="auto" below. Confirm before removing.
torch_dtype = torch.float16 if _cuda_ok else torch.float32

# Tokenizer and causal-LM weights, loaded from the same repo id.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    device_map="auto",
)
# Generate function
def generate(
    prompt,
    max_new_tokens,
    temperature,
    top_p,
):
    """Sample a continuation of *prompt* from the loaded causal language model.

    Args:
        prompt: Text from the Prompt textbox. ``None``, empty, or
            whitespace-only input short-circuits with a hint message.
        max_new_tokens: Upper bound on the number of generated tokens. The
            slider may deliver it as a float, so it is coerced to ``int``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling mass forwarded to ``model.generate``.

    Returns:
        Only the newly generated text (the prompt is not echoed back), or
        ``"Please enter a prompt."`` when *prompt* is empty.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
    )
    # Move every encoded tensor (input_ids, attention_mask, ...) to the
    # device holding the model so generate() sees matching devices.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            # EOS doubles as the pad id; presumably the tokenizer defines no
            # dedicated pad token — confirm against the model repo.
            pad_token_id=tokenizer.eos_token_id,
        )

    # A causal LM returns prompt tokens followed by the new ones; drop the
    # prompt prefix so the Response box shows only the generated text.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(
        new_tokens,
        skip_special_tokens=True,
    )
# Sample prompts offered through gr.Examples.
examples = [
    "Write a short story about a robot explorer.",
    "Explain quantum computing in simple terms.",
    "Create a fantasy character profile.",
]

# Stylesheet for gr.Blocks: centres the main column and caps its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 900px;\n"
    "}\n"
)
# Single-column UI: prompt in, generated text out, with the sampling knobs
# tucked into a collapsed accordion.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# AuroraImageGen Chat")

        prompt = gr.Textbox(
            label="Prompt", lines=6, placeholder="Enter your prompt..."
        )
        output = gr.Textbox(label="Response", lines=20)

        # Sampling controls; collapsed by default.
        with gr.Accordion("Advanced Settings", open=False):
            max_new_tokens = gr.Slider(
                minimum=32, maximum=2048, value=512, step=32,
                label="Max New Tokens",
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                label="Top-P",
            )

        run_button = gr.Button("Generate", variant="primary")

        # No fn is attached, so picking an example just fills the prompt box.
        gr.Examples(examples=examples, inputs=[prompt])

    # The inputs list is passed positionally, so its order must match
    # generate(prompt, max_new_tokens, temperature, top_p).
    run_button.click(
        fn=generate,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()