#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
Dialogue score data storage and management system.

Manages score results, statistical analysis, data export, and related features.
'''

import json
import os
import sqlite3
from contextlib import contextmanager
from dataclasses import asdict, is_dataclass
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from dialogue_scorer import ScoreResult


class ScoreDataManager:
    """Persist and query dialogue scoring data in a SQLite database.

    The database holds four tables: ``score_results`` (one row per scored
    dialogue, unique on ``dialogue_id``), ``score_statistics``,
    ``training_data`` and ``optimization_records``. Every table has a
    ``created_at`` column filled by SQLite's ``CURRENT_TIMESTAMP``, i.e. a UTC
    string of the form ``YYYY-MM-DD HH:MM:SS``.
    """

    def __init__(self, db_path: str = "score_data.db"):
        """Remember the database path and create any missing tables.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._init_database()

    @contextmanager
    def _connect(self, use_rows: bool = False):
        """Yield a connection that is committed (or rolled back) and then closed.

        ``sqlite3.Connection`` used directly as a context manager only manages
        the transaction; it does not close the connection, so this wrapper
        closes it explicitly.

        Args:
            use_rows: When true, rows are returned as ``sqlite3.Row`` so they
                can be converted to dicts keyed by column name.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            if use_rows:
                conn.row_factory = sqlite3.Row
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    def _fetch_dicts(self, query: str, params=()) -> List[Dict]:
        """Run a SELECT and return every row as a ``{column: value}`` dict."""
        with self._connect(use_rows=True) as conn:
            return [dict(row) for row in conn.execute(query, params).fetchall()]

    @staticmethod
    def _turn_to_jsonable(turn):
        """Convert one dialogue-history entry into something ``json.dumps`` accepts.

        Dataclass instances go through ``asdict``; other objects that carry a
        ``__dict__`` are serialized from their attributes; everything else
        (typically a plain dict) is passed through unchanged.
        """
        if is_dataclass(turn) and not isinstance(turn, type):
            return asdict(turn)
        if hasattr(turn, '__dict__'):
            return dict(vars(turn))
        return turn

    def _init_database(self):
        """Create the four tables if they do not exist yet."""
        with self._connect() as conn:
            # Score results table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS score_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    dialogue_id TEXT UNIQUE,
                    session_id TEXT,
                    speaker TEXT,
                    content TEXT,
                    timestamp TEXT,
                    scorer_type TEXT,
                    coherence_score REAL,
                    character_consistency_score REAL,
                    naturalness_score REAL,
                    information_density_score REAL,
                    creativity_score REAL,
                    overall_score REAL,
                    feedback TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Score statistics table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS score_statistics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT,
                    speaker TEXT,
                    date_range TEXT,
                    total_dialogues INTEGER,
                    avg_overall_score REAL,
                    avg_coherence REAL,
                    avg_character_consistency REAL,
                    avg_naturalness REAL,
                    avg_information_density REAL,
                    avg_creativity REAL,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Model training data table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS training_data (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    dialogue_content TEXT,
                    speaker TEXT,
                    character_data TEXT,
                    context_info TEXT,
                    dialogue_history TEXT,
                    score_overall REAL,
                    score_dimensions TEXT,
                    feedback TEXT,
                    is_high_quality INTEGER,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Optimization records table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS optimization_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    optimization_type TEXT,
                    before_score REAL,
                    after_score REAL,
                    improvement REAL,
                    parameters_changed TEXT,
                    description TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

    def save_score_result(self, score_result: "ScoreResult") -> bool:
        """Insert a score result, replacing any row with the same ``dialogue_id``.

        Dimension scores missing from ``score_result.scores`` are stored as 0.

        Args:
            score_result: Object exposing ``dialogue_id``, ``session_id``,
                ``speaker``, ``content``, ``timestamp``, ``scorer_type``,
                ``scores`` (a mapping), ``overall_score`` and ``feedback``.

        Returns:
            bool: True on success; False if any error occurred (the error is
            printed, not raised).
        """
        try:
            dimension_scores = score_result.scores
            with self._connect() as conn:
                conn.execute('''
                    INSERT OR REPLACE INTO score_results
                    (dialogue_id, session_id, speaker, content, timestamp, scorer_type,
                     coherence_score, character_consistency_score, naturalness_score,
                     information_density_score, creativity_score, overall_score, feedback)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    score_result.dialogue_id,
                    score_result.session_id,
                    score_result.speaker,
                    score_result.content,
                    score_result.timestamp,
                    score_result.scorer_type,
                    dimension_scores.get('coherence', 0),
                    dimension_scores.get('character_consistency', 0),
                    dimension_scores.get('naturalness', 0),
                    dimension_scores.get('information_density', 0),
                    dimension_scores.get('creativity', 0),
                    score_result.overall_score,
                    score_result.feedback,
                ))
            return True
        except Exception as e:
            # Boundary of the persistence layer: callers rely on the bool result.
            print(f"保存评分结果失败: {e}")
            return False

    def get_score_results(self,
                          session_id: Optional[str] = None,
                          speaker: Optional[str] = None,
                          min_score: Optional[float] = None,
                          max_score: Optional[float] = None,
                          limit: int = 100) -> List[Dict]:
        """Return score results, newest first, optionally filtered.

        Args:
            session_id: Only rows with this session id (ignored when falsy).
            speaker: Only rows with this speaker (ignored when falsy).
            min_score: Inclusive lower bound on ``overall_score``.
            max_score: Inclusive upper bound on ``overall_score``.
            limit: Maximum number of rows returned.

        Returns:
            List[Dict]: One dict per row, keyed by column name.
        """
        query = "SELECT * FROM score_results WHERE 1=1"
        params = []

        if session_id:
            query += " AND session_id = ?"
            params.append(session_id)

        if speaker:
            query += " AND speaker = ?"
            params.append(speaker)

        if min_score is not None:
            query += " AND overall_score >= ?"
            params.append(min_score)

        if max_score is not None:
            query += " AND overall_score <= ?"
            params.append(max_score)

        query += " ORDER BY created_at DESC LIMIT ?"
        params.append(limit)

        return self._fetch_dicts(query, params)

    def calculate_statistics(self,
                             session_id: Optional[str] = None,
                             speaker: Optional[str] = None,
                             days: int = 7) -> Dict:
        """Aggregate scores recorded during the last ``days`` days.

        Without a ``speaker`` filter the result has one entry per speaker;
        with one, a single aggregate row is returned (its count is 0 and its
        averages are None when nothing matches).

        Args:
            session_id: Only rows with this session id (ignored when falsy).
            speaker: Only rows with this speaker (ignored when falsy).
            days: Length of the look-back window in days.

        Returns:
            Dict: ``{"period": <label>, "statistics": [<row dict>, ...]}``.
        """
        # The label is built from local time, exactly as before.
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        # created_at is written by CURRENT_TIMESTAMP (UTC, 'YYYY-MM-DD HH:MM:SS'),
        # so the window is computed by SQLite in that same format. Comparing it
        # with a local-time ISO string ('T' separator) mis-orders rows.
        query = '''
            SELECT
                COUNT(*) as total_dialogues,
                AVG(overall_score) as avg_overall,
                AVG(coherence_score) as avg_coherence,
                AVG(character_consistency_score) as avg_character_consistency,
                AVG(naturalness_score) as avg_naturalness,
                AVG(information_density_score) as avg_information_density,
                AVG(creativity_score) as avg_creativity,
                MIN(overall_score) as min_score,
                MAX(overall_score) as max_score,
                speaker
            FROM score_results
            WHERE created_at >= datetime('now', ?) AND created_at <= datetime('now')
        '''
        params = [f"-{days} days"]

        if session_id:
            query += " AND session_id = ?"
            params.append(session_id)

        if speaker:
            query += " AND speaker = ?"
            params.append(speaker)
        else:
            query += " GROUP BY speaker"

        return {
            "period": f"{start_date.strftime('%Y-%m-%d')} 到 {end_date.strftime('%Y-%m-%d')}",
            "statistics": self._fetch_dicts(query, params),
        }

    def get_high_quality_samples(self, threshold: float = 8.0, limit: int = 50) -> List[Dict]:
        """Return the best-scoring dialogues.

        Args:
            threshold: Inclusive minimum ``overall_score``.
            limit: Maximum number of rows returned.

        Returns:
            List[Dict]: Rows ordered by ``overall_score`` descending.
        """
        query = '''
            SELECT dialogue_id, session_id, speaker, content, overall_score, feedback
            FROM score_results
            WHERE overall_score >= ?
            ORDER BY overall_score DESC
            LIMIT ?
        '''
        return self._fetch_dicts(query, [threshold, limit])

    def get_low_quality_samples(self, threshold: float = 4.0, limit: int = 50) -> List[Dict]:
        """Return the worst-scoring dialogues.

        Args:
            threshold: Inclusive maximum ``overall_score``.
            limit: Maximum number of rows returned.

        Returns:
            List[Dict]: Rows ordered by ``overall_score`` ascending.
        """
        query = '''
            SELECT dialogue_id, session_id, speaker, content, overall_score, feedback
            FROM score_results
            WHERE overall_score <= ?
            ORDER BY overall_score ASC
            LIMIT ?
        '''
        return self._fetch_dicts(query, [threshold, limit])

    def save_training_data(self,
                           dialogue_content: str,
                           speaker: str,
                           character_data: Dict,
                           context_info: List[Dict],
                           dialogue_history: List[Dict],
                           score_result: "ScoreResult") -> bool:
        """Store one scored dialogue as a model-training sample.

        Structured arguments are stored as JSON text. The sample is flagged
        high quality when its overall score is at least 7.0.

        Args:
            dialogue_content: Text of the dialogue line.
            speaker: Who spoke it.
            character_data: Character description, JSON-serializable.
            context_info: Context entries, JSON-serializable.
            dialogue_history: Previous turns; dicts, dataclass instances, or
                objects whose attributes are JSON-serializable.
            score_result: Object exposing ``overall_score``, ``scores`` and
                ``feedback``.

        Returns:
            bool: True on success; False if any error occurred (the error is
            printed, not raised).
        """
        try:
            is_high_quality = 1 if score_result.overall_score >= 7.0 else 0
            history = [self._turn_to_jsonable(turn) for turn in dialogue_history]

            with self._connect() as conn:
                conn.execute('''
                    INSERT INTO training_data
                    (dialogue_content, speaker, character_data, context_info,
                     dialogue_history, score_overall, score_dimensions, feedback, is_high_quality)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    dialogue_content,
                    speaker,
                    json.dumps(character_data, ensure_ascii=False),
                    json.dumps(context_info, ensure_ascii=False),
                    json.dumps(history, ensure_ascii=False),
                    score_result.overall_score,
                    json.dumps(score_result.scores, ensure_ascii=False),
                    score_result.feedback,
                    is_high_quality,
                ))
            return True
        except Exception as e:
            # Boundary of the persistence layer: callers rely on the bool result.
            print(f"保存训练数据失败: {e}")
            return False

    def export_to_csv(self, output_path: str = "score_data_export.csv") -> bool:
        """Write every score result to a CSV file, newest first.

        The file is encoded as UTF-8 with a BOM (``utf-8-sig``).

        Args:
            output_path: Destination file path.

        Returns:
            bool: True on success; False if any error occurred (the error is
            printed, not raised).
        """
        try:
            query = '''
                SELECT dialogue_id, session_id, speaker, content, timestamp, scorer_type,
                       coherence_score, character_consistency_score, naturalness_score,
                       information_density_score, creativity_score, overall_score, feedback
                FROM score_results
                ORDER BY created_at DESC
            '''

            with self._connect() as conn:
                df = pd.read_sql_query(query, conn)
            df.to_csv(output_path, index=False, encoding='utf-8-sig')

            print(f"✓ 评分数据已导出到: {output_path}")
            return True
        except Exception as e:
            print(f"导出数据失败: {e}")
            return False

    def get_score_trends(self, speaker: str, days: int = 30) -> Dict:
        """Return per-day average score and count for one speaker.

        Args:
            speaker: Speaker whose rows are aggregated.
            days: Length of the look-back window in days.

        Returns:
            Dict: ``{"speaker", "period", "trend_data"}`` where ``trend_data``
            is a date-ordered list of ``{"date", "avg_score", "count"}``.
        """
        # The label is built from local time, exactly as before.
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        # The window is computed by SQLite so it matches the UTC
        # 'YYYY-MM-DD HH:MM:SS' format that CURRENT_TIMESTAMP stores.
        query = '''
            SELECT DATE(created_at) as date,
                   AVG(overall_score) as avg_score,
                   COUNT(*) as count
            FROM score_results
            WHERE speaker = ?
              AND created_at >= datetime('now', ?)
              AND created_at <= datetime('now')
            GROUP BY DATE(created_at)
            ORDER BY date
        '''

        return {
            "speaker": speaker,
            "period": f"{start_date.strftime('%Y-%m-%d')} 到 {end_date.strftime('%Y-%m-%d')}",
            "trend_data": self._fetch_dicts(query, [speaker, f"-{days} days"]),
        }

    def save_optimization_record(self,
                                 optimization_type: str,
                                 before_score: float,
                                 after_score: float,
                                 parameters_changed: Dict,
                                 description: str = "") -> bool:
        """Record the outcome of one optimization step.

        The stored improvement is ``after_score - before_score``.

        Args:
            optimization_type: Kind of optimization performed.
            before_score: Score before the change.
            after_score: Score after the change.
            parameters_changed: Changed parameters, stored as JSON text.
            description: Free-form description.

        Returns:
            bool: True on success; False if any error occurred (the error is
            printed, not raised).
        """
        try:
            improvement = after_score - before_score

            with self._connect() as conn:
                conn.execute('''
                    INSERT INTO optimization_records
                    (optimization_type, before_score, after_score, improvement,
                     parameters_changed, description)
                    VALUES (?, ?, ?, ?, ?, ?)
                ''', (
                    optimization_type,
                    before_score,
                    after_score,
                    improvement,
                    json.dumps(parameters_changed, ensure_ascii=False),
                    description,
                ))
            return True
        except Exception as e:
            # Boundary of the persistence layer: callers rely on the bool result.
            print(f"保存优化记录失败: {e}")
            return False

    def get_optimization_history(self, limit: int = 20) -> List[Dict]:
        """Return the most recent optimization records.

        Args:
            limit: Maximum number of rows returned.

        Returns:
            List[Dict]: Rows ordered by ``created_at`` descending.
        """
        query = '''
            SELECT * FROM optimization_records
            ORDER BY created_at DESC
            LIMIT ?
        '''
        return self._fetch_dicts(query, [limit])


class ScoreAnalyzer:
    """Derive per-character statistics and suggestions from stored scores.

    All data is read through ``score_manager.get_score_results``; rows are
    dicts keyed by the ``score_results`` column names.
    """

    # Dimension column -> label used in generated suggestions (runtime text).
    _DIMENSION_LABELS = {
        'coherence_score': '连贯性',
        'character_consistency_score': '角色一致性',
        'naturalness_score': '自然度',
        'information_density_score': '信息密度',
        'creativity_score': '创意性',
    }

    # (bucket name, inclusive lower bound, exclusive upper bound)
    _SCORE_BANDS = (
        ("excellent", 8.0, float("inf")),
        ("good", 6.0, 8.0),
        ("average", 4.0, 6.0),
        ("poor", float("-inf"), 4.0),
    )

    def __init__(self, score_manager: "ScoreDataManager"):
        """Keep a reference to the data manager used for all queries.

        Args:
            score_manager: Object providing
                ``get_score_results(speaker=..., limit=...)``.
        """
        self.score_manager = score_manager

    @staticmethod
    def _column_values(rows: List[Dict], column: str) -> List[float]:
        """Return the non-NULL values of ``column`` across ``rows``."""
        return [row[column] for row in rows if row[column] is not None]

    def analyze_character_performance(self, character_name: str) -> Dict:
        """Summarize how one character's dialogues have scored.

        Up to 1000 of the most recent rows are considered. Rows whose overall
        score is NULL count toward ``total_dialogues`` but are left out of the
        numeric statistics.

        Args:
            character_name: Speaker to analyze.

        Returns:
            Dict: The analysis, or ``{"error": ...}`` when the character has
            no rows with an overall score.
        """
        rows = self.score_manager.get_score_results(speaker=character_name, limit=1000)
        overall_scores = self._column_values(rows, 'overall_score') if rows else []

        if not overall_scores:
            return {"error": f"没有找到角色 {character_name} 的评分数据"}

        return {
            "character_name": character_name,
            "total_dialogues": len(rows),
            # float(): hand plain Python numbers, not NumPy scalars, to callers.
            "average_score": float(round(np.mean(overall_scores), 2)),
            "median_score": float(round(np.median(overall_scores), 2)),
            "score_std": float(round(np.std(overall_scores), 2)),
            "min_score": min(overall_scores),
            "max_score": max(overall_scores),
            "score_distribution": self._calculate_score_distribution(overall_scores),
            "dimension_analysis": self._analyze_dimensions(rows),
            "improvement_suggestions": self._generate_improvement_suggestions(rows),
        }

    def _calculate_score_distribution(self, scores: List[float]) -> Dict:
        """Bucket scores into excellent / good / average / poor.

        Returns:
            Dict: ``{bucket: {"count": int, "percentage": float}}``. The shape
            is the same for empty input, where every percentage is 0.0.
        """
        total = len(scores)
        distribution = {}
        for bucket, low, high in self._SCORE_BANDS:
            count = sum(1 for value in scores if low <= value < high)
            percentage = round(count / total * 100, 1) if total else 0.0
            distribution[bucket] = {"count": count, "percentage": percentage}
        return distribution

    def _analyze_dimensions(self, scores: List[Dict]) -> Dict:
        """Compute average / min / max / std for each scoring dimension.

        Dimensions with no non-NULL values are omitted from the result.
        """
        dimension_analysis = {}
        for column in self._DIMENSION_LABELS:
            values = self._column_values(scores, column)
            if not values:
                continue
            dimension_analysis[column] = {
                "average": float(round(np.mean(values), 2)),
                "min": min(values),
                "max": max(values),
                "std": float(round(np.std(values), 2)),
            }
        return dimension_analysis

    def _generate_improvement_suggestions(self, scores: List[Dict]) -> List[str]:
        """Build human-readable suggestions from dimension and overall averages.

        The weakest dimension is called out when its average is below 6.0. An
        overall remark is added only when at least one overall score exists,
        so empty input yields no misleading "excellent" message.
        """
        suggestions = []

        dim_averages = {}
        for column, label in self._DIMENSION_LABELS.items():
            values = self._column_values(scores, column)
            if values:
                dim_averages[label] = np.mean(values)

        if dim_averages:
            weakest_label, weakest_avg = min(dim_averages.items(), key=lambda item: item[1])
            if weakest_avg < 6.0:
                suggestions.append(f"需要重点改进{weakest_label},当前平均分为{weakest_avg:.1f}")

        overall_scores = self._column_values(scores, 'overall_score')
        if not overall_scores:
            return suggestions

        overall_avg = np.mean(overall_scores)
        if overall_avg < 5.0:
            suggestions.append("整体表现需要改进,建议调整角色设定或提示词")
        elif overall_avg < 7.0:
            suggestions.append("表现良好,可以尝试增加对话的创意性和深度")
        else:
            suggestions.append("表现优秀,继续保持当前设定")

        return suggestions

    def compare_characters(self, character_names: List[str]) -> Dict:
        """Compare several characters by their average overall score.

        Up to 500 of the most recent rows are read per character. Characters
        with no rows, or with no non-NULL overall score, are left out.

        Args:
            character_names: Speakers to compare.

        Returns:
            Dict: ``{"characters": {name: summary}}`` plus, when at least one
            character has data, ``"ranking"``: the summaries sorted by average
            score, best first.
        """
        comparison = {"characters": {}}

        for name in character_names:
            rows = self.score_manager.get_score_results(speaker=name, limit=500)
            overall_scores = self._column_values(rows, 'overall_score') if rows else []
            if not overall_scores:
                continue
            comparison["characters"][name] = {
                "total_dialogues": len(rows),
                "average_score": float(round(np.mean(overall_scores), 2)),
                "score_range": f"{min(overall_scores):.1f} - {max(overall_scores):.1f}",
            }

        if comparison["characters"]:
            ranked = sorted(comparison["characters"].items(),
                            key=lambda item: item[1]["average_score"], reverse=True)
            comparison["ranking"] = [{"character": name, **summary} for name, summary in ranked]

        return comparison