File size: 1,077 Bytes
0f3ad1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea4c290
0f3ad1b
ea4c290
0f3ad1b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

def _endpoint_config(config_type: str, repo_id: str, model: str) -> dict:
    """Return a fresh HuggingFace endpoint config dict.

    Every config shares the same deterministic generation settings. Only the
    config type tag, the Hub ``repo_id`` and the display ``model`` label
    differ. A new dict is built on each call, so configs never share nested
    state.
    """
    return {
        'type': config_type,
        # Keyword arguments that setup_model forwards to HuggingFaceEndpoint.
        'param': {
            'repo_id': repo_id,
            'task': 'text-generation',
            'max_new_tokens': 512,
            'do_sample': False,
            'repetition_penalty': 1.03,
            'provider': 'auto',
        },
        'model': model,
    }


# Full-size Qwen instruct model.
HUGGINGFACE = _endpoint_config(
    'huggingface',
    "Qwen/Qwen2.5-7B-Instruct",
    'huggingface:Qwen2.5-7B-Instruct',
)

# Smaller Qwen instruct model with identical generation settings.
HUGGINGFACE_LITE = _endpoint_config(
    'huggingface-lite',
    "Qwen/Qwen2.5-1.5B-Instruct",
    'huggingface:Qwen2.5-1.5B-Instruct',
)

# All configs this module exposes for selection.
valid_LLM = [HUGGINGFACE, HUGGINGFACE_LITE]

def setup_model(config: dict, API_KEY: str, callbacks=None):
    """Build a LangChain chat model from a config dict such as those in ``valid_LLM``.

    Args:
        config: Mapping with a ``'type'`` string. When the type starts with
            ``'huggingface'``, its ``'param'`` mapping is unpacked as keyword
            arguments to ``HuggingFaceEndpoint``. Otherwise ``'param'`` is never
            read.
        API_KEY: Hugging Face Hub token, passed as ``huggingfacehub_api_token``.
        callbacks: Passed unchanged to ``ChatHuggingFace``.

    Returns:
        A ``ChatHuggingFace`` wrapping the configured endpoint, or ``None`` when
        the config type is not a HuggingFace variant.
    """
    # Unknown config types yield None rather than raising.
    if not config['type'].startswith('huggingface'):
        return None

    endpoint = HuggingFaceEndpoint(huggingfacehub_api_token=API_KEY, **config['param'])
    return ChatHuggingFace(llm=endpoint, callbacks=callbacks)