Project02/AITrain/dialogue_scorer.py

497 lines
17 KiB
Python
Raw Permalink Normal View History

2025-08-23 18:37:00 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
对话质量评分系统
使用AI模型对生成的对话进行多维度质量评分
'''
import json
import re
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from datetime import datetime
# 尝试导入numpy如果失败则跳过
try:
import numpy as np
NUMPY_AVAILABLE = True
except ImportError:
NUMPY_AVAILABLE = False
@dataclass
class ScoreResult:
"""单次打分结果"""
dialogue_id: str
session_id: str
speaker: str
content: str
timestamp: str
scores: Dict[str, float] # 各维度分数
overall_score: float # 总分
feedback: str # 反馈意见
scorer_type: str # 打分器类型 (ai/human)
class DialogueAIScorer:
"""AI对话质量评分器"""
def __init__(self, base_model_path: str, tokenizer=None, model=None):
"""
初始化AI评分器
Args:
base_model_path: 基础模型路径
tokenizer: 分词器可选复用现有的
model: 模型可选复用现有的
"""
self.base_model_path = base_model_path
self.tokenizer = tokenizer
self.model = model
# 如果没有传入模型,则加载
if self.tokenizer is None or self.model is None:
self._load_model()
# 评分维度定义
self.score_dimensions = {
"coherence": {
"name": "连贯性",
"description": "对话是否与上下文逻辑连贯",
"weight": 0.25
},
"character_consistency": {
"name": "角色一致性",
"description": "是否符合角色设定和人格特征",
"weight": 0.25
},
"naturalness": {
"name": "自然度",
"description": "语言表达是否自然流畅",
"weight": 0.20
},
"information_density": {
"name": "信息密度",
"description": "是否包含有意义的信息,避免废话",
"weight": 0.15
},
"creativity": {
"name": "创意性",
"description": "内容是否有趣、有创意",
"weight": 0.15
}
}
def _load_model(self):
"""加载模型和分词器"""
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
print(f"Loading scorer tokenizer from: {self.base_model_path}")
self.tokenizer = AutoTokenizer.from_pretrained(
self.base_model_path,
use_fast=False,
trust_remote_code=True
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
print(f"Loading scorer model from: {self.base_model_path}")
self.model = AutoModelForCausalLM.from_pretrained(
self.base_model_path,
device_map="auto",
torch_dtype=torch.bfloat16,
trust_remote_code=True
)
except Exception as e:
print(f"✗ AI评分器模型加载失败: {e}")
raise
def score_dialogue(self,
dialogue_content: str,
speaker: str,
character_data: Dict,
dialogue_history: List[Dict] = None,
context_info: List[Dict] = None) -> ScoreResult:
"""
对单条对话进行AI评分
Args:
dialogue_content: 对话内容
speaker: 说话者
character_data: 角色数据
dialogue_history: 对话历史
context_info: 上下文信息
Returns:
ScoreResult: 评分结果
"""
# 构建评分提示
scoring_prompt = self._build_scoring_prompt(
dialogue_content, speaker, character_data, dialogue_history, context_info
)
# 使用AI模型生成评分
try:
scores, feedback = self._generate_ai_scores(scoring_prompt)
# 计算总分
overall_score = self._calculate_overall_score(scores)
# 创建评分结果
result = ScoreResult(
dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
session_id="", # 由调用方设置
speaker=speaker,
content=dialogue_content,
timestamp=datetime.now().isoformat(),
scores=scores,
overall_score=overall_score,
feedback=feedback,
scorer_type="ai"
)
return result
except Exception as e:
print(f"✗ AI评分失败: {e}")
# 返回默认评分
return self._create_default_score(dialogue_content, speaker)
def _build_scoring_prompt(self,
dialogue_content: str,
speaker: str,
character_data: Dict,
dialogue_history: List[Dict] = None,
context_info: List[Dict] = None) -> str:
"""构建评分提示"""
# 基础角色信息
character_info = ""
if character_data:
personality = character_data.get('personality', {})
traits = personality.get('core_traits', [])
occupation = character_data.get('basic_info', {}).get('occupation', '未知')
character_info = f"角色职业: {occupation}, 性格特点: {', '.join(traits[:3])}"
# 对话历史
history_text = ""
if dialogue_history:
history_text = "对话历史:\n"
for turn in dialogue_history[-3:]: # 只取最近3轮
history_text += f"{turn.get('speaker', '未知')}: {turn.get('content', '')}\n"
# 构建完整提示
prompt = f"""请对以下对话内容进行质量评分。
角色设定:
{character_info}
{history_text}
当前对话:
{speaker}: {dialogue_content}
请从以下5个维度评分1-10
1. 连贯性 - 对话是否与上下文逻辑连贯
2. 角色一致性 - 是否符合角色设定和人格特征
3. 自然度 - 语言表达是否自然流畅
4. 信息密度 - 是否包含有意义的信息避免废话
5. 创意性 - 内容是否有趣有创意
请按以下格式输出
连贯性: X分
角色一致性: X分
自然度: X分
信息密度: X分
创意性: X分
总体评价: [具体的改进建议和优点分析]"""
return prompt
def _generate_ai_scores(self, prompt: str) -> Tuple[Dict[str, float], str]:
"""使用AI模型生成评分"""
import torch
# 准备消息
messages = [
{"role": "system", "content": "你是一个专业的对话质量评估专家,请客观公正地评分。"},
{"role": "user", "content": prompt}
]
# 应用对话模板
inputs = self.tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True,
enable_thinking=False
)
# 移动到设备
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
# 生成评分
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=300,
do_sample=True,
temperature=0.3, # 较低温度确保评分稳定
top_p=0.8,
pad_token_id=self.tokenizer.eos_token_id,
repetition_penalty=1.1
)
# 解码输出
response = outputs[0][inputs['input_ids'].shape[1]:]
result_text = self.tokenizer.decode(response, skip_special_tokens=True).strip()
# 解析评分结果
scores, feedback = self._parse_score_response(result_text)
return scores, feedback
def _parse_score_response(self, response: str) -> Tuple[Dict[str, float], str]:
"""解析AI评分响应"""
scores = {}
feedback = ""
# 定义维度映射
dimension_map = {
"连贯性": "coherence",
"角色一致性": "character_consistency",
"自然度": "naturalness",
"信息密度": "information_density",
"创意性": "creativity"
}
try:
lines = response.split('\n')
feedback_start = False
for line in lines:
line = line.strip()
# 查找评分
for chinese_name, english_key in dimension_map.items():
if chinese_name in line and ':' in line:
# 提取分数
score_match = re.search(r'(\d+(?:\.\d+)?)', line)
if score_match:
score = float(score_match.group(1))
# 确保分数在1-10范围内
score = max(1.0, min(10.0, score))
scores[english_key] = score
# 查找总体评价
if '总体评价' in line or '评价' in line:
feedback_start = True
feedback_content = line.split(':', 1)
if len(feedback_content) > 1:
feedback += feedback_content[1].strip()
elif feedback_start and line:
feedback += " " + line
# 确保所有维度都有分数
for english_key in dimension_map.values():
if english_key not in scores:
scores[english_key] = 5.0 # 默认中等分数
if not feedback:
feedback = "AI评分完成建议根据各维度分数进行改进。"
except Exception as e:
print(f"解析评分响应失败: {e}")
# 使用默认分数
for english_key in dimension_map.values():
scores[english_key] = 5.0
feedback = "评分解析失败,使用默认分数。"
return scores, feedback
def _calculate_overall_score(self, scores: Dict[str, float]) -> float:
"""计算总分"""
total_score = 0.0
total_weight = 0.0
for dimension, score in scores.items():
if dimension in self.score_dimensions:
weight = self.score_dimensions[dimension]["weight"]
total_score += score * weight
total_weight += weight
if total_weight > 0:
return round(total_score / total_weight, 2)
else:
return 5.0
def _create_default_score(self, dialogue_content: str, speaker: str) -> ScoreResult:
"""创建默认评分结果"""
default_scores = {
"coherence": 5.0,
"character_consistency": 5.0,
"naturalness": 5.0,
"information_density": 5.0,
"creativity": 5.0
}
return ScoreResult(
dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
session_id="",
speaker=speaker,
content=dialogue_content,
timestamp=datetime.now().isoformat(),
scores=default_scores,
overall_score=5.0,
feedback="使用默认评分",
scorer_type="ai"
)
def batch_score_dialogue(self, dialogue_list: List[Dict]) -> List[ScoreResult]:
"""批量评分对话"""
results = []
for i, dialogue_item in enumerate(dialogue_list):
print(f"正在评分 {i+1}/{len(dialogue_list)}: {dialogue_item.get('speaker', '未知')}")
try:
result = self.score_dialogue(
dialogue_content=dialogue_item.get('content', ''),
speaker=dialogue_item.get('speaker', '未知'),
character_data=dialogue_item.get('character_data', {}),
dialogue_history=dialogue_item.get('dialogue_history', []),
context_info=dialogue_item.get('context_info', [])
)
# 设置session_id
result.session_id = dialogue_item.get('session_id', '')
results.append(result)
except Exception as e:
print(f"评分失败: {e}")
# 添加默认评分
default_result = self._create_default_score(
dialogue_item.get('content', ''),
dialogue_item.get('speaker', '未知')
)
default_result.session_id = dialogue_item.get('session_id', '')
results.append(default_result)
return results
class HumanScorer:
"""人工评分器"""
def __init__(self):
self.score_dimensions = {
"coherence": "连贯性",
"character_consistency": "角色一致性",
"naturalness": "自然度",
"information_density": "信息密度",
"creativity": "创意性"
}
def score_dialogue_interactive(self,
dialogue_content: str,
speaker: str,
session_id: str = "") -> ScoreResult:
"""交互式人工评分"""
print(f"\n=== 人工评分 ===")
print(f"角色: {speaker}")
print(f"对话: {dialogue_content}")
print(f"请对以下维度评分 (1-10分):")
scores = {}
for dimension_key, dimension_name in self.score_dimensions.items():
while True:
try:
score_input = input(f"{dimension_name} (1-10): ").strip()
score = float(score_input)
if 1 <= score <= 10:
scores[dimension_key] = score
break
else:
print("请输入1-10之间的分数")
except ValueError:
print("请输入有效的数字")
# 获取反馈
feedback = input("请输入评价和建议 (可选): ").strip()
if not feedback:
feedback = "人工评分完成"
# 计算总分
overall_score = sum(scores.values()) / len(scores)
return ScoreResult(
dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
session_id=session_id,
speaker=speaker,
content=dialogue_content,
timestamp=datetime.now().isoformat(),
scores=scores,
overall_score=round(overall_score, 2),
feedback=feedback,
scorer_type="human"
)
class QuickScorer:
"""快速规则评分器(用于实时反馈)"""
def __init__(self):
pass
def quick_score(self,
dialogue_content: str,
speaker: str,
dialogue_history: List[Dict] = None) -> float:
"""快速评分(基于规则)"""
score = 5.0 # 基础分
# 长度检查
content_length = len(dialogue_content.strip())
if content_length < 10:
score -= 2.0 # 太短
elif content_length > 200:
score -= 1.0 # 太长
elif 30 <= content_length <= 100:
score += 0.5 # 长度适中
# 重复检查
if dialogue_history:
recent_content = [turn.get('content', '') for turn in dialogue_history[-3:]]
for prev_content in recent_content:
if dialogue_content == prev_content:
score -= 3.0 # 重复内容
elif self._calculate_similarity(dialogue_content, prev_content) > 0.8:
score -= 1.5 # 高度相似
# 内容质量检查
if any(word in dialogue_content for word in ['...', '', '', '嗯嗯']):
score -= 0.5 # 含有填充词
if re.search(r'[。!?]', dialogue_content):
score += 0.3 # 有标点符号
# 确保分数在合理范围内
return max(1.0, min(10.0, score))
def _calculate_similarity(self, text1: str, text2: str) -> float:
"""计算文本相似度(简单方法)"""
words1 = set(text1)
words2 = set(text2)
if not words1 and not words2:
return 1.0
intersection = len(words1.intersection(words2))
union = len(words1.union(words2))
return intersection / union if union > 0 else 0.0