| import argparse |
| import os |
| import logging |
| import re |
| import pandas as pd |
| from typing import Tuple |
| import numpy as np |
| import soundfile as sf |
| import zhconv |
| import librosa |
|
|
|
|
def setup_logging(filename):
    """Configure the root logger to write to both the console and a file.

    Args:
        filename: Base name (without extension) of the log file; the file is
            created next to this script.

    Returns:
        The configured root logger.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # BUG FIX: the log-file name previously ignored the `filename` argument,
    # so every run clobbered the same log file.
    log_file = os.path.join(script_dir, f"{filename}.log")

    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Drop any pre-existing handlers so repeated calls don't duplicate output.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    # File handler: overwrite on each run, UTF-8 so Chinese text logs cleanly.
    file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(log_format, date_format))

    # Console handler with the same format.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter(log_format, date_format))

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger
|
|
|
|
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Load an audio file as a contiguous mono float32 waveform at 16 kHz.

    Args:
        filename: Path to the audio file (any format soundfile can read).

    Returns:
        Tuple of (samples, sample_rate) where sample_rate is always 16000.
    """
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # keep only the first channel
    if sample_rate != 16000:
        # BUG FIX: the original resampled an undefined name `wave`, raising
        # NameError for any input that was not already 16 kHz.
        data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000
    samples = np.ascontiguousarray(data)
    return samples, sample_rate
|
|
|
|
def compute_feat(filename: str, n_mels: int = 80):
    """Compute a Whisper-style log-mel feature array for an audio file.

    Args:
        filename: Path to the audio file.
        n_mels: Number of mel bands (80 for most Whisper sizes).

    Returns:
        Array of shape (1, n_mels, 3000): normalized log-mel spectrogram,
        truncated or zero-padded to exactly 3000 frames.
    """
    waveform, sr = load_audio(filename)
    if sr != 16000:
        waveform = librosa.resample(waveform, orig_sr=sr, target_sr=16000)
        sr = 16000

    power_spec = librosa.feature.melspectrogram(
        y=waveform,
        sr=sr,
        n_fft=480,
        hop_length=160,
        window="hann",
        center=True,
        pad_mode="reflect",
        power=2.0,
        n_mels=n_mels,
    )

    # Log compression with dynamic-range clamping and rescaling, matching the
    # OpenAI Whisper feature pipeline.
    log_spec = np.log10(np.maximum(power_spec, 1e-10))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    features = (log_spec + 4.0) / 4.0

    target_frames = 3000
    n_frames = features.shape[1]
    if n_frames > target_frames:
        # Too long: truncate and zero the last 50 frames.
        features = features[:, :target_frames]
        features[:, -50:] = 0
    elif n_frames < target_frames:
        # Too short: right-pad with zeros up to the target length.
        pad = np.zeros((n_mels, target_frames - n_frames), dtype=np.float32)
        features = np.concatenate((features, pad), axis=-1)

    # Prepend a batch dimension.
    return features[np.newaxis, ...]
|
|
|
|
class AIShellDataset:
    """AIShell test-set parser.

    Reads a ground-truth file where each non-empty line is
    ``<utterance_id> <transcription>`` and resolves audio paths under the
    sibling ``aishell_S0764`` directory.
    """

    def __init__(self, gt_path: str):
        """
        Args:
            gt_path: Path to the ground-truth text file; audio files are
                expected under ``<dirname(gt_path)>/aishell_S0764``.
        """
        self.gt_path = gt_path
        self.dataset_dir = os.path.dirname(gt_path)
        self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")

        assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
        assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"

        self.data = []
        with open(gt_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # crashing on the unpack below.
                    continue
                # Split only on the FIRST space so transcriptions that
                # themselves contain spaces are kept intact.
                audio_path, gt = line.split(" ", 1)
                audio_path = os.path.join(self.voice_dir, audio_path + ".wav")
                self.data.append({"audio_path": audio_path, "gt": gt})

        logger = logging.getLogger()
        logger.info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Reset the cursor and return the iterator (self)."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next (audio_path, ground_truth) pair."""
        if self.index >= len(self.data):
            raise StopIteration

        item = self.data[self.index]
        audio_path = item["audio_path"]
        ground_truth = item["gt"]

        self.index += 1
        return audio_path, ground_truth

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
|
|
|
|
class CommonVoiceDataset:
    """Common Voice dataset parser (TSV manifest + ``clips`` directory)."""

    def __init__(self, tsv_path: str):
        """
        Args:
            tsv_path: Path to the Common Voice ``*.tsv`` manifest; audio
                files are expected under ``<dirname(tsv_path)>/clips``.
        """
        self.tsv_path = tsv_path
        self.dataset_dir = os.path.dirname(tsv_path)
        self.voice_dir = os.path.join(self.dataset_dir, "clips")

        assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
        assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"

        self.data = []
        with open(tsv_path, "r", encoding="utf-8") as f:
            f.readline()  # skip the header row
            for line in f:
                line = line.strip()
                if not line:
                    # BUG FIX: a blank/trailing line previously raised
                    # IndexError on the column access below; skip it.
                    continue
                splits = line.split("\t")
                audio_path = splits[1]  # column 1: relative clip filename
                gt = splits[2]  # column 2: sentence (transcription)
                audio_path = os.path.join(self.voice_dir, audio_path)
                self.data.append({"audio_path": audio_path, "gt": gt})

        logger = logging.getLogger()
        logger.info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Reset the cursor and return the iterator (self)."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next (audio_path, ground_truth) pair."""
        if self.index >= len(self.data):
            raise StopIteration

        item = self.data[self.index]
        audio_path = item["audio_path"]
        ground_truth = item["gt"]

        self.index += 1
        return audio_path, ground_truth

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
|
|
|
|
class CustomDataset:
    """Parser for a custom TSV-labelled dataset.

    The label file must contain the columns SPEAKER_ID, UTTRANS_ID and
    TRANSCRIPTION; each audio file lives at
    ``<dataset_dir>/<SPEAKER_ID>/<UTTRANS_ID>``.
    """

    def __init__(self, label_path: str):
        """
        Args:
            label_path: Path to the tab-separated label file.
        """
        self.label_path = label_path
        self.dataset_dir = os.path.dirname(label_path)

        assert os.path.exists(label_path), f"{label_path}文件不存在: {label_path}"

        table = pd.read_csv(label_path, sep="\t")
        self.data = [
            {
                "audio_path": os.path.join(
                    self.dataset_dir, row["SPEAKER_ID"], row["UTTRANS_ID"]
                ),
                "gt": row["TRANSCRIPTION"],
            }
            for _, row in table.iterrows()
        ]

        logging.getLogger().info(f"加载了 {len(self.data)} 条数据")

    def __iter__(self):
        """Start a fresh pass over the data."""
        self.index = 0
        return self

    def __next__(self):
        """Yield the next (audio_path, ground_truth) pair."""
        if self.index >= len(self.data):
            raise StopIteration
        entry = self.data[self.index]
        self.index += 1
        return entry["audio_path"], entry["gt"]

    def __len__(self):
        """Number of labelled samples."""
        return len(self.data)
|
|
|
|
def get_args():
    """Parse and return the command-line arguments for the WER test."""
    arg_parser = argparse.ArgumentParser(
        prog="whisper", description="Test WER on dataset"
    )
    arg_parser.add_argument(
        "--dataset",
        "-d",
        type=str,
        required=True,
        choices=["aishell", "common_voice", "custom"],
        help="Test dataset",
    )
    arg_parser.add_argument(
        "--gt_path",
        "-g",
        type=str,
        required=True,
        help="Test dataset ground truth file",
    )
    arg_parser.add_argument(
        "--max_num",
        type=int,
        default=-1,
        required=False,
        help="Maximum test data num",
    )
    arg_parser.add_argument(
        "--model_type",
        "-t",
        type=str,
        choices=["tiny", "base", "small", "medium", "large", "large-v3", "turbo"],
        required=True,
        help="model type, only support tiny, base and small currently",
    )
    arg_parser.add_argument(
        "--model_path",
        "-p",
        type=str,
        required=False,
        default="../models-ax650",
        help="model path for *.axmodel, tokens.txt",
    )
    arg_parser.add_argument(
        "--repo_id",
        type=str,
        default=None,
        help="repo id from huggingface",
    )
    arg_parser.add_argument(
        "--language",
        "-l",
        type=str,
        required=False,
        default="zh",
        help="Target language, support en, zh, ja, and others. See languages.py for more options.",
    )
    arg_parser.add_argument(
        "--backend",
        type=str,
        default="ax",
        choices=["ax", "torch", "onnx"],
    )
    arg_parser.add_argument("--log_name", type=str, default="test_wer")
    return arg_parser.parse_args()
|
|
|
|
def print_args(args):
    """Log the parsed command-line arguments as a dict via the root logger."""
    logging.getLogger().info(vars(args))
|
|
|
|
def min_distance(word1: str, word2: str) -> int:
    """Return the Levenshtein edit distance between *word1* and *word2*.

    Counts the minimum number of single-character insertions, deletions and
    substitutions needed to transform *word1* into *word2*.
    """
    # Rolling one-row DP: prev[j] is the distance between the first i-1
    # characters of word1 and the first j characters of word2.
    prev = list(range(len(word2) + 1))
    for i, ch1 in enumerate(word1, start=1):
        curr = [i]  # distance from i chars of word1 to the empty prefix
        for j, ch2 in enumerate(word2, start=1):
            if ch1 == ch2:
                # Characters match: no extra edit needed.
                curr.append(prev[j - 1])
            else:
                # 1 + best of replace, insert, delete.
                curr.append(1 + min(prev[j - 1], curr[j - 1], prev[j]))
        prev = curr
    return prev[-1]
|
|
|
|
def remove_punctuation(text):
    """Strip punctuation and underscores from *text*.

    Removes every character that is neither a word character nor whitespace,
    plus the underscore (which counts as a word character but is unwanted).
    """
    return re.sub(r"[^\w\s]|_", "", text)
|
|
|
|
def main():
    """Run the selected Whisper backend over a dataset and report total WER.

    Parses CLI arguments, builds the dataset parser and the model for the
    chosen backend (ax / torch / onnx), transcribes up to ``--max_num``
    utterances and logs the per-utterance and overall character error rate.
    """
    args = get_args()

    logger = setup_logging(args.log_name)
    print_args(args)

    # Select the dataset parser matching --dataset.
    dataset_type = args.dataset.lower()
    if dataset_type == "aishell":
        dataset = AIShellDataset(args.gt_path)
    elif dataset_type == "common_voice":
        dataset = CommonVoiceDataset(args.gt_path)
    elif dataset_type == "custom":
        dataset = CustomDataset(args.gt_path)
    else:
        raise ValueError(f"Unknown dataset type {dataset_type}")

    max_num = args.max_num

    use_hf_model = False
    tokenizer = None
    task = "transcribe"

    # Backend-specific model construction. Imports are kept local so only
    # the selected backend's dependencies must be installed.
    if args.backend == "ax":
        from whisper_ax import Whisper

        model = Whisper(args.model_type, args.model_path, args.language, task)
    elif args.backend == "torch":
        if args.repo_id is not None:
            use_hf_model = True

            from transformers import WhisperForConditionalGeneration, WhisperTokenizer
            import torch

            model = WhisperForConditionalGeneration.from_pretrained(
                args.repo_id,
                dtype=torch.float32,
            ).cpu()
            # BUG FIX: the HF path previously left `tokenizer` as None and
            # crashed at decode time below; load the matching tokenizer here.
            tokenizer = WhisperTokenizer.from_pretrained(args.repo_id)
        else:
            import whisper

            model = whisper.load_model(args.model_type).cpu()

            tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True)
    elif args.backend == "onnx":
        import onnxruntime as ort
        # NOTE(review): this relative import only resolves when the script is
        # run as part of a package (python -m ...) — confirm the intended
        # invocation.
        from ..model_convert.generate_data import OnnxModel

        encoder_path = os.path.join(
            args.model_path, f"{args.model_type}/{args.model_type}-encoder.onnx"
        )
        decoder_path = os.path.join(
            args.model_path, f"{args.model_type}/{args.model_type}-decoder.onnx"
        )
        model = OnnxModel(encoder_path, decoder_path)

    references = []
    hyp = []
    all_character_error_num = 0
    all_character_num = 0
    max_data_num = max_num if max_num > 0 else len(dataset)
    for n, (audio_path, reference) in enumerate(dataset):
        if args.backend == "ax":
            hypothesis = model.run(audio_path)
        elif args.backend == "torch":
            if use_hf_model:
                with torch.no_grad():
                    feature = compute_feat(audio_path, model.config.num_mel_bins)
                    r = model.generate(
                        torch.from_numpy(feature),
                        output_scores=True,
                        return_dict_in_generate=True,
                        return_timestamps=False,
                        language=args.language,
                        task="transcribe",
                    )

                # Drop the leading special tokens and the trailing EOT.
                # NOTE(review): assumes a fixed 4-token forced prompt —
                # confirm this holds for every model type.
                tokens = r["sequences"][0][4:-1]
                hypothesis = "".join(tokenizer.decode(tokens)).strip()
            else:
                result = model.transcribe(
                    audio_path, fp16=False, language=args.language
                )
                hypothesis = result["text"]
                if args.language == "zh":
                    # Normalize to simplified Chinese so comparison with the
                    # (simplified) references is fair.
                    hypothesis = zhconv.convert(hypothesis, "zh-hans")

        elif args.backend == "onnx":
            hypothesis = model.run(audio_path, args.language, task)

        # Compare on lower-cased, punctuation-free text.
        hypothesis = remove_punctuation(hypothesis).lower()
        reference = remove_punctuation(reference).lower()

        character_error_num = min_distance(reference, hypothesis)
        character_num = len(reference)
        # BUG FIX: guard against an empty reference (e.g. all-punctuation
        # ground truth), which previously raised ZeroDivisionError.
        character_error_rate = (
            character_error_num / character_num * 100 if character_num > 0 else 0.0
        )

        all_character_error_num += character_error_num
        all_character_num += character_num

        hyp.append(hypothesis)
        references.append(reference)

        line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
        logger.info(line_content)

        if n + 1 >= max_data_num:
            break

    # BUG FIX: avoid division by zero when the dataset yielded no characters.
    total_character_error_rate = (
        all_character_error_num / all_character_num * 100
        if all_character_num > 0
        else 0.0
    )

    logger.info(f"Total WER: {total_character_error_rate}%")
|
|
|
|
# Script entry point: run the WER evaluation when executed directly.
if __name__ == "__main__":
    main()
|
|