Fix scoring bug
This commit is contained in:
parent
5bc280c539
commit
4c29745fa3
497
AITrain/dialogue_scorer.py
Normal file
@@ -0,0 +1,497 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
Dialogue quality scoring system.
Uses an AI model to score generated dialogue across multiple quality dimensions.
'''

import json
import re
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from datetime import datetime

# Try to import numpy; skip it if unavailable
try:
    import numpy as np
    NUMPY_AVAILABLE = True
except ImportError:
    NUMPY_AVAILABLE = False


@dataclass
class ScoreResult:
    """Result of a single scoring pass"""
    dialogue_id: str
    session_id: str
    speaker: str
    content: str
    timestamp: str
    scores: Dict[str, float]   # per-dimension scores
    overall_score: float       # weighted overall score
    feedback: str              # feedback text
    scorer_type: str           # scorer type (ai/human)

class DialogueAIScorer:
    """AI-based dialogue quality scorer"""

    def __init__(self, base_model_path: str, tokenizer=None, model=None):
        """
        Initialize the AI scorer.

        Args:
            base_model_path: path of the base model
            tokenizer: tokenizer (optional, to reuse an existing one)
            model: model (optional, to reuse an existing one)
        """
        self.base_model_path = base_model_path
        self.tokenizer = tokenizer
        self.model = model

        # Load the model only if it was not passed in
        if self.tokenizer is None or self.model is None:
            self._load_model()

        # Scoring dimension definitions
        self.score_dimensions = {
            "coherence": {
                "name": "连贯性",
                "description": "对话是否与上下文逻辑连贯",
                "weight": 0.25
            },
            "character_consistency": {
                "name": "角色一致性",
                "description": "是否符合角色设定和人格特征",
                "weight": 0.25
            },
            "naturalness": {
                "name": "自然度",
                "description": "语言表达是否自然流畅",
                "weight": 0.20
            },
            "information_density": {
                "name": "信息密度",
                "description": "是否包含有意义的信息,避免废话",
                "weight": 0.15
            },
            "creativity": {
                "name": "创意性",
                "description": "内容是否有趣、有创意",
                "weight": 0.15
            }
        }

    def _load_model(self):
        """Load the model and tokenizer"""
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            import torch

            print(f"Loading scorer tokenizer from: {self.base_model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.base_model_path,
                use_fast=False,
                trust_remote_code=True
            )

            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print(f"Loading scorer model from: {self.base_model_path}")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.base_model_path,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            )

        except Exception as e:
            print(f"✗ AI评分器模型加载失败: {e}")
            raise

    def score_dialogue(self,
                       dialogue_content: str,
                       speaker: str,
                       character_data: Dict,
                       dialogue_history: List[Dict] = None,
                       context_info: List[Dict] = None) -> ScoreResult:
        """
        Score a single dialogue turn with the AI model.

        Args:
            dialogue_content: dialogue text
            speaker: speaker name
            character_data: character profile data
            dialogue_history: dialogue history
            context_info: context information

        Returns:
            ScoreResult: the scoring result
        """
        # Build the scoring prompt
        scoring_prompt = self._build_scoring_prompt(
            dialogue_content, speaker, character_data, dialogue_history, context_info
        )

        # Generate the scores with the AI model
        try:
            scores, feedback = self._generate_ai_scores(scoring_prompt)

            # Compute the overall score
            overall_score = self._calculate_overall_score(scores)

            # Build the result object
            result = ScoreResult(
                dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                session_id="",  # set by the caller
                speaker=speaker,
                content=dialogue_content,
                timestamp=datetime.now().isoformat(),
                scores=scores,
                overall_score=overall_score,
                feedback=feedback,
                scorer_type="ai"
            )

            return result

        except Exception as e:
            print(f"✗ AI评分失败: {e}")
            # Fall back to a default score
            return self._create_default_score(dialogue_content, speaker)

    def _build_scoring_prompt(self,
                              dialogue_content: str,
                              speaker: str,
                              character_data: Dict,
                              dialogue_history: List[Dict] = None,
                              context_info: List[Dict] = None) -> str:
        """Build the scoring prompt"""

        # Basic character information
        character_info = ""
        if character_data:
            personality = character_data.get('personality', {})
            traits = personality.get('core_traits', [])
            occupation = character_data.get('basic_info', {}).get('occupation', '未知')
            character_info = f"角色职业: {occupation}, 性格特点: {', '.join(traits[:3])}"

        # Dialogue history
        history_text = ""
        if dialogue_history:
            history_text = "对话历史:\n"
            for turn in dialogue_history[-3:]:  # only the last 3 turns
                history_text += f"{turn.get('speaker', '未知')}: {turn.get('content', '')}\n"

        # Assemble the full prompt
        prompt = f"""请对以下对话内容进行质量评分。

角色设定:
{character_info}

{history_text}

当前对话:
{speaker}: {dialogue_content}

请从以下5个维度评分(1-10分):
1. 连贯性 - 对话是否与上下文逻辑连贯
2. 角色一致性 - 是否符合角色设定和人格特征
3. 自然度 - 语言表达是否自然流畅
4. 信息密度 - 是否包含有意义的信息,避免废话
5. 创意性 - 内容是否有趣、有创意

请按以下格式输出:
连贯性: X分
角色一致性: X分
自然度: X分
信息密度: X分
创意性: X分

总体评价: [具体的改进建议和优点分析]"""

        return prompt

    def _generate_ai_scores(self, prompt: str) -> Tuple[Dict[str, float], str]:
        """Generate scores with the AI model"""
        import torch

        # Prepare the chat messages
        messages = [
            {"role": "system", "content": "你是一个专业的对话质量评估专家,请客观公正地评分。"},
            {"role": "user", "content": prompt}
        ]

        # Apply the chat template
        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            enable_thinking=False
        )

        # Move tensors to the model device
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        # Generate the scoring response
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.3,  # low temperature keeps the scores stable
                top_p=0.8,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        # Decode only the newly generated tokens
        response = outputs[0][inputs['input_ids'].shape[1]:]
        result_text = self.tokenizer.decode(response, skip_special_tokens=True).strip()

        # Parse the scoring response
        scores, feedback = self._parse_score_response(result_text)

        return scores, feedback

    def _parse_score_response(self, response: str) -> Tuple[Dict[str, float], str]:
        """Parse the AI scoring response"""
        scores = {}
        feedback = ""

        # Map the Chinese dimension names to the internal keys
        dimension_map = {
            "连贯性": "coherence",
            "角色一致性": "character_consistency",
            "自然度": "naturalness",
            "信息密度": "information_density",
            "创意性": "creativity"
        }

        try:
            lines = response.split('\n')
            feedback_start = False

            for line in lines:
                line = line.strip()

                # Look for a dimension score
                for chinese_name, english_key in dimension_map.items():
                    if chinese_name in line and ':' in line:
                        # Extract the numeric score
                        score_match = re.search(r'(\d+(?:\.\d+)?)', line)
                        if score_match:
                            score = float(score_match.group(1))
                            # Clamp the score to the 1-10 range
                            score = max(1.0, min(10.0, score))
                            scores[english_key] = score

                # Look for the overall feedback
                if '总体评价' in line or '评价' in line:
                    feedback_start = True
                    feedback_content = line.split(':', 1)
                    if len(feedback_content) > 1:
                        feedback += feedback_content[1].strip()
                elif feedback_start and line:
                    feedback += " " + line

            # Make sure every dimension has a score
            for english_key in dimension_map.values():
                if english_key not in scores:
                    scores[english_key] = 5.0  # default medium score

            if not feedback:
                feedback = "AI评分完成,建议根据各维度分数进行改进。"

        except Exception as e:
            print(f"解析评分响应失败: {e}")
            # Fall back to default scores
            for english_key in dimension_map.values():
                scores[english_key] = 5.0
            feedback = "评分解析失败,使用默认分数。"

        return scores, feedback
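
    # Illustrative example (hypothetical model output): a response like
    #   连贯性: 8分
    #   角色一致性: 7分
    #   自然度: 9分
    #   信息密度: 6分
    #   创意性: 7分
    #   总体评价: 整体自然,可适当增加信息量
    # parses into {"coherence": 8.0, "character_consistency": 7.0, "naturalness": 9.0,
    # "information_density": 6.0, "creativity": 7.0}, with the 总体评价 sentence
    # returned as the feedback string.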
    def _calculate_overall_score(self, scores: Dict[str, float]) -> float:
        """Compute the weighted overall score"""
        total_score = 0.0
        total_weight = 0.0

        for dimension, score in scores.items():
            if dimension in self.score_dimensions:
                weight = self.score_dimensions[dimension]["weight"]
                total_score += score * weight
                total_weight += weight

        if total_weight > 0:
            return round(total_score / total_weight, 2)
        else:
            return 5.0
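
    # Illustrative example (hypothetical values): with the default weights
    # 0.25 / 0.25 / 0.20 / 0.15 / 0.15 and dimension scores 8, 7, 9, 6, 7,
    # the weighted sum is 0.25*8 + 0.25*7 + 0.20*9 + 0.15*6 + 0.15*7 = 7.5;
    # since the weights sum to 1.0, the overall score is 7.5.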
    def _create_default_score(self, dialogue_content: str, speaker: str) -> ScoreResult:
        """Create a default scoring result"""
        default_scores = {
            "coherence": 5.0,
            "character_consistency": 5.0,
            "naturalness": 5.0,
            "information_density": 5.0,
            "creativity": 5.0
        }

        return ScoreResult(
            dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            session_id="",
            speaker=speaker,
            content=dialogue_content,
            timestamp=datetime.now().isoformat(),
            scores=default_scores,
            overall_score=5.0,
            feedback="使用默认评分",
            scorer_type="ai"
        )

    def batch_score_dialogue(self, dialogue_list: List[Dict]) -> List[ScoreResult]:
        """Score a batch of dialogue turns"""
        results = []

        for i, dialogue_item in enumerate(dialogue_list):
            print(f"正在评分 {i+1}/{len(dialogue_list)}: {dialogue_item.get('speaker', '未知')}")

            try:
                result = self.score_dialogue(
                    dialogue_content=dialogue_item.get('content', ''),
                    speaker=dialogue_item.get('speaker', '未知'),
                    character_data=dialogue_item.get('character_data', {}),
                    dialogue_history=dialogue_item.get('dialogue_history', []),
                    context_info=dialogue_item.get('context_info', [])
                )

                # Attach the session id
                result.session_id = dialogue_item.get('session_id', '')
                results.append(result)

            except Exception as e:
                print(f"评分失败: {e}")
                # Fall back to a default score
                default_result = self._create_default_score(
                    dialogue_item.get('content', ''),
                    dialogue_item.get('speaker', '未知')
                )
                default_result.session_id = dialogue_item.get('session_id', '')
                results.append(default_result)

        return results

class HumanScorer:
    """Human (manual) scorer"""

    def __init__(self):
        self.score_dimensions = {
            "coherence": "连贯性",
            "character_consistency": "角色一致性",
            "naturalness": "自然度",
            "information_density": "信息密度",
            "creativity": "创意性"
        }

    def score_dialogue_interactive(self,
                                   dialogue_content: str,
                                   speaker: str,
                                   session_id: str = "") -> ScoreResult:
        """Interactive manual scoring"""

        print(f"\n=== 人工评分 ===")
        print(f"角色: {speaker}")
        print(f"对话: {dialogue_content}")
        print(f"请对以下维度评分 (1-10分):")

        scores = {}

        for dimension_key, dimension_name in self.score_dimensions.items():
            while True:
                try:
                    score_input = input(f"{dimension_name} (1-10): ").strip()
                    score = float(score_input)
                    if 1 <= score <= 10:
                        scores[dimension_key] = score
                        break
                    else:
                        print("请输入1-10之间的分数")
                except ValueError:
                    print("请输入有效的数字")

        # Collect optional feedback
        feedback = input("请输入评价和建议 (可选): ").strip()
        if not feedback:
            feedback = "人工评分完成"

        # Overall score is a plain average of the dimensions
        overall_score = sum(scores.values()) / len(scores)

        return ScoreResult(
            dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            session_id=session_id,
            speaker=speaker,
            content=dialogue_content,
            timestamp=datetime.now().isoformat(),
            scores=scores,
            overall_score=round(overall_score, 2),
            feedback=feedback,
            scorer_type="human"
        )

class QuickScorer:
    """Quick rule-based scorer (for real-time feedback)"""

    def __init__(self):
        pass

    def quick_score(self,
                    dialogue_content: str,
                    speaker: str,
                    dialogue_history: List[Dict] = None) -> float:
        """Quick rule-based score"""

        score = 5.0  # base score

        # Length check
        content_length = len(dialogue_content.strip())
        if content_length < 10:
            score -= 2.0  # too short
        elif content_length > 200:
            score -= 1.0  # too long
        elif 30 <= content_length <= 100:
            score += 0.5  # reasonable length

        # Repetition check
        if dialogue_history:
            recent_content = [turn.get('content', '') for turn in dialogue_history[-3:]]
            for prev_content in recent_content:
                if dialogue_content == prev_content:
                    score -= 3.0  # exact repetition
                elif self._calculate_similarity(dialogue_content, prev_content) > 0.8:
                    score -= 1.5  # highly similar

        # Content quality check
        if any(word in dialogue_content for word in ['...', '呃', '额', '嗯嗯']):
            score -= 0.5  # contains filler words

        if re.search(r'[。!?]', dialogue_content):
            score += 0.3  # has sentence punctuation

        # Keep the score within the valid range
        return max(1.0, min(10.0, score))

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """Compute a simple character-level Jaccard similarity"""
        words1 = set(text1)
        words2 = set(text2)

        if not words1 and not words2:
            return 1.0

        intersection = len(words1.intersection(words2))
        union = len(words1.union(words2))

        return intersection / union if union > 0 else 0.0
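
For reference, a minimal sketch of how these scorers might be used together (the model path and the sample turn are hypothetical placeholders; DialogueAIScorer loads a transformers model, so a configured model environment is assumed):

from dialogue_scorer import DialogueAIScorer, QuickScorer

turn = {  # hypothetical sample turn
    "content": "今天的训练任务已经完成,我们可以开始评估了。",
    "speaker": "助手",
    "character_data": {},
    "dialogue_history": [],
}

# Cheap rule-based score, e.g. for real-time feedback
quick = QuickScorer()
print(quick.quick_score(turn["content"], turn["speaker"], turn["dialogue_history"]))

# Full multi-dimensional AI score (loads the model on first use)
scorer = DialogueAIScorer(base_model_path="/path/to/base/model")
result = scorer.score_dialogue(turn["content"], turn["speaker"],
                               turn["character_data"], turn["dialogue_history"])
print(result.overall_score, result.scores, result.feedback)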
@@ -529,215 +529,75 @@ def run_model_optimization():
|
|||||||
|
|
||||||
conv_mgr = ConversationManager("./conversation_data/conversations.db")
|
conv_mgr = ConversationManager("./conversation_data/conversations.db")
|
||||||
|
|
||||||
print("模型优化选项:")
|
# print("模型优化选项:")
|
||||||
print("1. 分析优化需求")
|
|
||||||
print("2. 生成LoRA训练脚本")
|
# print("1. 生成LoRA训练脚本")
|
||||||
print("3. 创建提示优化配置")
|
# print("2. 执行增量训练")
|
||||||
print("4. 执行增量训练")
|
|
||||||
print("5. 性能对比验证")
|
|
||||||
|
|
||||||
choice = input("请输入选择 (1-5): ").strip()
|
# choice = input("请输入选择 (1-5): ").strip()
|
||||||
|
|
||||||
|
# if choice == '1':
|
||||||
|
# 生成LoRA训练脚本
|
||||||
|
print("\n=== 生成LoRA训练脚本 ===")
|
||||||
|
|
||||||
if choice == '1':
|
script_content = generate_lora_training_script()
|
||||||
# 分析优化需求
|
script_path = "./scripts/iterative_lora_training.py"
|
||||||
print("\n=== 优化需求分析 ===")
|
|
||||||
|
|
||||||
with sqlite3.connect(conv_mgr.db_path) as conn:
|
|
||||||
# 获取性能数据
|
|
||||||
cursor = conn.execute("""
|
|
||||||
SELECT
|
|
||||||
COUNT(*) as total,
|
|
||||||
AVG(dialogue_score) as avg_score,
|
|
||||||
AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate
|
|
||||||
FROM dialogue_turns WHERE dialogue_score > 0
|
|
||||||
""")
|
|
||||||
|
|
||||||
total, avg_score, hq_rate = cursor.fetchone()
|
|
||||||
|
|
||||||
print(f"当前性能指标:")
|
|
||||||
print(f" - 总评分样本: {total}")
|
|
||||||
print(f" - 平均分数: {avg_score:.2f}")
|
|
||||||
print(f" - 高质量率: {hq_rate*100:.1f}%")
|
|
||||||
|
|
||||||
# 分析优化潜力
|
|
||||||
optimization_needs = []
|
|
||||||
|
|
||||||
if avg_score < 7.0:
|
|
||||||
optimization_needs.append("整体质量需要提升")
|
|
||||||
|
|
||||||
if hq_rate < 0.6:
|
|
||||||
optimization_needs.append("高质量对话比例偏低")
|
|
||||||
|
|
||||||
# 分析各维度表现
|
|
||||||
cursor = conn.execute("""
|
|
||||||
SELECT score_details FROM dialogue_turns
|
|
||||||
WHERE dialogue_score > 0 AND score_details != '{}'
|
|
||||||
ORDER BY timestamp DESC LIMIT 100
|
|
||||||
""")
|
|
||||||
|
|
||||||
dim_scores = {'coherence': [], 'character_consistency': [],
|
|
||||||
'naturalness': [], 'information_density': [], 'creativity': []}
|
|
||||||
|
|
||||||
for (score_details,) in cursor.fetchall():
|
|
||||||
try:
|
|
||||||
scores = json.loads(score_details)
|
|
||||||
for dim, score in scores.items():
|
|
||||||
if dim in dim_scores:
|
|
||||||
dim_scores[dim].append(float(score))
|
|
||||||
except:
|
|
||||||
continue
|
|
||||||
|
|
||||||
weak_dimensions = []
|
|
||||||
print(f"\n维度分析:")
|
|
||||||
for dim, scores in dim_scores.items():
|
|
||||||
if scores:
|
|
||||||
avg = sum(scores) / len(scores)
|
|
||||||
print(f" - {dim}: {avg:.2f}分")
|
|
||||||
if avg < 7.0:
|
|
||||||
weak_dimensions.append(dim)
|
|
||||||
|
|
||||||
if weak_dimensions:
|
|
||||||
optimization_needs.append(f"薄弱维度: {weak_dimensions}")
|
|
||||||
|
|
||||||
print(f"\n优化建议:")
|
|
||||||
if optimization_needs:
|
|
||||||
for i, need in enumerate(optimization_needs, 1):
|
|
||||||
print(f" {i}. {need}")
|
|
||||||
else:
|
|
||||||
print(" 当前模型表现良好,可考虑细微调优")
|
|
||||||
|
|
||||||
elif choice == '2':
|
|
||||||
# 生成LoRA训练脚本
|
|
||||||
print("\n=== 生成LoRA训练脚本 ===")
|
|
||||||
|
|
||||||
script_content = generate_lora_training_script()
|
|
||||||
script_path = "./scripts/iterative_lora_training.py"
|
|
||||||
|
|
||||||
os.makedirs("./scripts", exist_ok=True)
|
|
||||||
with open(script_path, 'w', encoding='utf-8') as f:
|
|
||||||
f.write(script_content)
|
|
||||||
|
|
||||||
print(f"✓ LoRA训练脚本已生成: {script_path}")
|
|
||||||
print("使用方法:")
|
|
||||||
print(" 1. 先运行训练数据生成 (选项8)")
|
|
||||||
print(" 2. 修改脚本中的路径配置")
|
|
||||||
print(f" 3. 运行: python {script_path}")
|
|
||||||
|
|
||||||
elif choice == '3':
|
|
||||||
# 创建提示优化配置
|
|
||||||
print("\n=== 创建提示优化配置 ===")
|
|
||||||
|
|
||||||
config = generate_prompt_optimization_config()
|
|
||||||
config_path = "./config/prompt_optimization.json"
|
|
||||||
|
|
||||||
os.makedirs("./config", exist_ok=True)
|
|
||||||
with open(config_path, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
print(f"✓ 提示优化配置已生成: {config_path}")
|
|
||||||
print("配置包含:")
|
|
||||||
print(" - 动态提示调整规则")
|
|
||||||
print(" - 质量阈值设置")
|
|
||||||
print(" - 生成参数优化")
|
|
||||||
|
|
||||||
elif choice == '4':
|
|
||||||
# 执行增量训练
|
|
||||||
print("\n=== 执行增量训练 ===")
|
|
||||||
|
|
||||||
# 检查训练数据
|
|
||||||
training_dir = "./training_data"
|
|
||||||
if not os.path.exists(training_dir):
|
|
||||||
print("❌ 训练数据目录不存在,请先生成训练数据 (选项8)")
|
|
||||||
return
|
|
||||||
|
|
||||||
training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')]
|
|
||||||
if not training_files:
|
|
||||||
print("❌ 未找到训练数据文件,请先生成训练数据 (选项8)")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"找到训练数据文件:")
|
|
||||||
for i, file in enumerate(training_files, 1):
|
|
||||||
print(f" {i}. {file}")
|
|
||||||
|
|
||||||
file_idx = input(f"选择训练数据文件 (1-{len(training_files)}): ").strip()
|
|
||||||
try:
|
|
||||||
selected_file = training_files[int(file_idx) - 1]
|
|
||||||
training_file_path = os.path.join(training_dir, selected_file)
|
|
||||||
|
|
||||||
print(f"将使用训练文件: {selected_file}")
|
|
||||||
print("⚠ 注意:实际训练需要配置正确的模型路径和计算资源")
|
|
||||||
|
|
||||||
# 生成训练命令
|
|
||||||
training_command = generate_training_command(training_file_path)
|
|
||||||
print(f"建议训练命令:")
|
|
||||||
print(f" {training_command}")
|
|
||||||
|
|
||||||
# 可选:执行训练(需要用户确认)
|
|
||||||
confirm = input("是否现在执行训练?(y/N): ").strip().lower()
|
|
||||||
if confirm == 'y':
|
|
||||||
print("开始增量训练...")
|
|
||||||
# 这里可以添加实际的训练执行逻辑
|
|
||||||
print("⚠ 训练功能需要根据实际环境配置")
|
|
||||||
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
print("❌ 无效的文件选择")
|
|
||||||
|
|
||||||
elif choice == '5':
|
|
||||||
# 性能对比验证
|
|
||||||
print("\n=== 性能对比验证 ===")
|
|
||||||
|
|
||||||
print("验证选项:")
|
|
||||||
print("1. 历史性能趋势对比")
|
|
||||||
print("2. A/B测试配置生成")
|
|
||||||
print("3. 模型版本性能对比")
|
|
||||||
|
|
||||||
verify_choice = input("请输入选择 (1-3): ").strip()
|
|
||||||
|
|
||||||
if verify_choice == '1':
|
|
||||||
# 历史性能趋势
|
|
||||||
with sqlite3.connect(conv_mgr.db_path) as conn:
|
|
||||||
cursor = conn.execute("""
|
|
||||||
SELECT
|
|
||||||
DATE(timestamp) as date,
|
|
||||||
AVG(dialogue_score) as avg_score,
|
|
||||||
COUNT(*) as count
|
|
||||||
FROM dialogue_turns
|
|
||||||
WHERE dialogue_score > 0
|
|
||||||
AND timestamp >= datetime('now', '-30 days')
|
|
||||||
GROUP BY DATE(timestamp)
|
|
||||||
ORDER BY date ASC
|
|
||||||
""")
|
|
||||||
|
|
||||||
trend_data = cursor.fetchall()
|
|
||||||
if trend_data:
|
|
||||||
print(f"30天性能趋势:")
|
|
||||||
for date, avg_score, count in trend_data:
|
|
||||||
trend = "📈" if avg_score > 7.5 else "📉" if avg_score < 6.5 else "📊"
|
|
||||||
print(f" {date}: {avg_score:.2f}分 {trend} ({count}条对话)")
|
|
||||||
else:
|
|
||||||
print("暂无足够的历史数据")
|
|
||||||
|
|
||||||
elif verify_choice == '2':
|
|
||||||
# A/B测试配置
|
|
||||||
ab_config = generate_ab_test_config()
|
|
||||||
ab_config_path = "./config/ab_test_config.json"
|
|
||||||
|
|
||||||
os.makedirs("./config", exist_ok=True)
|
|
||||||
with open(ab_config_path, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(ab_config, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
print(f"✓ A/B测试配置已生成: {ab_config_path}")
|
|
||||||
print("配置包含:")
|
|
||||||
print(" - 对照组和实验组设置")
|
|
||||||
print(" - 评估指标定义")
|
|
||||||
print(" - 测试持续时间配置")
|
|
||||||
|
|
||||||
elif verify_choice == '3':
|
|
||||||
print("模型版本对比功能开发中...")
|
|
||||||
print("建议手动记录不同版本的性能指标进行对比")
|
|
||||||
|
|
||||||
else:
|
os.makedirs("./scripts", exist_ok=True)
|
||||||
print("❌ 无效选择")
|
with open(script_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(script_content)
|
||||||
|
|
||||||
|
print(f"✓ LoRA训练脚本已生成: {script_path}")
|
||||||
|
print("使用方法:")
|
||||||
|
print(" 1. 先运行训练数据生成 (选项8)")
|
||||||
|
print(" 2. 修改脚本中的路径配置")
|
||||||
|
print(f" 3. 运行: python {script_path}")
|
||||||
|
|
||||||
|
# elif choice == '2':
|
||||||
|
# # 执行增量训练
|
||||||
|
# print("\n=== 执行增量训练 ===")
|
||||||
|
|
||||||
|
# # 检查训练数据
|
||||||
|
# training_dir = "./training_data"
|
||||||
|
# if not os.path.exists(training_dir):
|
||||||
|
# print("❌ 训练数据目录不存在,请先生成训练数据 (选项8)")
|
||||||
|
# return
|
||||||
|
|
||||||
|
# training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')]
|
||||||
|
# if not training_files:
|
||||||
|
# print("❌ 未找到训练数据文件,请先生成训练数据 (选项8)")
|
||||||
|
# return
|
||||||
|
|
||||||
|
# print(f"找到训练数据文件:")
|
||||||
|
# for i, file in enumerate(training_files, 1):
|
||||||
|
# print(f" {i}. {file}")
|
||||||
|
|
||||||
|
# file_idx = input(f"选择训练数据文件 (1-{len(training_files)}): ").strip()
|
||||||
|
# try:
|
||||||
|
# selected_file = training_files[int(file_idx) - 1]
|
||||||
|
# training_file_path = os.path.join(training_dir, selected_file)
|
||||||
|
|
||||||
|
# print(f"将使用训练文件: {selected_file}")
|
||||||
|
# print("⚠ 注意:实际训练需要配置正确的模型路径和计算资源")
|
||||||
|
|
||||||
|
# # 生成训练命令
|
||||||
|
# training_command = generate_training_command(training_file_path)
|
||||||
|
# print(f"建议训练命令:")
|
||||||
|
# print(f" {training_command}")
|
||||||
|
|
||||||
|
# # 可选:执行训练(需要用户确认)
|
||||||
|
# confirm = input("是否现在执行训练?(y/N): ").strip().lower()
|
||||||
|
# if confirm == 'y':
|
||||||
|
# print("开始增量训练...")
|
||||||
|
# # 这里可以添加实际的训练执行逻辑
|
||||||
|
# print("⚠ 训练功能需要根据实际环境配置")
|
||||||
|
|
||||||
|
# except (ValueError, IndexError):
|
||||||
|
# print("❌ 无效的文件选择")
|
||||||
|
|
||||||
|
# else:
|
||||||
|
# print("❌ 无效选择")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ 模型优化失败: {e}")
|
print(f"✗ 模型优化失败: {e}")
|
||||||
@@ -837,102 +697,8 @@ if __name__ == '__main__':
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def generate_prompt_optimization_config():
|
|
||||||
"""生成提示优化配置"""
|
|
||||||
return {
|
|
||||||
"optimization_rules": {
|
|
||||||
"quality_thresholds": {
|
|
||||||
"excellent": 9.0,
|
|
||||||
"good": 8.0,
|
|
||||||
"acceptable": 7.0,
|
|
||||||
"needs_improvement": 6.0
|
|
||||||
},
|
|
||||||
"adaptive_adjustments": {
|
|
||||||
"low_coherence": {
|
|
||||||
"add_context_emphasis": True,
|
|
||||||
"increase_history_weight": 1.2,
|
|
||||||
"add_logical_constraints": True
|
|
||||||
},
|
|
||||||
"low_character_consistency": {
|
|
||||||
"enhance_character_description": True,
|
|
||||||
"add_personality_reminders": True,
|
|
||||||
"increase_character_weight": 1.5
|
|
||||||
},
|
|
||||||
"low_creativity": {
|
|
||||||
"add_creativity_prompts": True,
|
|
||||||
"increase_temperature": 0.1,
|
|
||||||
"diversify_examples": True
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"dynamic_prompts": {
|
|
||||||
"quality_boost_templates": [
|
|
||||||
"请生成一个富有创意且符合角色特征的高质量对话",
|
|
||||||
"确保对话逻辑连贯,信息丰富,表达自然",
|
|
||||||
"体现角色的独特性格和说话风格"
|
|
||||||
],
|
|
||||||
"problem_specific_guidance": {
|
|
||||||
"repetitive": "避免重复之前的内容,提供新的信息和观点",
|
|
||||||
"inconsistent": "严格遵循角色设定,保持人格一致性",
|
|
||||||
"dull": "增加对话的趣味性和深度,使用生动的表达"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"generation_parameters": {
|
|
||||||
"adaptive_temperature": {
|
|
||||||
"high_creativity_needed": 0.9,
|
|
||||||
"normal": 0.8,
|
|
||||||
"high_consistency_needed": 0.7
|
|
||||||
},
|
|
||||||
"adaptive_top_p": {
|
|
||||||
"creative_mode": 0.9,
|
|
||||||
"balanced_mode": 0.8,
|
|
||||||
"conservative_mode": 0.7
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def generate_ab_test_config():
|
|
||||||
"""生成A/B测试配置"""
|
|
||||||
return {
|
|
||||||
"test_name": "model_optimization_validation",
|
|
||||||
"created_at": datetime.now().isoformat(),
|
|
||||||
"groups": {
|
|
||||||
"control": {
|
|
||||||
"name": "原始模型",
|
|
||||||
"description": "未优化的基础模型",
|
|
||||||
"model_path": "/path/to/base/model",
|
|
||||||
"sample_ratio": 0.5
|
|
||||||
},
|
|
||||||
"experimental": {
|
|
||||||
"name": "优化模型",
|
|
||||||
"description": "经过迭代优化的模型",
|
|
||||||
"model_path": "/path/to/optimized/model",
|
|
||||||
"sample_ratio": 0.5
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"evaluation_metrics": {
|
|
||||||
"primary": [
|
|
||||||
"overall_dialogue_score",
|
|
||||||
"character_consistency_score",
|
|
||||||
"creativity_score"
|
|
||||||
],
|
|
||||||
"secondary": [
|
|
||||||
"user_satisfaction",
|
|
||||||
"response_time",
|
|
||||||
"coherence_score"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"test_duration": {
|
|
||||||
"target_samples": 200,
|
|
||||||
"max_duration_days": 7,
|
|
||||||
"min_samples_per_group": 50
|
|
||||||
},
|
|
||||||
"statistical_settings": {
|
|
||||||
"confidence_level": 0.95,
|
|
||||||
"minimum_effect_size": 0.3,
|
|
||||||
"power": 0.8
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def generate_training_command(training_file_path):
|
def generate_training_command(training_file_path):
|
||||||
"""生成训练命令"""
|
"""生成训练命令"""
|
||||||
|
|||||||
617
AITrain/score_data_manager.py
Normal file
@@ -0,0 +1,617 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
Dialogue score data storage and management system.
Manages score results, statistical analysis, data export, and related tasks.
'''

import sqlite3
import json
import pandas as pd
from typing import Dict, List, Optional, Tuple
from datetime import datetime, timedelta
from dataclasses import asdict
import numpy as np
import os

from dialogue_scorer import ScoreResult


class ScoreDataManager:
    """Scoring data manager"""

    def __init__(self, db_path: str = "score_data.db"):
        """
        Initialize the scoring data manager.

        Args:
            db_path: path to the database file
        """
        self.db_path = db_path
        self._init_database()

    def _init_database(self):
        """Initialize the database"""
        with sqlite3.connect(self.db_path) as conn:
            # Score results table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS score_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    dialogue_id TEXT UNIQUE,
                    session_id TEXT,
                    speaker TEXT,
                    content TEXT,
                    timestamp TEXT,
                    scorer_type TEXT,
                    coherence_score REAL,
                    character_consistency_score REAL,
                    naturalness_score REAL,
                    information_density_score REAL,
                    creativity_score REAL,
                    overall_score REAL,
                    feedback TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Score statistics table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS score_statistics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT,
                    speaker TEXT,
                    date_range TEXT,
                    total_dialogues INTEGER,
                    avg_overall_score REAL,
                    avg_coherence REAL,
                    avg_character_consistency REAL,
                    avg_naturalness REAL,
                    avg_information_density REAL,
                    avg_creativity REAL,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Model training data table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS training_data (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    dialogue_content TEXT,
                    speaker TEXT,
                    character_data TEXT,
                    context_info TEXT,
                    dialogue_history TEXT,
                    score_overall REAL,
                    score_dimensions TEXT,
                    feedback TEXT,
                    is_high_quality INTEGER,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Optimization records table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS optimization_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    optimization_type TEXT,
                    before_score REAL,
                    after_score REAL,
                    improvement REAL,
                    parameters_changed TEXT,
                    description TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            conn.commit()

    def save_score_result(self, score_result: ScoreResult) -> bool:
        """
        Save a scoring result.

        Args:
            score_result: the scoring result object

        Returns:
            bool: whether the save succeeded
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.execute('''
                    INSERT OR REPLACE INTO score_results
                    (dialogue_id, session_id, speaker, content, timestamp, scorer_type,
                     coherence_score, character_consistency_score, naturalness_score,
                     information_density_score, creativity_score, overall_score, feedback)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    score_result.dialogue_id,
                    score_result.session_id,
                    score_result.speaker,
                    score_result.content,
                    score_result.timestamp,
                    score_result.scorer_type,
                    score_result.scores.get('coherence', 0),
                    score_result.scores.get('character_consistency', 0),
                    score_result.scores.get('naturalness', 0),
                    score_result.scores.get('information_density', 0),
                    score_result.scores.get('creativity', 0),
                    score_result.overall_score,
                    score_result.feedback
                ))
                conn.commit()
                return True
        except Exception as e:
            print(f"保存评分结果失败: {e}")
            return False

    def get_score_results(self,
                          session_id: Optional[str] = None,
                          speaker: Optional[str] = None,
                          min_score: Optional[float] = None,
                          max_score: Optional[float] = None,
                          limit: int = 100) -> List[Dict]:
        """
        Fetch scoring results.

        Args:
            session_id: filter by session id
            speaker: filter by character
            min_score: minimum overall score
            max_score: maximum overall score
            limit: maximum number of rows to return

        Returns:
            List[Dict]: list of scoring results
        """
        query = "SELECT * FROM score_results WHERE 1=1"
        params = []

        if session_id:
            query += " AND session_id = ?"
            params.append(session_id)

        if speaker:
            query += " AND speaker = ?"
            params.append(speaker)

        if min_score is not None:
            query += " AND overall_score >= ?"
            params.append(min_score)

        if max_score is not None:
            query += " AND overall_score <= ?"
            params.append(max_score)

        query += " ORDER BY created_at DESC LIMIT ?"
        params.append(limit)

        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, params)
            return [dict(row) for row in cursor.fetchall()]

    def calculate_statistics(self,
                             session_id: Optional[str] = None,
                             speaker: Optional[str] = None,
                             days: int = 7) -> Dict:
        """
        Compute scoring statistics.

        Args:
            session_id: filter by session id
            speaker: filter by character
            days: number of days to cover

        Returns:
            Dict: statistics
        """
        # Compute the time range
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        query = '''
            SELECT
                COUNT(*) as total_dialogues,
                AVG(overall_score) as avg_overall,
                AVG(coherence_score) as avg_coherence,
                AVG(character_consistency_score) as avg_character_consistency,
                AVG(naturalness_score) as avg_naturalness,
                AVG(information_density_score) as avg_information_density,
                AVG(creativity_score) as avg_creativity,
                MIN(overall_score) as min_score,
                MAX(overall_score) as max_score,
                speaker
            FROM score_results
            WHERE created_at >= ? AND created_at <= ?
        '''

        params = [start_date.isoformat(), end_date.isoformat()]

        if session_id:
            query += " AND session_id = ?"
            params.append(session_id)

        if speaker:
            query += " AND speaker = ?"
            params.append(speaker)
        else:
            query += " GROUP BY speaker"

        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, params)
            results = [dict(row) for row in cursor.fetchall()]

        return {
            "period": f"{start_date.strftime('%Y-%m-%d')} 到 {end_date.strftime('%Y-%m-%d')}",
            "statistics": results
        }

    def get_high_quality_samples(self, threshold: float = 8.0, limit: int = 50) -> List[Dict]:
        """
        Fetch high-quality dialogue samples.

        Args:
            threshold: score threshold
            limit: maximum number of rows to return

        Returns:
            List[Dict]: list of high-quality samples
        """
        query = '''
            SELECT dialogue_id, session_id, speaker, content, overall_score, feedback
            FROM score_results
            WHERE overall_score >= ?
            ORDER BY overall_score DESC
            LIMIT ?
        '''

        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, [threshold, limit])
            return [dict(row) for row in cursor.fetchall()]

    def get_low_quality_samples(self, threshold: float = 4.0, limit: int = 50) -> List[Dict]:
        """
        Fetch low-quality dialogue samples.

        Args:
            threshold: score threshold
            limit: maximum number of rows to return

        Returns:
            List[Dict]: list of low-quality samples
        """
        query = '''
            SELECT dialogue_id, session_id, speaker, content, overall_score, feedback
            FROM score_results
            WHERE overall_score <= ?
            ORDER BY overall_score ASC
            LIMIT ?
        '''

        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, [threshold, limit])
            return [dict(row) for row in cursor.fetchall()]

    def save_training_data(self,
                           dialogue_content: str,
                           speaker: str,
                           character_data: Dict,
                           context_info: List[Dict],
                           dialogue_history: List[Dict],
                           score_result: ScoreResult) -> bool:
        """
        Save a record intended for model training.

        Args:
            dialogue_content: dialogue text
            speaker: speaker name
            character_data: character profile data
            context_info: context information
            dialogue_history: dialogue history
            score_result: scoring result

        Returns:
            bool: whether the save succeeded
        """
        try:
            # Mark whether this is a high-quality sample
            is_high_quality = 1 if score_result.overall_score >= 7.0 else 0

            with sqlite3.connect(self.db_path) as conn:
                conn.execute('''
                    INSERT INTO training_data
                    (dialogue_content, speaker, character_data, context_info,
                     dialogue_history, score_overall, score_dimensions, feedback, is_high_quality)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    dialogue_content,
                    speaker,
                    json.dumps(character_data, ensure_ascii=False),
                    json.dumps(context_info, ensure_ascii=False),
                    json.dumps([asdict(turn) if hasattr(turn, '__dict__') else turn for turn in dialogue_history], ensure_ascii=False),
                    score_result.overall_score,
                    json.dumps(score_result.scores, ensure_ascii=False),
                    score_result.feedback,
                    is_high_quality
                ))
                conn.commit()
                return True
        except Exception as e:
            print(f"保存训练数据失败: {e}")
            return False

    def export_to_csv(self, output_path: str = "score_data_export.csv") -> bool:
        """
        Export the scoring data to a CSV file.

        Args:
            output_path: output file path

        Returns:
            bool: whether the export succeeded
        """
        try:
            query = '''
                SELECT dialogue_id, session_id, speaker, content, timestamp, scorer_type,
                       coherence_score, character_consistency_score, naturalness_score,
                       information_density_score, creativity_score, overall_score, feedback
                FROM score_results
                ORDER BY created_at DESC
            '''

            with sqlite3.connect(self.db_path) as conn:
                df = pd.read_sql_query(query, conn)
                df.to_csv(output_path, index=False, encoding='utf-8-sig')

            print(f"✓ 评分数据已导出到: {output_path}")
            return True
        except Exception as e:
            print(f"导出数据失败: {e}")
            return False

    def get_score_trends(self, speaker: str, days: int = 30) -> Dict:
        """
        Fetch score trend data.

        Args:
            speaker: character name
            days: number of days to cover

        Returns:
            Dict: trend data
        """
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        query = '''
            SELECT DATE(created_at) as date,
                   AVG(overall_score) as avg_score,
                   COUNT(*) as count
            FROM score_results
            WHERE speaker = ? AND created_at >= ? AND created_at <= ?
            GROUP BY DATE(created_at)
            ORDER BY date
        '''

        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, [speaker, start_date.isoformat(), end_date.isoformat()])
            results = [dict(row) for row in cursor.fetchall()]

        return {
            "speaker": speaker,
            "period": f"{start_date.strftime('%Y-%m-%d')} 到 {end_date.strftime('%Y-%m-%d')}",
            "trend_data": results
        }

    def save_optimization_record(self,
                                 optimization_type: str,
                                 before_score: float,
                                 after_score: float,
                                 parameters_changed: Dict,
                                 description: str = "") -> bool:
        """
        Save an optimization record.

        Args:
            optimization_type: type of optimization
            before_score: score before optimization
            after_score: score after optimization
            parameters_changed: parameters that were changed
            description: description

        Returns:
            bool: whether the save succeeded
        """
        try:
            improvement = after_score - before_score

            with sqlite3.connect(self.db_path) as conn:
                conn.execute('''
                    INSERT INTO optimization_records
                    (optimization_type, before_score, after_score, improvement,
                     parameters_changed, description)
                    VALUES (?, ?, ?, ?, ?, ?)
                ''', (
                    optimization_type,
                    before_score,
                    after_score,
                    improvement,
                    json.dumps(parameters_changed, ensure_ascii=False),
                    description
                ))
                conn.commit()
                return True
        except Exception as e:
            print(f"保存优化记录失败: {e}")
            return False

    def get_optimization_history(self, limit: int = 20) -> List[Dict]:
        """
        Fetch the optimization history.

        Args:
            limit: maximum number of rows to return

        Returns:
            List[Dict]: optimization history records
        """
        query = '''
            SELECT * FROM optimization_records
            ORDER BY created_at DESC
            LIMIT ?
        '''

        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, [limit])
            return [dict(row) for row in cursor.fetchall()]

class ScoreAnalyzer:
    """Score analyzer"""

    def __init__(self, score_manager: ScoreDataManager):
        """
        Initialize the score analyzer.

        Args:
            score_manager: the scoring data manager
        """
        self.score_manager = score_manager

    def analyze_character_performance(self, character_name: str) -> Dict:
        """
        Analyze how a character performs.

        Args:
            character_name: character name

        Returns:
            Dict: analysis result
        """
        # Fetch all scoring data for the character
        scores = self.score_manager.get_score_results(speaker=character_name, limit=1000)

        if not scores:
            return {"error": f"没有找到角色 {character_name} 的评分数据"}

        # Compute the various statistics
        overall_scores = [score['overall_score'] for score in scores]

        analysis = {
            "character_name": character_name,
            "total_dialogues": len(scores),
            "average_score": round(np.mean(overall_scores), 2),
            "median_score": round(np.median(overall_scores), 2),
            "score_std": round(np.std(overall_scores), 2),
            "min_score": min(overall_scores),
            "max_score": max(overall_scores),
            "score_distribution": self._calculate_score_distribution(overall_scores),
            "dimension_analysis": self._analyze_dimensions(scores),
            "improvement_suggestions": self._generate_improvement_suggestions(scores)
        }

        return analysis

    def _calculate_score_distribution(self, scores: List[float]) -> Dict:
        """Compute the score distribution"""
        distribution = {
            "excellent": len([s for s in scores if s >= 8.0]),
            "good": len([s for s in scores if 6.0 <= s < 8.0]),
            "average": len([s for s in scores if 4.0 <= s < 6.0]),
            "poor": len([s for s in scores if s < 4.0])
        }

        total = len(scores)
        if total > 0:
            distribution = {k: {"count": v, "percentage": round(v/total*100, 1)}
                            for k, v in distribution.items()}

        return distribution

    def _analyze_dimensions(self, scores: List[Dict]) -> Dict:
        """Analyze per-dimension performance"""
        dimensions = ['coherence_score', 'character_consistency_score', 'naturalness_score',
                      'information_density_score', 'creativity_score']

        dimension_analysis = {}

        for dim in dimensions:
            dim_scores = [score[dim] for score in scores if score[dim] is not None]
            if dim_scores:
                dimension_analysis[dim] = {
                    "average": round(np.mean(dim_scores), 2),
                    "min": min(dim_scores),
                    "max": max(dim_scores),
                    "std": round(np.std(dim_scores), 2)
                }

        return dimension_analysis

    def _generate_improvement_suggestions(self, scores: List[Dict]) -> List[str]:
        """Generate improvement suggestions"""
        suggestions = []

        # Dimension keys and their display names
        dimensions = {
            'coherence_score': '连贯性',
            'character_consistency_score': '角色一致性',
            'naturalness_score': '自然度',
            'information_density_score': '信息密度',
            'creativity_score': '创意性'
        }

        dim_averages = {}
        for dim_key, dim_name in dimensions.items():
            dim_scores = [score[dim_key] for score in scores if score[dim_key] is not None]
            if dim_scores:
                dim_averages[dim_name] = np.mean(dim_scores)

        # Pick out the weakest dimension
        if dim_averages:
            weakest_dim = min(dim_averages.items(), key=lambda x: x[1])
            if weakest_dim[1] < 6.0:
                suggestions.append(f"需要重点改进{weakest_dim[0]},当前平均分为{weakest_dim[1]:.1f}")

        # Suggestions based on the overall average
        overall_avg = np.mean([score['overall_score'] for score in scores])
        if overall_avg < 5.0:
            suggestions.append("整体表现需要改进,建议调整角色设定或提示词")
        elif overall_avg < 7.0:
            suggestions.append("表现良好,可以尝试增加对话的创意性和深度")
        else:
            suggestions.append("表现优秀,继续保持当前设定")

        return suggestions

    def compare_characters(self, character_names: List[str]) -> Dict:
        """
        Compare the performance of several characters.

        Args:
            character_names: list of character names

        Returns:
            Dict: comparison result
        """
        comparison = {"characters": {}}

        for char_name in character_names:
            scores = self.score_manager.get_score_results(speaker=char_name, limit=500)
            if scores:
                overall_scores = [score['overall_score'] for score in scores]
                comparison["characters"][char_name] = {
                    "total_dialogues": len(scores),
                    "average_score": round(np.mean(overall_scores), 2),
                    "score_range": f"{min(overall_scores):.1f} - {max(overall_scores):.1f}"
                }

        # Rank the characters by average score
        if comparison["characters"]:
            sorted_chars = sorted(comparison["characters"].items(),
                                  key=lambda x: x[1]["average_score"], reverse=True)
            comparison["ranking"] = [{"character": char, **data} for char, data in sorted_chars]

        return comparison
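
A minimal sketch of how the storage and analysis classes might be combined (the database path, speaker name, and thresholds are hypothetical placeholders; pandas and numpy must be installed since the module imports both):

from score_data_manager import ScoreDataManager, ScoreAnalyzer

manager = ScoreDataManager(db_path="score_data.db")

# `result` would be a ScoreResult produced by one of the scorers above, e.g.
# result = DialogueAIScorer("/path/to/base/model").score_dialogue(...)
# manager.save_score_result(result)

# Query stored scores and build an analysis report
recent = manager.get_score_results(speaker="助手", min_score=7.0, limit=20)
weekly = manager.calculate_statistics(days=7)

analyzer = ScoreAnalyzer(manager)
report = analyzer.analyze_character_performance("助手")
print(report.get("average_score"), report.get("improvement_suggestions"))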