| |
|
|
| import streamlit as st |
| import torch |
| import joblib |
| import dill |
| import numpy as np |
| import gdown |
| import os |
|
|
| |
@st.cache_resource
def load_assets():
    """Load the serialized preprocessing function, TF-IDF vectorizer and model.

    The three artifacts are read from ``preprocess_function.pkl`` (via dill),
    ``tfidf_vectorizer.pkl`` and ``sage_model.pkl`` (both via joblib) in the
    current working directory. The function is wrapped in
    ``st.cache_resource``.

    Returns:
        A ``(preprocess_text, tfidf, model)`` tuple, in that order.
    """
    # dill is used for the preprocessing callable; the other two go through joblib.
    with open("preprocess_function.pkl", "rb") as handle:
        preprocessor = dill.load(handle)

    # Generator unpacking keeps the original load order: vectorizer, then model.
    vectorizer, classifier = (
        joblib.load(filename)
        for filename in ("tfidf_vectorizer.pkl", "sage_model.pkl")
    )
    return preprocessor, vectorizer, classifier
|
|
| |
@st.cache_resource
def ensure_knn_model():
    """Return the KNN model, downloading ``knn_model.pkl`` first if it is missing.

    The result is cached with ``st.cache_resource`` (the same mechanism
    ``load_assets`` uses) so the pickle is not re-read from disk every time
    Streamlit re-executes the script.

    Returns:
        The object deserialized by ``joblib.load`` from ``knn_model.pkl``.

    Raises:
        FileNotFoundError: If the file is still absent after the download
            attempt.
    """
    knn_path = "knn_model.pkl"
    if not os.path.exists(knn_path):
        gdown.download(
            "https://drive.google.com/uc?id=166HWcckEVofU1TzVpZPNzbHdjxV_SqpT",
            knn_path,
            quiet=False,
        )
        # Do not rely on gdown's return value or exceptions (they differ between
        # versions -- TODO confirm for the pinned one); check the file directly so
        # a failed download is reported as such instead of surfacing from joblib.
        if not os.path.exists(knn_path):
            raise FileNotFoundError(
                f"Could not download the KNN model to {knn_path!r}; "
                "check network access and the Google Drive sharing settings."
            )
    return joblib.load(knn_path)
|
|
| |
# Model artifacts used by the detection step further down in the script.
preprocess_text, tfidf_vectorizer, sage_model = load_assets()
knn_model = ensure_knn_model()

# Page header. The emoji are written as \U escapes so the source file stays
# pure ASCII and cannot be corrupted by a wrong-encoding round trip.
st.title("\U0001F9E0 Disinformation Detection")  # brain emoji
st.write("This app predicts whether a given news article is **real** or **disinformation** using a trained GraphSAGE model.")

# Free-text input that the "Detect" button handler analyzes.
user_input = st.text_area("\U0001F4DD Enter a news article or headline:")  # memo emoji
|
|
if st.button("Detect"):
    if not user_input.strip():
        st.warning("Please enter some text to analyze.")
    else:
        # 1. Vectorize the submitted text into a single dense TF-IDF row.
        cleaned_text = preprocess_text(user_input)
        tfidf_vector = tfidf_vectorizer.transform([cleaned_text])
        input_feature = torch.tensor(tfidf_vector.toarray(), dtype=torch.float)

        # 2. Use the matrix stored on the KNN model as the existing graph nodes.
        #    NOTE(review): `_fit_X` is a private attribute (presumably of a
        #    scikit-learn neighbors estimator) and may change between versions.
        fitted_matrix = knn_model._fit_X
        if hasattr(fitted_matrix, "toarray"):
            # A sparse matrix cannot be passed to torch.tensor directly.
            fitted_matrix = fitted_matrix.toarray()
        original_features = torch.tensor(fitted_matrix, dtype=torch.float)

        # 3. Append the new sample as the LAST node; step 5 relies on that position.
        combined_features = torch.cat([original_features, input_feature], dim=0)

        # 4. Build the edge list from each node's nearest neighbours, dropping
        #    self-loops. Vectorised, but yields the same row-major edge order as
        #    a nested loop over (node, neighbour) pairs, and an empty result is
        #    still a well-formed (2, 0) tensor.
        #    NOTE(review): edges are stored as (node, neighbour); confirm this is
        #    the same direction the model was trained with.
        neighbors = np.asarray(
            knn_model.kneighbors(combined_features.numpy(), return_distance=False)
        )
        sources = np.repeat(np.arange(neighbors.shape[0]), neighbors.shape[1])
        targets = neighbors.reshape(-1)
        keep = sources != targets
        edge_index = torch.tensor(
            np.stack([sources[keep], targets[keep]]), dtype=torch.long
        )

        # 5. Run inference and read the output row of the appended node.
        sage_model.eval()
        with torch.no_grad():
            logits = sage_model(combined_features, edge_index)
            pred_node_logits = logits[-1]
            prediction = torch.argmax(pred_node_logits).item()
            # softmax equals exp() when the model already emits log-probabilities
            # (softmax is shift-invariant) and, unlike exp(), still produces a
            # valid probability if the model emits raw logits.
            probabilities = torch.softmax(pred_node_logits, dim=-1)
            confidence = probabilities[prediction].item()

        # 6. Report the result. Class index 1 is displayed as real news; any
        #    other index as disinformation. Emoji are \U escapes to keep the
        #    source ASCII-only (green circle / red circle).
        label = "\U0001F7E2 Real News" if prediction == 1 else "\U0001F534 Disinformation"
        st.markdown(f"### Prediction: {label}")
        st.markdown(f"**Confidence:** {confidence:.2%}")
|
|