Project02/AITrain/score_data_manager.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
Dialogue scoring data storage and management system.
Manages score results, statistical analysis, data export, and related functionality.
'''

import sqlite3
import json
import pandas as pd
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from dataclasses import asdict, is_dataclass
import numpy as np

from dialogue_scorer import ScoreResult


class ScoreDataManager:
    """Score data manager."""

    def __init__(self, db_path: str = "score_data.db"):
        """
        Initialize the score data manager.

        Args:
            db_path: Path to the SQLite database file.
        """
        self.db_path = db_path
        self._init_database()

    def _init_database(self):
        """Initialize the database schema."""
        with sqlite3.connect(self.db_path) as conn:
            # Score results table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS score_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    dialogue_id TEXT UNIQUE,
                    session_id TEXT,
                    speaker TEXT,
                    content TEXT,
                    timestamp TEXT,
                    scorer_type TEXT,
                    coherence_score REAL,
                    character_consistency_score REAL,
                    naturalness_score REAL,
                    information_density_score REAL,
                    creativity_score REAL,
                    overall_score REAL,
                    feedback TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Score statistics table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS score_statistics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT,
                    speaker TEXT,
                    date_range TEXT,
                    total_dialogues INTEGER,
                    avg_overall_score REAL,
                    avg_coherence REAL,
                    avg_character_consistency REAL,
                    avg_naturalness REAL,
                    avg_information_density REAL,
                    avg_creativity REAL,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Model training data table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS training_data (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    dialogue_content TEXT,
                    speaker TEXT,
                    character_data TEXT,
                    context_info TEXT,
                    dialogue_history TEXT,
                    score_overall REAL,
                    score_dimensions TEXT,
                    feedback TEXT,
                    is_high_quality INTEGER,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Optimization records table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS optimization_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    optimization_type TEXT,
                    before_score REAL,
                    after_score REAL,
                    improvement REAL,
                    parameters_changed TEXT,
                    description TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            conn.commit()

    def save_score_result(self, score_result: ScoreResult) -> bool:
        """
        Save a score result.

        Args:
            score_result: Score result object.

        Returns:
            bool: True if the result was saved successfully.
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.execute('''
                    INSERT OR REPLACE INTO score_results
                    (dialogue_id, session_id, speaker, content, timestamp, scorer_type,
                     coherence_score, character_consistency_score, naturalness_score,
                     information_density_score, creativity_score, overall_score, feedback)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    score_result.dialogue_id,
                    score_result.session_id,
                    score_result.speaker,
                    score_result.content,
                    score_result.timestamp,
                    score_result.scorer_type,
                    score_result.scores.get('coherence', 0),
                    score_result.scores.get('character_consistency', 0),
                    score_result.scores.get('naturalness', 0),
                    score_result.scores.get('information_density', 0),
                    score_result.scores.get('creativity', 0),
                    score_result.overall_score,
                    score_result.feedback
                ))
                conn.commit()
            return True
        except Exception as e:
            print(f"Failed to save score result: {e}")
            return False

    def get_score_results(self,
                          session_id: Optional[str] = None,
                          speaker: Optional[str] = None,
                          min_score: Optional[float] = None,
                          max_score: Optional[float] = None,
                          limit: int = 100) -> List[Dict]:
        """
        Query score results.

        Args:
            session_id: Filter by session ID.
            speaker: Filter by speaker/character.
            min_score: Minimum overall score.
            max_score: Maximum overall score.
            limit: Maximum number of rows to return.

        Returns:
            List[Dict]: List of score result rows.
        """
        query = "SELECT * FROM score_results WHERE 1=1"
        params = []

        if session_id:
            query += " AND session_id = ?"
            params.append(session_id)
        if speaker:
            query += " AND speaker = ?"
            params.append(speaker)
        if min_score is not None:
            query += " AND overall_score >= ?"
            params.append(min_score)
        if max_score is not None:
            query += " AND overall_score <= ?"
            params.append(max_score)

        query += " ORDER BY created_at DESC LIMIT ?"
        params.append(limit)

        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, params)
            return [dict(row) for row in cursor.fetchall()]

    def calculate_statistics(self,
                             session_id: Optional[str] = None,
                             speaker: Optional[str] = None,
                             days: int = 7) -> Dict:
        """
        Compute score statistics.

        Args:
            session_id: Filter by session ID.
            speaker: Filter by speaker/character.
            days: Number of days to include.

        Returns:
            Dict: Statistics summary.
        """
        # Compute the time range. created_at is stored as 'YYYY-MM-DD HH:MM:SS'
        # (SQLite CURRENT_TIMESTAMP), so format the bounds the same way so the
        # string comparison in the WHERE clause behaves as expected.
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        query = '''
            SELECT
                COUNT(*) as total_dialogues,
                AVG(overall_score) as avg_overall,
                AVG(coherence_score) as avg_coherence,
                AVG(character_consistency_score) as avg_character_consistency,
                AVG(naturalness_score) as avg_naturalness,
                AVG(information_density_score) as avg_information_density,
                AVG(creativity_score) as avg_creativity,
                MIN(overall_score) as min_score,
                MAX(overall_score) as max_score,
                speaker
            FROM score_results
            WHERE created_at >= ? AND created_at <= ?
        '''
        params = [start_date.strftime('%Y-%m-%d %H:%M:%S'),
                  end_date.strftime('%Y-%m-%d %H:%M:%S')]

        if session_id:
            query += " AND session_id = ?"
            params.append(session_id)
        if speaker:
            query += " AND speaker = ?"
            params.append(speaker)
        else:
            query += " GROUP BY speaker"

        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, params)
            results = [dict(row) for row in cursor.fetchall()]

        return {
            "period": f"{start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}",
            "statistics": results
        }

    def get_high_quality_samples(self, threshold: float = 8.0, limit: int = 50) -> List[Dict]:
        """
        Get high-quality dialogue samples.

        Args:
            threshold: Minimum overall score.
            limit: Maximum number of rows to return.

        Returns:
            List[Dict]: High-quality samples.
        """
        query = '''
            SELECT dialogue_id, session_id, speaker, content, overall_score, feedback
            FROM score_results
            WHERE overall_score >= ?
            ORDER BY overall_score DESC
            LIMIT ?
        '''
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, [threshold, limit])
            return [dict(row) for row in cursor.fetchall()]

    def get_low_quality_samples(self, threshold: float = 4.0, limit: int = 50) -> List[Dict]:
        """
        Get low-quality dialogue samples.

        Args:
            threshold: Maximum overall score.
            limit: Maximum number of rows to return.

        Returns:
            List[Dict]: Low-quality samples.
        """
        query = '''
            SELECT dialogue_id, session_id, speaker, content, overall_score, feedback
            FROM score_results
            WHERE overall_score <= ?
            ORDER BY overall_score ASC
            LIMIT ?
        '''
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, [threshold, limit])
            return [dict(row) for row in cursor.fetchall()]

    def save_training_data(self,
                           dialogue_content: str,
                           speaker: str,
                           character_data: Dict,
                           context_info: List[Dict],
                           dialogue_history: List[Dict],
                           score_result: ScoreResult) -> bool:
        """
        Save a sample for model training.

        Args:
            dialogue_content: Dialogue text.
            speaker: Speaker/character name.
            character_data: Character profile data.
            context_info: Context information.
            dialogue_history: Dialogue history turns.
            score_result: Score result for the dialogue.

        Returns:
            bool: True if the sample was saved successfully.
        """
        try:
            # Mark samples scoring 7.0 or above as high quality
            is_high_quality = 1 if score_result.overall_score >= 7.0 else 0

            with sqlite3.connect(self.db_path) as conn:
                conn.execute('''
                    INSERT INTO training_data
                    (dialogue_content, speaker, character_data, context_info,
                     dialogue_history, score_overall, score_dimensions, feedback, is_high_quality)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    dialogue_content,
                    speaker,
                    json.dumps(character_data, ensure_ascii=False),
                    json.dumps(context_info, ensure_ascii=False),
                    # asdict() only accepts dataclass instances, so pass other turns through unchanged
                    json.dumps([asdict(turn) if is_dataclass(turn) else turn for turn in dialogue_history],
                               ensure_ascii=False),
                    score_result.overall_score,
                    json.dumps(score_result.scores, ensure_ascii=False),
                    score_result.feedback,
                    is_high_quality
                ))
                conn.commit()
            return True
        except Exception as e:
            print(f"Failed to save training data: {e}")
            return False

    def export_to_csv(self, output_path: str = "score_data_export.csv") -> bool:
        """
        Export score data to a CSV file.

        Args:
            output_path: Output file path.

        Returns:
            bool: True if the export succeeded.
        """
        try:
            query = '''
                SELECT dialogue_id, session_id, speaker, content, timestamp, scorer_type,
                       coherence_score, character_consistency_score, naturalness_score,
                       information_density_score, creativity_score, overall_score, feedback
                FROM score_results
                ORDER BY created_at DESC
            '''
            with sqlite3.connect(self.db_path) as conn:
                df = pd.read_sql_query(query, conn)

            df.to_csv(output_path, index=False, encoding='utf-8-sig')
            print(f"✓ Score data exported to: {output_path}")
            return True
        except Exception as e:
            print(f"Failed to export data: {e}")
            return False

    def get_score_trends(self, speaker: str, days: int = 30) -> Dict:
        """
        Get score trend data for a character.

        Args:
            speaker: Character name.
            days: Number of days to include.

        Returns:
            Dict: Trend data.
        """
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        query = '''
            SELECT DATE(created_at) as date,
                   AVG(overall_score) as avg_score,
                   COUNT(*) as count
            FROM score_results
            WHERE speaker = ? AND created_at >= ? AND created_at <= ?
            GROUP BY DATE(created_at)
            ORDER BY date
        '''
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            # Format bounds to match SQLite's CURRENT_TIMESTAMP text format
            cursor = conn.execute(query, [speaker,
                                          start_date.strftime('%Y-%m-%d %H:%M:%S'),
                                          end_date.strftime('%Y-%m-%d %H:%M:%S')])
            results = [dict(row) for row in cursor.fetchall()]

        return {
            "speaker": speaker,
            "period": f"{start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}",
            "trend_data": results
        }

    def save_optimization_record(self,
                                 optimization_type: str,
                                 before_score: float,
                                 after_score: float,
                                 parameters_changed: Dict,
                                 description: str = "") -> bool:
        """
        Save an optimization record.

        Args:
            optimization_type: Type of optimization.
            before_score: Score before the optimization.
            after_score: Score after the optimization.
            parameters_changed: Parameters that were changed.
            description: Optional description.

        Returns:
            bool: True if the record was saved successfully.
        """
        try:
            improvement = after_score - before_score

            with sqlite3.connect(self.db_path) as conn:
                conn.execute('''
                    INSERT INTO optimization_records
                    (optimization_type, before_score, after_score, improvement,
                     parameters_changed, description)
                    VALUES (?, ?, ?, ?, ?, ?)
                ''', (
                    optimization_type,
                    before_score,
                    after_score,
                    improvement,
                    json.dumps(parameters_changed, ensure_ascii=False),
                    description
                ))
                conn.commit()
            return True
        except Exception as e:
            print(f"Failed to save optimization record: {e}")
            return False

    def get_optimization_history(self, limit: int = 20) -> List[Dict]:
        """
        Get optimization history records.

        Args:
            limit: Maximum number of rows to return.

        Returns:
            List[Dict]: Optimization history records.
        """
        query = '''
            SELECT * FROM optimization_records
            ORDER BY created_at DESC
            LIMIT ?
        '''
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute(query, [limit])
            return [dict(row) for row in cursor.fetchall()]


class ScoreAnalyzer:
    """Score analyzer."""

    def __init__(self, score_manager: ScoreDataManager):
        """
        Initialize the score analyzer.

        Args:
            score_manager: Score data manager instance.
        """
        self.score_manager = score_manager

    def analyze_character_performance(self, character_name: str) -> Dict:
        """
        Analyze a character's performance.

        Args:
            character_name: Character name.

        Returns:
            Dict: Analysis result.
        """
        # Fetch all score data for this character
        scores = self.score_manager.get_score_results(speaker=character_name, limit=1000)
        if not scores:
            return {"error": f"No score data found for character {character_name}"}

        # Compute summary statistics
        overall_scores = [score['overall_score'] for score in scores]
        analysis = {
            "character_name": character_name,
            "total_dialogues": len(scores),
            "average_score": round(np.mean(overall_scores), 2),
            "median_score": round(np.median(overall_scores), 2),
            "score_std": round(np.std(overall_scores), 2),
            "min_score": min(overall_scores),
            "max_score": max(overall_scores),
            "score_distribution": self._calculate_score_distribution(overall_scores),
            "dimension_analysis": self._analyze_dimensions(scores),
            "improvement_suggestions": self._generate_improvement_suggestions(scores)
        }
        return analysis

    def _calculate_score_distribution(self, scores: List[float]) -> Dict:
        """Compute the score distribution."""
        distribution = {
            "excellent": len([s for s in scores if s >= 8.0]),
            "good": len([s for s in scores if 6.0 <= s < 8.0]),
            "average": len([s for s in scores if 4.0 <= s < 6.0]),
            "poor": len([s for s in scores if s < 4.0])
        }
        total = len(scores)
        if total > 0:
            distribution = {k: {"count": v, "percentage": round(v / total * 100, 1)}
                            for k, v in distribution.items()}
        return distribution

    def _analyze_dimensions(self, scores: List[Dict]) -> Dict:
        """Analyze performance per scoring dimension."""
        dimensions = ['coherence_score', 'character_consistency_score', 'naturalness_score',
                      'information_density_score', 'creativity_score']
        dimension_analysis = {}
        for dim in dimensions:
            dim_scores = [score[dim] for score in scores if score[dim] is not None]
            if dim_scores:
                dimension_analysis[dim] = {
                    "average": round(np.mean(dim_scores), 2),
                    "min": min(dim_scores),
                    "max": max(dim_scores),
                    "std": round(np.std(dim_scores), 2)
                }
        return dimension_analysis

    def _generate_improvement_suggestions(self, scores: List[Dict]) -> List[str]:
        """Generate improvement suggestions."""
        suggestions = []

        # Average each scoring dimension to find the weakest one
        dimensions = {
            'coherence_score': 'coherence',
            'character_consistency_score': 'character consistency',
            'naturalness_score': 'naturalness',
            'information_density_score': 'information density',
            'creativity_score': 'creativity'
        }
        dim_averages = {}
        for dim_key, dim_name in dimensions.items():
            dim_scores = [score[dim_key] for score in scores if score[dim_key] is not None]
            if dim_scores:
                dim_averages[dim_name] = np.mean(dim_scores)

        # Flag the weakest dimension if it is clearly lagging
        if dim_averages:
            weakest_dim = min(dim_averages.items(), key=lambda x: x[1])
            if weakest_dim[1] < 6.0:
                suggestions.append(
                    f"Focus on improving {weakest_dim[0]}; current average score is {weakest_dim[1]:.1f}")

        # Add an overall suggestion based on the average score
        overall_avg = np.mean([score['overall_score'] for score in scores])
        if overall_avg < 5.0:
            suggestions.append("Overall performance needs improvement; consider adjusting the character profile or prompts")
        elif overall_avg < 7.0:
            suggestions.append("Good performance; try adding more creativity and depth to the dialogue")
        else:
            suggestions.append("Excellent performance; keep the current configuration")

        return suggestions

    def compare_characters(self, character_names: List[str]) -> Dict:
        """
        Compare the performance of multiple characters.

        Args:
            character_names: List of character names.

        Returns:
            Dict: Comparison result.
        """
        comparison = {"characters": {}}
        for char_name in character_names:
            scores = self.score_manager.get_score_results(speaker=char_name, limit=500)
            if scores:
                overall_scores = [score['overall_score'] for score in scores]
                comparison["characters"][char_name] = {
                    "total_dialogues": len(scores),
                    "average_score": round(np.mean(overall_scores), 2),
                    "score_range": f"{min(overall_scores):.1f} - {max(overall_scores):.1f}"
                }

        # Rank characters by average score
        if comparison["characters"]:
            sorted_chars = sorted(comparison["characters"].items(),
                                  key=lambda x: x[1]["average_score"], reverse=True)
            comparison["ranking"] = [{"character": char, **data} for char, data in sorted_chars]

        return comparison
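

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module. It assumes ScoreResult
# (defined in dialogue_scorer, not shown here) can be constructed with keyword
# arguments matching the attributes read in save_score_result(); adjust the
# construction to the real definition before using this.
if __name__ == "__main__":
    manager = ScoreDataManager(db_path="score_data_demo.db")

    # Hypothetical score result; field names mirror the attributes used above
    # and may differ from the actual ScoreResult dataclass.
    sample = ScoreResult(
        dialogue_id="demo-001",
        session_id="session-001",
        speaker="Alice",
        content="Hello, how can I help you today?",
        timestamp=datetime.now().isoformat(),
        scorer_type="rule_based",
        scores={
            "coherence": 8.0,
            "character_consistency": 7.5,
            "naturalness": 8.2,
            "information_density": 6.8,
            "creativity": 7.0,
        },
        overall_score=7.5,
        feedback="Fluent and on-character; could carry more information.",
    )

    if manager.save_score_result(sample):
        # Query back the stored data and run a simple analysis
        print(manager.calculate_statistics(days=7))
        analyzer = ScoreAnalyzer(manager)
        print(analyzer.analyze_character_performance("Alice"))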