File size: 1,782 Bytes
02bc7b8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | import json
import heapq
def get_top_k_indices(json_file_path, k):
"""
读取JSON文件中的列表,返回最大的k个元素的索引
Args:
json_file_path: JSON文件路径
k: 需要获取的最大元素的个数
Returns:
list: 按元素大小降序排列的索引列表
Raises:
FileNotFoundError: 文件不存在时抛出
ValueError: k值无效或数据格式错误时抛出
"""
# 读取JSON文件
try:
with open(json_file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
except FileNotFoundError:
raise FileNotFoundError(f"文件 {json_file_path} 不存在")
except json.JSONDecodeError:
raise ValueError("JSON文件格式错误")
# 验证数据类型
if not isinstance(data, list):
raise ValueError("JSON文件内容不是列表")
# 验证k值的有效性
if k <= 0 or k > len(data):
raise ValueError(f"k值无效,应在1到{len(data)}之间")
# 获取元素值和索引的元组列表 [(value, index), ...]
value_index_pairs = [(value, idx) for idx, value in enumerate(data)]
# 方法1:使用heapq获取最大的k个元素(效率更高,O(n log k))
top_k_pairs = heapq.nlargest(k, value_index_pairs, key=lambda x: x[0])
# 方法2:使用排序(简单直观,O(n log n))
# sorted_pairs = sorted(value_index_pairs, key=lambda x: x[0], reverse=True)
# top_k_pairs = sorted_pairs[:k]
# 提取索引
top_k_indices = [pair[1] for pair in top_k_pairs]
return top_k_indices
a = get_top_k_indices('/mnt/bn/life-mllm/users/cxr/quantization/quantization_metric/metrics/alpha/alpha_mlp_Llama-2-7b-hf.json', 10)
print(a) |