Text Generation
PEFT
Safetensors
Chinese
English
qwen
qlora
radar
aircraft-cabin
structured-prediction
qa
conversational
Instructions to use sutama/CabinLavatoryPrediction with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use sutama/CabinLavatoryPrediction with PEFT:
```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B")
model = PeftModel.from_pretrained(base_model, "sutama/CabinLavatoryPrediction")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| import argparse | |
| import json | |
| import math | |
| import re | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from peft import PeftModel | |
| from sklearn.metrics import accuracy_score, f1_score, mean_absolute_error, precision_recall_fscore_support | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
# Numeric fields of a structured prediction; scored with mean absolute error.
TIME_FIELDS = [
    "elapsed_seconds_in_current_behavior",
    "estimated_remaining_seconds",
    "full_remaining_seconds",
    "expected_end_time",
]

# Every key a complete structured prediction must contain. The numeric
# fields sit between the behavior flags and the stage/sequence fields.
STRUCT_FIELDS = [
    "current_behavior",
    "is_transition",
    *TIME_FIELDS,
    "next_possible_behavior",
    "stage_index",
    "total_stages",
    "sequence_so_far",
]

# Every key a complete QA prediction must contain.
QA_FIELDS = [
    "occupied",
    "time_to_free_minutes",
    "used_areas",
    "is_abnormal",
]
def read_jsonl(path, limit=None):
    """Load records from a UTF-8 JSON Lines file.

    Blank (whitespace-only) lines are skipped. Reading stops as soon as
    ``limit`` records have been collected; a falsy ``limit`` (``None`` or
    ``0``) means "no cap".

    Args:
        path: Location of the ``.jsonl`` file.
        limit: Optional maximum number of records to return.

    Returns:
        A list with one decoded JSON value per non-blank line.
    """
    records = []
    with open(path, encoding="utf-8") as handle:
        for raw_line in handle:
            if raw_line.strip():
                records.append(json.loads(raw_line))
                # The cap is checked right after appending so that exactly
                # ``limit`` records are kept.
                if limit and len(records) >= limit:
                    break
    return records
def load_model(model_name, adapter_dir=None):
    """Load the tokenizer and a 4-bit quantized causal LM, optionally with a PEFT adapter.

    Args:
        model_name: Identifier or path handed to ``from_pretrained`` for both
            the tokenizer and the base model.
        adapter_dir: Optional PEFT adapter location; when falsy the base
            model is returned unwrapped.

    Returns:
        A ``(tokenizer, model)`` tuple with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
    # Fall back to EOS for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # generate_predictions strips the prompt by slicing every row at the same
    # input length, which only lines up when padding is on the left.
    tokenizer.padding_side = "left"

    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model_kwargs = {
        "trust_remote_code": True,
        "quantization_config": quantization,
        "device_map": "auto",
        "torch_dtype": torch.bfloat16,
    }
    base_model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    model = PeftModel.from_pretrained(base_model, adapter_dir) if adapter_dir else base_model
    model.eval()
    return tokenizer, model
def render_prompt(tokenizer, messages):
    """Render chat *messages* into a prompt string ending with the generation prompt.

    ``enable_thinking=False`` is requested first; if the tokenizer's
    ``apply_chat_template`` raises ``TypeError`` the call is repeated without
    that keyword.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        messages: Chat messages to render.

    Returns:
        Whatever ``apply_chat_template`` returns with ``tokenize=False``.
    """
    shared_kwargs = {"tokenize": False, "add_generation_prompt": True}
    try:
        prompt = tokenizer.apply_chat_template(messages, enable_thinking=False, **shared_kwargs)
    except TypeError:
        prompt = tokenizer.apply_chat_template(messages, **shared_kwargs)
    return prompt
def json_candidates(text):
    """Yield every JSON object that can be decoded starting at a ``{`` in *text*.

    Candidates are produced in order of their opening brace, so an object
    nested inside another one is yielded after its parent. Positions where
    decoding fails are skipped silently.

    Args:
        text: Arbitrary text that may embed JSON objects.

    Yields:
        Each successfully decoded ``dict``.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # Decode in place from ``start`` instead of slicing ``text[start:]``,
            # which would copy the remaining text once per brace.
            obj, _ = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):
            # JSONDecodeError is a ValueError; RecursionError covers
            # pathologically deep nesting. Either way: not a candidate.
            obj = None
        if isinstance(obj, dict):
            yield obj
        start = text.find("{", start + 1)


def parse_json_text(text, preferred_fields=None):
    """Extract a JSON object from model output.

    The stripped text is first decoded as a whole. If that fails, or if it
    decodes to something other than an object (list, number, string, null),
    the text is scanned for embedded objects instead so that a wrapped answer
    such as ``[{...}]`` is still recovered.

    Args:
        text: Raw generated text.
        preferred_fields: Optional iterable of expected keys. When given, the
            embedded object sharing the most keys with it is chosen; ties go
            to the object that starts earliest in the text.

    Returns:
        ``(obj, None)`` with the chosen ``dict`` on success, or
        ``(None, "no_json_object")`` when no object could be found.
    """
    text = text.strip()
    try:
        whole = json.loads(text)
    except (ValueError, RecursionError):
        whole = None
    if isinstance(whole, dict):
        return whole, None

    candidates = list(json_candidates(text))
    if not candidates:
        return None, "no_json_object"
    if preferred_fields:
        preferred = set(preferred_fields)
        # max() keeps the first of several equally good candidates, i.e. the
        # earliest one in the text.
        best = max(candidates, key=lambda obj: len(preferred.intersection(obj)))
        return best, None
    return candidates[0], None
def generate_predictions(rows, tokenizer, model, max_new_tokens, batch_size, preferred_fields, max_input_tokens, pred_path=None):
    """Generate a prediction for every row and pair it with the row's target.

    For each row, all ``messages`` except the last form the prompt; the last
    message's ``content`` is the target (decoded from JSON when it is a
    string). Generation uses ``do_sample=False`` and the decoded continuation
    is parsed with ``parse_json_text``.

    Args:
        rows: Sequence of dicts, each with a ``messages`` list.
        tokenizer: Tokenizer used for prompt rendering, encoding and decoding.
        model: Model exposing ``generate`` and ``device``.
        max_new_tokens: Generation budget per example.
        batch_size: Number of rows generated together.
        preferred_fields: Keys used to pick the best JSON object in the output.
        max_input_tokens: Prompts are truncated to this many tokens.
        pred_path: Optional path object (must provide ``open``); when given,
            each record is appended to it as one JSON line and flushed
            immediately so partial runs leave usable output.

    Returns:
        A list of dicts with keys ``target``, ``prediction``,
        ``raw_prediction`` and ``parse_error``, in row order.
    """
    records = []
    pred_file = pred_path.open("w", encoding="utf-8") if pred_path else None
    # try/finally guarantees the predictions file is closed even when
    # tokenization, generation or target decoding raises mid-run.
    try:
        for start in range(0, len(rows), batch_size):
            batch = rows[start : start + batch_size]
            prompts = [render_prompt(tokenizer, row["messages"][:-1]) for row in batch]
            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_input_tokens,
            ).to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    # NOTE(review): sampling knobs are explicitly cleared,
                    # presumably to avoid warnings about unused sampling
                    # settings under greedy decoding - confirm.
                    temperature=None,
                    top_p=None,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Every row in the padded batch has the same input length, so the
            # continuation starts at the same column for all of them.
            prompt_len = inputs["input_ids"].shape[1]
            decoded = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
            for row, pred_text in zip(batch, decoded):
                target_content = row["messages"][-1]["content"]
                target = json.loads(target_content) if isinstance(target_content, str) else target_content
                pred, error = parse_json_text(pred_text, preferred_fields)
                record = {
                    "target": target,
                    "prediction": pred,
                    "raw_prediction": pred_text,
                    "parse_error": error,
                }
                records.append(record)
                if pred_file:
                    pred_file.write(json.dumps(record, ensure_ascii=False, separators=(",", ":")) + "\n")
                    pred_file.flush()
            print(f"generated {min(start + batch_size, len(rows))}/{len(rows)}", flush=True)
    finally:
        if pred_file:
            pred_file.close()
    return records
def safe_eq(a, b):
    """Return the result of comparing *a* and *b* with ``==``."""
    outcome = a == b
    return outcome
def _as_finite_float(value):
    """Return *value* as a finite float, or ``None`` when it is not a usable number."""
    # bool is a subclass of int; a true/false answer is not a numeric value.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    try:
        number = float(value)
    except OverflowError:
        # JSON integers are unbounded and may not fit in a float.
        return None
    return number if math.isfinite(number) else None


def numeric_pairs(records, field):
    """Collect aligned (target, prediction) values of a numeric field.

    A record contributes only when its prediction is a dict and both the
    target and predicted values are finite real numbers. Booleans, NaN,
    infinities and integers too large for a float are skipped on either side.

    Args:
        records: Iterable of dicts with ``target`` (dict) and ``prediction``.
        field: Name of the numeric field to extract.

    Returns:
        ``(y_true, y_pred)``: two equally long lists of floats.
    """
    y_true, y_pred = [], []
    for rec in records:
        pred = rec["prediction"]
        if not isinstance(pred, dict):
            continue
        true_value = _as_finite_float(rec["target"].get(field))
        pred_value = _as_finite_float(pred.get(field))
        if true_value is None or pred_value is None:
            continue
        y_true.append(true_value)
        y_pred.append(pred_value)
    return y_true, y_pred
def classification_metrics(records, field):
    """Score a categorical field over the records whose prediction contains it.

    Args:
        records: Sequence of dicts with ``target`` (dict) and ``prediction``.
        field: Name of the field to score.

    Returns:
        Dict with ``accuracy`` and ``macro_f1`` computed on the covered
        records, and ``coverage`` = covered records / all records. All three
        are ``0.0`` when no prediction contains the field.
    """
    covered = [
        (rec["target"].get(field), rec["prediction"].get(field))
        for rec in records
        if isinstance(rec["prediction"], dict) and field in rec["prediction"]
    ]
    if not covered:
        return {"accuracy": 0.0, "macro_f1": 0.0, "coverage": 0.0}

    def as_label(value):
        # sklearn cannot sort mixed label types (e.g. None next to str), so
        # every value is mapped to a string for metric computation only.
        return "<NULL>" if value is None else str(value)

    y_true = [as_label(expected) for expected, _ in covered]
    y_pred = [as_label(predicted) for _, predicted in covered]
    accuracy = float(accuracy_score(y_true, y_pred))
    macro_f1 = float(f1_score(y_true, y_pred, average="macro", zero_division=0))
    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "coverage": len(covered) / len(records),
    }
def _predicted_labels(value):
    """Return the ``label`` of every dict item in *value*; ``[]`` unless *value* is a list."""
    # Model output is untrusted: a number or other non-list value here must
    # count as an empty sequence rather than abort the whole evaluation.
    if not isinstance(value, list):
        return []
    return [item.get("label") for item in value if isinstance(item, dict)]


def sequence_metrics(records):
    """Compare predicted and target ``sequence_so_far`` label sequences.

    Only records whose prediction is a dict are scored. A predicted
    ``sequence_so_far`` that is not a list is treated as an empty sequence.

    Args:
        records: Iterable of dicts with ``target`` (dict) and ``prediction``.

    Returns:
        Dict of means over the scored records (``0.0`` when there are none):
        ``sequence_exact_match`` (label lists identical),
        ``sequence_last_label_accuracy`` (both non-empty and final labels
        equal) and ``sequence_prefix_label_match`` (position-wise matches
        divided by the target length, with a minimum divisor of 1).
    """
    exact, last, prefix = [], [], []
    for rec in records:
        pred = rec["prediction"]
        if not isinstance(pred, dict):
            continue
        true_seq = [step.get("label") for step in rec["target"].get("sequence_so_far") or []]
        pred_seq = _predicted_labels(pred.get("sequence_so_far"))
        exact.append(true_seq == pred_seq)
        last.append(bool(true_seq and pred_seq and true_seq[-1] == pred_seq[-1]))
        # zip stops at the shorter sequence, so only shared positions count.
        matches = sum(1 for expected, got in zip(true_seq, pred_seq) if expected == got)
        prefix.append(matches / max(1, len(true_seq)))
    return {
        "sequence_exact_match": float(np.mean(exact)) if exact else 0.0,
        "sequence_last_label_accuracy": float(np.mean(last)) if last else 0.0,
        "sequence_prefix_label_match": float(np.mean(prefix)) if prefix else 0.0,
    }
def evaluate_struct(records):
    """Compute all metrics for the structured-prediction task.

    Args:
        records: List of dicts with ``target`` and ``prediction`` entries.

    Returns:
        Dict containing, in this order: example count, JSON parse rate,
        rate of predictions holding every ``STRUCT_FIELDS`` key, accuracy
        (plus macro F1 for the behavior fields and ``is_transition``) per
        categorical field, MAE and coverage per ``TIME_FIELDS`` entry, and
        the ``sequence_metrics`` values. Rates use ``max(1, len(records))``
        as denominator; an MAE is ``None`` when no numeric pair exists.
    """
    denominator = max(1, len(records))
    parsed = [rec for rec in records if isinstance(rec["prediction"], dict)]
    complete_count = sum(
        1 for rec in parsed if all(name in rec["prediction"] for name in STRUCT_FIELDS)
    )
    metrics = {
        "num_examples": len(records),
        "json_parse_rate": len(parsed) / denominator,
        "required_field_complete_rate": complete_count / denominator,
    }

    categorical_fields = (
        "current_behavior",
        "next_possible_behavior",
        "is_transition",
        "stage_index",
        "total_stages",
    )
    for name in categorical_fields:
        scores = classification_metrics(records, name)
        metrics[f"{name}_accuracy"] = scores["accuracy"]
        # Macro F1 is reported only for behavior labels and the transition flag.
        reports_f1 = name == "is_transition" or "behavior" in name
        if reports_f1:
            metrics[f"{name}_macro_f1"] = scores["macro_f1"]

    for name in TIME_FIELDS:
        expected, predicted = numeric_pairs(records, name)
        if expected:
            metrics[f"{name}_mae"] = float(mean_absolute_error(expected, predicted))
        else:
            metrics[f"{name}_mae"] = None
        metrics[f"{name}_coverage"] = len(expected) / denominator

    metrics.update(sequence_metrics(records))
    return metrics
def normalize_areas(value):
    """Return the items of *value* as a set of strings; an empty set unless *value* is a list."""
    return {str(item) for item in value} if isinstance(value, list) else set()
def evaluate_qa(records):
    """Compute all metrics for the QA task.

    Args:
        records: List of dicts with ``target`` and ``prediction`` entries.

    Returns:
        Dict containing example count, JSON parse rate, rate of predictions
        holding every ``QA_FIELDS`` key, accuracy and macro F1 for
        ``occupied`` and ``is_abnormal``, MAE of ``time_to_free_minutes``,
        and micro precision/recall/F1 of ``used_areas``. The MAE and the
        ``used_areas`` scores are ``None`` when nothing could be scored, so
        the set of keys is the same for every run.
    """
    denominator = max(1, len(records))
    parsed = [rec for rec in records if isinstance(rec["prediction"], dict)]
    metrics = {
        "num_examples": len(records),
        "json_parse_rate": len(parsed) / denominator,
        "required_field_complete_rate": sum(
            all(name in rec["prediction"] for name in QA_FIELDS) for rec in parsed
        )
        / denominator,
    }
    for field in ["occupied", "is_abnormal"]:
        cm = classification_metrics(records, field)
        metrics[f"{field}_accuracy"] = cm["accuracy"]
        metrics[f"{field}_f1"] = cm["macro_f1"]

    y_true, y_pred = numeric_pairs(records, "time_to_free_minutes")
    metrics["time_to_free_minutes_mae"] = float(mean_absolute_error(y_true, y_pred)) if y_true else None

    # Area names, in order: door, toilet, washbasin, trash bin.
    labels = ["门", "马桶", "洗手池", "垃圾桶"]
    true_flat, pred_flat = [], []
    for rec in parsed:
        expected_areas = normalize_areas(rec["target"].get("used_areas"))
        predicted_areas = normalize_areas(rec["prediction"].get("used_areas"))
        # One membership decision per (record, area); pooling them and
        # scoring as a binary problem gives micro-averaged metrics.
        true_flat.extend(label in expected_areas for label in labels)
        pred_flat.extend(label in predicted_areas for label in labels)

    if true_flat:
        pr, rc, f1, _ = precision_recall_fscore_support(
            true_flat, pred_flat, average="binary", zero_division=0
        )
        precision, recall, f1_value = float(pr), float(rc), float(f1)
    else:
        # Nothing parsed: still emit the keys, mirroring how the MAE is
        # reported as None, so downstream readers never hit a missing key.
        precision = recall = f1_value = None
    metrics["used_areas_micro_precision"] = precision
    metrics["used_areas_micro_recall"] = recall
    metrics["used_areas_micro_f1"] = f1_value
    return metrics
def _build_arg_parser():
    """Create the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", default="Qwen/Qwen3.5-9B")
    parser.add_argument("--adapter-dir", default=None)
    parser.add_argument("--input-file", default=None)
    parser.add_argument("--predictions-file", default=None)
    parser.add_argument("--task-type", choices=["struct", "qa"], required=True)
    parser.add_argument("--output-dir", default="outputs")
    parser.add_argument("--run-name", required=True)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--max-new-tokens", type=int, default=1536)
    parser.add_argument("--max-input-tokens", type=int, default=6144)
    return parser


def main():
    """Run generation (or reuse saved predictions), score them and write the metrics.

    With ``--predictions-file`` the records are read from that file and no
    model is loaded; otherwise ``--input-file`` is required, predictions are
    generated and stored under ``<output-dir>/predictions``. The metrics
    payload is written to ``<output-dir>/metrics`` and printed.

    Raises:
        ValueError: If neither ``--predictions-file`` nor ``--input-file``
            is given.
    """
    args = _build_arg_parser().parse_args()
    is_struct = args.task_type == "struct"

    output_root = Path(args.output_dir)
    prediction_dir = output_root / "predictions"
    metrics_dir = output_root / "metrics"
    for directory in (prediction_dir, metrics_dir):
        directory.mkdir(parents=True, exist_ok=True)
    pred_path = prediction_dir / f"{args.run_name}_{args.task_type}_predictions.jsonl"

    if args.predictions_file:
        # Re-score existing predictions without touching the model.
        records = read_jsonl(args.predictions_file, args.max_samples)
    else:
        if not args.input_file:
            raise ValueError("--input-file is required unless --predictions-file is provided")
        rows = read_jsonl(args.input_file, args.max_samples)
        tokenizer, model = load_model(args.model_name, args.adapter_dir)
        preferred_fields = STRUCT_FIELDS if is_struct else QA_FIELDS
        records = generate_predictions(
            rows,
            tokenizer,
            model,
            args.max_new_tokens,
            args.batch_size,
            preferred_fields,
            args.max_input_tokens,
            pred_path,
        )

    if is_struct:
        metrics = evaluate_struct(records)
    else:
        metrics = evaluate_qa(records)

    metric_payload = {
        "run_name": args.run_name,
        "task_type": args.task_type,
        "input_file": args.input_file,
        "predictions_file": args.predictions_file,
        "metrics": metrics,
    }
    rendered = json.dumps(metric_payload, ensure_ascii=False, indent=2)
    metric_path = metrics_dir / f"{args.run_name}_{args.task_type}_metrics.json"
    metric_path.write_text(rendered, encoding="utf-8")
    print(rendered)


if __name__ == "__main__":
    main()