import os

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# Have vLLM download model weights from ModelScope instead of the Hugging Face Hub.
os.environ['VLLM_USE_MODELSCOPE'] = 'True'


def get_completion(prompts, model, tokenizer=None, temperature=1.0, top_p=0.95, top_k=20, min_p=0,
                   max_tokens=2048, max_model_len=4096):
    # Stop generation at Qwen's <|im_end|> (151645) and <|endoftext|> (151643) tokens.
    stop_token_ids = [151645, 151643]
    # Build the sampling parameters: temperature controls the diversity of the generated text,
    # top_p sets the nucleus-sampling probability mass,
    # top_k limits the number of candidate tokens to balance quality and diversity,
    # min_p filters candidates by a probability threshold, adding diversity while preserving quality.
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, top_k=top_k,
                                     min_p=min_p, max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids)
    # Initialize the vLLM inference engine.
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        max_model_len=max_model_len,
        gpu_memory_utilization=0.85,
        trust_remote_code=True,
        enforce_eager=True,
        swap_space=2  # use 2 GB of CPU swap space
    )
    outputs = llm.generate(prompts, sampling_params)
    return outputs
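
# Note: get_completion() builds a fresh vLLM engine on every call, which is slow and
# GPU-memory hungry. A minimal sketch of reusing a single engine across calls
# (illustrative only, not wired into this script):
#
#     llm = LLM(model=model, max_model_len=4096, gpu_memory_utilization=0.85,
#               trust_remote_code=True, enforce_eager=True, swap_space=2)
#     outputs = llm.generate(prompts, sampling_params)  # call repeatedly with new prompts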


if __name__ == '__main__':
    model = '/home/tong/AIProject/Qwen/Qwen/Qwen3-0.6B'
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)  # load the tokenizer
    prompt = "给我一个关于大模型的简短介绍"  # "Give me a brief introduction to large language models"
    messages = [
        {"role": "user", "content": prompt}
    ]
    # Render the chat template into a plain prompt string; thinking mode is disabled here.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False)

    # For thinking mode, the officially recommended sampling settings are
    # temperature=0.6, top_p=0.95, top_k=20, min_p=0.
    outputs = get_completion(text, model, tokenizer=None, temperature=0.6, top_p=0.95, top_k=20, min_p=0)

    # The result is a list of RequestOutput objects that contain the prompt,
    # the generated text and other information.
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, \nResponse: {generated_text!r}")
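
# A minimal sketch of trying thinking mode instead (illustrative only): re-render the
# prompt with enable_thinking=True, then split the reply on the closing </think> tag;
# the thinking-mode settings recommended above (temperature=0.6, top_p=0.95, top_k=20,
# min_p=0) stay the same.
#
#     text = tokenizer.apply_chat_template(messages, tokenize=False,
#                                          add_generation_prompt=True, enable_thinking=True)
#     reply = get_completion(text, model, temperature=0.6, top_p=0.95, top_k=20, min_p=0)[0].outputs[0].text
#     thinking, _, answer = reply.partition("</think>")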