添加对话内容 ai打分功能

This commit is contained in:
997146918 2025-08-23 17:10:23 +08:00
parent 350b207527
commit 6e2b99438b
4 changed files with 1083 additions and 18 deletions

View File

@ -286,6 +286,9 @@ class ConversationManager:
timestamp TEXT, timestamp TEXT,
context_used TEXT, context_used TEXT,
relevance_score REAL, relevance_score REAL,
dialogue_score REAL DEFAULT 0.0,
score_details TEXT,
score_feedback TEXT,
FOREIGN KEY (session_id) REFERENCES conversations (session_id) FOREIGN KEY (session_id) REFERENCES conversations (session_id)
) )
''') ''')
@ -305,7 +308,9 @@ class ConversationManager:
print(f"✓ 创建对话会话: {session_id}") print(f"✓ 创建对话会话: {session_id}")
return session_id return session_id
def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: List[str] = None, relevance_score: float = 0.0): def add_dialogue_turn(self, session_id: str, speaker: str, content: str, context_used: List[str] = None,
relevance_score: float = 0.0, dialogue_score: float = 0.0,
score_details: str = None, score_feedback: str = None):
"""添加对话轮次""" """添加对话轮次"""
if context_used is None: if context_used is None:
context_used = [] context_used = []
@ -318,10 +323,11 @@ class ConversationManager:
# 插入对话轮次 # 插入对话轮次
conn.execute( conn.execute(
"""INSERT INTO dialogue_turns """INSERT INTO dialogue_turns
(session_id, turn_number, speaker, content, timestamp, context_used, relevance_score) (session_id, turn_number, speaker, content, timestamp, context_used, relevance_score,
VALUES (?, ?, ?, ?, ?, ?, ?)""", dialogue_score, score_details, score_feedback)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(session_id, turn_number, speaker, content, datetime.now().isoformat(), (session_id, turn_number, speaker, content, datetime.now().isoformat(),
json.dumps(context_used), relevance_score) json.dumps(context_used), relevance_score, dialogue_score, score_details, score_feedback)
) )
# 更新会话最后更新时间 # 更新会话最后更新时间
@ -335,7 +341,7 @@ class ConversationManager:
"""获取对话历史""" """获取对话历史"""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute( cursor = conn.execute(
"""SELECT speaker, content, timestamp, context_used, relevance_score """SELECT speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback
FROM dialogue_turns FROM dialogue_turns
WHERE session_id = ? WHERE session_id = ?
ORDER BY turn_number DESC LIMIT ?""", ORDER BY turn_number DESC LIMIT ?""",
@ -344,14 +350,20 @@ class ConversationManager:
turns = [] turns = []
for row in cursor.fetchall(): for row in cursor.fetchall():
speaker, content, timestamp, context_used, relevance_score = row speaker, content, timestamp, context_used, relevance_score, dialogue_score, score_feedback = row
turns.append(DialogueTurn( turn = DialogueTurn(
speaker=speaker, speaker=speaker,
content=content, content=content,
timestamp=timestamp, timestamp=timestamp,
context_used=json.loads(context_used or "[]"), context_used=json.loads(context_used or "[]"),
relevance_score=relevance_score relevance_score=relevance_score
)) )
# 添加评分信息到turn对象
if dialogue_score:
turn.dialogue_score = dialogue_score
if score_feedback:
turn.score_feedback = score_feedback
turns.append(turn)
return list(reversed(turns)) # 按时间正序返回 return list(reversed(turns)) # 按时间正序返回
@ -378,10 +390,70 @@ class ConversationManager:
class DualAIDialogueEngine: class DualAIDialogueEngine:
"""双AI对话引擎""" """双AI对话引擎"""
def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator): def __init__(self, knowledge_base: RAGKnowledgeBase, conversation_manager: ConversationManager, llm_generator,
enable_scoring: bool = True, base_model_path: str = None):
self.kb = knowledge_base self.kb = knowledge_base
self.conv_mgr = conversation_manager self.conv_mgr = conversation_manager
self.llm_generator = llm_generator self.llm_generator = llm_generator
self.enable_scoring = enable_scoring
self.scorer = None
# 初始化评分器
if enable_scoring and base_model_path:
try:
from dialogue_scorer import DialogueAIScorer
print("正在初始化对话评分系统...")
self.scorer = DialogueAIScorer(
base_model_path=base_model_path,
tokenizer=getattr(llm_generator, 'tokenizer', None),
model=getattr(llm_generator, 'model', None)
)
print("✓ 对话评分系统初始化成功")
except Exception as e:
print(f"⚠ 对话评分系统初始化失败: {e}")
self.enable_scoring = False
def score_dialogue_turn(self, dialogue_content: str, speaker: str, dialogue_history: List[DialogueTurn]) -> Tuple[float, str, str]:
"""对单条对话进行评分
Args:
dialogue_content: 对话内容
speaker: 说话者
dialogue_history: 对话历史
Returns:
tuple: (总分, 详细分数JSON, 反馈意见)
"""
if not self.enable_scoring or not self.scorer:
return 0.0, "{}", "评分系统未启用"
try:
# 获取角色数据
character_data = self.kb.character_data.get(speaker, {})
# 转换对话历史格式
history_for_scoring = []
for turn in dialogue_history[-5:]: # 最近5轮对话
history_for_scoring.append({
'speaker': turn.speaker,
'content': turn.content
})
# 进行AI评分
score_result = self.scorer.score_dialogue(
dialogue_content=dialogue_content,
speaker=speaker,
character_data=character_data,
dialogue_history=history_for_scoring,
context_info=[]
)
# 返回评分结果
return score_result.overall_score, json.dumps(score_result.scores), score_result.feedback
except Exception as e:
print(f"⚠ 对话评分失败: {e}")
return 0.0, "{}", f"评分失败: {str(e)}"
def generate_character_prompt(self, character_name: str, context_info: List[Dict], dialogue_history: List[DialogueTurn], def generate_character_prompt(self, character_name: str, context_info: List[Dict], dialogue_history: List[DialogueTurn],
history_context_count: int = 3, context_info_count: int = 2) -> str: history_context_count: int = 3, context_info_count: int = 2) -> str:
@ -494,9 +566,17 @@ class DualAIDialogueEngine:
context_used = [f"{info['section']}.{info['subsection']}" for info in context_info[:context_info_count]] context_used = [f"{info['section']}.{info['subsection']}" for info in context_info[:context_info_count]]
avg_relevance = sum(info['relevance_score'] for info in context_info[:context_info_count]) / len(context_info[:context_info_count]) if context_info else 0.0 avg_relevance = sum(info['relevance_score'] for info in context_info[:context_info_count]) / len(context_info[:context_info_count]) if context_info else 0.0
# 保存对话轮次 # 对对话进行评分
if self.enable_scoring:
dialogue_score, score_details, score_feedback = self.score_dialogue_turn(response, current_speaker, dialogue_history)
print(f" [评分: {dialogue_score:.2f}] {score_feedback}")
else:
dialogue_score, score_details, score_feedback = 0.0, "{}", ""
# 保存对话轮次(包含评分信息)
self.conv_mgr.add_dialogue_turn( self.conv_mgr.add_dialogue_turn(
session_id, current_speaker, response, context_used, avg_relevance session_id, current_speaker, response, context_used, avg_relevance,
dialogue_score, score_details, score_feedback
) )
return response, context_used return response, context_used
@ -605,14 +685,29 @@ class DualAIDialogueEngine:
max_new_tokens=150 max_new_tokens=150
) )
# 保存对话到数据库 # 保存对话到数据库并进行评分
for result in conversation_results: for result in conversation_results:
# 获取当前对话历史进行评分
current_dialogue_history = self.conv_mgr.get_conversation_history(session_id)
# 对对话进行评分
if self.enable_scoring:
dialogue_score, score_details, score_feedback = self.score_dialogue_turn(
result['dialogue'], result['speaker'], current_dialogue_history
)
print(f" [评分: {dialogue_score:.2f}] {score_feedback[:100]}...")
else:
dialogue_score, score_details, score_feedback = 0.0, "{}", ""
self.conv_mgr.add_dialogue_turn( self.conv_mgr.add_dialogue_turn(
session_id, session_id,
result['speaker'], result['speaker'],
result['dialogue'], result['dialogue'],
[result.get('context_used', '')], [result.get('context_used', '')],
0.8 # 默认相关性分数 0.8, # 默认相关性分数
dialogue_score,
score_details,
score_feedback
) )

File diff suppressed because it is too large Load Diff