diff --git a/AITrain/dialogue_scorer.py b/AITrain/dialogue_scorer.py new file mode 100644 index 0000000..ef11310 --- /dev/null +++ b/AITrain/dialogue_scorer.py @@ -0,0 +1,497 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +''' +对话质量评分系统 +使用AI模型对生成的对话进行多维度质量评分 +''' + +import json +import re +from typing import Dict, List, Tuple, Optional +from dataclasses import dataclass +from datetime import datetime + +# 尝试导入numpy,如果失败则跳过 +try: + import numpy as np + NUMPY_AVAILABLE = True +except ImportError: + NUMPY_AVAILABLE = False + +@dataclass +class ScoreResult: + """单次打分结果""" + dialogue_id: str + session_id: str + speaker: str + content: str + timestamp: str + scores: Dict[str, float] # 各维度分数 + overall_score: float # 总分 + feedback: str # 反馈意见 + scorer_type: str # 打分器类型 (ai/human) + +class DialogueAIScorer: + """AI对话质量评分器""" + + def __init__(self, base_model_path: str, tokenizer=None, model=None): + """ + 初始化AI评分器 + + Args: + base_model_path: 基础模型路径 + tokenizer: 分词器(可选,复用现有的) + model: 模型(可选,复用现有的) + """ + self.base_model_path = base_model_path + self.tokenizer = tokenizer + self.model = model + + # 如果没有传入模型,则加载 + if self.tokenizer is None or self.model is None: + self._load_model() + + # 评分维度定义 + self.score_dimensions = { + "coherence": { + "name": "连贯性", + "description": "对话是否与上下文逻辑连贯", + "weight": 0.25 + }, + "character_consistency": { + "name": "角色一致性", + "description": "是否符合角色设定和人格特征", + "weight": 0.25 + }, + "naturalness": { + "name": "自然度", + "description": "语言表达是否自然流畅", + "weight": 0.20 + }, + "information_density": { + "name": "信息密度", + "description": "是否包含有意义的信息,避免废话", + "weight": 0.15 + }, + "creativity": { + "name": "创意性", + "description": "内容是否有趣、有创意", + "weight": 0.15 + } + } + + def _load_model(self): + """加载模型和分词器""" + try: + from transformers import AutoModelForCausalLM, AutoTokenizer + import torch + + print(f"Loading scorer tokenizer from: {self.base_model_path}") + self.tokenizer = AutoTokenizer.from_pretrained( + self.base_model_path, + use_fast=False, + trust_remote_code=True + ) + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + print(f"Loading scorer model from: {self.base_model_path}") + self.model = AutoModelForCausalLM.from_pretrained( + self.base_model_path, + device_map="auto", + torch_dtype=torch.bfloat16, + trust_remote_code=True + ) + + except Exception as e: + print(f"✗ AI评分器模型加载失败: {e}") + raise + + def score_dialogue(self, + dialogue_content: str, + speaker: str, + character_data: Dict, + dialogue_history: List[Dict] = None, + context_info: List[Dict] = None) -> ScoreResult: + """ + 对单条对话进行AI评分 + + Args: + dialogue_content: 对话内容 + speaker: 说话者 + character_data: 角色数据 + dialogue_history: 对话历史 + context_info: 上下文信息 + + Returns: + ScoreResult: 评分结果 + """ + # 构建评分提示 + scoring_prompt = self._build_scoring_prompt( + dialogue_content, speaker, character_data, dialogue_history, context_info + ) + + # 使用AI模型生成评分 + try: + scores, feedback = self._generate_ai_scores(scoring_prompt) + + # 计算总分 + overall_score = self._calculate_overall_score(scores) + + # 创建评分结果 + result = ScoreResult( + dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", + session_id="", # 由调用方设置 + speaker=speaker, + content=dialogue_content, + timestamp=datetime.now().isoformat(), + scores=scores, + overall_score=overall_score, + feedback=feedback, + scorer_type="ai" + ) + + return result + + except Exception as e: + print(f"✗ AI评分失败: {e}") + # 返回默认评分 + return self._create_default_score(dialogue_content, speaker) + + def _build_scoring_prompt(self, + dialogue_content: str, + speaker: str, + character_data: Dict, + dialogue_history: List[Dict] = None, + context_info: List[Dict] = None) -> str: + """构建评分提示""" + + # 基础角色信息 + character_info = "" + if character_data: + personality = character_data.get('personality', {}) + traits = personality.get('core_traits', []) + occupation = character_data.get('basic_info', {}).get('occupation', '未知') + character_info = f"角色职业: {occupation}, 性格特点: {', '.join(traits[:3])}" + + # 对话历史 + history_text = "" + if dialogue_history: + history_text = "对话历史:\n" + for turn in dialogue_history[-3:]: # 只取最近3轮 + history_text += f"{turn.get('speaker', '未知')}: {turn.get('content', '')}\n" + + # 构建完整提示 + prompt = f"""请对以下对话内容进行质量评分。 + +角色设定: +{character_info} + +{history_text} + +当前对话: +{speaker}: {dialogue_content} + +请从以下5个维度评分(1-10分): +1. 连贯性 - 对话是否与上下文逻辑连贯 +2. 角色一致性 - 是否符合角色设定和人格特征 +3. 自然度 - 语言表达是否自然流畅 +4. 信息密度 - 是否包含有意义的信息,避免废话 +5. 创意性 - 内容是否有趣、有创意 + +请按以下格式输出: +连贯性: X分 +角色一致性: X分 +自然度: X分 +信息密度: X分 +创意性: X分 + +总体评价: [具体的改进建议和优点分析]""" + + return prompt + + def _generate_ai_scores(self, prompt: str) -> Tuple[Dict[str, float], str]: + """使用AI模型生成评分""" + import torch + + # 准备消息 + messages = [ + {"role": "system", "content": "你是一个专业的对话质量评估专家,请客观公正地评分。"}, + {"role": "user", "content": prompt} + ] + + # 应用对话模板 + inputs = self.tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_tensors="pt", + return_dict=True, + enable_thinking=False + ) + + # 移动到设备 + inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + + # 生成评分 + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=300, + do_sample=True, + temperature=0.3, # 较低温度确保评分稳定 + top_p=0.8, + pad_token_id=self.tokenizer.eos_token_id, + repetition_penalty=1.1 + ) + + # 解码输出 + response = outputs[0][inputs['input_ids'].shape[1]:] + result_text = self.tokenizer.decode(response, skip_special_tokens=True).strip() + + # 解析评分结果 + scores, feedback = self._parse_score_response(result_text) + + return scores, feedback + + def _parse_score_response(self, response: str) -> Tuple[Dict[str, float], str]: + """解析AI评分响应""" + scores = {} + feedback = "" + + # 定义维度映射 + dimension_map = { + "连贯性": "coherence", + "角色一致性": "character_consistency", + "自然度": "naturalness", + "信息密度": "information_density", + "创意性": "creativity" + } + + try: + lines = response.split('\n') + feedback_start = False + + for line in lines: + line = line.strip() + + # 查找评分 + for chinese_name, english_key in dimension_map.items(): + if chinese_name in line and ':' in line: + # 提取分数 + score_match = re.search(r'(\d+(?:\.\d+)?)', line) + if score_match: + score = float(score_match.group(1)) + # 确保分数在1-10范围内 + score = max(1.0, min(10.0, score)) + scores[english_key] = score + + # 查找总体评价 + if '总体评价' in line or '评价' in line: + feedback_start = True + feedback_content = line.split(':', 1) + if len(feedback_content) > 1: + feedback += feedback_content[1].strip() + elif feedback_start and line: + feedback += " " + line + + # 确保所有维度都有分数 + for english_key in dimension_map.values(): + if english_key not in scores: + scores[english_key] = 5.0 # 默认中等分数 + + if not feedback: + feedback = "AI评分完成,建议根据各维度分数进行改进。" + + except Exception as e: + print(f"解析评分响应失败: {e}") + # 使用默认分数 + for english_key in dimension_map.values(): + scores[english_key] = 5.0 + feedback = "评分解析失败,使用默认分数。" + + return scores, feedback + + def _calculate_overall_score(self, scores: Dict[str, float]) -> float: + """计算总分""" + total_score = 0.0 + total_weight = 0.0 + + for dimension, score in scores.items(): + if dimension in self.score_dimensions: + weight = self.score_dimensions[dimension]["weight"] + total_score += score * weight + total_weight += weight + + if total_weight > 0: + return round(total_score / total_weight, 2) + else: + return 5.0 + + def _create_default_score(self, dialogue_content: str, speaker: str) -> ScoreResult: + """创建默认评分结果""" + default_scores = { + "coherence": 5.0, + "character_consistency": 5.0, + "naturalness": 5.0, + "information_density": 5.0, + "creativity": 5.0 + } + + return ScoreResult( + dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", + session_id="", + speaker=speaker, + content=dialogue_content, + timestamp=datetime.now().isoformat(), + scores=default_scores, + overall_score=5.0, + feedback="使用默认评分", + scorer_type="ai" + ) + + def batch_score_dialogue(self, dialogue_list: List[Dict]) -> List[ScoreResult]: + """批量评分对话""" + results = [] + + for i, dialogue_item in enumerate(dialogue_list): + print(f"正在评分 {i+1}/{len(dialogue_list)}: {dialogue_item.get('speaker', '未知')}") + + try: + result = self.score_dialogue( + dialogue_content=dialogue_item.get('content', ''), + speaker=dialogue_item.get('speaker', '未知'), + character_data=dialogue_item.get('character_data', {}), + dialogue_history=dialogue_item.get('dialogue_history', []), + context_info=dialogue_item.get('context_info', []) + ) + + # 设置session_id + result.session_id = dialogue_item.get('session_id', '') + results.append(result) + + except Exception as e: + print(f"评分失败: {e}") + # 添加默认评分 + default_result = self._create_default_score( + dialogue_item.get('content', ''), + dialogue_item.get('speaker', '未知') + ) + default_result.session_id = dialogue_item.get('session_id', '') + results.append(default_result) + + return results + +class HumanScorer: + """人工评分器""" + + def __init__(self): + self.score_dimensions = { + "coherence": "连贯性", + "character_consistency": "角色一致性", + "naturalness": "自然度", + "information_density": "信息密度", + "creativity": "创意性" + } + + def score_dialogue_interactive(self, + dialogue_content: str, + speaker: str, + session_id: str = "") -> ScoreResult: + """交互式人工评分""" + + print(f"\n=== 人工评分 ===") + print(f"角色: {speaker}") + print(f"对话: {dialogue_content}") + print(f"请对以下维度评分 (1-10分):") + + scores = {} + + for dimension_key, dimension_name in self.score_dimensions.items(): + while True: + try: + score_input = input(f"{dimension_name} (1-10): ").strip() + score = float(score_input) + if 1 <= score <= 10: + scores[dimension_key] = score + break + else: + print("请输入1-10之间的分数") + except ValueError: + print("请输入有效的数字") + + # 获取反馈 + feedback = input("请输入评价和建议 (可选): ").strip() + if not feedback: + feedback = "人工评分完成" + + # 计算总分 + overall_score = sum(scores.values()) / len(scores) + + return ScoreResult( + dialogue_id=f"{speaker}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", + session_id=session_id, + speaker=speaker, + content=dialogue_content, + timestamp=datetime.now().isoformat(), + scores=scores, + overall_score=round(overall_score, 2), + feedback=feedback, + scorer_type="human" + ) + +class QuickScorer: + """快速规则评分器(用于实时反馈)""" + + def __init__(self): + pass + + def quick_score(self, + dialogue_content: str, + speaker: str, + dialogue_history: List[Dict] = None) -> float: + """快速评分(基于规则)""" + + score = 5.0 # 基础分 + + # 长度检查 + content_length = len(dialogue_content.strip()) + if content_length < 10: + score -= 2.0 # 太短 + elif content_length > 200: + score -= 1.0 # 太长 + elif 30 <= content_length <= 100: + score += 0.5 # 长度适中 + + # 重复检查 + if dialogue_history: + recent_content = [turn.get('content', '') for turn in dialogue_history[-3:]] + for prev_content in recent_content: + if dialogue_content == prev_content: + score -= 3.0 # 重复内容 + elif self._calculate_similarity(dialogue_content, prev_content) > 0.8: + score -= 1.5 # 高度相似 + + # 内容质量检查 + if any(word in dialogue_content for word in ['...', '呃', '额', '嗯嗯']): + score -= 0.5 # 含有填充词 + + if re.search(r'[。!?]', dialogue_content): + score += 0.3 # 有标点符号 + + # 确保分数在合理范围内 + return max(1.0, min(10.0, score)) + + def _calculate_similarity(self, text1: str, text2: str) -> float: + """计算文本相似度(简单方法)""" + words1 = set(text1) + words2 = set(text2) + + if not words1 and not words2: + return 1.0 + + intersection = len(words1.intersection(words2)) + union = len(words1.union(words2)) + + return intersection / union if union > 0 else 0.0 \ No newline at end of file diff --git a/AITrain/main_controller.py b/AITrain/main_controller.py index b3b578a..605387a 100644 --- a/AITrain/main_controller.py +++ b/AITrain/main_controller.py @@ -529,215 +529,75 @@ def run_model_optimization(): conv_mgr = ConversationManager("./conversation_data/conversations.db") - print("模型优化选项:") - print("1. 分析优化需求") - print("2. 生成LoRA训练脚本") - print("3. 创建提示优化配置") - print("4. 执行增量训练") - print("5. 性能对比验证") + # print("模型优化选项:") + + # print("1. 生成LoRA训练脚本") + # print("2. 执行增量训练") + - choice = input("请输入选择 (1-5): ").strip() + # choice = input("请输入选择 (1-5): ").strip() + + # if choice == '1': + # 生成LoRA训练脚本 + print("\n=== 生成LoRA训练脚本 ===") - if choice == '1': - # 分析优化需求 - print("\n=== 优化需求分析 ===") - - with sqlite3.connect(conv_mgr.db_path) as conn: - # 获取性能数据 - cursor = conn.execute(""" - SELECT - COUNT(*) as total, - AVG(dialogue_score) as avg_score, - AVG(CASE WHEN dialogue_score >= 8.0 THEN 1.0 ELSE 0.0 END) as high_quality_rate - FROM dialogue_turns WHERE dialogue_score > 0 - """) - - total, avg_score, hq_rate = cursor.fetchone() - - print(f"当前性能指标:") - print(f" - 总评分样本: {total}") - print(f" - 平均分数: {avg_score:.2f}") - print(f" - 高质量率: {hq_rate*100:.1f}%") - - # 分析优化潜力 - optimization_needs = [] - - if avg_score < 7.0: - optimization_needs.append("整体质量需要提升") - - if hq_rate < 0.6: - optimization_needs.append("高质量对话比例偏低") - - # 分析各维度表现 - cursor = conn.execute(""" - SELECT score_details FROM dialogue_turns - WHERE dialogue_score > 0 AND score_details != '{}' - ORDER BY timestamp DESC LIMIT 100 - """) - - dim_scores = {'coherence': [], 'character_consistency': [], - 'naturalness': [], 'information_density': [], 'creativity': []} - - for (score_details,) in cursor.fetchall(): - try: - scores = json.loads(score_details) - for dim, score in scores.items(): - if dim in dim_scores: - dim_scores[dim].append(float(score)) - except: - continue - - weak_dimensions = [] - print(f"\n维度分析:") - for dim, scores in dim_scores.items(): - if scores: - avg = sum(scores) / len(scores) - print(f" - {dim}: {avg:.2f}分") - if avg < 7.0: - weak_dimensions.append(dim) - - if weak_dimensions: - optimization_needs.append(f"薄弱维度: {weak_dimensions}") - - print(f"\n优化建议:") - if optimization_needs: - for i, need in enumerate(optimization_needs, 1): - print(f" {i}. {need}") - else: - print(" 当前模型表现良好,可考虑细微调优") - - elif choice == '2': - # 生成LoRA训练脚本 - print("\n=== 生成LoRA训练脚本 ===") - - script_content = generate_lora_training_script() - script_path = "./scripts/iterative_lora_training.py" - - os.makedirs("./scripts", exist_ok=True) - with open(script_path, 'w', encoding='utf-8') as f: - f.write(script_content) - - print(f"✓ LoRA训练脚本已生成: {script_path}") - print("使用方法:") - print(" 1. 先运行训练数据生成 (选项8)") - print(" 2. 修改脚本中的路径配置") - print(f" 3. 运行: python {script_path}") - - elif choice == '3': - # 创建提示优化配置 - print("\n=== 创建提示优化配置 ===") - - config = generate_prompt_optimization_config() - config_path = "./config/prompt_optimization.json" - - os.makedirs("./config", exist_ok=True) - with open(config_path, 'w', encoding='utf-8') as f: - json.dump(config, f, ensure_ascii=False, indent=2) - - print(f"✓ 提示优化配置已生成: {config_path}") - print("配置包含:") - print(" - 动态提示调整规则") - print(" - 质量阈值设置") - print(" - 生成参数优化") - - elif choice == '4': - # 执行增量训练 - print("\n=== 执行增量训练 ===") - - # 检查训练数据 - training_dir = "./training_data" - if not os.path.exists(training_dir): - print("❌ 训练数据目录不存在,请先生成训练数据 (选项8)") - return - - training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')] - if not training_files: - print("❌ 未找到训练数据文件,请先生成训练数据 (选项8)") - return - - print(f"找到训练数据文件:") - for i, file in enumerate(training_files, 1): - print(f" {i}. {file}") - - file_idx = input(f"选择训练数据文件 (1-{len(training_files)}): ").strip() - try: - selected_file = training_files[int(file_idx) - 1] - training_file_path = os.path.join(training_dir, selected_file) - - print(f"将使用训练文件: {selected_file}") - print("⚠ 注意:实际训练需要配置正确的模型路径和计算资源") - - # 生成训练命令 - training_command = generate_training_command(training_file_path) - print(f"建议训练命令:") - print(f" {training_command}") - - # 可选:执行训练(需要用户确认) - confirm = input("是否现在执行训练?(y/N): ").strip().lower() - if confirm == 'y': - print("开始增量训练...") - # 这里可以添加实际的训练执行逻辑 - print("⚠ 训练功能需要根据实际环境配置") - - except (ValueError, IndexError): - print("❌ 无效的文件选择") - - elif choice == '5': - # 性能对比验证 - print("\n=== 性能对比验证 ===") - - print("验证选项:") - print("1. 历史性能趋势对比") - print("2. A/B测试配置生成") - print("3. 模型版本性能对比") - - verify_choice = input("请输入选择 (1-3): ").strip() - - if verify_choice == '1': - # 历史性能趋势 - with sqlite3.connect(conv_mgr.db_path) as conn: - cursor = conn.execute(""" - SELECT - DATE(timestamp) as date, - AVG(dialogue_score) as avg_score, - COUNT(*) as count - FROM dialogue_turns - WHERE dialogue_score > 0 - AND timestamp >= datetime('now', '-30 days') - GROUP BY DATE(timestamp) - ORDER BY date ASC - """) - - trend_data = cursor.fetchall() - if trend_data: - print(f"30天性能趋势:") - for date, avg_score, count in trend_data: - trend = "📈" if avg_score > 7.5 else "📉" if avg_score < 6.5 else "📊" - print(f" {date}: {avg_score:.2f}分 {trend} ({count}条对话)") - else: - print("暂无足够的历史数据") - - elif verify_choice == '2': - # A/B测试配置 - ab_config = generate_ab_test_config() - ab_config_path = "./config/ab_test_config.json" - - os.makedirs("./config", exist_ok=True) - with open(ab_config_path, 'w', encoding='utf-8') as f: - json.dump(ab_config, f, ensure_ascii=False, indent=2) - - print(f"✓ A/B测试配置已生成: {ab_config_path}") - print("配置包含:") - print(" - 对照组和实验组设置") - print(" - 评估指标定义") - print(" - 测试持续时间配置") - - elif verify_choice == '3': - print("模型版本对比功能开发中...") - print("建议手动记录不同版本的性能指标进行对比") + script_content = generate_lora_training_script() + script_path = "./scripts/iterative_lora_training.py" - else: - print("❌ 无效选择") + os.makedirs("./scripts", exist_ok=True) + with open(script_path, 'w', encoding='utf-8') as f: + f.write(script_content) + + print(f"✓ LoRA训练脚本已生成: {script_path}") + print("使用方法:") + print(" 1. 先运行训练数据生成 (选项8)") + print(" 2. 修改脚本中的路径配置") + print(f" 3. 运行: python {script_path}") + + # elif choice == '2': + # # 执行增量训练 + # print("\n=== 执行增量训练 ===") + + # # 检查训练数据 + # training_dir = "./training_data" + # if not os.path.exists(training_dir): + # print("❌ 训练数据目录不存在,请先生成训练数据 (选项8)") + # return + + # training_files = [f for f in os.listdir(training_dir) if f.endswith('.json')] + # if not training_files: + # print("❌ 未找到训练数据文件,请先生成训练数据 (选项8)") + # return + + # print(f"找到训练数据文件:") + # for i, file in enumerate(training_files, 1): + # print(f" {i}. {file}") + + # file_idx = input(f"选择训练数据文件 (1-{len(training_files)}): ").strip() + # try: + # selected_file = training_files[int(file_idx) - 1] + # training_file_path = os.path.join(training_dir, selected_file) + + # print(f"将使用训练文件: {selected_file}") + # print("⚠ 注意:实际训练需要配置正确的模型路径和计算资源") + + # # 生成训练命令 + # training_command = generate_training_command(training_file_path) + # print(f"建议训练命令:") + # print(f" {training_command}") + + # # 可选:执行训练(需要用户确认) + # confirm = input("是否现在执行训练?(y/N): ").strip().lower() + # if confirm == 'y': + # print("开始增量训练...") + # # 这里可以添加实际的训练执行逻辑 + # print("⚠ 训练功能需要根据实际环境配置") + + # except (ValueError, IndexError): + # print("❌ 无效的文件选择") + + # else: + # print("❌ 无效选择") except Exception as e: print(f"✗ 模型优化失败: {e}") @@ -837,102 +697,8 @@ if __name__ == '__main__': trainer.train() ''' -def generate_prompt_optimization_config(): - """生成提示优化配置""" - return { - "optimization_rules": { - "quality_thresholds": { - "excellent": 9.0, - "good": 8.0, - "acceptable": 7.0, - "needs_improvement": 6.0 - }, - "adaptive_adjustments": { - "low_coherence": { - "add_context_emphasis": True, - "increase_history_weight": 1.2, - "add_logical_constraints": True - }, - "low_character_consistency": { - "enhance_character_description": True, - "add_personality_reminders": True, - "increase_character_weight": 1.5 - }, - "low_creativity": { - "add_creativity_prompts": True, - "increase_temperature": 0.1, - "diversify_examples": True - } - } - }, - "dynamic_prompts": { - "quality_boost_templates": [ - "请生成一个富有创意且符合角色特征的高质量对话", - "确保对话逻辑连贯,信息丰富,表达自然", - "体现角色的独特性格和说话风格" - ], - "problem_specific_guidance": { - "repetitive": "避免重复之前的内容,提供新的信息和观点", - "inconsistent": "严格遵循角色设定,保持人格一致性", - "dull": "增加对话的趣味性和深度,使用生动的表达" - } - }, - "generation_parameters": { - "adaptive_temperature": { - "high_creativity_needed": 0.9, - "normal": 0.8, - "high_consistency_needed": 0.7 - }, - "adaptive_top_p": { - "creative_mode": 0.9, - "balanced_mode": 0.8, - "conservative_mode": 0.7 - } - } - } -def generate_ab_test_config(): - """生成A/B测试配置""" - return { - "test_name": "model_optimization_validation", - "created_at": datetime.now().isoformat(), - "groups": { - "control": { - "name": "原始模型", - "description": "未优化的基础模型", - "model_path": "/path/to/base/model", - "sample_ratio": 0.5 - }, - "experimental": { - "name": "优化模型", - "description": "经过迭代优化的模型", - "model_path": "/path/to/optimized/model", - "sample_ratio": 0.5 - } - }, - "evaluation_metrics": { - "primary": [ - "overall_dialogue_score", - "character_consistency_score", - "creativity_score" - ], - "secondary": [ - "user_satisfaction", - "response_time", - "coherence_score" - ] - }, - "test_duration": { - "target_samples": 200, - "max_duration_days": 7, - "min_samples_per_group": 50 - }, - "statistical_settings": { - "confidence_level": 0.95, - "minimum_effect_size": 0.3, - "power": 0.8 - } - } + def generate_training_command(training_file_path): """生成训练命令""" diff --git a/AITrain/score_data_manager.py b/AITrain/score_data_manager.py new file mode 100644 index 0000000..68d10fe --- /dev/null +++ b/AITrain/score_data_manager.py @@ -0,0 +1,617 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +''' +对话评分数据存储管理系统 +管理评分结果、统计分析、数据导出等功能 +''' + +import sqlite3 +import json +import pandas as pd +from typing import Dict, List, Optional, Tuple +from datetime import datetime, timedelta +from dataclasses import asdict +import numpy as np +import os + +from dialogue_scorer import ScoreResult + +class ScoreDataManager: + """评分数据管理器""" + + def __init__(self, db_path: str = "score_data.db"): + """ + 初始化评分数据管理器 + + Args: + db_path: 数据库文件路径 + """ + self.db_path = db_path + self._init_database() + + def _init_database(self): + """初始化数据库""" + with sqlite3.connect(self.db_path) as conn: + # 评分结果表 + conn.execute(''' + CREATE TABLE IF NOT EXISTS score_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + dialogue_id TEXT UNIQUE, + session_id TEXT, + speaker TEXT, + content TEXT, + timestamp TEXT, + scorer_type TEXT, + coherence_score REAL, + character_consistency_score REAL, + naturalness_score REAL, + information_density_score REAL, + creativity_score REAL, + overall_score REAL, + feedback TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + ''') + + # 评分统计表 + conn.execute(''' + CREATE TABLE IF NOT EXISTS score_statistics ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT, + speaker TEXT, + date_range TEXT, + total_dialogues INTEGER, + avg_overall_score REAL, + avg_coherence REAL, + avg_character_consistency REAL, + avg_naturalness REAL, + avg_information_density REAL, + avg_creativity REAL, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + ''') + + # 模型训练数据表 + conn.execute(''' + CREATE TABLE IF NOT EXISTS training_data ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + dialogue_content TEXT, + speaker TEXT, + character_data TEXT, + context_info TEXT, + dialogue_history TEXT, + score_overall REAL, + score_dimensions TEXT, + feedback TEXT, + is_high_quality INTEGER, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + ''') + + # 优化记录表 + conn.execute(''' + CREATE TABLE IF NOT EXISTS optimization_records ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + optimization_type TEXT, + before_score REAL, + after_score REAL, + improvement REAL, + parameters_changed TEXT, + description TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + ''') + + conn.commit() + + def save_score_result(self, score_result: ScoreResult) -> bool: + """ + 保存评分结果 + + Args: + score_result: 评分结果对象 + + Returns: + bool: 保存是否成功 + """ + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute(''' + INSERT OR REPLACE INTO score_results + (dialogue_id, session_id, speaker, content, timestamp, scorer_type, + coherence_score, character_consistency_score, naturalness_score, + information_density_score, creativity_score, overall_score, feedback) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', ( + score_result.dialogue_id, + score_result.session_id, + score_result.speaker, + score_result.content, + score_result.timestamp, + score_result.scorer_type, + score_result.scores.get('coherence', 0), + score_result.scores.get('character_consistency', 0), + score_result.scores.get('naturalness', 0), + score_result.scores.get('information_density', 0), + score_result.scores.get('creativity', 0), + score_result.overall_score, + score_result.feedback + )) + conn.commit() + return True + except Exception as e: + print(f"保存评分结果失败: {e}") + return False + + def get_score_results(self, + session_id: Optional[str] = None, + speaker: Optional[str] = None, + min_score: Optional[float] = None, + max_score: Optional[float] = None, + limit: int = 100) -> List[Dict]: + """ + 获取评分结果 + + Args: + session_id: 会话ID过滤 + speaker: 角色过滤 + min_score: 最低分数过滤 + max_score: 最高分数过滤 + limit: 返回数量限制 + + Returns: + List[Dict]: 评分结果列表 + """ + query = "SELECT * FROM score_results WHERE 1=1" + params = [] + + if session_id: + query += " AND session_id = ?" + params.append(session_id) + + if speaker: + query += " AND speaker = ?" + params.append(speaker) + + if min_score is not None: + query += " AND overall_score >= ?" + params.append(min_score) + + if max_score is not None: + query += " AND overall_score <= ?" + params.append(max_score) + + query += " ORDER BY created_at DESC LIMIT ?" + params.append(limit) + + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute(query, params) + return [dict(row) for row in cursor.fetchall()] + + def calculate_statistics(self, + session_id: Optional[str] = None, + speaker: Optional[str] = None, + days: int = 7) -> Dict: + """ + 计算评分统计信息 + + Args: + session_id: 会话ID过滤 + speaker: 角色过滤 + days: 统计天数 + + Returns: + Dict: 统计信息 + """ + # 计算时间范围 + end_date = datetime.now() + start_date = end_date - timedelta(days=days) + + query = ''' + SELECT + COUNT(*) as total_dialogues, + AVG(overall_score) as avg_overall, + AVG(coherence_score) as avg_coherence, + AVG(character_consistency_score) as avg_character_consistency, + AVG(naturalness_score) as avg_naturalness, + AVG(information_density_score) as avg_information_density, + AVG(creativity_score) as avg_creativity, + MIN(overall_score) as min_score, + MAX(overall_score) as max_score, + speaker + FROM score_results + WHERE created_at >= ? AND created_at <= ? + ''' + + params = [start_date.isoformat(), end_date.isoformat()] + + if session_id: + query += " AND session_id = ?" + params.append(session_id) + + if speaker: + query += " AND speaker = ?" + params.append(speaker) + else: + query += " GROUP BY speaker" + + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute(query, params) + results = [dict(row) for row in cursor.fetchall()] + + return { + "period": f"{start_date.strftime('%Y-%m-%d')} 到 {end_date.strftime('%Y-%m-%d')}", + "statistics": results + } + + def get_high_quality_samples(self, threshold: float = 8.0, limit: int = 50) -> List[Dict]: + """ + 获取高质量对话样本 + + Args: + threshold: 分数阈值 + limit: 返回数量限制 + + Returns: + List[Dict]: 高质量样本列表 + """ + query = ''' + SELECT dialogue_id, session_id, speaker, content, overall_score, feedback + FROM score_results + WHERE overall_score >= ? + ORDER BY overall_score DESC + LIMIT ? + ''' + + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute(query, [threshold, limit]) + return [dict(row) for row in cursor.fetchall()] + + def get_low_quality_samples(self, threshold: float = 4.0, limit: int = 50) -> List[Dict]: + """ + 获取低质量对话样本 + + Args: + threshold: 分数阈值 + limit: 返回数量限制 + + Returns: + List[Dict]: 低质量样本列表 + """ + query = ''' + SELECT dialogue_id, session_id, speaker, content, overall_score, feedback + FROM score_results + WHERE overall_score <= ? + ORDER BY overall_score ASC + LIMIT ? + ''' + + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute(query, [threshold, limit]) + return [dict(row) for row in cursor.fetchall()] + + def save_training_data(self, + dialogue_content: str, + speaker: str, + character_data: Dict, + context_info: List[Dict], + dialogue_history: List[Dict], + score_result: ScoreResult) -> bool: + """ + 保存用于模型训练的数据 + + Args: + dialogue_content: 对话内容 + speaker: 说话者 + character_data: 角色数据 + context_info: 上下文信息 + dialogue_history: 对话历史 + score_result: 评分结果 + + Returns: + bool: 保存是否成功 + """ + try: + # 判断是否为高质量样本 + is_high_quality = 1 if score_result.overall_score >= 7.0 else 0 + + with sqlite3.connect(self.db_path) as conn: + conn.execute(''' + INSERT INTO training_data + (dialogue_content, speaker, character_data, context_info, + dialogue_history, score_overall, score_dimensions, feedback, is_high_quality) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', ( + dialogue_content, + speaker, + json.dumps(character_data, ensure_ascii=False), + json.dumps(context_info, ensure_ascii=False), + json.dumps([asdict(turn) if hasattr(turn, '__dict__') else turn for turn in dialogue_history], ensure_ascii=False), + score_result.overall_score, + json.dumps(score_result.scores, ensure_ascii=False), + score_result.feedback, + is_high_quality + )) + conn.commit() + return True + except Exception as e: + print(f"保存训练数据失败: {e}") + return False + + def export_to_csv(self, output_path: str = "score_data_export.csv") -> bool: + """ + 导出评分数据到CSV文件 + + Args: + output_path: 输出文件路径 + + Returns: + bool: 导出是否成功 + """ + try: + query = ''' + SELECT dialogue_id, session_id, speaker, content, timestamp, scorer_type, + coherence_score, character_consistency_score, naturalness_score, + information_density_score, creativity_score, overall_score, feedback + FROM score_results + ORDER BY created_at DESC + ''' + + with sqlite3.connect(self.db_path) as conn: + df = pd.read_sql_query(query, conn) + df.to_csv(output_path, index=False, encoding='utf-8-sig') + + print(f"✓ 评分数据已导出到: {output_path}") + return True + except Exception as e: + print(f"导出数据失败: {e}") + return False + + def get_score_trends(self, speaker: str, days: int = 30) -> Dict: + """ + 获取评分趋势数据 + + Args: + speaker: 角色名称 + days: 统计天数 + + Returns: + Dict: 趋势数据 + """ + end_date = datetime.now() + start_date = end_date - timedelta(days=days) + + query = ''' + SELECT DATE(created_at) as date, + AVG(overall_score) as avg_score, + COUNT(*) as count + FROM score_results + WHERE speaker = ? AND created_at >= ? AND created_at <= ? + GROUP BY DATE(created_at) + ORDER BY date + ''' + + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute(query, [speaker, start_date.isoformat(), end_date.isoformat()]) + results = [dict(row) for row in cursor.fetchall()] + + return { + "speaker": speaker, + "period": f"{start_date.strftime('%Y-%m-%d')} 到 {end_date.strftime('%Y-%m-%d')}", + "trend_data": results + } + + def save_optimization_record(self, + optimization_type: str, + before_score: float, + after_score: float, + parameters_changed: Dict, + description: str = "") -> bool: + """ + 保存优化记录 + + Args: + optimization_type: 优化类型 + before_score: 优化前分数 + after_score: 优化后分数 + parameters_changed: 改变的参数 + description: 描述 + + Returns: + bool: 保存是否成功 + """ + try: + improvement = after_score - before_score + + with sqlite3.connect(self.db_path) as conn: + conn.execute(''' + INSERT INTO optimization_records + (optimization_type, before_score, after_score, improvement, + parameters_changed, description) + VALUES (?, ?, ?, ?, ?, ?) + ''', ( + optimization_type, + before_score, + after_score, + improvement, + json.dumps(parameters_changed, ensure_ascii=False), + description + )) + conn.commit() + return True + except Exception as e: + print(f"保存优化记录失败: {e}") + return False + + def get_optimization_history(self, limit: int = 20) -> List[Dict]: + """ + 获取优化历史记录 + + Args: + limit: 返回数量限制 + + Returns: + List[Dict]: 优化历史记录 + """ + query = ''' + SELECT * FROM optimization_records + ORDER BY created_at DESC + LIMIT ? + ''' + + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute(query, [limit]) + return [dict(row) for row in cursor.fetchall()] + +class ScoreAnalyzer: + """评分分析器""" + + def __init__(self, score_manager: ScoreDataManager): + """ + 初始化评分分析器 + + Args: + score_manager: 评分数据管理器 + """ + self.score_manager = score_manager + + def analyze_character_performance(self, character_name: str) -> Dict: + """ + 分析角色表现 + + Args: + character_name: 角色名称 + + Returns: + Dict: 分析结果 + """ + # 获取角色的所有评分数据 + scores = self.score_manager.get_score_results(speaker=character_name, limit=1000) + + if not scores: + return {"error": f"没有找到角色 {character_name} 的评分数据"} + + # 计算各种统计信息 + overall_scores = [score['overall_score'] for score in scores] + + analysis = { + "character_name": character_name, + "total_dialogues": len(scores), + "average_score": round(np.mean(overall_scores), 2), + "median_score": round(np.median(overall_scores), 2), + "score_std": round(np.std(overall_scores), 2), + "min_score": min(overall_scores), + "max_score": max(overall_scores), + "score_distribution": self._calculate_score_distribution(overall_scores), + "dimension_analysis": self._analyze_dimensions(scores), + "improvement_suggestions": self._generate_improvement_suggestions(scores) + } + + return analysis + + def _calculate_score_distribution(self, scores: List[float]) -> Dict: + """计算分数分布""" + distribution = { + "excellent": len([s for s in scores if s >= 8.0]), + "good": len([s for s in scores if 6.0 <= s < 8.0]), + "average": len([s for s in scores if 4.0 <= s < 6.0]), + "poor": len([s for s in scores if s < 4.0]) + } + + total = len(scores) + if total > 0: + distribution = {k: {"count": v, "percentage": round(v/total*100, 1)} + for k, v in distribution.items()} + + return distribution + + def _analyze_dimensions(self, scores: List[Dict]) -> Dict: + """分析各维度表现""" + dimensions = ['coherence_score', 'character_consistency_score', 'naturalness_score', + 'information_density_score', 'creativity_score'] + + dimension_analysis = {} + + for dim in dimensions: + dim_scores = [score[dim] for score in scores if score[dim] is not None] + if dim_scores: + dimension_analysis[dim] = { + "average": round(np.mean(dim_scores), 2), + "min": min(dim_scores), + "max": max(dim_scores), + "std": round(np.std(dim_scores), 2) + } + + return dimension_analysis + + def _generate_improvement_suggestions(self, scores: List[Dict]) -> List[str]: + """生成改进建议""" + suggestions = [] + + # 分析最弱的维度 + dimensions = { + 'coherence_score': '连贯性', + 'character_consistency_score': '角色一致性', + 'naturalness_score': '自然度', + 'information_density_score': '信息密度', + 'creativity_score': '创意性' + } + + dim_averages = {} + for dim_key, dim_name in dimensions.items(): + dim_scores = [score[dim_key] for score in scores if score[dim_key] is not None] + if dim_scores: + dim_averages[dim_name] = np.mean(dim_scores) + + # 找出最弱的维度 + if dim_averages: + weakest_dim = min(dim_averages.items(), key=lambda x: x[1]) + if weakest_dim[1] < 6.0: + suggestions.append(f"需要重点改进{weakest_dim[0]},当前平均分为{weakest_dim[1]:.1f}") + + # 根据分数分布给建议 + overall_avg = np.mean([score['overall_score'] for score in scores]) + if overall_avg < 5.0: + suggestions.append("整体表现需要改进,建议调整角色设定或提示词") + elif overall_avg < 7.0: + suggestions.append("表现良好,可以尝试增加对话的创意性和深度") + else: + suggestions.append("表现优秀,继续保持当前设定") + + return suggestions + + def compare_characters(self, character_names: List[str]) -> Dict: + """ + 比较多个角色的表现 + + Args: + character_names: 角色名称列表 + + Returns: + Dict: 比较结果 + """ + comparison = {"characters": {}} + + for char_name in character_names: + scores = self.score_manager.get_score_results(speaker=char_name, limit=500) + if scores: + overall_scores = [score['overall_score'] for score in scores] + comparison["characters"][char_name] = { + "total_dialogues": len(scores), + "average_score": round(np.mean(overall_scores), 2), + "score_range": f"{min(overall_scores):.1f} - {max(overall_scores):.1f}" + } + + # 排序 + if comparison["characters"]: + sorted_chars = sorted(comparison["characters"].items(), + key=lambda x: x[1]["average_score"], reverse=True) + comparison["ranking"] = [{"character": char, **data} for char, data in sorted_chars] + + return comparison \ No newline at end of file