修改打分bug
This commit is contained in:
parent
5bc280c539
commit
4c29745fa3
497
AITrain/dialogue_scorer.py
Normal file
497
AITrain/dialogue_scorer.py
Normal file
@ -0,0 +1,497 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
对话质量评分系统
|
||||
使用AI模型对生成的对话进行多维度质量评分
|
||||
'''
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
# 尝试导入numpy,如果失败则跳过
|
||||
try:
|
||||
import numpy as np
|
||||
NUMPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
NUMPY_AVAILABLE = False
|
||||
|
||||
@dataclass
|
||||
class ScoreResult:
|
||||
"""单次打分结果"""
|
||||
dialogue_id: str
|
||||
session_id: str
|
||||
speaker: str
|
||||
content: str
|
||||
timestamp: str
|
||||
scores: Dict[str, float] # 各维度分数
|
||||
overall_score: float # 总分
|
||||
feedback: str # 反馈意见
|
||||
scorer_type: str # 打分器类型 (ai/human)
|
||||
|
||||
class DialogueAIScorer:
|
||||
"""AI对话质量评分器"""
|
||||
|
||||
def __init__(self, base_model_path: str, tokenizer=None, model=None):
|
||||
"""
|
||||
初始化AI评分器
|
||||
|
||||
Args:
|
||||
base_model_path: 基础模型路径
|
||||
tokenizer: 分词器(可选,复用现有的)
|
||||
model: 模型(可选,复用现有的)
|
||||
"""
|
||||
self.base_model_path = base_model_path
|
||||
self.tokenizer = tokenizer
|
||||
self.model = model
|
||||
|
||||
# 如果没有传入模型,则加载
|
||||
if self.tokenizer is None or self.model is None:
|
||||
self._load_model()
|
||||
|
||||
# 评分维度定义
|
||||
self.score_dimensions = {
|
||||
"coherence": {
|
||||
"name": "连贯性",
|
||||
"description": "对话是否与上下文逻辑连贯",
|
||||
"weight": 0.25
|
||||
},
|
||||
"character_consistency": {
|
||||
"name": "角色一致性",
|
||||
"description": "是否符合角色设定和人格特征",
|
||||
"weight": 0.25
|
||||
},
|
||||
"naturalness": {
|
||||
"name": "自然度",
|
||||
"description": "语言表达是否自然流畅",
|
||||
"weight": 0.20
|
||||
},
|
||||
"information_density": {
|
||||
"name": "信息密度",
|
||||
"description": "是否包含有意义的信息,避免废话",
|
||||
"weight": 0.15
|
||||
},
|
||||
"creativity": {
|
||||
"name": "创意性",
|
||||
"description": "内容是否有趣、有创意",
|
||||
"weight": 0.15
|
||||
}
|
||||
}
|
||||
|
||||
def _load_model(self):
|
||||
"""加载模型和分词器"""
|
||||
try:
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
print(f"Loading scorer tokenizer from: {self.base_model_path}")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.base_model_path,
|
||||
use_fast=False,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
print(f"Loading scorer model from: {self.base_model_path}")
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.base_model_path,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ AI评分器模型加载失败: {e}")
|
||||
raise
|
||||
|
||||
def score_dialogue(self,
|
||||
dialogue_content: str,
|
||||
speaker: str,
|
||||
character_data: Dict,
|
||||
dialogue_history: List[Dict] = None,
|
||||
context_info: List[Dict] = None) -> ScoreResult:
|
||||
"""
|
||||
对单条对话进行AI评分
|
||||
|
||||
Args:
|
||||
dialogue_content: 对话内容
|
||||
speaker: 说话者
|
||||
character_data: 角色数据
|
||||
dialogue_history: 对话历史
|
||||
context_info: 上下文信息
|
||||
|
||||
Returns:
|
||||
ScoreResult: 评分结果
|
||||
"""
|
||||
# 构建评分提示
|
||||
scoring_prompt = self._build_scoring_prompt(
|
||||
dialogue_content, speaker, character_data, dialogue_history, context_info
|
||||
)
|
||||
|
||||
# 使用AI模型生成评分
|
||||
try:
|
||||
scores, feedback = self._generate_ai_scores(scoring_prompt)
|
||||
|
||||
# 计算总分
|
||||
overall_score = self._calculate_overall_score(scores)
|
||||
|
||||
# 创建评分结果
|
||||
result = ScoreResult(
|
||||
dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||
session_id="", # 由调用方设置
|
||||
speaker=speaker,
|
||||
content=dialogue_content,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
scores=scores,
|
||||
overall_score=overall_score,
|
||||
feedback=feedback,
|
||||
scorer_type="ai"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ AI评分失败: {e}")
|
||||
# 返回默认评分
|
||||
return self._create_default_score(dialogue_content, speaker)
|
||||
|
||||
def _build_scoring_prompt(self,
|
||||
dialogue_content: str,
|
||||
speaker: str,
|
||||
character_data: Dict,
|
||||
dialogue_history: List[Dict] = None,
|
||||
context_info: List[Dict] = None) -> str:
|
||||
"""构建评分提示"""
|
||||
|
||||
# 基础角色信息
|
||||
character_info = ""
|
||||
if character_data:
|
||||
personality = character_data.get('personality', {})
|
||||
traits = personality.get('core_traits', [])
|
||||
occupation = character_data.get('basic_info', {}).get('occupation', '未知')
|
||||
character_info = f"角色职业: {occupation}, 性格特点: {', '.join(traits[:3])}"
|
||||
|
||||
# 对话历史
|
||||
history_text = ""
|
||||
if dialogue_history:
|
||||
history_text = "对话历史:\n"
|
||||
for turn in dialogue_history[-3:]: # 只取最近3轮
|
||||
history_text += f"{turn.get('speaker', '未知')}: {turn.get('content', '')}\n"
|
||||
|
||||
# 构建完整提示
|
||||
prompt = f"""请对以下对话内容进行质量评分。
|
||||
|
||||
角色设定:
|
||||
{character_info}
|
||||
|
||||
{history_text}
|
||||
|
||||
当前对话:
|
||||
{speaker}: {dialogue_content}
|
||||
|
||||
请从以下5个维度评分(1-10分):
|
||||
1. 连贯性 - 对话是否与上下文逻辑连贯
|
||||
2. 角色一致性 - 是否符合角色设定和人格特征
|
||||
3. 自然度 - 语言表达是否自然流畅
|
||||
4. 信息密度 - 是否包含有意义的信息,避免废话
|
||||
5. 创意性 - 内容是否有趣、有创意
|
||||
|
||||
请按以下格式输出:
|
||||
连贯性: X分
|
||||
角色一致性: X分
|
||||
自然度: X分
|
||||
信息密度: X分
|
||||
创意性: X分
|
||||
|
||||
总体评价: [具体的改进建议和优点分析]"""
|
||||
|
||||
return prompt
|
||||
|
||||
def _generate_ai_scores(self, prompt: str) -> Tuple[Dict[str, float], str]:
|
||||
"""使用AI模型生成评分"""
|
||||
import torch
|
||||
|
||||
# 准备消息
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的对话质量评估专家,请客观公正地评分。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
# 应用对话模板
|
||||
inputs = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
|
||||
# 移动到设备
|
||||
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
||||
|
||||
# 生成评分
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=300,
|
||||
do_sample=True,
|
||||
temperature=0.3, # 较低温度确保评分稳定
|
||||
top_p=0.8,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
repetition_penalty=1.1
|
||||
)
|
||||
|
||||
# 解码输出
|
||||
response = outputs[0][inputs['input_ids'].shape[1]:]
|
||||
result_text = self.tokenizer.decode(response, skip_special_tokens=True).strip()
|
||||
|
||||
# 解析评分结果
|
||||
scores, feedback = self._parse_score_response(result_text)
|
||||
|
||||
return scores, feedback
|
||||
|
||||
def _parse_score_response(self, response: str) -> Tuple[Dict[str, float], str]:
|
||||
"""解析AI评分响应"""
|
||||
scores = {}
|
||||
feedback = ""
|
||||
|
||||
# 定义维度映射
|
||||
dimension_map = {
|
||||
"连贯性": "coherence",
|
||||
"角色一致性": "character_consistency",
|
||||
"自然度": "naturalness",
|
||||
"信息密度": "information_density",
|
||||
"创意性": "creativity"
|
||||
}
|
||||
|
||||
try:
|
||||
lines = response.split('\n')
|
||||
feedback_start = False
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
# 查找评分
|
||||
for chinese_name, english_key in dimension_map.items():
|
||||
if chinese_name in line and ':' in line:
|
||||
# 提取分数
|
||||
score_match = re.search(r'(\d+(?:\.\d+)?)', line)
|
||||
if score_match:
|
||||
score = float(score_match.group(1))
|
||||
# 确保分数在1-10范围内
|
||||
score = max(1.0, min(10.0, score))
|
||||
scores[english_key] = score
|
||||
|
||||
# 查找总体评价
|
||||
if '总体评价' in line or '评价' in line:
|
||||
feedback_start = True
|
||||
feedback_content = line.split(':', 1)
|
||||
if len(feedback_content) > 1:
|
||||
feedback += feedback_content[1].strip()
|
||||
elif feedback_start and line:
|
||||
feedback += " " + line
|
||||
|
||||
# 确保所有维度都有分数
|
||||
for english_key in dimension_map.values():
|
||||
if english_key not in scores:
|
||||
scores[english_key] = 5.0 # 默认中等分数
|
||||
|
||||
if not feedback:
|
||||
feedback = "AI评分完成,建议根据各维度分数进行改进。"
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析评分响应失败: {e}")
|
||||
# 使用默认分数
|
||||
for english_key in dimension_map.values():
|
||||
scores[english_key] = 5.0
|
||||
feedback = "评分解析失败,使用默认分数。"
|
||||
|
||||
return scores, feedback
|
||||
|
||||
def _calculate_overall_score(self, scores: Dict[str, float]) -> float:
|
||||
"""计算总分"""
|
||||
total_score = 0.0
|
||||
total_weight = 0.0
|
||||
|
||||
for dimension, score in scores.items():
|
||||
if dimension in self.score_dimensions:
|
||||
weight = self.score_dimensions[dimension]["weight"]
|
||||
total_score += score * weight
|
||||
total_weight += weight
|
||||
|
||||
if total_weight > 0:
|
||||
return round(total_score / total_weight, 2)
|
||||
else:
|
||||
return 5.0
|
||||
|
||||
def _create_default_score(self, dialogue_content: str, speaker: str) -> ScoreResult:
|
||||
"""创建默认评分结果"""
|
||||
default_scores = {
|
||||
"coherence": 5.0,
|
||||
"character_consistency": 5.0,
|
||||
"naturalness": 5.0,
|
||||
"information_density": 5.0,
|
||||
"creativity": 5.0
|
||||
}
|
||||
|
||||
return ScoreResult(
|
||||
dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||
session_id="",
|
||||
speaker=speaker,
|
||||
content=dialogue_content,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
scores=default_scores,
|
||||
overall_score=5.0,
|
||||
feedback="使用默认评分",
|
||||
scorer_type="ai"
|
||||
)
|
||||
|
||||
def batch_score_dialogue(self, dialogue_list: List[Dict]) -> List[ScoreResult]:
|
||||
"""批量评分对话"""
|
||||
results = []
|
||||
|
||||
for i, dialogue_item in enumerate(dialogue_list):
|
||||
print(f"正在评分 {i+1}/{len(dialogue_list)}: {dialogue_item.get('speaker', '未知')}")
|
||||
|
||||
try:
|
||||
result = self.score_dialogue(
|
||||
dialogue_content=dialogue_item.get('content', ''),
|
||||
speaker=dialogue_item.get('speaker', '未知'),
|
||||
character_data=dialogue_item.get('character_data', {}),
|
||||
dialogue_history=dialogue_item.get('dialogue_history', []),
|
||||
context_info=dialogue_item.get('context_info', [])
|
||||
)
|
||||
|
||||
# 设置session_id
|
||||
result.session_id = dialogue_item.get('session_id', '')
|
||||
results.append(result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"评分失败: {e}")
|
||||
# 添加默认评分
|
||||
default_result = self._create_default_score(
|
||||
dialogue_item.get('content', ''),
|
||||
dialogue_item.get('speaker', '未知')
|
||||
)
|
||||
default_result.session_id = dialogue_item.get('session_id', '')
|
||||
results.append(default_result)
|
||||
|
||||
return results
|
||||
|
||||
class HumanScorer:
|
||||
"""人工评分器"""
|
||||
|
||||
def __init__(self):
|
||||
self.score_dimensions = {
|
||||
"coherence": "连贯性",
|
||||
"character_consistency": "角色一致性",
|
||||
"naturalness": "自然度",
|
||||
"information_density": "信息密度",
|
||||
"creativity": "创意性"
|
||||
}
|
||||
|
||||
def score_dialogue_interactive(self,
|
||||
dialogue_content: str,
|
||||
speaker: str,
|
||||
session_id: str = "") -> ScoreResult:
|
||||
"""交互式人工评分"""
|
||||
|
||||
print(f"\n=== 人工评分 ===")
|
||||
print(f"角色: {speaker}")
|
||||
print(f"对话: {dialogue_content}")
|
||||
print(f"请对以下维度评分 (1-10分):")
|
||||
|
||||
scores = {}
|
||||
|
||||
for dimension_key, dimension_name in self.score_dimensions.items():
|
||||
while True:
|
||||
try:
|
||||
score_input = input(f"{dimension_name} (1-10): ").strip()
|
||||
score = float(score_input)
|
||||
if 1 <= score <= 10:
|
||||
scores[dimension_key] = score
|
||||
break
|
||||
else:
|
||||
print("请输入1-10之间的分数")
|
||||
except ValueError:
|
||||
print("请输入有效的数字")
|
||||
|
||||
# 获取反馈
|
||||
feedback = input("请输入评价和建议 (可选): ").strip()
|
||||
if not feedback:
|
||||
feedback = "人工评分完成"
|
||||
|
||||
# 计算总分
|
||||
overall_score = sum(scores.values()) / len(scores)
|
||||
|
||||
return ScoreResult(
|
||||
dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||
session_id=session_id,
|
||||
speaker=speaker,
|
||||
content=dialogue_content,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
scores=scores,
|
||||
overall_score=round(overall_score, 2),
|
||||
feedback=feedback,
|
||||
scorer_type="human"
|
||||
)
|
||||
|
||||
class QuickScorer:
|
||||
"""快速规则评分器(用于实时反馈)"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def quick_score(self,
|
||||
dialogue_content: str,
|
||||
speaker: str,
|
||||
dialogue_history: List[Dict] = None) -> float:
|
||||
"""快速评分(基于规则)"""
|
||||
|
||||
score = 5.0 # 基础分
|
||||
|
||||
# 长度检查
|
||||
content_length = len(dialogue_content.strip())
|
||||
if content_length < 10:
|
||||
score -= 2.0 # 太短
|
||||
elif content_length > 200:
|
||||
score -= 1.0 # 太长
|
||||
elif 30 <= content_length <= 100:
|
||||
score += 0.5 # 长度适中
|
||||
|
||||
# 重复检查
|
||||
if dialogue_history:
|
||||
recent_content = [turn.get('content', '') for turn in dialogue_history[-3:]]
|
||||
for prev_content in recent_content:
|
||||
if dialogue_content == prev_content:
|
||||
score -= 3.0 # 重复内容
|
||||
elif self._calculate_similarity(dialogue_content, prev_content) > 0.8:
|
||||
score -= 1.5 # 高度相似
|
||||
|
||||
# 内容质量检查
|
||||
if any(word in dialogue_content for word in ['...', '呃', '额', '嗯嗯']):
|
||||
score -= 0.5 # 含有填充词
|
||||
|
||||
if re.search(r'[。!?]', dialogue_content):
|
||||
score += 0.3 # 有标点符号
|
||||
|
||||
# 确保分数在合理范围内
|
||||
return max(1.0, min(10.0, score))
|
||||
|
||||
def _calculate_similarity(self, text1: str, text2: str) -> float:
|
||||
"""计算文本相似度(简单方法)"""
|
||||
words1 = set(text1)
|
||||
words2 = set(text2)
|
||||
|
||||
if not words1 and not words2:
|
||||
return 1.0
|
||||
|
||||
intersection = len(words1.intersection(words2))
|
||||
union = len(words1.union(words2))
|
||||
|
||||
return intersection / union if union > 0 else 0.0
|
||||
@ -529,84 +529,15 @@ def run_model_optimization():
|
||||
|
||||
conv_mgr = ConversationManager("./conversation_data/conversations.db")
|
||||
|
||||
print("模型优化选项:")
|
||||
print("1. 分析优化需求")
|
||||
print("2. 生成LoRA训练脚本")
|
||||
print("3. 创建提示优化配置")
|
||||
print("4. 执行增量训练")
|
||||
print("5. 性能对比验证")
|
||||
# print("模型优化选项:")
|
||||
|
||||
choice = input("请输入选择 (1-5): ").strip()
|
||||
# print("1. 生成LoRA训练脚本")
|
||||
# print("2. 执行增量训练")
|
||||
|
||||
if choice == '1':
|
||||
# 分析优化需求
|
||||
print("\n=== 优化需求分析 ===")
|
||||
|
||||
with sqlite3.connect(conv_mgr.db_path) as conn:
|
||||
# 获取性能数据
|
||||
cursor = conn.execute("""
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
AVG(dialogue_score) as avg_score,
|
||||
AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate
|
||||
FROM dialogue_turns WHERE dialogue_score > 0
|
||||
""")
|
||||
# choice = input("请输入选择 (1-5): ").strip()
|
||||
|
||||
total, avg_score, hq_rate = cursor.fetchone()
|
||||
|
||||
print(f"当前性能指标:")
|
||||
print(f" - 总评分样本: {total}")
|
||||
print(f" - 平均分数: {avg_score:.2f}")
|
||||
print(f" - 高质量率: {hq_rate*100:.1f}%")
|
||||
|
||||
# 分析优化潜力
|
||||
optimization_needs = []
|
||||
|
||||
if avg_score < 7.0:
|
||||
optimization_needs.append("整体质量需要提升")
|
||||
|
||||
if hq_rate < 0.6:
|
||||
optimization_needs.append("高质量对话比例偏低")
|
||||
|
||||
# 分析各维度表现
|
||||
cursor = conn.execute("""
|
||||
SELECT score_details FROM dialogue_turns
|
||||
WHERE dialogue_score > 0 AND score_details != '{}'
|
||||
ORDER BY timestamp DESC LIMIT 100
|
||||
""")
|
||||
|
||||
dim_scores = {'coherence': [], 'character_consistency': [],
|
||||
'naturalness': [], 'information_density': [], 'creativity': []}
|
||||
|
||||
for (score_details,) in cursor.fetchall():
|
||||
try:
|
||||
scores = json.loads(score_details)
|
||||
for dim, score in scores.items():
|
||||
if dim in dim_scores:
|
||||
dim_scores[dim].append(float(score))
|
||||
except:
|
||||
continue
|
||||
|
||||
weak_dimensions = []
|
||||
print(f"\n维度分析:")
|
||||
for dim, scores in dim_scores.items():
|
||||
if scores:
|
||||
avg = sum(scores) / len(scores)
|
||||
print(f" - {dim}: {avg:.2f}分")
|
||||
if avg < 7.0:
|
||||
weak_dimensions.append(dim)
|
||||
|
||||
if weak_dimensions:
|
||||
optimization_needs.append(f"薄弱维度: {weak_dimensions}")
|
||||
|
||||
print(f"\n优化建议:")
|
||||
if optimization_needs:
|
||||
for i, need in enumerate(optimization_needs, 1):
|
||||
print(f" {i}. {need}")
|
||||
else:
|
||||
print(" 当前模型表现良好,可考虑细微调优")
|
||||
|
||||
elif choice == '2':
|
||||
# if choice == '1':
|
||||
# 生成LoRA训练脚本
|
||||
print("\n=== 生成LoRA训练脚本 ===")
|
||||
|
||||
@ -623,121 +554,50 @@ def run_model_optimization():
|
||||
print(" 2. 修改脚本中的路径配置")
|
||||
print(f" 3. 运行: python {script_path}")
|
||||
|
||||
elif choice == '3':
|
||||
# 创建提示优化配置
|
||||
print("\n=== 创建提示优化配置 ===")
|
||||
# elif choice == '2':
|
||||
# # 执行增量训练
|
||||
# print("\n=== 执行增量训练 ===")
|
||||
|
||||
config = generate_prompt_optimization_config()
|
||||
config_path = "./config/prompt_optimization.json"
|
||||
# # 检查训练数据
|
||||
# training_dir = "./training_data"
|
||||
# if not os.path.exists(training_dir):
|
||||
# print("❌ 训练数据目录不存在,请先生成训练数据 (选项8)")
|
||||
# return
|
||||
|
||||
os.makedirs("./config", exist_ok=True)
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
# training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')]
|
||||
# if not training_files:
|
||||
# print("❌ 未找到训练数据文件,请先生成训练数据 (选项8)")
|
||||
# return
|
||||
|
||||
print(f"✓ 提示优化配置已生成: {config_path}")
|
||||
print("配置包含:")
|
||||
print(" - 动态提示调整规则")
|
||||
print(" - 质量阈值设置")
|
||||
print(" - 生成参数优化")
|
||||
# print(f"找到训练数据文件:")
|
||||
# for i, file in enumerate(training_files, 1):
|
||||
# print(f" {i}. {file}")
|
||||
|
||||
elif choice == '4':
|
||||
# 执行增量训练
|
||||
print("\n=== 执行增量训练 ===")
|
||||
# file_idx = input(f"选择训练数据文件 (1-{len(training_files)}): ").strip()
|
||||
# try:
|
||||
# selected_file = training_files[int(file_idx) - 1]
|
||||
# training_file_path = os.path.join(training_dir, selected_file)
|
||||
|
||||
# 检查训练数据
|
||||
training_dir = "./training_data"
|
||||
if not os.path.exists(training_dir):
|
||||
print("❌ 训练数据目录不存在,请先生成训练数据 (选项8)")
|
||||
return
|
||||
# print(f"将使用训练文件: {selected_file}")
|
||||
# print("⚠ 注意:实际训练需要配置正确的模型路径和计算资源")
|
||||
|
||||
training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')]
|
||||
if not training_files:
|
||||
print("❌ 未找到训练数据文件,请先生成训练数据 (选项8)")
|
||||
return
|
||||
# # 生成训练命令
|
||||
# training_command = generate_training_command(training_file_path)
|
||||
# print(f"建议训练命令:")
|
||||
# print(f" {training_command}")
|
||||
|
||||
print(f"找到训练数据文件:")
|
||||
for i, file in enumerate(training_files, 1):
|
||||
print(f" {i}. {file}")
|
||||
# # 可选:执行训练(需要用户确认)
|
||||
# confirm = input("是否现在执行训练?(y/N): ").strip().lower()
|
||||
# if confirm == 'y':
|
||||
# print("开始增量训练...")
|
||||
# # 这里可以添加实际的训练执行逻辑
|
||||
# print("⚠ 训练功能需要根据实际环境配置")
|
||||
|
||||
file_idx = input(f"选择训练数据文件 (1-{len(training_files)}): ").strip()
|
||||
try:
|
||||
selected_file = training_files[int(file_idx) - 1]
|
||||
training_file_path = os.path.join(training_dir, selected_file)
|
||||
# except (ValueError, IndexError):
|
||||
# print("❌ 无效的文件选择")
|
||||
|
||||
print(f"将使用训练文件: {selected_file}")
|
||||
print("⚠ 注意:实际训练需要配置正确的模型路径和计算资源")
|
||||
|
||||
# 生成训练命令
|
||||
training_command = generate_training_command(training_file_path)
|
||||
print(f"建议训练命令:")
|
||||
print(f" {training_command}")
|
||||
|
||||
# 可选:执行训练(需要用户确认)
|
||||
confirm = input("是否现在执行训练?(y/N): ").strip().lower()
|
||||
if confirm == 'y':
|
||||
print("开始增量训练...")
|
||||
# 这里可以添加实际的训练执行逻辑
|
||||
print("⚠ 训练功能需要根据实际环境配置")
|
||||
|
||||
except (ValueError, IndexError):
|
||||
print("❌ 无效的文件选择")
|
||||
|
||||
elif choice == '5':
|
||||
# 性能对比验证
|
||||
print("\n=== 性能对比验证 ===")
|
||||
|
||||
print("验证选项:")
|
||||
print("1. 历史性能趋势对比")
|
||||
print("2. A/B测试配置生成")
|
||||
print("3. 模型版本性能对比")
|
||||
|
||||
verify_choice = input("请输入选择 (1-3): ").strip()
|
||||
|
||||
if verify_choice == '1':
|
||||
# 历史性能趋势
|
||||
with sqlite3.connect(conv_mgr.db_path) as conn:
|
||||
cursor = conn.execute("""
|
||||
SELECT
|
||||
DATE(timestamp) as date,
|
||||
AVG(dialogue_score) as avg_score,
|
||||
COUNT(*) as count
|
||||
FROM dialogue_turns
|
||||
WHERE dialogue_score > 0
|
||||
AND timestamp >= datetime('now', '-30 days')
|
||||
GROUP BY DATE(timestamp)
|
||||
ORDER BY date ASC
|
||||
""")
|
||||
|
||||
trend_data = cursor.fetchall()
|
||||
if trend_data:
|
||||
print(f"30天性能趋势:")
|
||||
for date, avg_score, count in trend_data:
|
||||
trend = "📈" if avg_score > 7.5 else "📉" if avg_score < 6.5 else "📊"
|
||||
print(f" {date}: {avg_score:.2f}分 {trend} ({count}条对话)")
|
||||
else:
|
||||
print("暂无足够的历史数据")
|
||||
|
||||
elif verify_choice == '2':
|
||||
# A/B测试配置
|
||||
ab_config = generate_ab_test_config()
|
||||
ab_config_path = "./config/ab_test_config.json"
|
||||
|
||||
os.makedirs("./config", exist_ok=True)
|
||||
with open(ab_config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(ab_config, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"✓ A/B测试配置已生成: {ab_config_path}")
|
||||
print("配置包含:")
|
||||
print(" - 对照组和实验组设置")
|
||||
print(" - 评估指标定义")
|
||||
print(" - 测试持续时间配置")
|
||||
|
||||
elif verify_choice == '3':
|
||||
print("模型版本对比功能开发中...")
|
||||
print("建议手动记录不同版本的性能指标进行对比")
|
||||
|
||||
else:
|
||||
print("❌ 无效选择")
|
||||
# else:
|
||||
# print("❌ 无效选择")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ 模型优化失败: {e}")
|
||||
@ -837,102 +697,8 @@ if __name__ == '__main__':
|
||||
trainer.train()
|
||||
'''
|
||||
|
||||
def generate_prompt_optimization_config():
|
||||
"""生成提示优化配置"""
|
||||
return {
|
||||
"optimization_rules": {
|
||||
"quality_thresholds": {
|
||||
"excellent": 9.0,
|
||||
"good": 8.0,
|
||||
"acceptable": 7.0,
|
||||
"needs_improvement": 6.0
|
||||
},
|
||||
"adaptive_adjustments": {
|
||||
"low_coherence": {
|
||||
"add_context_emphasis": True,
|
||||
"increase_history_weight": 1.2,
|
||||
"add_logical_constraints": True
|
||||
},
|
||||
"low_character_consistency": {
|
||||
"enhance_character_description": True,
|
||||
"add_personality_reminders": True,
|
||||
"increase_character_weight": 1.5
|
||||
},
|
||||
"low_creativity": {
|
||||
"add_creativity_prompts": True,
|
||||
"increase_temperature": 0.1,
|
||||
"diversify_examples": True
|
||||
}
|
||||
}
|
||||
},
|
||||
"dynamic_prompts": {
|
||||
"quality_boost_templates": [
|
||||
"请生成一个富有创意且符合角色特征的高质量对话",
|
||||
"确保对话逻辑连贯,信息丰富,表达自然",
|
||||
"体现角色的独特性格和说话风格"
|
||||
],
|
||||
"problem_specific_guidance": {
|
||||
"repetitive": "避免重复之前的内容,提供新的信息和观点",
|
||||
"inconsistent": "严格遵循角色设定,保持人格一致性",
|
||||
"dull": "增加对话的趣味性和深度,使用生动的表达"
|
||||
}
|
||||
},
|
||||
"generation_parameters": {
|
||||
"adaptive_temperature": {
|
||||
"high_creativity_needed": 0.9,
|
||||
"normal": 0.8,
|
||||
"high_consistency_needed": 0.7
|
||||
},
|
||||
"adaptive_top_p": {
|
||||
"creative_mode": 0.9,
|
||||
"balanced_mode": 0.8,
|
||||
"conservative_mode": 0.7
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def generate_ab_test_config():
|
||||
"""生成A/B测试配置"""
|
||||
return {
|
||||
"test_name": "model_optimization_validation",
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"groups": {
|
||||
"control": {
|
||||
"name": "原始模型",
|
||||
"description": "未优化的基础模型",
|
||||
"model_path": "/path/to/base/model",
|
||||
"sample_ratio": 0.5
|
||||
},
|
||||
"experimental": {
|
||||
"name": "优化模型",
|
||||
"description": "经过迭代优化的模型",
|
||||
"model_path": "/path/to/optimized/model",
|
||||
"sample_ratio": 0.5
|
||||
}
|
||||
},
|
||||
"evaluation_metrics": {
|
||||
"primary": [
|
||||
"overall_dialogue_score",
|
||||
"character_consistency_score",
|
||||
"creativity_score"
|
||||
],
|
||||
"secondary": [
|
||||
"user_satisfaction",
|
||||
"response_time",
|
||||
"coherence_score"
|
||||
]
|
||||
},
|
||||
"test_duration": {
|
||||
"target_samples": 200,
|
||||
"max_duration_days": 7,
|
||||
"min_samples_per_group": 50
|
||||
},
|
||||
"statistical_settings": {
|
||||
"confidence_level": 0.95,
|
||||
"minimum_effect_size": 0.3,
|
||||
"power": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def generate_training_command(training_file_path):
|
||||
"""生成训练命令"""
|
||||
|
||||
617
AITrain/score_data_manager.py
Normal file
617
AITrain/score_data_manager.py
Normal file
@ -0,0 +1,617 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
对话评分数据存储管理系统
|
||||
管理评分结果、统计分析、数据导出等功能
|
||||
'''
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import asdict
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from dialogue_scorer import ScoreResult
|
||||
|
||||
class ScoreDataManager:
|
||||
"""评分数据管理器"""
|
||||
|
||||
def __init__(self, db_path: str = "score_data.db"):
|
||||
"""
|
||||
初始化评分数据管理器
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self._init_database()
|
||||
|
||||
def _init_database(self):
|
||||
"""初始化数据库"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
# 评分结果表
|
||||
conn.execute('''
|
||||
CREATE TABLE IF NOT EXISTS score_results (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
dialogue_id TEXT UNIQUE,
|
||||
session_id TEXT,
|
||||
speaker TEXT,
|
||||
content TEXT,
|
||||
timestamp TEXT,
|
||||
scorer_type TEXT,
|
||||
coherence_score REAL,
|
||||
character_consistency_score REAL,
|
||||
naturalness_score REAL,
|
||||
information_density_score REAL,
|
||||
creativity_score REAL,
|
||||
overall_score REAL,
|
||||
feedback TEXT,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# 评分统计表
|
||||
conn.execute('''
|
||||
CREATE TABLE IF NOT EXISTS score_statistics (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT,
|
||||
speaker TEXT,
|
||||
date_range TEXT,
|
||||
total_dialogues INTEGER,
|
||||
avg_overall_score REAL,
|
||||
avg_coherence REAL,
|
||||
avg_character_consistency REAL,
|
||||
avg_naturalness REAL,
|
||||
avg_information_density REAL,
|
||||
avg_creativity REAL,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# 模型训练数据表
|
||||
conn.execute('''
|
||||
CREATE TABLE IF NOT EXISTS training_data (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
dialogue_content TEXT,
|
||||
speaker TEXT,
|
||||
character_data TEXT,
|
||||
context_info TEXT,
|
||||
dialogue_history TEXT,
|
||||
score_overall REAL,
|
||||
score_dimensions TEXT,
|
||||
feedback TEXT,
|
||||
is_high_quality INTEGER,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# 优化记录表
|
||||
conn.execute('''
|
||||
CREATE TABLE IF NOT EXISTS optimization_records (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
optimization_type TEXT,
|
||||
before_score REAL,
|
||||
after_score REAL,
|
||||
improvement REAL,
|
||||
parameters_changed TEXT,
|
||||
description TEXT,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
conn.commit()
|
||||
|
||||
def save_score_result(self, score_result: ScoreResult) -> bool:
|
||||
"""
|
||||
保存评分结果
|
||||
|
||||
Args:
|
||||
score_result: 评分结果对象
|
||||
|
||||
Returns:
|
||||
bool: 保存是否成功
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute('''
|
||||
INSERT OR REPLACE INTO score_results
|
||||
(dialogue_id, session_id, speaker, content, timestamp, scorer_type,
|
||||
coherence_score, character_consistency_score, naturalness_score,
|
||||
information_density_score, creativity_score, overall_score, feedback)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
score_result.dialogue_id,
|
||||
score_result.session_id,
|
||||
score_result.speaker,
|
||||
score_result.content,
|
||||
score_result.timestamp,
|
||||
score_result.scorer_type,
|
||||
score_result.scores.get('coherence', 0),
|
||||
score_result.scores.get('character_consistency', 0),
|
||||
score_result.scores.get('naturalness', 0),
|
||||
score_result.scores.get('information_density', 0),
|
||||
score_result.scores.get('creativity', 0),
|
||||
score_result.overall_score,
|
||||
score_result.feedback
|
||||
))
|
||||
conn.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"保存评分结果失败: {e}")
|
||||
return False
|
||||
|
||||
def get_score_results(self,
|
||||
session_id: Optional[str] = None,
|
||||
speaker: Optional[str] = None,
|
||||
min_score: Optional[float] = None,
|
||||
max_score: Optional[float] = None,
|
||||
limit: int = 100) -> List[Dict]:
|
||||
"""
|
||||
获取评分结果
|
||||
|
||||
Args:
|
||||
session_id: 会话ID过滤
|
||||
speaker: 角色过滤
|
||||
min_score: 最低分数过滤
|
||||
max_score: 最高分数过滤
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
List[Dict]: 评分结果列表
|
||||
"""
|
||||
query = "SELECT * FROM score_results WHERE 1=1"
|
||||
params = []
|
||||
|
||||
if session_id:
|
||||
query += " AND session_id = ?"
|
||||
params.append(session_id)
|
||||
|
||||
if speaker:
|
||||
query += " AND speaker = ?"
|
||||
params.append(speaker)
|
||||
|
||||
if min_score is not None:
|
||||
query += " AND overall_score >= ?"
|
||||
params.append(min_score)
|
||||
|
||||
if max_score is not None:
|
||||
query += " AND overall_score <= ?"
|
||||
params.append(max_score)
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def calculate_statistics(self,
|
||||
session_id: Optional[str] = None,
|
||||
speaker: Optional[str] = None,
|
||||
days: int = 7) -> Dict:
|
||||
"""
|
||||
计算评分统计信息
|
||||
|
||||
Args:
|
||||
session_id: 会话ID过滤
|
||||
speaker: 角色过滤
|
||||
days: 统计天数
|
||||
|
||||
Returns:
|
||||
Dict: 统计信息
|
||||
"""
|
||||
# 计算时间范围
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
query = '''
|
||||
SELECT
|
||||
COUNT(*) as total_dialogues,
|
||||
AVG(overall_score) as avg_overall,
|
||||
AVG(coherence_score) as avg_coherence,
|
||||
AVG(character_consistency_score) as avg_character_consistency,
|
||||
AVG(naturalness_score) as avg_naturalness,
|
||||
AVG(information_density_score) as avg_information_density,
|
||||
AVG(creativity_score) as avg_creativity,
|
||||
MIN(overall_score) as min_score,
|
||||
MAX(overall_score) as max_score,
|
||||
speaker
|
||||
FROM score_results
|
||||
WHERE created_at >= ? AND created_at <= ?
|
||||
'''
|
||||
|
||||
params = [start_date.isoformat(), end_date.isoformat()]
|
||||
|
||||
if session_id:
|
||||
query += " AND session_id = ?"
|
||||
params.append(session_id)
|
||||
|
||||
if speaker:
|
||||
query += " AND speaker = ?"
|
||||
params.append(speaker)
|
||||
else:
|
||||
query += " GROUP BY speaker"
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute(query, params)
|
||||
results = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
return {
|
||||
"period": f"{start_date.strftime('%Y-%m-%d')} 到 {end_date.strftime('%Y-%m-%d')}",
|
||||
"statistics": results
|
||||
}
|
||||
|
||||
def get_high_quality_samples(self, threshold: float = 8.0, limit: int = 50) -> List[Dict]:
|
||||
"""
|
||||
获取高质量对话样本
|
||||
|
||||
Args:
|
||||
threshold: 分数阈值
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
List[Dict]: 高质量样本列表
|
||||
"""
|
||||
query = '''
|
||||
SELECT dialogue_id, session_id, speaker, content, overall_score, feedback
|
||||
FROM score_results
|
||||
WHERE overall_score >= ?
|
||||
ORDER BY overall_score DESC
|
||||
LIMIT ?
|
||||
'''
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute(query, [threshold, limit])
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_low_quality_samples(self, threshold: float = 4.0, limit: int = 50) -> List[Dict]:
|
||||
"""
|
||||
获取低质量对话样本
|
||||
|
||||
Args:
|
||||
threshold: 分数阈值
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
List[Dict]: 低质量样本列表
|
||||
"""
|
||||
query = '''
|
||||
SELECT dialogue_id, session_id, speaker, content, overall_score, feedback
|
||||
FROM score_results
|
||||
WHERE overall_score <= ?
|
||||
ORDER BY overall_score ASC
|
||||
LIMIT ?
|
||||
'''
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute(query, [threshold, limit])
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def save_training_data(self,
|
||||
dialogue_content: str,
|
||||
speaker: str,
|
||||
character_data: Dict,
|
||||
context_info: List[Dict],
|
||||
dialogue_history: List[Dict],
|
||||
score_result: ScoreResult) -> bool:
|
||||
"""
|
||||
保存用于模型训练的数据
|
||||
|
||||
Args:
|
||||
dialogue_content: 对话内容
|
||||
speaker: 说话者
|
||||
character_data: 角色数据
|
||||
context_info: 上下文信息
|
||||
dialogue_history: 对话历史
|
||||
score_result: 评分结果
|
||||
|
||||
Returns:
|
||||
bool: 保存是否成功
|
||||
"""
|
||||
try:
|
||||
# 判断是否为高质量样本
|
||||
is_high_quality = 1 if score_result.overall_score >= 7.0 else 0
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute('''
|
||||
INSERT INTO training_data
|
||||
(dialogue_content, speaker, character_data, context_info,
|
||||
dialogue_history, score_overall, score_dimensions, feedback, is_high_quality)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
dialogue_content,
|
||||
speaker,
|
||||
json.dumps(character_data, ensure_ascii=False),
|
||||
json.dumps(context_info, ensure_ascii=False),
|
||||
json.dumps([asdict(turn) if hasattr(turn, '__dict__') else turn for turn in dialogue_history], ensure_ascii=False),
|
||||
score_result.overall_score,
|
||||
json.dumps(score_result.scores, ensure_ascii=False),
|
||||
score_result.feedback,
|
||||
is_high_quality
|
||||
))
|
||||
conn.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"保存训练数据失败: {e}")
|
||||
return False
|
||||
|
||||
def export_to_csv(self, output_path: str = "score_data_export.csv") -> bool:
|
||||
"""
|
||||
导出评分数据到CSV文件
|
||||
|
||||
Args:
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
bool: 导出是否成功
|
||||
"""
|
||||
try:
|
||||
query = '''
|
||||
SELECT dialogue_id, session_id, speaker, content, timestamp, scorer_type,
|
||||
coherence_score, character_consistency_score, naturalness_score,
|
||||
information_density_score, creativity_score, overall_score, feedback
|
||||
FROM score_results
|
||||
ORDER BY created_at DESC
|
||||
'''
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
df = pd.read_sql_query(query, conn)
|
||||
df.to_csv(output_path, index=False, encoding='utf-8-sig')
|
||||
|
||||
print(f"✓ 评分数据已导出到: {output_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"导出数据失败: {e}")
|
||||
return False
|
||||
|
||||
def get_score_trends(self, speaker: str, days: int = 30) -> Dict:
|
||||
"""
|
||||
获取评分趋势数据
|
||||
|
||||
Args:
|
||||
speaker: 角色名称
|
||||
days: 统计天数
|
||||
|
||||
Returns:
|
||||
Dict: 趋势数据
|
||||
"""
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
query = '''
|
||||
SELECT DATE(created_at) as date,
|
||||
AVG(overall_score) as avg_score,
|
||||
COUNT(*) as count
|
||||
FROM score_results
|
||||
WHERE speaker = ? AND created_at >= ? AND created_at <= ?
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date
|
||||
'''
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute(query, [speaker, start_date.isoformat(), end_date.isoformat()])
|
||||
results = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
return {
|
||||
"speaker": speaker,
|
||||
"period": f"{start_date.strftime('%Y-%m-%d')} 到 {end_date.strftime('%Y-%m-%d')}",
|
||||
"trend_data": results
|
||||
}
|
||||
|
||||
def save_optimization_record(self,
|
||||
optimization_type: str,
|
||||
before_score: float,
|
||||
after_score: float,
|
||||
parameters_changed: Dict,
|
||||
description: str = "") -> bool:
|
||||
"""
|
||||
保存优化记录
|
||||
|
||||
Args:
|
||||
optimization_type: 优化类型
|
||||
before_score: 优化前分数
|
||||
after_score: 优化后分数
|
||||
parameters_changed: 改变的参数
|
||||
description: 描述
|
||||
|
||||
Returns:
|
||||
bool: 保存是否成功
|
||||
"""
|
||||
try:
|
||||
improvement = after_score - before_score
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute('''
|
||||
INSERT INTO optimization_records
|
||||
(optimization_type, before_score, after_score, improvement,
|
||||
parameters_changed, description)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
optimization_type,
|
||||
before_score,
|
||||
after_score,
|
||||
improvement,
|
||||
json.dumps(parameters_changed, ensure_ascii=False),
|
||||
description
|
||||
))
|
||||
conn.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"保存优化记录失败: {e}")
|
||||
return False
|
||||
|
||||
def get_optimization_history(self, limit: int = 20) -> List[Dict]:
|
||||
"""
|
||||
获取优化历史记录
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
List[Dict]: 优化历史记录
|
||||
"""
|
||||
query = '''
|
||||
SELECT * FROM optimization_records
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
'''
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute(query, [limit])
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
class ScoreAnalyzer:
|
||||
"""评分分析器"""
|
||||
|
||||
def __init__(self, score_manager: ScoreDataManager):
|
||||
"""
|
||||
初始化评分分析器
|
||||
|
||||
Args:
|
||||
score_manager: 评分数据管理器
|
||||
"""
|
||||
self.score_manager = score_manager
|
||||
|
||||
def analyze_character_performance(self, character_name: str) -> Dict:
|
||||
"""
|
||||
分析角色表现
|
||||
|
||||
Args:
|
||||
character_name: 角色名称
|
||||
|
||||
Returns:
|
||||
Dict: 分析结果
|
||||
"""
|
||||
# 获取角色的所有评分数据
|
||||
scores = self.score_manager.get_score_results(speaker=character_name, limit=1000)
|
||||
|
||||
if not scores:
|
||||
return {"error": f"没有找到角色 {character_name} 的评分数据"}
|
||||
|
||||
# 计算各种统计信息
|
||||
overall_scores = [score['overall_score'] for score in scores]
|
||||
|
||||
analysis = {
|
||||
"character_name": character_name,
|
||||
"total_dialogues": len(scores),
|
||||
"average_score": round(np.mean(overall_scores), 2),
|
||||
"median_score": round(np.median(overall_scores), 2),
|
||||
"score_std": round(np.std(overall_scores), 2),
|
||||
"min_score": min(overall_scores),
|
||||
"max_score": max(overall_scores),
|
||||
"score_distribution": self._calculate_score_distribution(overall_scores),
|
||||
"dimension_analysis": self._analyze_dimensions(scores),
|
||||
"improvement_suggestions": self._generate_improvement_suggestions(scores)
|
||||
}
|
||||
|
||||
return analysis
|
||||
|
||||
def _calculate_score_distribution(self, scores: List[float]) -> Dict:
|
||||
"""计算分数分布"""
|
||||
distribution = {
|
||||
"excellent": len([s for s in scores if s >= 8.0]),
|
||||
"good": len([s for s in scores if 6.0 <= s < 8.0]),
|
||||
"average": len([s for s in scores if 4.0 <= s < 6.0]),
|
||||
"poor": len([s for s in scores if s < 4.0])
|
||||
}
|
||||
|
||||
total = len(scores)
|
||||
if total > 0:
|
||||
distribution = {k: {"count": v, "percentage": round(v/total*100, 1)}
|
||||
for k, v in distribution.items()}
|
||||
|
||||
return distribution
|
||||
|
||||
def _analyze_dimensions(self, scores: List[Dict]) -> Dict:
|
||||
"""分析各维度表现"""
|
||||
dimensions = ['coherence_score', 'character_consistency_score', 'naturalness_score',
|
||||
'information_density_score', 'creativity_score']
|
||||
|
||||
dimension_analysis = {}
|
||||
|
||||
for dim in dimensions:
|
||||
dim_scores = [score[dim] for score in scores if score[dim] is not None]
|
||||
if dim_scores:
|
||||
dimension_analysis[dim] = {
|
||||
"average": round(np.mean(dim_scores), 2),
|
||||
"min": min(dim_scores),
|
||||
"max": max(dim_scores),
|
||||
"std": round(np.std(dim_scores), 2)
|
||||
}
|
||||
|
||||
return dimension_analysis
|
||||
|
||||
def _generate_improvement_suggestions(self, scores: List[Dict]) -> List[str]:
|
||||
"""生成改进建议"""
|
||||
suggestions = []
|
||||
|
||||
# 分析最弱的维度
|
||||
dimensions = {
|
||||
'coherence_score': '连贯性',
|
||||
'character_consistency_score': '角色一致性',
|
||||
'naturalness_score': '自然度',
|
||||
'information_density_score': '信息密度',
|
||||
'creativity_score': '创意性'
|
||||
}
|
||||
|
||||
dim_averages = {}
|
||||
for dim_key, dim_name in dimensions.items():
|
||||
dim_scores = [score[dim_key] for score in scores if score[dim_key] is not None]
|
||||
if dim_scores:
|
||||
dim_averages[dim_name] = np.mean(dim_scores)
|
||||
|
||||
# 找出最弱的维度
|
||||
if dim_averages:
|
||||
weakest_dim = min(dim_averages.items(), key=lambda x: x[1])
|
||||
if weakest_dim[1] < 6.0:
|
||||
suggestions.append(f"需要重点改进{weakest_dim[0]},当前平均分为{weakest_dim[1]:.1f}")
|
||||
|
||||
# 根据分数分布给建议
|
||||
overall_avg = np.mean([score['overall_score'] for score in scores])
|
||||
if overall_avg < 5.0:
|
||||
suggestions.append("整体表现需要改进,建议调整角色设定或提示词")
|
||||
elif overall_avg < 7.0:
|
||||
suggestions.append("表现良好,可以尝试增加对话的创意性和深度")
|
||||
else:
|
||||
suggestions.append("表现优秀,继续保持当前设定")
|
||||
|
||||
return suggestions
|
||||
|
||||
def compare_characters(self, character_names: List[str]) -> Dict:
|
||||
"""
|
||||
比较多个角色的表现
|
||||
|
||||
Args:
|
||||
character_names: 角色名称列表
|
||||
|
||||
Returns:
|
||||
Dict: 比较结果
|
||||
"""
|
||||
comparison = {"characters": {}}
|
||||
|
||||
for char_name in character_names:
|
||||
scores = self.score_manager.get_score_results(speaker=char_name, limit=500)
|
||||
if scores:
|
||||
overall_scores = [score['overall_score'] for score in scores]
|
||||
comparison["characters"][char_name] = {
|
||||
"total_dialogues": len(scores),
|
||||
"average_score": round(np.mean(overall_scores), 2),
|
||||
"score_range": f"{min(overall_scores):.1f} - {max(overall_scores):.1f}"
|
||||
}
|
||||
|
||||
# 排序
|
||||
if comparison["characters"]:
|
||||
sorted_chars = sorted(comparison["characters"].items(),
|
||||
key=lambda x: x[1]["average_score"], reverse=True)
|
||||
comparison["ranking"] = [{"character": char, **data} for char, data in sorted_chars]
|
||||
|
||||
return comparison
|
||||
Loading…
x
Reference in New Issue
Block a user