Spaces:
Running
Running
| import os | |
| import json | |
| import time | |
| import shutil | |
| import warnings | |
| from html import escape | |
| from pathlib import Path | |
| from typing import Optional | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| from PIL import Image, ImageFile | |
| from handler import EndpointHandler | |
| from translator import translate_texts | |
# ------------------------------------------------------------------
# Upload security limits
# ------------------------------------------------------------------
# 1) Cap the raw upload size before any decoding: rejects disguised files,
#    images with appended payloads, and oversized high-entropy noise.
MAX_UPLOAD_BYTES = 8 << 20  # 8 MiB
# 2) Cap each side so absurd resolutions are rejected early.
MAX_IMAGE_SIDE = 4096
# 3) Cap the total pixel count (pixel-bomb / decoded-memory protection).
MAX_IMAGE_PIXELS = 20 * 1000 * 1000  # 20 megapixels
# 4) Cap the estimated decoded size in bytes.
MAX_DECOMPRESSED_BYTES = 160 << 20  # 160 MiB
# 5) Accept only these common image formats (Pillow format names).
ALLOWED_IMAGE_FORMATS = set("PNG JPEG WEBP BMP GIF".split())
# Process-wide Pillow hardening:
# - refuse to load truncated files instead of accepting partial pixel data;
# - Pillow warns when an image exceeds MAX_IMAGE_PIXELS; promote that warning
#   to an exception so the open is aborted rather than merely logged.
ImageFile.LOAD_TRUNCATED_IMAGES = False
Image.MAX_IMAGE_PIXELS = MAX_IMAGE_PIXELS
warnings.filterwarnings("error", category=Image.DecompressionBombWarning)
class ImageValidationError(ValueError):
    """An uploaded image failed one of the upload safety checks.

    Derives from ValueError; its message is the (Chinese) text meant for the user.
    """
def _format_size(num_bytes: int) -> str:
    """Render a byte count as "N B", "x.xx KB" or "x.xx MB" (1024-based units)."""
    kib = 1024
    mib = kib * kib
    if num_bytes < kib:
        return f"{num_bytes} B"
    value, unit = (num_bytes / kib, "KB") if num_bytes < mib else (num_bytes / mib, "MB")
    return f"{value:.2f} {unit}"
def validate_and_open_image(image_path: str) -> Image.Image:
    """Safely open a user-uploaded image and return it as an in-memory RGB image.

    Checks, in this order, before any pixel data is decoded:
    - a path was given and points to an existing regular file;
    - the file is non-empty and at most MAX_UPLOAD_BYTES;
    - Pillow can parse the header and ``verify()`` passes (decompression-bomb
      warnings/errors are reported as such);
    - the format is in ALLOWED_IMAGE_FORMATS;
    - width/height are positive, each side <= MAX_IMAGE_SIDE, total pixels
      <= MAX_IMAGE_PIXELS;
    - the estimated decoded size (3 bytes per pixel) <= MAX_DECOMPRESSED_BYTES.
    Only then is the image decoded and converted to RGB (or copied if already RGB).

    Args:
        image_path: filesystem path of the uploaded file.

    Returns:
        A fully loaded PIL image in RGB mode, detached from the file.

    Raises:
        ImageValidationError: on any failed check or decode error; the message
            is the user-facing (Chinese) text shown in the UI.
    """
    if not image_path:
        raise ImageValidationError("未检测到上传文件。")
    if not os.path.isfile(image_path):
        raise ImageValidationError("上传文件不存在或无法访问。")
    try:
        file_size = os.path.getsize(image_path)
    except OSError as e:
        # The file can disappear or become unreadable between isfile() and here.
        raise ImageValidationError("上传文件不存在或无法访问。") from e
    if file_size <= 0:
        raise ImageValidationError("上传文件为空。")
    if file_size > MAX_UPLOAD_BYTES:
        raise ImageValidationError(
            f"图片文件过大:{_format_size(file_size)},超过限制 {_format_size(MAX_UPLOAD_BYTES)}。"
        )

    # Pillow warns above MAX_IMAGE_PIXELS (promoted to an exception at import
    # time) and raises DecompressionBombError above twice that; both are bombs.
    bomb_errors = (Image.DecompressionBombWarning, Image.DecompressionBombError)

    try:
        with Image.open(image_path) as probe:
            img_format = (probe.format or "").upper()
            width, height = probe.size
            # verify() checks integrity without decoding; the image object is
            # unusable afterwards, which is why the file is reopened below.
            probe.verify()
    except bomb_errors as e:
        raise ImageValidationError("图片疑似像素炸弹,已被拒绝处理。") from e
    except Exception as e:
        raise ImageValidationError(f"无法解析为有效图片文件:{e}") from e

    if img_format not in ALLOWED_IMAGE_FORMATS:
        raise ImageValidationError(
            f"不支持的图片格式:{img_format or '未知'}。仅允许:{', '.join(sorted(ALLOWED_IMAGE_FORMATS))}。"
        )
    if width <= 0 or height <= 0:
        raise ImageValidationError("图片尺寸非法。")
    if width > MAX_IMAGE_SIDE or height > MAX_IMAGE_SIDE:
        raise ImageValidationError(
            f"图片尺寸过大:{width}×{height},单边不得超过 {MAX_IMAGE_SIDE} 像素。"
        )
    total_pixels = width * height
    if total_pixels > MAX_IMAGE_PIXELS:
        raise ImageValidationError(
            f"图片总像素过大:{total_pixels:,},超过限制 {MAX_IMAGE_PIXELS:,}。"
        )
    # Rough size of the final RGB buffer (3 bytes per pixel).
    estimated_decompressed_bytes = total_pixels * 3
    if estimated_decompressed_bytes > MAX_DECOMPRESSED_BYTES:
        raise ImageValidationError(
            "图片解码后的内存占用过高,已拒绝处理。"
            f" 预计占用约 {_format_size(estimated_decompressed_bytes)},"
            f"超过限制 {_format_size(MAX_DECOMPRESSED_BYTES)}。"
        )

    try:
        with Image.open(image_path) as decoded:
            decoded.load()
            # convert()/copy() both return a new in-memory image, so the result
            # stays valid after the file handle is closed by the with-block.
            img = decoded.convert("RGB") if decoded.mode != "RGB" else decoded.copy()
    except bomb_errors as e:
        raise ImageValidationError("图片在解码阶段触发像素炸弹保护,已拒绝处理。") from e
    except Exception as e:
        raise ImageValidationError(f"图片加载失败:{e}") from e
    return img
# ------------------------------------------------------------------
# PixAI Tagger v0.9 (new version) model configuration
# ------------------------------------------------------------------
def _env_first(*names: str, default: Optional[str] = None) -> Optional[str]:
    """Return the first of ``names`` whose environment value is set and non-blank.

    Values are stripped, so a secret pasted with a trailing newline still works,
    and a variable that is defined but empty counts as unset instead of being
    passed on as "". Returns ``default`` when no variable qualifies.
    """
    for name in names:
        value = os.environ.get(name)
        if value is not None and value.strip():
            return value.strip()
    return default


ASSETS_REPO_ID = _env_first("ASSETS_REPO_ID", default="pixai-labs/pixai-tagger-v0.9")
# None lets snapshot_download pick its default revision.
ASSETS_REVISION = _env_first("ASSETS_REVISION")
MODEL_DIR = _env_first("MODEL_DIR", default="./assets")
# Hub token, taken from the first of these variables that is set (checked in order).
HF_TOKEN = _env_first(
    "HUGGINGFACE_HUB_TOKEN",
    "HF_TOKEN",
    "HUGGINGFACE_TOKEN",
    "HUGGINGFACEHUB_API_TOKEN",
)
# Files that must be present in MODEL_DIR; ensure_assets downloads missing ones.
REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]
def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str) -> None:
    """Make sure every file in REQUIRED_FILES exists in ``target_dir``.

    Missing files are fetched from the Hub repo ``repo_id`` at ``revision`` with
    ``snapshot_download`` (restricted to the missing names) and copied into
    ``target_dir``, where the handler loads them from. Files already present are
    left untouched, so a fully populated directory causes no download.

    Raises:
        FileNotFoundError: if a missing file is not in the downloaded snapshot.
            Nothing is copied in that case.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    missing = [fname for fname in REQUIRED_FILES if not (target / fname).is_file()]
    if not missing:
        return

    snapshot_path = Path(
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            allow_patterns=missing,
            token=HF_TOKEN,
        )
    )

    # Validate first so a bad repo/revision does not leave a half-populated directory.
    absent = [fname for fname in missing if not (snapshot_path / fname).is_file()]
    if absent:
        names = "', '".join(absent)
        raise FileNotFoundError(
            f"模型资源缺失:'{names}' 未在 {repo_id} @ {revision or 'default'} 中找到。"
        )

    for fname in missing:
        src = snapshot_path / fname
        dst = target / fname
        if src.resolve() == dst.resolve():
            continue
        # Copy under a temporary name and rename atomically: an interrupted copy
        # must not leave a truncated file that passes the is_file() check next time.
        tmp = dst.with_name(dst.name + ".partial")
        shutil.copyfile(src, tmp)
        os.replace(tmp, dst)
# ------------------------------------------------------------------
# Tagger class: wraps the new EndpointHandler
# ------------------------------------------------------------------
class Tagger:
    """Wrapper around ``EndpointHandler`` for the PixAI Tagger v0.9 model.

    Construction downloads missing assets into MODEL_DIR and builds the handler.
    ``predict`` turns one image into per-category tag dictionaries for the UI.
    """

    def __init__(self):
        self.handler = None
        self.device = "unknown"
        self._load_model_and_labels()

    def _load_model_and_labels(self) -> None:
        """Fetch the model assets and construct the handler.

        Raises:
            RuntimeError: wrapping any error raised while downloading or loading.
        """
        try:
            ensure_assets(ASSETS_REPO_ID, ASSETS_REVISION, MODEL_DIR)
            self.handler = EndpointHandler(MODEL_DIR)
            self.device = getattr(self.handler, "device", "unknown")
            print(f"✅ PixAI Tagger v0.9 加载成功,设备:{str(self.device).upper()}")
        except Exception as e:
            print(f"❌ PixAI Tagger v0.9 加载失败: {e}")
            raise RuntimeError(f"模型初始化失败: {e}") from e

    @staticmethod
    def _display_tag(tag: str) -> str:
        """Return ``tag`` as displayed in the UI (underscores become spaces)."""
        return str(tag).replace("_", " ")

    @staticmethod
    def _get_score(scores: dict, tag: str) -> float:
        """Look up ``tag``'s score, tolerating space/underscore spelling differences.

        Tries the tag as given, then with underscores replaced by spaces, then with
        spaces replaced by underscores. Returns 0.0 when ``scores`` is not a dict,
        no spelling is present, or the stored value is not convertible to float.
        """
        if not isinstance(scores, dict):
            return 0.0
        for key in (tag, str(tag).replace("_", " "), str(tag).replace(" ", "_")):
            if key in scores:
                try:
                    return float(scores[key])
                except (TypeError, ValueError):
                    return 0.0
        return 0.0

    def predict(self, img: "Image.Image", gen_th: float = 0.30, char_th: float = 0.85):
        """Tag one image and shape the result for the UI.

        Args:
            img: the image, passed to the handler as ``inputs``.
            gen_th: threshold for general (feature) tags.
            char_th: threshold for character tags.

        Returns:
            ``(res, tag_categories_for_translation, raw_meta)``:
            - ``res["general"]`` / ``res["characters"]``: {display tag: score},
              sorted by descending score;
            - ``res["ips"]``: {display tag: None}; the model returns no IP scores
              and the old rating category was replaced by IPs;
            - ``tag_categories_for_translation``: the display tags per category,
              in the same order as ``res``;
            - ``raw_meta``: device, measured latency in seconds, and the handler's
              ``_params`` / ``_timings`` when present.

        Raises:
            RuntimeError: if the handler is not loaded.
            ValueError: if ``img`` is None.
        """
        if self.handler is None:
            raise RuntimeError("模型未成功加载,无法进行预测。")
        if img is None:
            raise ValueError("输入图像不能为空。")

        params = {
            "general_threshold": float(gen_th),
            "character_threshold": float(char_th),
            "mode": "threshold",
            "topk_general": 25,
            "topk_character": 10,
            "include_scores": True,
        }
        data = {
            "inputs": img,
            "parameters": params,
        }

        # perf_counter is monotonic and meant for measuring durations.
        started = time.perf_counter()
        out = self.handler(data)
        latency = round(time.perf_counter() - started, 4)

        feature_tags = out.get("feature", []) or []
        character_tags = out.get("character", []) or []
        ip_tags = out.get("ip", []) or []
        feature_scores = out.get("feature_scores", {}) or {}
        character_scores = out.get("character_scores", {}) or {}

        general = {
            self._display_tag(tag): self._get_score(feature_scores, tag)
            for tag in feature_tags
        }
        characters = {
            self._display_tag(tag): self._get_score(character_scores, tag)
            for tag in character_tags
        }
        # IP tags have no scores; None tells the UI not to show a confidence.
        ips = {self._display_tag(tag): None for tag in ip_tags}

        general = dict(sorted(general.items(), key=lambda kv: kv[1], reverse=True))
        characters = dict(sorted(characters.items(), key=lambda kv: kv[1], reverse=True))

        res = {
            "general": general,
            "characters": characters,
            "ips": ips,
        }
        tag_categories_for_translation = {
            "general": list(general.keys()),
            "characters": list(characters.keys()),
            "ips": list(ips.keys()),
        }
        raw_meta = {
            "device": str(self.device),
            "latency_s_total": latency,
            "_params": out.get("_params", params),
            "_timings": out.get("_timings", {}),
        }
        return res, tag_categories_for_translation, raw_meta
# Module-wide Tagger used by every request. It stays None when initialisation
# fails, so the UI still starts and the handler can report the problem.
try:
    tagger_instance = Tagger()
except RuntimeError as e:
    print(f"应用启动时 Tagger 初始化失败: {e}")
    tagger_instance = None

_device_name = "UNKNOWN" if tagger_instance is None else str(tagger_instance.device).upper()
DEVICE_LABEL = f"设备:{_device_name}"
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
# Page styles: the scrollable tag list and tag rows rendered by format_tags_html
# (.label-container, .tag-item, .tag-en / .tag-zh / .tag-score), plus spacing for
# the analyze button (elem_classes="btn-analyze-container").
custom_css = """
.label-container {
    max-height: 300px;
    overflow-y: auto;
    border: 1px solid #ddd;
    padding: 10px;
    border-radius: 5px;
    background-color: #f9f9f9;
}
.tag-item {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin: 2px 0;
    padding: 2px 5px;
    border-radius: 3px;
    background-color: #fff;
    transition: background-color 0.2s;
}
.tag-item:hover {
    background-color: #f0f0f0;
}
.tag-en {
    font-weight: bold;
    color: #333;
    cursor: pointer;
}
.tag-zh {
    color: #666;
    margin-left: 10px;
}
.tag-score {
    color: #999;
    font-size: 0.9em;
    white-space: nowrap;
}
.btn-analyze-container {
    margin-top: 15px;
    margin-bottom: 15px;
}
"""
# Client-side helpers passed to gr.Blocks(js=...).
# Gradio runs the js string as a function once when the page loads. It is called
# with no arguments, so a plain `function copyToClipboard(text)` is never reachable
# from the inline onclick handlers in the tag lists. The setup function therefore
# attaches the helper to `window` explicitly. It also falls back to
# document.execCommand('copy') where the async Clipboard API is missing (non-secure
# contexts such as plain http) or rejects the write.
_js_functions = """
function setupTagCopyHelpers() {
    function showToast(message, background, visibleMs) {
        const toast = document.createElement('div');
        toast.textContent = message;
        toast.style.position = 'fixed';
        toast.style.bottom = '20px';
        toast.style.left = '50%';
        toast.style.transform = 'translateX(-50%)';
        toast.style.backgroundColor = background;
        toast.style.color = 'white';
        toast.style.padding = '10px 20px';
        toast.style.borderRadius = '5px';
        toast.style.zIndex = '10000';
        toast.style.transition = 'opacity 0.5s ease-out';
        document.body.appendChild(toast);
        setTimeout(() => {
            toast.style.opacity = '0';
            setTimeout(() => {
                if (document.body.contains(toast)) {
                    document.body.removeChild(toast);
                }
            }, 500);
        }, visibleMs);
    }

    function legacyCopy(text) {
        return new Promise((resolve, reject) => {
            const area = document.createElement('textarea');
            area.value = text;
            area.setAttribute('readonly', '');
            area.style.position = 'fixed';
            area.style.opacity = '0';
            document.body.appendChild(area);
            area.select();
            let copied = false;
            try {
                copied = document.execCommand('copy');
            } catch (err) {
                copied = false;
            }
            document.body.removeChild(area);
            if (copied) {
                resolve();
            } else {
                reject(new Error('document.execCommand copy failed'));
            }
        });
    }

    window.copyToClipboard = function (text) {
        if (typeof text === 'undefined' || text === null) {
            console.warn('copyToClipboard was called with undefined or null text. Aborting this specific copy operation.');
            return;
        }
        const value = String(text);
        const copying = (navigator.clipboard && window.isSecureContext)
            ? navigator.clipboard.writeText(value).catch(() => legacyCopy(value))
            : legacyCopy(value);
        copying.then(() => {
            const preview = value.substring(0, 30) + (value.length > 30 ? '...' : '');
            showToast('已复制: ' + preview, '#4CAF50', 1500);
        }).catch(err => {
            console.error('Failed to copy tag. Error:', err, 'Attempted to copy text:', text);
            showToast('复制操作失败!', '#D32F2F', 2500);
        });
    };
}
"""
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
    # Header: title, intro text, and the model/device currently serving requests.
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    intro_markdown = (
        "上传图片自动识别标签,支持中英文显示和一键复制。"
        "[NovelAI在线绘画](https://nai.idlecloud.cc/)\n\n"
        f"**当前模型:pixai-labs/pixai-tagger-v0.9** | **{DEVICE_LABEL}**\n\n"
        "说明:新版模型不再返回评分标签,本页面已将原“评分标签”区域改为“IP 标签”。"
    )
    gr.Markdown(intro_markdown)

    # Per-session state: last result, its translations, and the tag order used for translation.
    state_res = gr.State(value={})
    state_translations_dict = gr.State(value={})
    state_tag_categories_for_translation = gr.State(value={})

    with gr.Row():
        # Left column: upload, analyze button, thresholds and summary options.
        with gr.Column(scale=1):
            img_in = gr.Image(type="filepath", label="上传图片", height=300)
            btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])

            with gr.Accordion("⚙️ 高级设置", open=False):
                gen_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.30,
                    step=0.01,
                    label="通用标签阈值",
                    info="越高 → 标签更少更准",
                )
                char_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.85,
                    step=0.01,
                    label="角色标签阈值",
                    info="推荐保持较高阈值",
                )
                show_tag_scores = gr.Checkbox(
                    value=True,
                    label="在列表中显示标签置信度",
                    info="IP 标签不返回置信度,因此不会显示分数。",
                )

            with gr.Accordion("📊 标签汇总设置", open=True):
                gr.Markdown("选择要包含在下方汇总文本框中的标签类别:")
                with gr.Row():
                    sum_general = gr.Checkbox(value=True, label="通用标签", min_width=50)
                    sum_char = gr.Checkbox(value=True, label="角色标签", min_width=50)
                    sum_ip = gr.Checkbox(value=False, label="IP 标签", min_width=50)
                sum_sep = gr.Dropdown(choices=["逗号", "换行", "空格"], value="逗号", label="标签之间的分隔符")
                sum_show_zh = gr.Checkbox(value=False, label="在汇总中显示中文翻译")

            processing_info = gr.Markdown(value="", visible=False)

        # Right column: tag lists per category, the summary text, and raw metadata.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"):
                    gr.Markdown("<p style='color:gray; font-size:small;'>提示:角色标签由模型推断,建议保持较高阈值。</p>")
                    out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("🌐 IP 标签"):
                    gr.Markdown("<p style='color:gray; font-size:small;'>提示:新版模型输出 IP 标签,但不返回评分标签/评分置信度。</p>")
                    out_ip = gr.HTML(label="IP Tags")

            gr.Markdown("### 标签汇总结果")
            out_summary = gr.Textbox(
                label="标签汇总",
                placeholder="分析完成后,此处将显示汇总的英文标签...",
                lines=5,
                show_copy_button=True,
            )
            with gr.Accordion("🧾 推理元数据", open=False):
                out_meta = gr.JSON(label="Metadata")
| # ----------------- 辅助函数 ----------------- | |
| def format_tags_html(tags_dict, translations_list, category_name, show_scores=True, show_translation_in_list=True): | |
| if not tags_dict: | |
| return "<p>暂无标签</p>" | |
| html = '<div class="label-container">' | |
| if not isinstance(translations_list, list): | |
| translations_list = [] | |
| tag_keys = list(tags_dict.keys()) | |
| for i, tag in enumerate(tag_keys): | |
| score = tags_dict[tag] | |
| safe_tag_text = escape(str(tag)) | |
| js_arg = json.dumps(str(tag), ensure_ascii=False) | |
| html += '<div class="tag-item">' | |
| tag_display_html = ( | |
| f'<span class="tag-en" onclick=\'copyToClipboard({js_arg})\'>{safe_tag_text}</span>' | |
| ) | |
| if show_translation_in_list and i < len(translations_list) and translations_list[i]: | |
| tag_display_html += f'<span class="tag-zh">({escape(str(translations_list[i]))})</span>' | |
| html += f"<div>{tag_display_html}</div>" | |
| if show_scores and isinstance(score, (int, float)): | |
| html += f'<span class="tag-score">{score:.3f}</span>' | |
| html += "</div>" | |
| html += "</div>" | |
| return html | |
| def generate_summary_text_content( | |
| current_res, | |
| current_translations_dict, | |
| s_gen, | |
| s_char, | |
| s_ip, | |
| s_sep_type, | |
| s_show_zh, | |
| ): | |
| if not current_res: | |
| return "请先分析图像或选择要汇总的标签类别。" | |
| summary_parts = [] | |
| separators = {"逗号": ", ", "换行": "\n", "空格": " "} | |
| separator = separators.get(s_sep_type, ", ") | |
| categories_to_summarize = [] | |
| if s_gen: | |
| categories_to_summarize.append("general") | |
| if s_char: | |
| categories_to_summarize.append("characters") | |
| if s_ip: | |
| categories_to_summarize.append("ips") | |
| if not categories_to_summarize: | |
| return "请至少选择一个标签类别进行汇总。" | |
| for cat_key in categories_to_summarize: | |
| if current_res.get(cat_key): | |
| tags_to_join = [] | |
| cat_tags_en = list(current_res[cat_key].keys()) | |
| cat_translations = current_translations_dict.get(cat_key, []) | |
| for i, en_tag in enumerate(cat_tags_en): | |
| if s_show_zh and i < len(cat_translations) and cat_translations[i]: | |
| tags_to_join.append(f"{en_tag}/*{cat_translations[i]}*/") | |
| else: | |
| tags_to_join.append(en_tag) | |
| if tags_to_join: | |
| summary_parts.append(separator.join(tags_to_join)) | |
| joiner = "\n\n" if separator != "\n" and len(summary_parts) > 1 else separator if separator == "\n" else " " | |
| final_summary = joiner.join(summary_parts) | |
| return final_summary if final_summary else "选定的类别中没有找到标签。" | |
| def process_image_and_generate_outputs( | |
| image_path, | |
| g_th, | |
| c_th, | |
| s_scores, | |
| s_gen, | |
| s_char, | |
| s_ip, | |
| s_sep, | |
| s_zh_in_sum, | |
| ): | |
| if image_path is None: | |
| yield ( | |
| gr.update(interactive=True, value="🚀 开始分析"), | |
| gr.update(visible=True, value="❌ 请先上传图片。"), | |
| "", | |
| "", | |
| "", | |
| "", | |
| {}, | |
| {}, | |
| {}, | |
| {}, | |
| ) | |
| return | |
| if tagger_instance is None: | |
| yield ( | |
| gr.update(interactive=True, value="🚀 开始分析"), | |
| gr.update(visible=True, value="❌ 分析器未成功初始化,请检查控制台错误。"), | |
| "", | |
| "", | |
| "", | |
| "", | |
| {}, | |
| {}, | |
| {}, | |
| {}, | |
| ) | |
| return | |
| yield ( | |
| gr.update(interactive=False, value="🔄 处理中..."), | |
| gr.update(visible=True, value="🔄 正在校验并分析图像,请稍候..."), | |
| gr.HTML(value="<p>分析中...</p>"), | |
| gr.HTML(value="<p>分析中...</p>"), | |
| gr.HTML(value="<p>分析中...</p>"), | |
| gr.update(value="分析中,请稍候..."), | |
| {}, | |
| {}, | |
| {}, | |
| {}, | |
| ) | |
| try: | |
| img = validate_and_open_image(image_path) | |
| res, tag_categories_original_order, meta = tagger_instance.predict(img, g_th, c_th) | |
| all_tags_to_translate = [] | |
| for cat_key in ["general", "characters", "ips"]: | |
| all_tags_to_translate.extend(tag_categories_original_order.get(cat_key, [])) | |
| all_translations_flat = [] | |
| if all_tags_to_translate: | |
| try: | |
| all_translations_flat = translate_texts(all_tags_to_translate, src_lang="auto", tgt_lang="zh") | |
| except Exception as translate_error: | |
| print(f"⚠️ 标签翻译失败,将仅显示英文标签:{translate_error}") | |
| all_translations_flat = [""] * len(all_tags_to_translate) | |
| current_translations_dict = {} | |
| offset = 0 | |
| for cat_key in ["general", "characters", "ips"]: | |
| cat_original_tags = tag_categories_original_order.get(cat_key, []) | |
| num_tags_in_cat = len(cat_original_tags) | |
| if num_tags_in_cat > 0: | |
| current_translations_dict[cat_key] = all_translations_flat[offset: offset + num_tags_in_cat] | |
| offset += num_tags_in_cat | |
| else: | |
| current_translations_dict[cat_key] = [] | |
| general_html = format_tags_html( | |
| res.get("general", {}), | |
| current_translations_dict.get("general", []), | |
| "general", | |
| s_scores, | |
| True, | |
| ) | |
| char_html = format_tags_html( | |
| res.get("characters", {}), | |
| current_translations_dict.get("characters", []), | |
| "characters", | |
| s_scores, | |
| True, | |
| ) | |
| ip_html = format_tags_html( | |
| res.get("ips", {}), | |
| current_translations_dict.get("ips", []), | |
| "ips", | |
| s_scores, | |
| True, | |
| ) | |
| summary_text = generate_summary_text_content( | |
| res, | |
| current_translations_dict, | |
| s_gen, | |
| s_char, | |
| s_ip, | |
| s_sep, | |
| s_zh_in_sum, | |
| ) | |
| yield ( | |
| gr.update(interactive=True, value="🚀 开始分析"), | |
| gr.update(visible=True, value="✅ 分析完成!"), | |
| general_html, | |
| char_html, | |
| ip_html, | |
| gr.update(value=summary_text), | |
| res, | |
| current_translations_dict, | |
| tag_categories_original_order, | |
| meta, | |
| ) | |
| except ImageValidationError as e: | |
| yield ( | |
| gr.update(interactive=True, value="🚀 开始分析"), | |
| gr.update(visible=True, value=f"❌ 上传图片未通过安全校验:{str(e)}"), | |
| "<p>图片已被安全策略拒绝</p>", | |
| "<p>图片已被安全策略拒绝</p>", | |
| "<p>图片已被安全策略拒绝</p>", | |
| gr.update(value=f"错误: {str(e)}", placeholder="上传图片未通过安全校验..."), | |
| {}, | |
| {}, | |
| {}, | |
| {}, | |
| ) | |
| except Exception as e: | |
| import traceback | |
| tb_str = traceback.format_exc() | |
| print(f"处理时发生错误: {e}\n{tb_str}") | |
| yield ( | |
| gr.update(interactive=True, value="🚀 开始分析"), | |
| gr.update(visible=True, value=f"❌ 处理失败: {str(e)}"), | |
| "<p>处理出错</p>", | |
| "<p>处理出错</p>", | |
| "<p>处理出错</p>", | |
| gr.update(value=f"错误: {str(e)}", placeholder="分析失败..."), | |
| {}, | |
| {}, | |
| {}, | |
| {}, | |
| ) | |
| def update_summary_display( | |
| s_gen, | |
| s_char, | |
| s_ip, | |
| s_sep, | |
| s_zh_in_sum, | |
| current_res_from_state, | |
| current_translations_from_state, | |
| ): | |
| if not current_res_from_state: | |
| return gr.update(placeholder="请先完成一次图像分析以生成汇总。", value="") | |
| new_summary_text = generate_summary_text_content( | |
| current_res_from_state, | |
| current_translations_from_state, | |
| s_gen, | |
| s_char, | |
| s_ip, | |
| s_sep, | |
| s_zh_in_sum, | |
| ) | |
| return gr.update(value=new_summary_text) | |
| btn.click( | |
| process_image_and_generate_outputs, | |
| inputs=[ | |
| img_in, | |
| gen_slider, | |
| char_slider, | |
| show_tag_scores, | |
| sum_general, | |
| sum_char, | |
| sum_ip, | |
| sum_sep, | |
| sum_show_zh, | |
| ], | |
| outputs=[ | |
| btn, | |
| processing_info, | |
| out_general, | |
| out_char, | |
| out_ip, | |
| out_summary, | |
| state_res, | |
| state_translations_dict, | |
| state_tag_categories_for_translation, | |
| out_meta, | |
| ], | |
| ) | |
| summary_controls = [sum_general, sum_char, sum_ip, sum_sep, sum_show_zh] | |
| for ctrl in summary_controls: | |
| ctrl.change( | |
| fn=update_summary_display, | |
| inputs=summary_controls + [state_res, state_translations_dict], | |
| outputs=[out_summary], | |
| ) | |
if __name__ == "__main__":
    # The app still launches without a model; each request then reports the failure in the UI.
    if tagger_instance is None:
        print("CRITICAL: Tagger 未能初始化,应用功能将受限。请检查之前的错误信息。")
    queued_demo = demo.queue(max_size=8)  # at most 8 waiting events in the queue
    queued_demo.launch(server_name="0.0.0.0", server_port=7860)