| import { useEffect, useState, useRef } from "react"; |
|
|
| import Chat from "./components/Chat"; |
| import ArrowRightIcon from "./components/icons/ArrowRightIcon"; |
| import StopIcon from "./components/icons/StopIcon"; |
| import Progress from "./components/Progress"; |
|
|
| const IS_WEBGPU_AVAILABLE = !!navigator.gpu; |
| const STICKY_SCROLL_THRESHOLD = 120; |
| const EXAMPLES = [ |
| "Give me some tips to improve my time management skills.", |
| "What is the difference between AI and ML?", |
| "Write python code to compute the nth fibonacci number.", |
| ]; |
|
|
| function App() { |
| |
| const worker = useRef(null); |
|
|
| const textareaRef = useRef(null); |
| const chatContainerRef = useRef(null); |
|
|
| |
| const [status, setStatus] = useState(null); |
| const [error, setError] = useState(null); |
| const [loadingMessage, setLoadingMessage] = useState(""); |
| const [progressItems, setProgressItems] = useState([]); |
| const [isRunning, setIsRunning] = useState(false); |
|
|
| |
| const [input, setInput] = useState(""); |
| const [messages, setMessages] = useState([]); |
| const [tps, setTps] = useState(null); |
| const [numTokens, setNumTokens] = useState(null); |
|
|
| function onEnter(message) { |
| setMessages((prev) => [...prev, { role: "user", content: message }]); |
| setTps(null); |
| setIsRunning(true); |
| setInput(""); |
| } |
|
|
| function onInterrupt() { |
| |
| |
| worker.current.postMessage({ type: "interrupt" }); |
| } |
|
|
| useEffect(() => { |
| resizeInput(); |
| }, [input]); |
|
|
| function resizeInput() { |
| if (!textareaRef.current) return; |
|
|
| const target = textareaRef.current; |
| target.style.height = "auto"; |
| const newHeight = Math.min(Math.max(target.scrollHeight, 24), 200); |
| target.style.height = `${newHeight}px`; |
| } |
|
|
| |
| useEffect(() => { |
| |
| if (!worker.current) { |
| worker.current = new Worker(new URL("./worker.js", import.meta.url), { |
| type: "module", |
| }); |
| worker.current.postMessage({ type: "check" }); |
| } |
|
|
| |
| const onMessageReceived = (e) => { |
| switch (e.data.status) { |
| case "loading": |
| |
| setStatus("loading"); |
| setLoadingMessage(e.data.data); |
| break; |
|
|
| case "initiate": |
| setProgressItems((prev) => [...prev, e.data]); |
| break; |
|
|
| case "progress": |
| |
| setProgressItems((prev) => |
| prev.map((item) => { |
| if (item.file === e.data.file) { |
| return { ...item, ...e.data }; |
| } |
| return item; |
| }), |
| ); |
| break; |
|
|
| case "done": |
| |
| setProgressItems((prev) => |
| prev.filter((item) => item.file !== e.data.file), |
| ); |
| break; |
|
|
| case "ready": |
| |
| setStatus("ready"); |
| break; |
|
|
| case "start": |
| { |
| |
| setMessages((prev) => [ |
| ...prev, |
| { role: "assistant", content: "" }, |
| ]); |
| } |
| break; |
|
|
| case "update": |
| { |
| |
| |
| const { output, tps, numTokens } = e.data; |
| setTps(tps); |
| setNumTokens(numTokens); |
| setMessages((prev) => { |
| const cloned = [...prev]; |
| const last = cloned.at(-1); |
| cloned[cloned.length - 1] = { |
| ...last, |
| content: last.content + output, |
| }; |
| return cloned; |
| }); |
| } |
| break; |
|
|
| case "complete": |
| |
| setIsRunning(false); |
| break; |
|
|
| case "error": |
| setError(e.data.data); |
| break; |
| } |
| }; |
|
|
| const onErrorReceived = (e) => { |
| console.error("Worker error:", e); |
| }; |
|
|
| |
| worker.current.addEventListener("message", onMessageReceived); |
| worker.current.addEventListener("error", onErrorReceived); |
|
|
| |
| return () => { |
| worker.current.removeEventListener("message", onMessageReceived); |
| worker.current.removeEventListener("error", onErrorReceived); |
| }; |
| }, []); |
|
|
| |
| useEffect(() => { |
| if (messages.filter((x) => x.role === "user").length === 0) { |
| |
| return; |
| } |
| if (messages.at(-1).role === "assistant") { |
| |
| return; |
| } |
| setTps(null); |
| worker.current.postMessage({ type: "generate", data: messages }); |
| }, [messages, isRunning]); |
|
|
| useEffect(() => { |
| if (!chatContainerRef.current || !isRunning) return; |
| const element = chatContainerRef.current; |
| if ( |
| element.scrollHeight - element.scrollTop - element.clientHeight < |
| STICKY_SCROLL_THRESHOLD |
| ) { |
| element.scrollTop = element.scrollHeight; |
| } |
| }, [messages, isRunning]); |
|
|
| return IS_WEBGPU_AVAILABLE ? ( |
| <div className="flex flex-col h-screen mx-auto items justify-end text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900"> |
| {status === null && messages.length === 0 && ( |
| <div className="h-full overflow-auto scrollbar-thin flex justify-center items-center flex-col relative"> |
| <div className="flex flex-col items-center mb-1 max-w-[340px] text-center"> |
| <img |
| src="logo.png" |
| width="75%" |
| height="auto" |
| className="block" |
| ></img> |
| <h1 className="text-4xl font-bold mb-1">Llama-3.2 WebGPU</h1> |
| <h2 className="font-semibold"> |
| A private and powerful AI chatbot <br /> |
| that runs locally in your browser. |
| </h2> |
| </div> |
| |
| <div className="flex flex-col items-center px-4"> |
| <p className="max-w-[514px] mb-4"> |
| <br /> |
| You are about to load{" "} |
| <a |
| href="https://huggingface.co/onnx-community/Llama-3.2-1B-Instruct-ONNX" |
| target="_blank" |
| rel="noreferrer" |
| className="font-medium underline" |
| > |
| Llama-3.2-1B-Instruct |
| </a> |
| , a 1.24 billion parameter LLM that is optimized for inference on |
| the web. Once downloaded, the model (1.01 GB) will be cached |
| and reused when you revisit the page. |
| <br /> |
| <br /> |
| Everything runs directly in your browser using{" "} |
| <a |
| href="https://huggingface.co/docs/transformers.js" |
| target="_blank" |
| rel="noreferrer" |
| className="underline" |
| > |
| 🤗 Transformers.js |
| </a>{" "} |
| and ONNX Runtime Web, meaning your conversations aren't sent |
| to a server. You can even disconnect from the internet after the |
| model has loaded! |
| <br /> |
| Want to learn more? Check out the demo's source code on{" "} |
| <a |
| href="https://github.com/huggingface/transformers.js-examples/tree/main/llama-3.2-webgpu" |
| target="_blank" |
| rel="noreferrer" |
| className="underline" |
| > |
| GitHub |
| </a> |
| ! |
| </p> |
| |
| {error && ( |
| <div className="text-red-500 text-center mb-2"> |
| <p className="mb-1"> |
| Unable to load model due to the following error: |
| </p> |
| <p className="text-sm">{error}</p> |
| </div> |
| )} |
| |
| <button |
| className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none" |
| onClick={() => { |
| worker.current.postMessage({ type: "load" }); |
| setStatus("loading"); |
| }} |
| disabled={status !== null || error !== null} |
| > |
| Load model |
| </button> |
| </div> |
| </div> |
| )} |
| {status === "loading" && ( |
| <> |
| <div className="w-full max-w-[500px] text-left mx-auto p-4 bottom-0 mt-auto"> |
| <p className="text-center mb-1">{loadingMessage}</p> |
| {progressItems.map(({ file, progress, total }, i) => ( |
| <Progress |
| key={i} |
| text={file} |
| percentage={progress} |
| total={total} |
| /> |
| ))} |
| </div> |
| </> |
| )} |
|
|
| {status === "ready" && ( |
| <div |
| ref={chatContainerRef} |
| className="overflow-y-auto scrollbar-thin w-full flex flex-col items-center h-full" |
| > |
| <Chat messages={messages} /> |
| {messages.length === 0 && ( |
| <div> |
| {EXAMPLES.map((msg, i) => ( |
| <div |
| key={i} |
| className="m-1 border dark:border-gray-600 rounded-md p-2 bg-gray-100 dark:bg-gray-700 cursor-pointer" |
| onClick={() => onEnter(msg)} |
| > |
| {msg} |
| </div> |
| ))} |
| </div> |
| )} |
| <p className="text-center text-sm min-h-6 text-gray-500 dark:text-gray-300"> |
| {tps && messages.length > 0 && ( |
| <> |
| {!isRunning && ( |
| <span> |
| Generated {numTokens} tokens in{" "} |
| {(numTokens / tps).toFixed(2)} seconds ( |
| </span> |
| )} |
| { |
| <> |
| <span className="font-medium text-center mr-1 text-black dark:text-white"> |
| {tps.toFixed(2)} |
| </span> |
| <span className="text-gray-500 dark:text-gray-300"> |
| tokens/second |
| </span> |
| </> |
| } |
| {!isRunning && ( |
| <> |
| <span className="mr-1">).</span> |
| <span |
| className="underline cursor-pointer" |
| onClick={() => { |
| worker.current.postMessage({ type: "reset" }); |
| setMessages([]); |
| }} |
| > |
| Reset |
| </span> |
| </> |
| )} |
| </> |
| )} |
| </p> |
| </div> |
| )} |
|
|
| <div className="mt-2 border dark:bg-gray-700 rounded-lg w-[600px] max-w-[80%] max-h-[200px] mx-auto relative mb-3 flex"> |
| <textarea |
| ref={textareaRef} |
| className="scrollbar-thin w-[550px] dark:bg-gray-700 px-3 py-4 rounded-lg bg-transparent border-none outline-none text-gray-800 disabled:text-gray-400 dark:text-gray-200 placeholder-gray-500 dark:placeholder-gray-400 disabled:placeholder-gray-200 resize-none disabled:cursor-not-allowed" |
| placeholder="Type your message..." |
| type="text" |
| rows={1} |
| value={input} |
| disabled={status !== "ready"} |
| title={status === "ready" ? "Model is ready" : "Model not loaded yet"} |
| onKeyDown={(e) => { |
| if ( |
| input.length > 0 && |
| !isRunning && |
| e.key === "Enter" && |
| !e.shiftKey |
| ) { |
| e.preventDefault(); // Prevent default behavior of Enter key |
| onEnter(input); |
| } |
| }} |
| onInput={(e) => setInput(e.target.value)} |
| /> |
| {isRunning ? ( |
| <div className="cursor-pointer" onClick={onInterrupt}> |
| <StopIcon className="h-8 w-8 p-1 rounded-md text-gray-800 dark:text-gray-100 absolute right-3 bottom-3" /> |
| </div> |
| ) : input.length > 0 ? ( |
| <div className="cursor-pointer" onClick={() => onEnter(input)}> |
| <ArrowRightIcon |
| className={`h-8 w-8 p-1 bg-gray-800 dark:bg-gray-100 text-white dark:text-black rounded-md absolute right-3 bottom-3`} |
| /> |
| </div> |
| ) : ( |
| <div> |
| <ArrowRightIcon |
| className={`h-8 w-8 p-1 bg-gray-200 dark:bg-gray-600 text-gray-50 dark:text-gray-800 rounded-md absolute right-3 bottom-3`} |
| /> |
| </div> |
| )} |
| </div> |
|
|
| <p className="text-xs text-gray-400 text-center mb-3"> |
| Disclaimer: Generated content may be inaccurate or false. |
| </p> |
| </div> |
| ) : ( |
| <div className="fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] text-white text-2xl font-semibold flex justify-center items-center text-center"> |
| WebGPU is not supported |
| <br /> |
| by this browser :( |
| </div> |
| ); |
| } |
|
|
| export default App; |
|
|