2025-07-03 19:01:27 +08:00
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
import sqlite3
|
|
|
|
|
|
|
|
|
|
class DatabaseHandle:
    """Small SQLite-backed store for characters and their chat records.

    Every public method opens a short-lived connection through
    ``_get_connection``, which commits when the ``with`` body succeeds,
    rolls back when it raises, and always closes the connection.
    """

    def __init__(self, db_path="data.db"):
        """Remember the database location and make sure the tables exist.

        :param db_path: path passed unchanged to ``sqlite3.connect``.
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the ``characters`` and ``chat_records`` tables if missing."""
        with self._get_connection() as conn:
            # Character table. ``name`` is UNIQUE because add_character
            # upserts on it.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS characters
                (id INTEGER PRIMARY KEY,
                name TEXT not NULL UNIQUE,
                age INTEGER,
                personality TEXT,
                profession TEXT,
                characterBackground TEXT,
                chat_style TEXT
                )
            ''')
            # Chat record table. ``character_ids`` holds a comma-separated
            # list of character ids as text.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS chat_records
                (id INTEGER PRIMARY KEY,
                character_ids TEXT not NULL,
                chat TEXT,
                time DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            ''')

    @contextmanager
    def _get_connection(self):
        """Yield a connection to ``self.db_path`` with transaction handling.

        Commits after the ``with`` body finishes normally; on any
        ``Exception`` rolls back and re-raises. The connection is closed in
        every case.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @staticmethod
    def _rows_to_dicts(cursor) -> list:
        """Turn the pending result set of *cursor* into a list of dicts keyed by column name."""
        columns = [col[0] for col in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def add_character(self, data: dict):
        """Insert a character, or update the existing one with the same name.

        :param data: mapping with the keys ``name``, ``age``,
            ``personality``, ``profession``, ``characterBackground`` and
            ``chat_style``.
        :return: ``id`` of the inserted or updated row.
        :raises KeyError: if one of the keys above is missing from *data*.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                INSERT INTO characters (
                name, age, personality, profession, characterBackground,
                chat_style
                ) VALUES (?, ?, ?, ?, ?, ?)
                ON CONFLICT(name) DO UPDATE SET
                age = excluded.age,
                personality = excluded.personality,
                profession = excluded.profession,
                characterBackground = excluded.characterBackground,
                chat_style = excluded.chat_style
            ''', (
                data["name"], data["age"], data["personality"],
                data["profession"], data["characterBackground"],
                data["chat_style"]
            ))
            inserted_id = cursor.lastrowid
            # When the upsert takes the UPDATE branch no row is inserted, so
            # lastrowid does not point at the affected row. Resolve the id
            # through the unique name instead.
            row = conn.execute(
                "SELECT id FROM characters WHERE name = ?", (data["name"],)
            ).fetchone()
            # The commit is performed by _get_connection on exit.
            return row[0] if row is not None else inserted_id

    def get_character_byname(self, name: str) -> list:
        """Look up characters by exact name.

        :param name: character name to match exactly; ``None`` or an empty
            string returns every character.
        :return: list of dicts, one per row, keyed by column name.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()

            # Build the statement dynamically: the WHERE clause is only
            # added when a name was supplied.
            sql = "SELECT * FROM characters"
            params = ()
            if name:
                sql += " WHERE name = ?"
                params = (name,)

            cursor.execute(sql, params)
            return self._rows_to_dicts(cursor)

    def add_chat(self, data: dict):
        """Insert one chat record.

        :param data: mapping with the keys ``character_ids`` and ``chat``.
        :return: ``id`` of the new ``chat_records`` row.
        :raises KeyError: if one of the keys above is missing from *data*.

        NOTE(review): ``character_ids`` is stored exactly as given, while
        ``get_chats_by_character_id`` sorts the ids numerically before
        matching. Callers presumably pass ids already in ascending order;
        confirm against the call sites.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                INSERT INTO chat_records (
                character_ids, chat
                ) VALUES (?, ?)
            ''', (
                data["character_ids"], data["chat"]
            ))
            # The commit is performed by _get_connection on exit.
            return cursor.lastrowid

    def get_chats_by_character_id(self, character_id: str) -> list:
        """Fetch chat records for a comma-separated set of character ids.

        The ids are sorted in ascending numeric order and re-joined with
        commas before being compared against ``character_ids``, so
        ``"2,1"`` and ``"1,2"`` select the same records.

        :param character_id: comma-separated character ids such as
            ``"1,2"``; ``None`` or an empty string returns every record.
        :return: list of dicts, one per row, keyed by column name.
        :raises ValueError: if any id in *character_id* is not an integer.
        """
        sql = "SELECT * FROM chat_records"
        params = ()
        if character_id:
            sorted_ids = sorted(character_id.split(","), key=int)  # numeric ascending
            sql += " WHERE character_ids = ?"
            params = (",".join(sorted_ids),)

        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(sql, params)
            return self._rows_to_dicts(cursor)
|