File size: 1,782 Bytes
02bc7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import json
import heapq
def get_top_k_indices(json_file_path, k):
    """
    读取JSON文件中的列表,返回最大的k个元素的索引
    
    Args:
        json_file_path: JSON文件路径
        k: 需要获取的最大元素的个数
        
    Returns:
        list: 按元素大小降序排列的索引列表
        
    Raises:
        FileNotFoundError: 文件不存在时抛出
        ValueError: k值无效或数据格式错误时抛出
    """
    # 读取JSON文件
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"文件 {json_file_path} 不存在")
    except json.JSONDecodeError:
        raise ValueError("JSON文件格式错误")
    
    # 验证数据类型
    if not isinstance(data, list):
        raise ValueError("JSON文件内容不是列表")
    
    # 验证k值的有效性
    if k <= 0 or k > len(data):
        raise ValueError(f"k值无效,应在1到{len(data)}之间")
    
    # 获取元素值和索引的元组列表 [(value, index), ...]
    value_index_pairs = [(value, idx) for idx, value in enumerate(data)]
    
    # 方法1:使用heapq获取最大的k个元素(效率更高,O(n log k))
    top_k_pairs = heapq.nlargest(k, value_index_pairs, key=lambda x: x[0])
    
    # 方法2:使用排序(简单直观,O(n log n))
    # sorted_pairs = sorted(value_index_pairs, key=lambda x: x[0], reverse=True)
    # top_k_pairs = sorted_pairs[:k]
    
    # 提取索引
    top_k_indices = [pair[1] for pair in top_k_pairs]
    
    return top_k_indices

a = get_top_k_indices('/mnt/bn/life-mllm/users/cxr/quantization/quantization_metric/metrics/alpha/alpha_mlp_Llama-2-7b-hf.json', 10)
print(a)