Instructions to use pra1223/psharma-models with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- ESPnet
How to use pra1223/psharma-models with ESPnet:
Unknown model type: ESPnet usage requires a text-to-speech or automatic-speech-recognition model.
- Notebooks
- Google Colab
- Kaggle
| import streamlit as st | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import re | |
# --- Backend Functions ---
# NOTE(review): Gemini models are not, as far as I know, published as
# transformers checkpoints; confirm this repo id actually exists on the Hub.
DEFAULT_MODEL_ID = "google/gemini-1.5-pro-001"


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(model_id, api_key):
    """Load the tokenizer and causal LM once per (model_id, api_key) pair.

    Cached by Streamlit so reruns (every button click) reuse the loaded
    objects instead of reloading the weights. Exceptions propagate and are
    therefore not cached, so a bad key or transient error can be retried.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=api_key)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=api_key,
        device_map="auto",  # let transformers choose device placement
        torch_dtype=torch.bfloat16,
    )
    return model, tokenizer


def initialize_gemini_api(api_key, model_id=DEFAULT_MODEL_ID):
    """Initialize the model and tokenizer used for generation.

    Args:
        api_key: Hugging Face access token passed as ``token=`` to
            ``from_pretrained``.
        model_id: Hub repository id to load; defaults to ``DEFAULT_MODEL_ID``.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` after reporting
        the error in the Streamlit UI.
    """
    try:
        return _load_model_and_tokenizer(model_id, api_key)
    except Exception as e:  # UI boundary: show the error instead of crashing the app
        st.error(f"Error initializing model: {e}")
        return None, None
# Prompt template per suggestion type; "{}" is replaced with the user's text.
_PROMPT_TEMPLATES = {
    "recipe_suggestion": "I have the following ingredients: {}. Suggest a recipe, and the recipe must include the ingredients I provided. Provide steps",
    "promotion_idea": "Suggest a promotion to increase customer engagement based on these goals/themes: {}.",
    "waste_reduction_tip": "Suggest strategies, including numbered steps, to minimize food waste based on this context/these ingredients: {}.",
    "event_planning": "I want to plan an event. Here's the description/goals/requirements: {}. Give detailed, step-by-step instructions and important considerations.",
}


def preprocess_input(user_input, input_type):
    """Wrap ``user_input`` in the prompt template registered for ``input_type``.

    Returns the literal string "Invalid input type." when ``input_type`` has no
    template; the sidebar selectbox only offers the known keys, so this is a
    defensive fallback.
    """
    if input_type not in _PROMPT_TEMPLATES:
        return "Invalid input type."
    template = _PROMPT_TEMPLATES[input_type]
    return template.format(user_input)
def generate_suggestion(model, tokenizer, processed_input):
    """Generate a completion for ``processed_input`` and return only the new text.

    Args:
        model: A causal LM exposing ``device`` and ``generate`` (as returned by
            ``initialize_gemini_api``).
        tokenizer: The tokenizer matching ``model``.
        processed_input: The fully formatted prompt string.

    Returns:
        The decoded continuation with the prompt removed, or a fixed error
        message if tokenization or generation fails (the error is also shown
        in the Streamlit UI).
    """
    try:
        # The tokenizer returns a mapping (input_ids, attention_mask, ...);
        # move its tensors to the model's device before generating.
        inputs = tokenizer(processed_input, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            do_sample=True,
        )
        # A causal LM's output sequence starts with the prompt tokens; decode
        # only what was generated so the prompt is not echoed to the user.
        prompt_length = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    except Exception as e:  # UI boundary: report and return a readable message
        st.error(f"Error during generation: {e}")
        return "An error occurred during suggestion generation."
# Whitespace that ends a sentence: it must follow "." or "?", but not an
# abbreviation, a title, or a numbered-list marker.
_SENTENCE_BOUNDARY_RE = re.compile(
    r"(?<!\w\.\w.)"       # not right after an abbreviation such as "e.g."
    r"(?<![A-Z][a-z]\.)"  # not right after a title such as "Mr." / "Dr."
    r"(?<!\b\d\.)"        # not after a 1-digit list marker ("1. Chop ...")
    r"(?<!\b\d\d\.)"      # not after a 2-digit list marker ("10. Serve ...")
    r"(?<=\.|\?)"         # only after a full stop or question mark
    r"\s"
)


def postprocess_output(raw_response, input_type):
    """Tidy generated text for display.

    Strips surrounding whitespace, then puts a blank line between sentences.
    Numbered-list markers ("1. ...") stay attached to their step text,
    because the prompts explicitly ask for numbered steps.

    Args:
        raw_response: Text returned by the model.
        input_type: Suggestion type. Currently unused; accepted so callers
            can pass it and per-type formatting can be added later.

    Returns:
        The formatted string.
    """
    cleaned_response = raw_response.strip()
    sentences = _SENTENCE_BOUNDARY_RE.split(cleaned_response)
    return "\n\n".join(sentences)
def get_ai_suggestion(user_input, input_type, api_key):
    """Run the whole pipeline: load the model, build the prompt, generate, format.

    Returns the formatted suggestion, or a short failure message when the
    model or tokenizer could not be loaded.
    """
    model, tokenizer = initialize_gemini_api(api_key)
    ready = model is not None and tokenizer is not None
    if not ready:
        return "Failed to initialize the model. Check your API key."

    prompt = preprocess_input(user_input, input_type)
    return postprocess_output(
        generate_suggestion(model, tokenizer, prompt),
        input_type,
    )
# --- Streamlit Frontend ---
st.set_page_config(page_title="AI Restaurant Assistant", layout="wide")
st.sidebar.title("AI Restaurant Assistant")

# --- API KEY HANDLING ---
# Use st.session_state to persist the API key *only for the session*.
if "api_key" not in st.session_state:
    st.session_state.api_key = ""

# IMPORTANT SECURITY NOTE: This method is suitable for demonstration/local development.
# For a production deployment, you MUST use a more secure method of storing the API key,
# such as environment variables and NEVER hardcode it or commit it to version control.
# NOTE(security): an API key was previously hard-coded in this file as UI text;
# treat it as exposed and revoke it.
api_key_input = st.sidebar.text_input(
    "Enter your Hugging Face API key:",
    type="password",
    value=st.session_state.api_key,
)
if api_key_input:
    # Remember the key across reruns within this browser session.
    st.session_state.api_key = api_key_input

if not st.session_state.api_key:
    # Prompt the user; never render a credential here, since this text is
    # shown to every visitor.
    st.sidebar.warning("Please enter your Hugging Face API key to continue.")
    st.stop()  # Nothing below can run without a key.

# --- Input Selection ---
input_type = st.sidebar.selectbox(
    "What kind of suggestion do you need?",
    ["recipe_suggestion", "promotion_idea", "waste_reduction_tip", "event_planning"],
)

# --- Main Area ---
st.title("Get AI-Powered Suggestions")
st.write("This tool leverages the power of the Gemini 1.5 Pro model to assist with various restaurant management tasks.")

user_input = st.text_area("Enter your input here:", height=150, key="user_input")

if st.button("Generate Suggestion"):
    if user_input:
        with st.spinner("Generating suggestion..."):
            suggestion = get_ai_suggestion(user_input, input_type, st.session_state.api_key)
        # Static markdown heading: no HTML, so unsafe_allow_html is unnecessary.
        st.markdown("### AI Suggestion:")
        st.write(suggestion)
    else:
        st.warning("Please enter some input.")