90 lines
2.9 KiB
Python
90 lines
2.9 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
'''
|
|
@File : utils.py
|
|
@Time : 2024/06/16 08:05:08
|
|
@Author : 不要葱姜蒜
|
|
@Version : 1.0
|
|
@Desc : None
|
|
'''
|
|
|
|
import os
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
from tqdm import tqdm
|
|
import tiktoken
|
|
import re
|
|
|
|
enc = tiktoken.get_encoding("cl100k_base")
|
|
|
|
class ReadFiles:
|
|
"""
|
|
class to read files
|
|
"""
|
|
|
|
def __init__(self, path: str) -> None:
|
|
self._path = path
|
|
|
|
def get_content(self, max_token_len: int = 300, cover_content: int = 50):
|
|
# 读取文件内容
|
|
content = self.read_file_content(self._path)
|
|
chunk_content = self.get_chunk(
|
|
content, max_token_len=max_token_len, cover_content=cover_content)
|
|
return chunk_content
|
|
|
|
@classmethod
|
|
def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150):
|
|
chunk_text = []
|
|
|
|
curr_len = 0
|
|
curr_chunk = ''
|
|
|
|
token_len = max_token_len - cover_content
|
|
lines = text.splitlines() # 假设以换行符分割文本为行
|
|
|
|
for line in lines:
|
|
line = line.replace(' ', '')
|
|
line_len = len(enc.encode(line))
|
|
if line_len > max_token_len:
|
|
# 如果单行长度就超过限制,则将其分割成多个块
|
|
num_chunks = (line_len + token_len - 1) // token_len
|
|
for i in range(num_chunks):
|
|
start = i * token_len
|
|
end = start + token_len
|
|
# 避免跨单词分割
|
|
while not line[start:end].rstrip().isspace():
|
|
start += 1
|
|
end += 1
|
|
if start >= line_len:
|
|
break
|
|
curr_chunk = (curr_chunk[-cover_content:] if cover_content > 0 else '') + line[start:end]
|
|
chunk_text.append(curr_chunk)
|
|
|
|
if curr_len + line_len <= token_len:
|
|
curr_chunk += line
|
|
curr_chunk += '\n'
|
|
curr_len += line_len
|
|
curr_len += 1
|
|
else:
|
|
chunk_text.append(curr_chunk)
|
|
curr_chunk = curr_chunk[-cover_content:] + line if cover_content > 0 else line
|
|
curr_len = line_len + (cover_content if cover_content > 0 else 0)
|
|
|
|
if curr_chunk:
|
|
chunk_text.append(curr_chunk)
|
|
|
|
return chunk_text
|
|
|
|
@classmethod
|
|
def read_file_content(cls, file_path: str):
|
|
# 根据文件扩展名选择读取方法
|
|
if file_path.endswith('.txt'):
|
|
return cls.read_text(file_path)
|
|
else:
|
|
raise ValueError("Unsupported file type")
|
|
|
|
@classmethod
|
|
def read_text(cls, file_path: str):
|
|
# 读取文本文件
|
|
with open(file_path, 'r', encoding='utf-8') as file:
|
|
return file.read() |