|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import sys
|
| import argparse
|
| import math
|
| import csv
|
| from typing import List, Tuple
|
|
|
| try:
|
| from tqdm import tqdm
|
| except ImportError:
|
| print("tqdm not installed. Install with: pip install tqdm", file=sys.stderr)
|
| sys.exit(1)
|
|
|
| try:
|
| import pandas as pd
|
| except ImportError:
|
| print("pandas not installed. Install with: pip install pandas", file=sys.stderr)
|
| sys.exit(1)
|
|
|
| try:
|
| from transformers import AutoTokenizer
|
| except ImportError:
|
| print("transformers not installed. Install with: pip install transformers", file=sys.stderr)
|
| sys.exit(1)
|
|
|
def load_texts(csv_path: str) -> List[str]:
    """Load a headerless, single-column CSV of texts as a list of strings.

    Every cell is kept verbatim (as ``str``): empty cells become ``""`` and
    strings such as ``"NA"`` or ``"null"`` are NOT converted to missing values.

    Args:
        csv_path: Path to a CSV file with one (quoted) text per line and no
            header row.

    Returns:
        The texts in file order; an empty list for a completely empty file.

    Raises:
        ValueError: if the file parses into more than one column (typically
            an unquoted text containing the delimiter).
    """
    try:
        df = pd.read_csv(
            csv_path,
            header=None,
            quoting=csv.QUOTE_ALL,
            dtype=str,
            # Without these two options pandas turns empty cells and texts like
            # "NA", "null" or "N/A" into NaN, which ``astype(str)`` would then
            # render as the literal string "nan" and skew the token lengths.
            keep_default_na=False,
            na_filter=False,
        )
    except pd.errors.EmptyDataError:
        # A file with no content has no texts to analyze.
        return []

    if df.shape[1] != 1:
        raise ValueError(
            f"Expected exactly one column in {csv_path}, found {df.shape[1]}; "
            "make sure every text is quoted."
        )
    df.columns = ["text"]

    return df["text"].astype(str).tolist()
|
|
|
def compute_lengths(
    texts: List[str],
    tokenizer_name: str,
    max_length: int,
    batch_size: int = 64
) -> Tuple[List[int], List[int], int]:
    """Tokenize every text and record its full (untruncated) token length.

    Each batch is encoded twice, once with and once without special tokens.
    Truncation is disabled so the reported lengths are the true sizes.

    Args:
        texts: Input strings to tokenize.
        tokenizer_name: Name or path passed to
            ``AutoTokenizer.from_pretrained`` (a fast tokenizer is requested).
        max_length: Not used by this function; accepted so callers can pass
            the analysis configuration uniformly.
        batch_size: Number of texts per tokenizer call; must be >= 1.

    Returns:
        A tuple ``(lengths_with_special, lengths_no_special,
        special_tokens_to_add)``:

        - lengths_with_special: token lengths with ``add_special_tokens=True``.
        - lengths_no_special: token lengths with ``add_special_tokens=False``.
        - special_tokens_to_add: ``tokenizer.num_special_tokens_to_add(pair=False)``.

    Raises:
        ValueError: if ``batch_size`` < 1. This is checked before the
            tokenizer is loaded. Previously a negative value silently
            produced empty results.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    special_tokens_to_add = tokenizer.num_special_tokens_to_add(pair=False)

    # Options shared by both passes: no truncation (we want real lengths), and
    # skip attention masks / token type ids since only input_ids are counted.
    encode_kwargs = dict(
        truncation=False,
        return_attention_mask=False,
        return_token_type_ids=False,
    )

    lengths_with_special: List[int] = []
    lengths_no_special: List[int] = []

    for start in tqdm(range(0, len(texts), batch_size), desc="Tokenizing", unit="batch"):
        batch = texts[start:start + batch_size]

        enc_with = tokenizer(batch, add_special_tokens=True, **encode_kwargs)
        enc_without = tokenizer(batch, add_special_tokens=False, **encode_kwargs)

        lengths_with_special.extend(len(ids) for ids in enc_with["input_ids"])
        lengths_no_special.extend(len(ids) for ids in enc_without["input_ids"])

    return lengths_with_special, lengths_no_special, special_tokens_to_add
|
|
|
def summarize(lengths: List[int]):
    """Return descriptive statistics for a list of token lengths.

    The result is a dict whose keys are, in this order: ``count``, ``mean``,
    ``median``, ``min``, ``p90``, ``p95``, ``p99``, ``max``. ``count``,
    ``min`` and ``max`` are ints; the other values are floats. An empty input
    yields zeros of the same types instead of raising.
    """
    import numpy as np

    values = np.array(lengths, dtype=int)

    # Guard clause: NumPy reductions such as min/max fail on empty arrays.
    if values.size == 0:
        return {
            "count": 0,
            "mean": 0.0,
            "median": 0.0,
            "min": 0,
            "p90": 0.0,
            "p95": 0.0,
            "p99": 0.0,
            "max": 0,
        }

    return {
        "count": int(values.size),
        "mean": float(values.mean()),
        "median": float(np.median(values)),
        "min": int(values.min()),
        "p90": float(np.percentile(values, 90)),
        "p95": float(np.percentile(values, 95)),
        "p99": float(np.percentile(values, 99)),
        "max": int(values.max()),
    }
|
|
|
def estimate_windows(
    lengths_no_special: List[int],
    max_length: int,
    special_tokens_to_add: int,
    doc_stride: int
) -> Tuple[int, int]:
    """Estimate how many sliding-window chunks the dataset would need.

    Each window holds ``content_capacity = max_length - special_tokens_to_add``
    content tokens. Consecutive windows overlap by ``doc_stride`` content
    tokens, so each additional window advances by
    ``content_capacity - doc_stride`` new tokens. A sample that fits in one
    window counts as a single window.

    Args:
        lengths_no_special: Per-sample token counts excluding special tokens.
        max_length: Maximum sequence length including special tokens.
        special_tokens_to_add: Special tokens added to every window.
        doc_stride: Overlap between consecutive windows, in content tokens;
            must satisfy ``0 <= doc_stride < content_capacity``.

    Returns:
        ``(total_windows, num_samples_needing_chunking)``.

    Raises:
        ValueError: if the content capacity is not positive, or if
            ``doc_stride`` is negative or not smaller than the capacity. A
            negative stride would skip tokens between windows, and a stride
            at or above the capacity lets the window never advance.
    """
    content_capacity = max_length - special_tokens_to_add
    if content_capacity <= 0:
        raise ValueError(f"Invalid content capacity: {content_capacity}. Check tokenizer specials and max_length.")
    if not 0 <= doc_stride < content_capacity:
        raise ValueError(
            f"doc_stride must satisfy 0 <= doc_stride < content capacity "
            f"({content_capacity}), got {doc_stride}."
        )

    # Loop-invariant: new content tokens covered by each additional window.
    step = content_capacity - doc_stride

    total_windows = 0
    need_chunking = 0

    for n in lengths_no_special:
        if n <= content_capacity:
            total_windows += 1
            continue

        need_chunking += 1
        # The first window covers content_capacity tokens; the remainder needs
        # ceil(remaining / step) more windows (integer ceil-division).
        remaining = n - content_capacity
        total_windows += 1 + (remaining + step - 1) // step

    return total_windows, need_chunking
|
|
|
def histogram_counts(
    lengths_with_special: List[int],
    max_length: int,
    *,
    bins: Tuple[int, ...] = (128, 256, 384, 512),
) -> List[Tuple[str, int]]:
    """Bucket sequence lengths into a labelled histogram.

    Buckets are ``(0, bins[0]]``, ``(bins[0], bins[1]]``, ..., plus a final
    overflow bucket ``> bins[-1]``, with labels like ``"   1- 128"`` and
    ``">512"``. Zero-length sequences are counted in the first bucket, so the
    counts always sum to ``len(lengths_with_special)``.

    Args:
        lengths_with_special: Token lengths to bucket.
        max_length: Not used by this function; kept for call compatibility.
        bins: Inclusive upper edges of the buckets; must be non-empty,
            strictly increasing positive ints. Defaults to the original
            hard-coded edges.

    Returns:
        ``(label, count)`` pairs in bucket order, overflow bucket last.

    Raises:
        ValueError: if ``bins`` is empty, not strictly increasing, or starts
            below 1.
    """
    from bisect import bisect_left

    edges = list(bins)
    if not edges or edges[0] < 1 or any(lo >= hi for lo, hi in zip(edges, edges[1:])):
        raise ValueError(
            f"bins must be a non-empty, strictly increasing sequence of positive ints, got {bins!r}"
        )

    # One pass over the data: bisect_left gives the index of the first edge
    # >= length, i.e. the bucket whose inclusive upper edge contains it.
    # len(edges) is the overflow bucket.
    counts = [0] * (len(edges) + 1)
    for length in lengths_with_special:
        counts[bisect_left(edges, length)] += 1

    lower_bounds = [0] + edges[:-1]
    labels = [f"{lo + 1:>4}-{hi:>4}" for lo, hi in zip(lower_bounds, edges)]
    labels.append(f">{edges[-1]}")

    return list(zip(labels, counts))
|
|
|
def main():
    """Command-line entry point for the token-length analysis.

    Parses CLI options and loads the input CSV, resolving a relative path
    against this script's directory. It then tokenizes every text and prints
    length statistics, a length histogram and a sliding-window chunk
    estimate to stdout.

    Exit codes:
        1 if the input file is missing or the window estimate is impossible.
        2 for invalid numeric options (argparse's usage-error convention).
    """
    parser = argparse.ArgumentParser(description="DeBERTa v3 base token-length analysis for train_clean.csv")
    parser.add_argument("--input_csv", default="train_clean.csv", help="Input CSV (one quoted text per line).")
    parser.add_argument("--tokenizer", default="microsoft/mdeberta-v3-base", help="HF tokenizer name.")
    parser.add_argument("--max_length", type=int, default=512, help="Max sequence length (incl. specials).")
    parser.add_argument("--doc_stride", type=int, default=128, help="Sliding window overlap on content tokens.")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for tokenization.")
    args = parser.parse_args()

    # Fail fast on impossible settings, before loading the tokenizer and
    # tokenizing the whole dataset (both can be slow).
    if args.max_length < 1:
        parser.error(f"--max_length must be >= 1, got {args.max_length}")
    if args.doc_stride < 0:
        parser.error(f"--doc_stride must be >= 0, got {args.doc_stride}")
    if args.batch_size < 1:
        parser.error(f"--batch_size must be >= 1, got {args.batch_size}")

    # A relative --input_csv is resolved against this script's directory, not
    # the current working directory. os.path.join passes absolute paths
    # through unchanged.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    input_path = os.path.join(script_dir, args.input_csv)

    if not os.path.isfile(input_path):
        print(f"Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    texts = load_texts(input_path)

    lengths_with_special, lengths_no_special, specials = compute_lengths(
        texts, args.tokenizer, args.max_length, batch_size=args.batch_size
    )

    stats_with = summarize(lengths_with_special)
    stats_no = summarize(lengths_no_special)

    total = len(lengths_with_special)
    exceed = sum(1 for length in lengths_with_special if length > args.max_length)
    frac = (exceed / total) * 100.0 if total else 0.0

    hist = histogram_counts(lengths_with_special, args.max_length)
    try:
        total_windows, need_chunking = estimate_windows(
            lengths_no_special, args.max_length, specials, args.doc_stride
        )
    except ValueError as err:
        # e.g. max_length too small to hold the tokenizer's special tokens;
        # report it as a clean CLI error instead of a traceback.
        print(f"Cannot estimate windows: {err}", file=sys.stderr)
        sys.exit(1)

    print("\n=== Token Length Analysis (DeBERTa v3 base) ===")
    print(f"Tokenizer: {args.tokenizer}")
    print(f"max_length: {args.max_length} (includes {specials} special tokens per sequence)")
    print(f"doc_stride: {args.doc_stride} (applies to content tokens)")
    print(f"Samples: {total}")
    print(f"Exceeding max_length: {exceed} ({frac:.2f}%)")
    print(f"Samples needing chunking (by content capacity): {need_chunking}")

    print("\n-- Length stats WITH specials --")
    for k, v in stats_with.items():
        print(f"{k:>6}: {v}")

    print("\n-- Length stats WITHOUT specials --")
    for k, v in stats_no.items():
        print(f"{k:>6}: {v}")

    print("\n-- Histogram (WITH specials) --")
    for label, cnt in hist:
        print(f"{label}: {cnt}")

    content_capacity = args.max_length - specials
    print(f"\nContent capacity per window (tokens excluding specials): {content_capacity}")
    print(f"Estimated total windows if chunked with doc_stride={args.doc_stride}: {total_windows}")

    print("\nDone.")


if __name__ == "__main__":
    main()
|
|
|