269 lines
9.9 KiB
Python
269 lines
9.9 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
'''
|
||
@File : LLM.py
|
||
@Time : 2024/06/16 07:59:28
|
||
@Author : 不要葱姜蒜
|
||
@Version : 1.0
|
||
@Desc : DeepSeek chat wrapper backed by the DeepSeek API or a local vLLM model
|
||
'''
|
||
|
||
import os
|
||
import sys
|
||
import platform
|
||
from typing import Dict, List, Optional, Tuple, Union
|
||
|
||
from openai import OpenAI
|
||
from vllm import LLM
|
||
from vllm import SamplingParams
|
||
from transformers import AutoTokenizer
|
||
|
||
from dotenv import load_dotenv, find_dotenv
|
||
# Load variables from the nearest .env file (existing environment variables win).
_ = load_dotenv(find_dotenv())

# Windows multiprocessing compatibility fix: force the "spawn" start method.
if platform.system() == "Windows":
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)

# Tell vLLM to fetch models from ModelScope instead of the Hugging Face Hub.
os.environ['VLLM_USE_MODELSCOPE'] = 'True'

# DeepSeek API settings come from the environment or the .env file loaded above.
# The API key must never be hard-coded here: a committed key is leaked and has to
# be rotated. setdefault keeps any value already supplied by the environment/.env.
os.environ.setdefault('DEEPSEEK_BASE_URL', 'https://api.deepseek.com')
if not os.getenv('DEEPSEEK_API'):
    print("Warning: DEEPSEEK_API is not set; API mode (use_api=True) requires it.")
|
||
|
||
class BaseModel:
    """Minimal base class for the chat-model wrappers in this module.

    It only stores the model path; subclasses are expected to override
    ``chat`` (possibly with their own signature) and the loading hooks.
    """

    def __init__(self, path: str = '') -> None:
        # Location of the model; left empty when no local model is involved.
        self.path = path

    def chat(self, prompt: str, history: List[dict], content: str) -> str:
        """Placeholder chat hook: the base implementation does nothing and returns None."""
        return None

    def load_model(self):
        """Placeholder loading hook: the base implementation is a no-op returning None."""
        return None
|
||
|
||
|
||
class DeepseekChat(BaseModel):
    """Chat wrapper backed either by the DeepSeek HTTP API or by a local vLLM model.

    * ``use_api=True`` (default): requests go through the OpenAI-compatible client,
      configured from the ``DEEPSEEK_API`` / ``DEEPSEEK_BASE_URL`` environment variables.
    * ``use_api=False``: the tokenizer is loaded eagerly from ``path``; the vLLM ``LLM``
      engine is created lazily on the first ``chat`` call and reused afterwards.
    """

    # Substrings (matched against the lower-cased, stripped line) that mark the start
    # of a reasoning paragraph in local-model output; see _clean_thinking_process.
    _THINKING_LINE_MARKERS = (
        "let's tackle this", 'step by step', 'first, i need',
        'looking at the text', 'i can identify', 'now, constructing',
        'double-checking', "finally, i'll",
    )

    def __init__(self, path: str = '', model: str = "deepseek-chat", use_api: bool = True) -> None:
        """
        Args:
            path: Local model directory; required when ``use_api`` is False.
            model: Model name sent to the DeepSeek API.
            use_api: True to use the remote API, False to use a local vLLM model.

        Raises:
            ValueError, FileNotFoundError, RuntimeError: propagated from
            ``load_local_model`` when ``use_api`` is False.
        """
        super().__init__(path)
        self.model = model
        self.use_api = use_api
        self.tokenizer = None
        self.local_llm = None  # cached vLLM engine, built on the first local chat
        self._llm_initialized = False

        if not use_api:
            self.load_local_model()

    def load_local_model(self):
        """Validate the local model directory and load its tokenizer.

        Raises:
            ValueError: if ``self.path`` is empty.
            FileNotFoundError: if ``self.path`` does not exist.
            RuntimeError: if the tokenizer cannot be loaded.
        """
        if not self.path:
            raise ValueError("Local model path is required when use_api=False")

        if not os.path.exists(self.path):
            raise FileNotFoundError(f"Model path does not exist: {self.path}")

        # Missing files only produce a warning; the tokenizer load below is the real check.
        required_files = ['config.json', 'tokenizer_config.json']
        missing_files = [name for name in required_files
                         if not os.path.exists(os.path.join(self.path, name))]
        if missing_files:
            print(f"Warning: Missing model files: {missing_files}")

        # Report whether the checkpoint declares a quantization config (informational only).
        config_path = os.path.join(self.path, 'config.json')
        if os.path.exists(config_path):
            import json
            # Explicit UTF-8: otherwise the locale codec is used (e.g. cp936 on Windows).
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

            quantization_config = config.get('quantization_config', {})
            if quantization_config:
                quant_method = quantization_config.get('quant_method', 'unknown')
                print(f"Model is pre-quantized with method: {quant_method}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.path, use_fast=False)
        except Exception as e:
            raise RuntimeError(f"Failed to load tokenizer from {self.path}. Error: {e}") from e
        print(f"Successfully loaded tokenizer from {self.path}")

    def check_gpu_memory(self):
        """Print and return the total memory of CUDA device 0, in GiB.

        Returns 0 when PyTorch is not installed or no CUDA device is available.
        """
        try:
            import torch
        except ImportError:
            print("PyTorch not available for GPU memory check")
            return 0

        if not torch.cuda.is_available():
            # Keep the return value numeric instead of falling through to None.
            return 0

        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"Available GPU memory: {gpu_memory:.1f} GB")
        if gpu_memory < 12:
            print("Warning: Less than 12GB GPU memory detected. 7B model may not fit.")
        return gpu_memory

    def chat(self, system_prompt: str, user_prompt: str, temperature: float = 0.7,
             top_p: float = 0.95, top_k: int = 20, min_p: float = 0,
             max_tokens: int = 2048, max_model_len: int = 4096) -> str:
        """Send one system + user turn and return the model's reply text.

        In API mode only ``temperature`` is forwarded; ``top_p``, ``top_k``, ``min_p``,
        ``max_tokens`` and ``max_model_len`` only affect the local vLLM backend.
        """
        if self.use_api:
            return self._chat_with_api(system_prompt, user_prompt, temperature)
        return self._chat_with_local(system_prompt, user_prompt, temperature,
                                     top_p, top_k, min_p, max_tokens, max_model_len)

    def _chat_with_api(self, system_prompt: str, user_prompt: str, temperature: float) -> str:
        """Query the DeepSeek chat-completions endpoint (non-streaming)."""
        client = OpenAI(api_key=os.getenv('DEEPSEEK_API'), base_url=os.getenv('DEEPSEEK_BASE_URL'))
        response = client.chat.completions.create(
            model=self.model,  # honour the model chosen in __init__ (default "deepseek-chat")
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            stream=False,
        )
        return response.choices[0].message.content

    def _initialize_llm_if_needed(self, max_model_len: int):
        """Create the vLLM engine on first use and reuse it afterwards.

        Note: ``max_model_len`` only takes effect on the first call; once the engine
        exists, later values are ignored until ``cleanup`` is called.
        """
        if self._llm_initialized and self.local_llm is not None:
            return

        print("Initializing LLM (this will only happen once)...")
        self.local_llm = LLM(
            model=self.path,
            tokenizer=None,  # let vLLM load the tokenizer from the model path
            max_model_len=max_model_len,
            gpu_memory_utilization=0.85,  # reduced to leave GPU headroom and avoid memory build-up
            trust_remote_code=True,
            enforce_eager=True,
            # CPU swap space in GiB. NOTE(review): vLLM's documented default is 4,
            # so this reduces rather than increases it — confirm intent.
            swap_space=2,
        )
        self._llm_initialized = True
        print("LLM initialization completed.")

    def _chat_with_local(self, system_prompt: str, user_prompt: str, temperature: float,
                         top_p: float, top_k: int, min_p: float, max_tokens: int,
                         max_model_len: int) -> str:
        """Generate a reply with the cached local vLLM engine and strip reasoning text."""
        if self.tokenizer is None:
            # E.g. the instance was created with use_api=True and switched to local later.
            self.load_local_model()

        self._initialize_llm_if_needed(max_model_len)

        # Append an instruction (in Chinese) asking for a direct answer without
        # reasoning or intermediate steps.
        enhanced_system_prompt = f"{system_prompt}\n\n请直接回答,不要输出思考过程或中间推理步骤。"
        messages = [
            {"role": "system", "content": enhanced_system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        # enable_thinking is passed through to the chat template; Qwen3-style templates
        # use it to suppress the reasoning block (assumed to be ignored by other templates).
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )

        # Custom stop conditions (stop token ids 151645/151643 and reasoning-marker
        # strings) were deliberately disabled during testing, so generation ends at
        # vLLM's default stopping criteria (EOS or max_tokens).
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            max_tokens=max_tokens,
        )

        outputs = self.local_llm.generate([text], sampling_params)
        raw_output = outputs[0].outputs[0].text
        return self._clean_thinking_process(raw_output)

    def cleanup(self):
        """Release the cached vLLM engine and free cached CUDA memory.

        Safe to call repeatedly; does nothing when no engine is loaded. Errors raised
        while releasing are printed rather than propagated.
        """
        # getattr guards against a partially constructed instance (called from __del__).
        if getattr(self, 'local_llm', None) is None:
            return

        print("Cleaning up LLM instance...")
        try:
            # Drop the vLLM engine reference first, then the wrapper itself.
            if hasattr(self.local_llm, 'llm_engine'):
                del self.local_llm.llm_engine
            del self.local_llm
            self.local_llm = None
            self._llm_initialized = False

            import gc
            gc.collect()

            # Best-effort: CUDA cache release is skipped when PyTorch is unavailable.
            try:
                import torch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print("GPU cache cleared.")
            except ImportError:
                pass

            print("LLM cleanup completed.")
        except Exception as e:
            print(f"Error during cleanup: {e}")

    def __del__(self):
        """Best-effort resource release when the object is garbage-collected."""
        try:
            self.cleanup()
        except Exception:
            # A destructor must never raise (e.g. when globals are already torn down
            # during interpreter shutdown).
            pass

    def _clean_thinking_process(self, text: str) -> str:
        """Remove reasoning/"thinking" text from a model completion.

        Steps:
          1. Delete ``<think>...</think>`` and ``<|thinking|>...</thinking>`` (or
             ``</|thinking|>``) spans.
          2. If an unpaired ``</think>`` remains (the opening tag was part of the
             prompt), drop everything up to and including its last occurrence.
          3. Drop lines containing a reasoning marker and every following line until
             a line that starts with ``[`` or ``{`` (start of a JSON answer).
        Returns the result stripped of surrounding whitespace.
        """
        import re

        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
        text = re.sub(r'<\|thinking\|>.*?</(?:\|thinking\||thinking)>', '', text, flags=re.DOTALL)

        # rpartition returns ('', '', text) when the tag is absent, so this is a no-op then.
        text = text.rpartition('</think>')[2]

        kept_lines = []
        skipping = False
        for line in text.split('\n'):
            lowered = line.lower().strip()
            if any(marker in lowered for marker in self._THINKING_LINE_MARKERS):
                skipping = True
                continue
            # A JSON opener ends the skipped reasoning region.
            if line.strip().startswith(('[', '{')):
                skipping = False
            if not skipping:
                kept_lines.append(line)

        return '\n'.join(kept_lines).strip()
|