添加pdf文件转换功能
This commit is contained in:
parent
2fc38d872d
commit
7a243becfd
431
AITrain/pdf_to_rag_processor.py
Normal file
431
AITrain/pdf_to_rag_processor.py
Normal file
@ -0,0 +1,431 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
'''
|
||||||
|
PDF文档转RAG知识库处理模块
|
||||||
|
支持将世界观文档(如COC.pdf)转换为RAG检索格式
|
||||||
|
'''
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import List, Dict, Tuple
|
||||||
|
import hashlib
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
try:
|
||||||
|
import PyPDF2
|
||||||
|
import fitz # pymupdf
|
||||||
|
PYMUPDF_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PYMUPDF_AVAILABLE = False
|
||||||
|
print("Warning: pymupdf not available, using PyPDF2 only")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
EMBEDDING_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
EMBEDDING_AVAILABLE = False
|
||||||
|
print("Warning: sentence-transformers or faiss not available, using simple text matching")
|
||||||
|
|
||||||
|
class PDFToRAGProcessor:
|
||||||
|
def __init__(self, embedding_model: str = "./sentence-transformers/all-MiniLM-L6-v2"):
|
||||||
|
"""
|
||||||
|
初始化PDF处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_model: 向量化模型名称
|
||||||
|
"""
|
||||||
|
global EMBEDDING_AVAILABLE
|
||||||
|
self.chunks = []
|
||||||
|
self.embeddings = None
|
||||||
|
self.index = None
|
||||||
|
|
||||||
|
if EMBEDDING_AVAILABLE:
|
||||||
|
try:
|
||||||
|
self.embedding_model = SentenceTransformer(embedding_model)
|
||||||
|
print(f"✓ 向量模型加载成功: {embedding_model}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ 向量模型加载失败: {e}")
|
||||||
|
self.embedding_model = None
|
||||||
|
EMBEDDING_AVAILABLE = False
|
||||||
|
else:
|
||||||
|
self.embedding_model = None
|
||||||
|
|
||||||
|
def extract_text_from_pdf(self, pdf_path: str) -> str:
|
||||||
|
"""从PDF提取文本"""
|
||||||
|
if not os.path.exists(pdf_path):
|
||||||
|
raise FileNotFoundError(f"PDF文件不存在: {pdf_path}")
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
|
||||||
|
# 优先使用pymupdf,效果更好
|
||||||
|
if PYMUPDF_AVAILABLE:
|
||||||
|
try:
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
for page_num in range(len(doc)):
|
||||||
|
page = doc.load_page(page_num)
|
||||||
|
text += page.get_text()
|
||||||
|
doc.close()
|
||||||
|
print(f"✓ 使用pymupdf提取文本成功")
|
||||||
|
return text
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ pymupdf提取失败: {e}, 尝试PyPDF2")
|
||||||
|
|
||||||
|
# 备用PyPDF2
|
||||||
|
try:
|
||||||
|
with open(pdf_path, 'rb') as file:
|
||||||
|
pdf_reader = PyPDF2.PdfReader(file)
|
||||||
|
for page in pdf_reader.pages:
|
||||||
|
text += page.extract_text()
|
||||||
|
print(f"✓ 使用PyPDF2提取文本成功")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"PDF文本提取失败: {e}")
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def clean_text(self, text: str) -> str:
|
||||||
|
"""清理文本"""
|
||||||
|
# 移除多余空行
|
||||||
|
text = re.sub(r'\n\s*\n', '\n\n', text)
|
||||||
|
# 移除页码等
|
||||||
|
text = re.sub(r'第\s*\d+\s*页', '', text)
|
||||||
|
text = re.sub(r'Page\s*\d+', '', text)
|
||||||
|
# 统一标点符号
|
||||||
|
text = text.replace(' ', ' ') # 全角空格转半角
|
||||||
|
text = re.sub(r'\s+', ' ', text) # 多个空格合并
|
||||||
|
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
def chunk_text_by_semantic(self, text: str, max_chunk_size: int = 500) -> List[Dict]:
|
||||||
|
"""按语义分块文本"""
|
||||||
|
# 按段落分割
|
||||||
|
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
current_chunk = ""
|
||||||
|
current_size = 0
|
||||||
|
|
||||||
|
for para in paragraphs:
|
||||||
|
para_size = len(para)
|
||||||
|
|
||||||
|
# 如果当前段落太长,需要进一步分割
|
||||||
|
if para_size > max_chunk_size:
|
||||||
|
# 保存当前chunk
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append({
|
||||||
|
"content": current_chunk.strip(),
|
||||||
|
"size": current_size,
|
||||||
|
"type": "paragraph"
|
||||||
|
})
|
||||||
|
current_chunk = ""
|
||||||
|
current_size = 0
|
||||||
|
|
||||||
|
# 按句子分割长段落
|
||||||
|
sentences = re.split(r'[。!?;]\s*', para)
|
||||||
|
for sentence in sentences:
|
||||||
|
if sentence.strip():
|
||||||
|
sentence += '。' # 添加标点
|
||||||
|
if current_size + len(sentence) > max_chunk_size and current_chunk:
|
||||||
|
chunks.append({
|
||||||
|
"content": current_chunk.strip(),
|
||||||
|
"size": current_size,
|
||||||
|
"type": "sentence"
|
||||||
|
})
|
||||||
|
current_chunk = sentence
|
||||||
|
current_size = len(sentence)
|
||||||
|
else:
|
||||||
|
current_chunk += sentence
|
||||||
|
current_size += len(sentence)
|
||||||
|
else:
|
||||||
|
# 正常段落处理
|
||||||
|
if current_size + para_size > max_chunk_size and current_chunk:
|
||||||
|
chunks.append({
|
||||||
|
"content": current_chunk.strip(),
|
||||||
|
"size": current_size,
|
||||||
|
"type": "paragraph"
|
||||||
|
})
|
||||||
|
current_chunk = para
|
||||||
|
current_size = para_size
|
||||||
|
else:
|
||||||
|
current_chunk += "\n" + para if current_chunk else para
|
||||||
|
current_size += para_size
|
||||||
|
|
||||||
|
# 保存最后的chunk
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append({
|
||||||
|
"content": current_chunk.strip(),
|
||||||
|
"size": current_size,
|
||||||
|
"type": "paragraph"
|
||||||
|
})
|
||||||
|
|
||||||
|
# 为每个chunk添加唯一ID和元数据
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
chunk["id"] = f"chunk_{i:04d}"
|
||||||
|
chunk["hash"] = hashlib.md5(chunk["content"].encode()).hexdigest()[:8]
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def extract_key_concepts(self, chunks: List[Dict]) -> List[Dict]:
|
||||||
|
"""提取关键概念和术语"""
|
||||||
|
concepts = []
|
||||||
|
|
||||||
|
# 简单的关键词提取规则
|
||||||
|
concept_patterns = [
|
||||||
|
r'(?:技能|属性|规则)[::]\s*([^。\n]+)',
|
||||||
|
r'([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)', # 英文专有名词
|
||||||
|
r'【([^】]+)】', # 中文术语标记
|
||||||
|
r'"([^"]+)"', # 引号内容
|
||||||
|
]
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
content = chunk["content"]
|
||||||
|
chunk_concepts = []
|
||||||
|
|
||||||
|
for pattern in concept_patterns:
|
||||||
|
matches = re.findall(pattern, content)
|
||||||
|
chunk_concepts.extend(matches)
|
||||||
|
|
||||||
|
if chunk_concepts:
|
||||||
|
concepts.append({
|
||||||
|
"chunk_id": chunk["id"],
|
||||||
|
"concepts": list(set(chunk_concepts)), # 去重
|
||||||
|
"content_preview": content[:100] + "..." if len(content) > 100 else content
|
||||||
|
})
|
||||||
|
|
||||||
|
return concepts
|
||||||
|
|
||||||
|
def build_vector_index(self, chunks: List[Dict]) -> bool:
|
||||||
|
"""构建向量索引"""
|
||||||
|
if not EMBEDDING_AVAILABLE or not self.embedding_model:
|
||||||
|
print("✗ 向量化功能不可用,将使用文本匹配")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 提取文本内容
|
||||||
|
texts = [chunk["content"] for chunk in chunks]
|
||||||
|
|
||||||
|
# 生成向量
|
||||||
|
print("生成文档向量...")
|
||||||
|
embeddings = self.embedding_model.encode(texts, show_progress_bar=True)
|
||||||
|
|
||||||
|
# 构建FAISS索引
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
self.index = faiss.IndexFlatL2(dimension)
|
||||||
|
self.index.add(embeddings.astype(np.float32))
|
||||||
|
|
||||||
|
self.embeddings = embeddings
|
||||||
|
print(f"✓ 向量索引构建完成,维度: {dimension}, 文档数: {len(chunks)}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ 向量索引构建失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def process_pdf_to_rag(self, pdf_path: str, output_dir: str = "./rag_knowledge") -> Dict:
|
||||||
|
"""
|
||||||
|
完整处理PDF到RAG知识库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_path: PDF文件路径
|
||||||
|
output_dir: 输出目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理结果统计
|
||||||
|
"""
|
||||||
|
print(f"开始处理PDF: {pdf_path}")
|
||||||
|
|
||||||
|
# 创建输出目录
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 1. 提取文本
|
||||||
|
raw_text = self.extract_text_from_pdf(pdf_path)
|
||||||
|
print(f"✓ 提取文本长度: {len(raw_text)} 字符")
|
||||||
|
|
||||||
|
# 2. 清理文本
|
||||||
|
clean_text = self.clean_text(raw_text)
|
||||||
|
|
||||||
|
# 保存清理后的文本
|
||||||
|
text_output_path = os.path.join(output_dir, "extracted_text.txt")
|
||||||
|
with open(text_output_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(clean_text)
|
||||||
|
print(f"✓ 清理后文本保存至: {text_output_path}")
|
||||||
|
|
||||||
|
# 3. 语义分块
|
||||||
|
chunks = self.chunk_text_by_semantic(clean_text)
|
||||||
|
self.chunks = chunks
|
||||||
|
print(f"✓ 文档分块完成: {len(chunks)} 个块")
|
||||||
|
|
||||||
|
# 4. 提取关键概念
|
||||||
|
concepts = self.extract_key_concepts(chunks)
|
||||||
|
print(f"✓ 提取关键概念: {len(concepts)} 个")
|
||||||
|
|
||||||
|
# 5. 构建向量索引
|
||||||
|
vector_success = self.build_vector_index(chunks)
|
||||||
|
|
||||||
|
# 6. 保存知识库文件
|
||||||
|
knowledge_base = {
|
||||||
|
"metadata": {
|
||||||
|
"source_file": os.path.basename(pdf_path),
|
||||||
|
"processed_time": datetime.now().isoformat(),
|
||||||
|
"total_chunks": len(chunks),
|
||||||
|
"total_concepts": len(concepts),
|
||||||
|
"vector_enabled": vector_success
|
||||||
|
},
|
||||||
|
"chunks": chunks,
|
||||||
|
"concepts": concepts
|
||||||
|
}
|
||||||
|
|
||||||
|
# 保存JSON知识库
|
||||||
|
kb_output_path = os.path.join(output_dir, "knowledge_base.json")
|
||||||
|
with open(kb_output_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(knowledge_base, f, ensure_ascii=False, indent=2)
|
||||||
|
print(f"✓ 知识库保存至: {kb_output_path}")
|
||||||
|
|
||||||
|
# 保存向量索引
|
||||||
|
if vector_success:
|
||||||
|
index_path = os.path.join(output_dir, "vector_index.faiss")
|
||||||
|
faiss.write_index(self.index, index_path)
|
||||||
|
|
||||||
|
embeddings_path = os.path.join(output_dir, "embeddings.npy")
|
||||||
|
np.save(embeddings_path, self.embeddings)
|
||||||
|
print(f"✓ 向量索引保存至: {index_path}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"chunks_count": len(chunks),
|
||||||
|
"concepts_count": len(concepts),
|
||||||
|
"vector_enabled": vector_success,
|
||||||
|
"output_dir": output_dir
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_knowledge_base(self, knowledge_dir: str) -> bool:
|
||||||
|
"""加载已有知识库"""
|
||||||
|
try:
|
||||||
|
# 加载JSON知识库
|
||||||
|
kb_path = os.path.join(knowledge_dir, "knowledge_base.json")
|
||||||
|
with open(kb_path, 'r', encoding='utf-8') as f:
|
||||||
|
knowledge_base = json.load(f)
|
||||||
|
|
||||||
|
self.chunks = knowledge_base["chunks"]
|
||||||
|
|
||||||
|
# 加载向量索引
|
||||||
|
if EMBEDDING_AVAILABLE:
|
||||||
|
index_path = os.path.join(knowledge_dir, "vector_index.faiss")
|
||||||
|
embeddings_path = os.path.join(knowledge_dir, "embeddings.npy")
|
||||||
|
|
||||||
|
if os.path.exists(index_path) and os.path.exists(embeddings_path):
|
||||||
|
self.index = faiss.read_index(index_path)
|
||||||
|
self.embeddings = np.load(embeddings_path)
|
||||||
|
print(f"✓ 向量索引加载成功")
|
||||||
|
|
||||||
|
print(f"✓ 知识库加载成功: {len(self.chunks)} 个文档块")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ 知识库加载失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def search_relevant_content(self, query: str, top_k: int = 3) -> List[Dict]:
|
||||||
|
"""搜索相关内容"""
|
||||||
|
if not self.chunks:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 向量搜索
|
||||||
|
if EMBEDDING_AVAILABLE and self.embedding_model and self.index:
|
||||||
|
try:
|
||||||
|
query_vector = self.embedding_model.encode([query])
|
||||||
|
distances, indices = self.index.search(query_vector.astype(np.float32), top_k)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
|
||||||
|
if idx < len(self.chunks):
|
||||||
|
result = self.chunks[idx].copy()
|
||||||
|
result["relevance_score"] = float(1 / (1 + distance)) # 转换为相似度分数
|
||||||
|
result["rank"] = i + 1
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"向量搜索失败: {e}, 使用文本匹配")
|
||||||
|
|
||||||
|
# 文本匹配搜索
|
||||||
|
query_lower = query.lower()
|
||||||
|
scored_chunks = []
|
||||||
|
|
||||||
|
for chunk in self.chunks:
|
||||||
|
content_lower = chunk["content"].lower()
|
||||||
|
|
||||||
|
# 简单的相关性评分
|
||||||
|
score = 0
|
||||||
|
query_words = query_lower.split()
|
||||||
|
|
||||||
|
for word in query_words:
|
||||||
|
if word in content_lower:
|
||||||
|
score += content_lower.count(word)
|
||||||
|
|
||||||
|
if score > 0:
|
||||||
|
result = chunk.copy()
|
||||||
|
result["relevance_score"] = score
|
||||||
|
scored_chunks.append(result)
|
||||||
|
|
||||||
|
# 按分数排序
|
||||||
|
scored_chunks.sort(key=lambda x: x["relevance_score"], reverse=True)
|
||||||
|
|
||||||
|
# 添加排名
|
||||||
|
for i, chunk in enumerate(scored_chunks[:top_k]):
|
||||||
|
chunk["rank"] = i + 1
|
||||||
|
|
||||||
|
return scored_chunks[:top_k]
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""测试PDF处理功能"""
|
||||||
|
processor = PDFToRAGProcessor()
|
||||||
|
|
||||||
|
# 示例:处理COC规则书
|
||||||
|
pdf_path = input("请输入PDF文件路径 (如: ./coc.pdf): ").strip()
|
||||||
|
|
||||||
|
if not os.path.exists(pdf_path):
|
||||||
|
print(f"文件不存在: {pdf_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = processor.process_pdf_to_rag(pdf_path)
|
||||||
|
|
||||||
|
print(f"\n{'='*50}")
|
||||||
|
print(f"PDF处理完成!")
|
||||||
|
print(f"状态: {result['status']}")
|
||||||
|
print(f"文档块数量: {result['chunks_count']}")
|
||||||
|
print(f"关键概念数量: {result['concepts_count']}")
|
||||||
|
print(f"向量索引: {'启用' if result['vector_enabled'] else '未启用'}")
|
||||||
|
print(f"输出目录: {result['output_dir']}")
|
||||||
|
|
||||||
|
# 测试搜索
|
||||||
|
print(f"\n{'='*50}")
|
||||||
|
print("测试知识库搜索:")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
query = input("\n请输入搜索关键词 (输入'quit'退出): ").strip()
|
||||||
|
if query.lower() == 'quit':
|
||||||
|
break
|
||||||
|
|
||||||
|
results = processor.search_relevant_content(query, top_k=3)
|
||||||
|
|
||||||
|
if results:
|
||||||
|
print(f"\n找到 {len(results)} 个相关结果:")
|
||||||
|
for result in results:
|
||||||
|
print(f"\n排名 {result['rank']} (相关度: {result['relevance_score']:.3f}):")
|
||||||
|
content = result['content']
|
||||||
|
preview = content[:200] + "..." if len(content) > 200 else content
|
||||||
|
print(f"{preview}")
|
||||||
|
print("-" * 40)
|
||||||
|
else:
|
||||||
|
print("未找到相关内容")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理失败: {e}")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Loading…
x
Reference in New Issue
Block a user