Project02/AIGC/DatabaseHandle.py
2025-07-09 19:58:44 +08:00

120 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from contextlib import contextmanager
import sqlite3
class DatabaseHandle:
    """SQLite-backed storage for characters and their chat records.

    Every operation opens its own short-lived connection through
    ``_get_connection``; no connection is kept open between calls.
    """

    def __init__(self, db_path="data.db"):
        """Remember the database location and make sure both tables exist.

        :param db_path: path handed to ``sqlite3.connect`` for every operation
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the ``characters`` and ``chat_records`` tables if missing."""
        with self._get_connection() as conn:
            # Character table: ``name`` is unique and is the upsert key used
            # by add_character().
            conn.execute('''
                CREATE TABLE IF NOT EXISTS characters
                (id INTEGER PRIMARY KEY,
                 name TEXT not NULL UNIQUE,
                 age INTEGER,
                 personality TEXT,
                 profession TEXT,
                 characterBackground TEXT,
                 chat_style TEXT
                )
            ''')
            # Chat record table: ``character_ids`` holds a comma-separated
            # list of character IDs.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS chat_records
                (id INTEGER PRIMARY KEY,
                 character_ids TEXT not NULL,
                 chat TEXT,
                 time DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            ''')

    @contextmanager
    def _get_connection(self):
        """Yield a connection that commits on success and rolls back on error.

        The connection is closed when the ``with`` block exits, whether or
        not an exception was raised; exceptions are re-raised to the caller.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @staticmethod
    def _rows_to_dicts(cursor):
        """Return the remaining rows of *cursor* as ``{column: value}`` dicts."""
        columns = [col[0] for col in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    @staticmethod
    def _normalize_ids(character_ids):
        """Sort a comma-separated ID string in ascending numeric order.

        :raises ValueError: if any item is not an integer
        """
        return ",".join(sorted(character_ids.split(","), key=int))

    def add_character(self, data: dict):
        """Insert a character, or update the existing one with the same name.

        :param data: dict with the keys ``name``, ``age``, ``personality``,
            ``profession``, ``characterBackground`` and ``chat_style``
        :return: id of the inserted or updated row
        :raises KeyError: if one of the keys above is missing
        """
        with self._get_connection() as conn:
            cursor = conn.execute('''
                INSERT INTO characters (
                    name, age, personality, profession, characterBackground,
                    chat_style
                ) VALUES (?, ?, ?, ?, ?, ?)
                ON CONFLICT(name) DO UPDATE SET
                    age = excluded.age,
                    personality = excluded.personality,
                    profession = excluded.profession,
                    characterBackground = excluded.characterBackground,
                    chat_style = excluded.chat_style
            ''', (
                data["name"], data["age"], data["personality"],
                data["profession"], data["characterBackground"],
                data["chat_style"],
            ))
            # lastrowid is only meaningful after a real INSERT; when the
            # upsert takes the UPDATE path it does not identify the updated
            # row, so look the id up through the unique name instead.
            row = conn.execute(
                "SELECT id FROM characters WHERE name = ?", (data["name"],)
            ).fetchone()
            return row[0] if row is not None else cursor.lastrowid

    def get_character_byname(self, name: str) -> list:
        """Fetch characters by exact name.

        :param name: character name to match exactly; ``None`` or an empty
            string returns every character
        :return: list of row dicts keyed by column name
        """
        sql = "SELECT * FROM characters"
        params = ()
        if name:  # only filter when a name was actually given
            sql += " WHERE name = ?"
            params = (name,)
        with self._get_connection() as conn:
            return self._rows_to_dicts(conn.execute(sql, params))

    def add_chat(self, data: dict):
        """Insert a chat record.

        ``character_ids`` is stored in the same numerically sorted form that
        ``get_chats_by_character_id`` searches for, so the order in which the
        caller lists the IDs does not matter.

        :param data: dict with the keys ``character_ids`` (comma-separated
            character IDs) and ``chat``
        :return: id of the inserted row
        :raises KeyError: if one of the keys above is missing
        """
        character_ids = data["character_ids"]
        try:
            character_ids = self._normalize_ids(character_ids)
        except (AttributeError, ValueError):
            # Not a comma-separated integer string: keep storing the value
            # unchanged, as this method always did.
            pass
        with self._get_connection() as conn:
            cursor = conn.execute('''
                INSERT INTO chat_records (
                    character_ids, chat
                ) VALUES (?, ?)
            ''', (character_ids, data["chat"]))
            return cursor.lastrowid

    def get_chats_by_character_id(self, character_id: str) -> list:
        """Fetch the chat records of one group of characters.

        :param character_id: comma-separated character IDs in any order;
            ``None`` or an empty string returns every record
        :return: list of row dicts keyed by column name
        :raises ValueError: if a non-empty *character_id* contains an item
            that is not an integer
        """
        sql = "SELECT * FROM chat_records"
        params = ()
        if character_id:
            sql += " WHERE character_ids = ?"
            params = (self._normalize_ids(character_id),)
        with self._get_connection() as conn:
            return self._rows_to_dicts(conn.execute(sql, params))