diff --git a/AITrain/pdf_to_rag_processor.py b/AITrain/pdf_to_rag_processor.py
new file mode 100644
index 0000000..48cf6f1
--- /dev/null
+++ b/AITrain/pdf_to_rag_processor.py
@@ -0,0 +1,431 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+'''
+PDF-to-RAG knowledge base processing module.
+Converts worldbook documents (e.g. COC.pdf) into a RAG retrieval format.
+'''
+
+import hashlib
+import json
+import os
+import re
+from datetime import datetime
+from typing import Dict, List
+
+# Import the two PDF backends independently so a missing pymupdf
+# does not also disable the PyPDF2 fallback
+try:
+    import PyPDF2
+except ImportError:
+    PyPDF2 = None
+
+try:
+    import fitz  # pymupdf
+    PYMUPDF_AVAILABLE = True
+except ImportError:
+    PYMUPDF_AVAILABLE = False
+    print("Warning: pymupdf not available, using PyPDF2 only")
+
+try:
+    from sentence_transformers import SentenceTransformer
+    import faiss
+    import numpy as np
+    EMBEDDING_AVAILABLE = True
+except ImportError:
+    EMBEDDING_AVAILABLE = False
+    print("Warning: sentence-transformers or faiss not available, using simple text matching")
+
+
+class PDFToRAGProcessor:
+    def __init__(self, embedding_model: str = "./sentence-transformers/all-MiniLM-L6-v2"):
+        """
+        Initialize the PDF processor.
+
+        Args:
+            embedding_model: name or local path of the embedding model
+        """
+        global EMBEDDING_AVAILABLE
+        self.chunks = []
+        self.embeddings = None
+        self.index = None
+
+        if EMBEDDING_AVAILABLE:
+            try:
+                self.embedding_model = SentenceTransformer(embedding_model)
+                print(f"✓ Embedding model loaded: {embedding_model}")
+            except Exception as e:
+                print(f"✗ Failed to load embedding model: {e}")
+                self.embedding_model = None
+                # Downgrade globally so later calls skip the vector paths
+                EMBEDDING_AVAILABLE = False
+        else:
+            self.embedding_model = None
+
+    def extract_text_from_pdf(self, pdf_path: str) -> str:
+        """Extract raw text from a PDF."""
+        if not os.path.exists(pdf_path):
+            raise FileNotFoundError(f"PDF file not found: {pdf_path}")
+
+        text = ""
+
+        # Prefer pymupdf; its extraction quality is better
+        if PYMUPDF_AVAILABLE:
+            try:
+                doc = fitz.open(pdf_path)
+                for page_num in range(len(doc)):
+                    page = doc.load_page(page_num)
+                    text += page.get_text()
+                doc.close()
+                print("✓ Text extracted with pymupdf")
+                return text
+            except Exception as e:
+                print(f"✗ pymupdf extraction failed: {e}, trying PyPDF2")
+
+        # PyPDF2 fallback
+        if PyPDF2 is None:
+            raise RuntimeError("Neither pymupdf nor PyPDF2 is available")
+        try:
+            with open(pdf_path, 'rb') as file:
+                pdf_reader = PyPDF2.PdfReader(file)
+                for page in pdf_reader.pages:
+                    text += page.extract_text()
+            print("✓ Text extracted with PyPDF2")
+        except Exception as e:
+            raise RuntimeError(f"PDF text extraction failed: {e}")
+
+        return text
+
+    def clean_text(self, text: str) -> str:
+        """Clean up extracted text."""
+        # Collapse runs of blank lines into single paragraph breaks
+        text = re.sub(r'\n\s*\n', '\n\n', text)
+        # Strip page markers ("第N页" is Chinese for "Page N")
+        text = re.sub(r'第\s*\d+\s*页', '', text)
+        text = re.sub(r'Page\s*\d+', '', text)
+        # Normalize whitespace: full-width space to half-width
+        text = text.replace('\u3000', ' ')
+        # Collapse runs of spaces and tabs only; newlines must survive
+        # because chunk_text_by_semantic splits paragraphs on '\n\n'
+        text = re.sub(r'[ \t]+', ' ', text)
+
+        return text.strip()
+
+    def chunk_text_by_semantic(self, text: str, max_chunk_size: int = 500) -> List[Dict]:
+        """Split text into semantically coherent chunks."""
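+        # Strategy: pack whole paragraphs into a chunk until max_chunk_size
+        # would be exceeded; paragraphs that are themselves longer than the
+        # limit are split again on sentence-ending punctuation.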
+        # Split into paragraphs on blank lines
+        paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
+
+        chunks = []
+        current_chunk = ""
+        current_size = 0
+
+        for para in paragraphs:
+            para_size = len(para)
+
+            # A paragraph longer than max_chunk_size must be split further
+            if para_size > max_chunk_size:
+                # Flush the chunk accumulated so far
+                if current_chunk:
+                    chunks.append({
+                        "content": current_chunk.strip(),
+                        "size": current_size,
+                        "type": "paragraph"
+                    })
+                    current_chunk = ""
+                    current_size = 0
+
+                # Split the long paragraph on sentence boundaries
+                sentences = re.split(r'[。!?;]\s*', para)
+                for sentence in sentences:
+                    if sentence.strip():
+                        sentence += '。'  # restore a sentence-ending mark
+                        if current_size + len(sentence) > max_chunk_size and current_chunk:
+                            chunks.append({
+                                "content": current_chunk.strip(),
+                                "size": current_size,
+                                "type": "sentence"
+                            })
+                            current_chunk = sentence
+                            current_size = len(sentence)
+                        else:
+                            current_chunk += sentence
+                            current_size += len(sentence)
+            else:
+                # Normal-sized paragraph: pack it into the current chunk
+                if current_size + para_size > max_chunk_size and current_chunk:
+                    chunks.append({
+                        "content": current_chunk.strip(),
+                        "size": current_size,
+                        "type": "paragraph"
+                    })
+                    current_chunk = para
+                    current_size = para_size
+                else:
+                    current_chunk += "\n" + para if current_chunk else para
+                    current_size += para_size
+
+        # Flush the final chunk
+        if current_chunk:
+            chunks.append({
+                "content": current_chunk.strip(),
+                "size": current_size,
+                "type": "paragraph"
+            })
+
+        # Attach a stable ID and a short content hash to every chunk
+        for i, chunk in enumerate(chunks):
+            chunk["id"] = f"chunk_{i:04d}"
+            chunk["hash"] = hashlib.md5(chunk["content"].encode()).hexdigest()[:8]
+
+        return chunks
+
+    def extract_key_concepts(self, chunks: List[Dict]) -> List[Dict]:
+        """Extract key concepts and terms from the chunks."""
+        concepts = []
+
+        # Simple rule-based extraction patterns
+        concept_patterns = [
+            r'(?:技能|属性|规则)[::]\s*([^。\n]+)',  # "skill/attribute/rule:" definitions
+            r'([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',  # English proper nouns
+            r'【([^】]+)】',  # terms in Chinese lenticular brackets
+            r'“([^”]+)”',  # quoted phrases
+        ]
+
+        for chunk in chunks:
+            content = chunk["content"]
+            chunk_concepts = []
+
+            for pattern in concept_patterns:
+                matches = re.findall(pattern, content)
+                chunk_concepts.extend(matches)
+
+            if chunk_concepts:
+                concepts.append({
+                    "chunk_id": chunk["id"],
+                    "concepts": list(set(chunk_concepts)),  # deduplicate
+                    "content_preview": content[:100] + "..." if len(content) > 100 else content
+                })
+
+        return concepts
+
+    def build_vector_index(self, chunks: List[Dict]) -> bool:
+        """Build a FAISS vector index over the chunks."""
+        if not EMBEDDING_AVAILABLE or not self.embedding_model:
+            print("✗ Embeddings unavailable, falling back to text matching")
+            return False
+
+        try:
+            # Collect the chunk texts
+            texts = [chunk["content"] for chunk in chunks]
+
+            # Encode them into vectors
+            print("Encoding document vectors...")
+            embeddings = self.embedding_model.encode(texts, show_progress_bar=True)
+
+            # Build a flat L2 FAISS index
+            dimension = embeddings.shape[1]
+            self.index = faiss.IndexFlatL2(dimension)
+            self.index.add(embeddings.astype(np.float32))
+
+            self.embeddings = embeddings
+            print(f"✓ Vector index built, dimension: {dimension}, documents: {len(chunks)}")
+            return True
+
+        except Exception as e:
+            print(f"✗ Failed to build vector index: {e}")
+            return False
+
+    def process_pdf_to_rag(self, pdf_path: str, output_dir: str = "./rag_knowledge") -> Dict:
+        """
+        Run the full PDF-to-RAG pipeline.
+
+        Args:
+            pdf_path: path to the PDF file
+            output_dir: output directory
+
+        Returns:
+            processing statistics
+        """
+        print(f"Processing PDF: {pdf_path}")
+
+        # Create the output directory
+        os.makedirs(output_dir, exist_ok=True)
+
+        # 1. Extract text
+        raw_text = self.extract_text_from_pdf(pdf_path)
+        print(f"✓ Extracted text length: {len(raw_text)} characters")
+
+        # 2. Clean the text
+        cleaned_text = self.clean_text(raw_text)
+
+        # Save the cleaned text
+        text_output_path = os.path.join(output_dir, "extracted_text.txt")
+        with open(text_output_path, 'w', encoding='utf-8') as f:
+            f.write(cleaned_text)
+        print(f"✓ Cleaned text saved to: {text_output_path}")
+
+        # 3. Semantic chunking
+        chunks = self.chunk_text_by_semantic(cleaned_text)
+        self.chunks = chunks
+        print(f"✓ Document split into {len(chunks)} chunks")
+
+        # 4. Extract key concepts
+        concepts = self.extract_key_concepts(chunks)
+        print(f"✓ Extracted {len(concepts)} key concepts")
+
+        # 5. Build the vector index
+        vector_success = self.build_vector_index(chunks)
+
+        # 6. Persist the knowledge base
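+        # Chunks and concepts are serialized to one JSON file; the FAISS
+        # index and raw embeddings are written separately so a later run
+        # can reload them without re-encoding the document.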
+        knowledge_base = {
+            "metadata": {
+                "source_file": os.path.basename(pdf_path),
+                "processed_time": datetime.now().isoformat(),
+                "total_chunks": len(chunks),
+                "total_concepts": len(concepts),
+                "vector_enabled": vector_success
+            },
+            "chunks": chunks,
+            "concepts": concepts
+        }
+
+        # Save the JSON knowledge base
+        kb_output_path = os.path.join(output_dir, "knowledge_base.json")
+        with open(kb_output_path, 'w', encoding='utf-8') as f:
+            json.dump(knowledge_base, f, ensure_ascii=False, indent=2)
+        print(f"✓ Knowledge base saved to: {kb_output_path}")
+
+        # Save the vector index
+        if vector_success:
+            index_path = os.path.join(output_dir, "vector_index.faiss")
+            faiss.write_index(self.index, index_path)
+
+            embeddings_path = os.path.join(output_dir, "embeddings.npy")
+            np.save(embeddings_path, self.embeddings)
+            print(f"✓ Vector index saved to: {index_path}")
+
+        return {
+            "status": "success",
+            "chunks_count": len(chunks),
+            "concepts_count": len(concepts),
+            "vector_enabled": vector_success,
+            "output_dir": output_dir
+        }
+
+    def load_knowledge_base(self, knowledge_dir: str) -> bool:
+        """Load a previously built knowledge base."""
+        try:
+            # Load the JSON knowledge base
+            kb_path = os.path.join(knowledge_dir, "knowledge_base.json")
+            with open(kb_path, 'r', encoding='utf-8') as f:
+                knowledge_base = json.load(f)
+
+            self.chunks = knowledge_base["chunks"]
+
+            # Load the vector index if present
+            if EMBEDDING_AVAILABLE:
+                index_path = os.path.join(knowledge_dir, "vector_index.faiss")
+                embeddings_path = os.path.join(knowledge_dir, "embeddings.npy")
+
+                if os.path.exists(index_path) and os.path.exists(embeddings_path):
+                    self.index = faiss.read_index(index_path)
+                    self.embeddings = np.load(embeddings_path)
+                    print("✓ Vector index loaded")
+
+            print(f"✓ Knowledge base loaded: {len(self.chunks)} chunks")
+            return True
+
+        except Exception as e:
+            print(f"✗ Failed to load knowledge base: {e}")
+            return False
+
+    def search_relevant_content(self, query: str, top_k: int = 3) -> List[Dict]:
+        """Search the knowledge base for content relevant to the query."""
+        if not self.chunks:
+            return []
+
+        # Vector search
+        if EMBEDDING_AVAILABLE and self.embedding_model and self.index:
+            try:
+                query_vector = self.embedding_model.encode([query])
+                distances, indices = self.index.search(query_vector.astype(np.float32), top_k)
+
+                results = []
+                for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
+                    # FAISS pads missing results with -1, so check the lower bound too
+                    if 0 <= idx < len(self.chunks):
+                        result = self.chunks[idx].copy()
+                        # Map L2 distance to a similarity-style score in (0, 1]
+                        result["relevance_score"] = float(1 / (1 + distance))
+                        result["rank"] = i + 1
+                        results.append(result)
+
+                return results
+
+            except Exception as e:
+                print(f"Vector search failed: {e}, falling back to text matching")
+
+        # Text-matching fallback
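+        # Scoring is a plain occurrence count of each query token; crude,
+        # but it keeps retrieval usable without an embedding model or index.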
+        query_lower = query.lower()
+        scored_chunks = []
+
+        for chunk in self.chunks:
+            content_lower = chunk["content"].lower()
+
+            # Naive relevance score
+            score = 0
+            query_words = query_lower.split()
+
+            for word in query_words:
+                if word in content_lower:
+                    score += content_lower.count(word)
+
+            if score > 0:
+                result = chunk.copy()
+                result["relevance_score"] = score
+                scored_chunks.append(result)
+
+        # Sort by score, best first
+        scored_chunks.sort(key=lambda x: x["relevance_score"], reverse=True)
+
+        # Attach ranks to the top results
+        for i, chunk in enumerate(scored_chunks[:top_k]):
+            chunk["rank"] = i + 1
+
+        return scored_chunks[:top_k]
+
+
+def main():
+    """Interactive test of the PDF processing pipeline."""
+    processor = PDFToRAGProcessor()
+
+    # Example: process a COC rulebook
+    pdf_path = input("Enter the PDF file path (e.g. ./coc.pdf): ").strip()
+
+    if not os.path.exists(pdf_path):
+        print(f"File not found: {pdf_path}")
+        return
+
+    try:
+        result = processor.process_pdf_to_rag(pdf_path)
+
+        print(f"\n{'='*50}")
+        print("PDF processing finished!")
+        print(f"Status: {result['status']}")
+        print(f"Chunks: {result['chunks_count']}")
+        print(f"Key concepts: {result['concepts_count']}")
+        print(f"Vector index: {'enabled' if result['vector_enabled'] else 'disabled'}")
+        print(f"Output directory: {result['output_dir']}")
+
+        # Interactive search test
+        print(f"\n{'='*50}")
+        print("Knowledge base search test:")
+
+        while True:
+            query = input("\nEnter a search query ('quit' to exit): ").strip()
+            if query.lower() == 'quit':
+                break
+
+            results = processor.search_relevant_content(query, top_k=3)
+
+            if results:
+                print(f"\nFound {len(results)} relevant results:")
+                for result in results:
+                    print(f"\nRank {result['rank']} (relevance: {result['relevance_score']:.3f}):")
+                    content = result['content']
+                    preview = content[:200] + "..." if len(content) > 200 else content
+                    print(f"{preview}")
+                    print("-" * 40)
+            else:
+                print("No relevant content found")
+
+    except Exception as e:
+        print(f"Processing failed: {e}")
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file