Project02/AIGC/DatabaseHandle.py

120 lines
4.6 KiB
Python
Raw Permalink Normal View History

2025-07-03 19:01:27 +08:00
from contextlib import contextmanager
import sqlite3
class DatabaseHandle:
    """SQLite-backed storage for characters and their chat records.

    Every public method opens its own short-lived connection via
    ``_get_connection``; no connection is kept open between calls.
    """

    def __init__(self, db_path="data.db"):
        """Remember the database path and create the tables if needed.

        :param db_path: path of the SQLite database file.
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the ``characters`` and ``chat_records`` tables if missing."""
        with self._get_connection() as conn:
            # Character table; ``name`` is UNIQUE so add_character can upsert on it.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS characters
                (id INTEGER PRIMARY KEY,
                 name TEXT not NULL UNIQUE,
                 age INTEGER,
                 personality TEXT,
                 profession TEXT,
                 characterBackground TEXT,
                 chat_style TEXT
                )
            ''')
            # Chat-record table; ``character_ids`` holds a comma-separated id list.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS chat_records
                (id INTEGER PRIMARY KEY,
                 character_ids TEXT not NULL,
                 chat TEXT,
                 time DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            ''')

    @contextmanager
    def _get_connection(self):
        """Yield a connection that commits on success and rolls back on error.

        The connection is always closed when the ``with`` block exits.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @staticmethod
    def _rows_to_dicts(cursor):
        """Turn the remaining rows of ``cursor`` into a list of column->value dicts."""
        columns = [col[0] for col in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    @staticmethod
    def _normalize_ids(character_ids):
        """Return ``character_ids`` as a canonical comma-separated string.

        Tokens are stripped of surrounding whitespace, empty tokens are
        dropped, and the rest are sorted in ascending numeric order, so
        ``"3, 1,2"`` becomes ``"1,2,3"``. A single ``int`` is accepted too.

        :raises ValueError: if a token is not an integer.
        """
        tokens = (tok.strip() for tok in str(character_ids).split(","))
        return ",".join(sorted((tok for tok in tokens if tok), key=int))

    def add_character(self, data: dict):
        """Insert a character, or update the existing one with the same name.

        :param data: dict with keys ``name``, ``age``, ``personality``,
            ``profession``, ``characterBackground`` and ``chat_style``.
        :return: the ``id`` of the inserted or updated row.
        """
        with self._get_connection() as conn:
            conn.execute('''
                INSERT INTO characters (
                    name, age, personality, profession, characterBackground,
                    chat_style
                ) VALUES (?, ?, ?, ?, ?, ?)
                ON CONFLICT(name) DO UPDATE SET
                    age = excluded.age,
                    personality = excluded.personality,
                    profession = excluded.profession,
                    characterBackground = excluded.characterBackground,
                    chat_style = excluded.chat_style
            ''', (
                data["name"], data["age"], data["personality"],
                data["profession"], data["characterBackground"],
                data["chat_style"],
            ))
            # cursor.lastrowid is not updated when the upsert takes the UPDATE
            # branch (it would report 0 on this fresh connection), so resolve
            # the id through the UNIQUE name instead.
            row = conn.execute(
                "SELECT id FROM characters WHERE name = ?", (data["name"],)
            ).fetchone()
            return row[0]

    def get_character_byname(self, name: str) -> list:
        """Look up characters by exact name; return all when ``name`` is empty.

        :param name: exact character name; ``None`` or ``""`` returns every row.
        :return: list of character dicts keyed by column name.
        """
        sql = "SELECT * FROM characters"
        params = ()
        if name:  # Only filter when a name was actually given.
            sql += " WHERE name = ?"
            params = (name,)
        with self._get_connection() as conn:
            cursor = conn.execute(sql, params)
            return self._rows_to_dicts(cursor)

    def add_chat(self, data: dict):
        """Insert one chat record.

        ``character_ids`` is stored in the same canonical form that
        ``get_chats_by_character_id`` matches against (numerically sorted,
        comma-separated), so ``"2,1"`` and ``"1,2"`` refer to the same chat.

        :param data: dict with keys ``character_ids`` and ``chat``.
        :return: the ``id`` of the new row.
        """
        character_ids = data["character_ids"]
        try:
            character_ids = self._normalize_ids(character_ids)
        except ValueError:
            # Not a list of integer ids: keep the previous behaviour and
            # store the value verbatim instead of rejecting the record.
            pass
        with self._get_connection() as conn:
            cursor = conn.execute(
                "INSERT INTO chat_records (character_ids, chat) VALUES (?, ?)",
                (character_ids, data["chat"]),
            )
            return cursor.lastrowid

    def get_chats_by_character_id(self, character_id: str) -> list:
        """Return chat records whose participant ids equal ``character_id``.

        :param character_id: comma-separated character ids in any order
            (e.g. ``"2,1"``); normalized before matching. ``None`` or ``""``
            returns every chat record.
        :return: list of chat-record dicts keyed by column name.
        :raises ValueError: if a non-empty id list contains a non-integer token.
        """
        sql = "SELECT * FROM chat_records"
        params = ()
        if character_id:
            sql += " WHERE character_ids = ?"
            params = (self._normalize_ids(character_id),)
        with self._get_connection() as conn:
            cursor = conn.execute(sql, params)
            return self._rows_to_dicts(cursor)