from contextlib import contextmanager
import sqlite3
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union


class DatabaseHandle:
    """SQLite-backed store for character profiles and chat records.

    Every public method opens its own short-lived connection to ``db_path``.
    The connection is committed on success, rolled back on error and always
    closed.
    """

    def __init__(self, db_path: str = "data.db"):
        """Remember the database path and make sure the tables exist.

        :param db_path: path of the SQLite database file.
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self) -> None:
        """Create the ``characters`` and ``chat_records`` tables if missing."""
        with self._get_connection() as conn:
            # Character table: character names are unique.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS characters (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL UNIQUE,
                    age INTEGER,
                    personality TEXT,
                    profession TEXT,
                    characterBackground TEXT,
                    chat_style TEXT
                )
            ''')
            # Chat-record table. character_ids holds the participants as a
            # canonical comma-separated string (see _normalize_character_ids).
            conn.execute('''
                CREATE TABLE IF NOT EXISTS chat_records (
                    id INTEGER PRIMARY KEY,
                    character_ids TEXT NOT NULL,
                    chat TEXT,
                    time DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            ''')

    @contextmanager
    def _get_connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection to ``db_path``.

        The connection is committed if the ``with`` body succeeds, rolled
        back (and the exception re-raised) if it fails, and closed in both
        cases.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @staticmethod
    def _rows_to_dicts(cursor: sqlite3.Cursor) -> List[Dict[str, Any]]:
        """Return all remaining rows of *cursor* as column-name -> value dicts."""
        columns = [col[0] for col in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    @staticmethod
    def _normalize_character_ids(
        character_ids: Union[str, int, Iterable[Any]],
    ) -> str:
        """Return *character_ids* as a canonical comma-separated string.

        Accepts a comma-separated string (``"2, 1"``), a single int (``5``)
        or an iterable of ids (``[2, 1]``). Whitespace and blank entries are
        dropped, and the ids are sorted numerically. This ensures that the
        same set of characters always maps to the same stored key, whatever
        order the caller lists them in.

        :raises ValueError: if an entry is not an integer.
        """
        if isinstance(character_ids, str):
            parts: Iterable[Any] = character_ids.split(",")
        elif isinstance(character_ids, int):
            parts = [character_ids]
        else:
            parts = character_ids
        stripped = (str(part).strip() for part in parts)
        # int() canonicalises forms like "01" so they match "1".
        ids = sorted(int(part) for part in stripped if part)
        return ",".join(str(i) for i in ids)

    def add_character(self, data: dict) -> int:
        """Insert a character, or update the existing one with the same name.

        :param data: mapping with the keys ``name``, ``age``, ``personality``,
            ``profession``, ``characterBackground`` and ``chat_style``
            (all required; a missing key raises ``KeyError``).
        :return: the id of the inserted or updated character row.
        """
        with self._get_connection() as conn:
            conn.execute('''
                INSERT INTO characters (
                    name, age, personality, profession,
                    characterBackground, chat_style
                ) VALUES (?, ?, ?, ?, ?, ?)
                ON CONFLICT(name) DO UPDATE SET
                    age = excluded.age,
                    personality = excluded.personality,
                    profession = excluded.profession,
                    characterBackground = excluded.characterBackground,
                    chat_style = excluded.chat_style
            ''', (
                data["name"],
                data["age"],
                data["personality"],
                data["profession"],
                data["characterBackground"],
                data["chat_style"],
            ))
            # cursor.lastrowid is not updated when the upsert takes the
            # DO UPDATE path, so look the id up through the UNIQUE name.
            row = conn.execute(
                "SELECT id FROM characters WHERE name = ?", (data["name"],)
            ).fetchone()
            return row[0]

    def get_character_byname(self, name: Optional[str]) -> List[Dict[str, Any]]:
        """Look up characters by exact name.

        :param name: character name (exact match). ``None`` or an empty
            string returns every character.
        :return: list of row dicts keyed by column name.
        """
        sql = "SELECT * FROM characters"
        params: tuple = ()
        if name:
            sql += " WHERE name = ?"
            params = (name,)
        with self._get_connection() as conn:
            cursor = conn.execute(sql, params)
            return self._rows_to_dicts(cursor)

    def add_chat(self, data: dict) -> int:
        """Store a chat record.

        :param data: mapping with the keys ``character_ids`` and ``chat``.
            ``character_ids`` may be a comma-separated string, an int or an
            iterable of ids. It is stored in canonical sorted form so that
            ``get_chats_by_character_id`` can find the record regardless of
            the order the ids were given in.
        :return: the id of the new chat record.
        :raises ValueError: if ``character_ids`` contains a non-integer id.
        """
        character_ids = self._normalize_character_ids(data["character_ids"])
        with self._get_connection() as conn:
            cursor = conn.execute(
                "INSERT INTO chat_records (character_ids, chat) VALUES (?, ?)",
                (character_ids, data["chat"]),
            )
            return cursor.lastrowid

    def get_chats_by_character_id(
        self, character_id: Optional[Union[str, int, Iterable[Any]]]
    ) -> List[Dict[str, Any]]:
        """Return the chat records for a given set of participants.

        :param character_id: participant ids as a comma-separated string,
            an int or an iterable. The order of the ids does not matter.
            ``None`` or an empty value returns every chat record.
        :return: list of row dicts keyed by column name, in insertion order.
        :raises ValueError: if ``character_id`` contains a non-integer id.
        """
        normalized = (
            "" if character_id is None
            else self._normalize_character_ids(character_id)
        )
        sql = "SELECT * FROM chat_records"
        params: tuple = ()
        if normalized:
            sql += " WHERE character_ids = ?"
            params = (normalized,)
        # Explicit ordering: SQLite does not guarantee row order otherwise.
        sql += " ORDER BY id"
        with self._get_connection() as conn:
            cursor = conn.execute(sql, params)
            return self._rows_to_dicts(cursor)