120 lines
4.6 KiB
Python
120 lines
4.6 KiB
Python
from contextlib import contextmanager
|
||
import sqlite3
|
||
|
||
class DatabaseHandle:
    """SQLite-backed storage for characters and their chat records.

    Every public method opens its own connection through ``_get_connection``,
    so the handle keeps no connection open between calls.
    """

    def __init__(self, db_path="data.db"):
        """Remember the database path and make sure the tables exist.

        :param db_path: path of the SQLite database file. ``":memory:"`` is
            not useful here: each method opens a new connection and would see
            a fresh, empty in-memory database.
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the ``characters`` and ``chat_records`` tables if missing."""
        with self._get_connection() as conn:
            # Characters table; ``name`` is UNIQUE so add_character can upsert on it.
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS characters (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL UNIQUE,
                    age INTEGER,
                    personality TEXT,
                    profession TEXT,
                    characterBackground TEXT,
                    chat_style TEXT
                )
                """
            )
            # Chat records table; ``character_ids`` holds the participants as a
            # comma-separated string of integer ids in ascending order.
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS chat_records (
                    id INTEGER PRIMARY KEY,
                    character_ids TEXT NOT NULL,
                    chat TEXT,
                    time DATETIME DEFAULT CURRENT_TIMESTAMP
                )
                """
            )

    @contextmanager
    def _get_connection(self):
        """Yield a connection that commits on success, rolls back on any
        exception (which is re-raised), and is always closed."""
        conn = sqlite3.connect(self.db_path)
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @staticmethod
    def _rows_to_dicts(cursor) -> list:
        """Return the remaining rows of an executed cursor as column-keyed dicts."""
        columns = [col[0] for col in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    @staticmethod
    def _normalize_ids(character_ids) -> str:
        """Return the canonical stored form of a group of character ids.

        Both add_chat and get_chats_by_character_id go through this helper, so
        ``"2,1"``, ``"1, 2"`` and ``[2, 1]`` all map to ``"1,2"``.

        :param character_ids: comma-separated string of integer ids, a single
            int, or an iterable of ints / numeric strings.
        :return: the ids as decimal integers, ascending, joined by ``","``.
        :raises ValueError: if any id is not an integer.
        """
        if isinstance(character_ids, str):
            parts = character_ids.split(",")
        elif isinstance(character_ids, int):
            parts = [character_ids]
        else:
            parts = list(character_ids)
        # int() tolerates surrounding whitespace and leading zeros; converting
        # back with str() gives one spelling per id.
        return ",".join(str(number) for number in sorted(int(p) for p in parts))

    def add_character(self, data: dict) -> int:
        """Insert a character, or update it if one with the same name exists.

        :param data: mapping with the keys ``name``, ``age``, ``personality``,
            ``profession``, ``characterBackground`` and ``chat_style``.
        :return: the ``id`` of the inserted or updated row.
        :raises KeyError: if any of the keys above is missing.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO characters (
                    name, age, personality, profession, characterBackground,
                    chat_style
                ) VALUES (?, ?, ?, ?, ?, ?)
                ON CONFLICT(name) DO UPDATE SET
                    age = excluded.age,
                    personality = excluded.personality,
                    profession = excluded.profession,
                    characterBackground = excluded.characterBackground,
                    chat_style = excluded.chat_style
                """,
                (
                    data["name"], data["age"], data["personality"],
                    data["profession"], data["characterBackground"],
                    data["chat_style"],
                ),
            )
            # When the upsert takes the DO UPDATE branch no row is inserted,
            # so cursor.lastrowid does not hold this character's id (on this
            # fresh connection it would be 0). Look it up by the unique name.
            cursor.execute(
                "SELECT id FROM characters WHERE name = ?", (data["name"],)
            )
            return cursor.fetchone()[0]

    def get_character_byname(self, name: str) -> list:
        """Return characters matching *name*, or all characters.

        :param name: exact character name; ``None`` or ``""`` returns every
            character.
        :return: list of dicts keyed by column name.
        """
        sql = "SELECT * FROM characters"
        params = ()
        if name:  # only filter when a name was actually given
            sql += " WHERE name = ?"
            params = (name,)

        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(sql, params)
            return self._rows_to_dicts(cursor)

    def add_chat(self, data: dict) -> int:
        """Store a chat record.

        The participant ids are stored in canonical form (ascending, no
        spaces) so get_chats_by_character_id can find the record regardless
        of the order the ids were given in.

        :param data: mapping with ``character_ids`` (comma-separated string,
            int, or iterable of integer ids) and ``chat``.
        :return: the ``id`` of the new chat record.
        :raises KeyError: if ``character_ids`` or ``chat`` is missing.
        :raises ValueError: if any character id is not an integer.
        """
        normalized_ids = self._normalize_ids(data["character_ids"])
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "INSERT INTO chat_records (character_ids, chat) VALUES (?, ?)",
                (normalized_ids, data["chat"]),
            )
            return cursor.lastrowid

    def get_chats_by_character_id(self, character_id: "str | None" = None) -> list:
        """Return the chat records for a group of characters, or all records.

        :param character_id: participant ids as a comma-separated string (any
            order, whitespace allowed) or an iterable of ints. ``None``, ``""``
            or a whitespace-only string returns every chat record.
        :return: list of dicts keyed by column name.
        :raises ValueError: if any id is not an integer.
        """
        if character_id is None or (
            isinstance(character_id, str) and not character_id.strip()
        ):
            sql = "SELECT * FROM chat_records"
            params = ()
        else:
            sql = "SELECT * FROM chat_records WHERE character_ids = ?"
            params = (self._normalize_ids(character_id),)

        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(sql, params)
            return self._rows_to_dicts(cursor)