Skip to content

Commit

Permalink
Fix divided by zero issue. (#4784)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

#4779

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
  • Loading branch information
KevinHuSh authored Feb 8, 2025
1 parent ccb72e6 commit f374dd3
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions rag/nlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def find_codec(blob):

return "utf-8"


QUESTION_PATTERN = [
r"第([零一二三四五六七八九十百0-9]+)问",
r"第([零一二三四五六七八九十百0-9]+)条",
Expand All @@ -83,6 +84,7 @@ def find_codec(blob):
r"QUESTION ([0-9]+)",
]


def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
section, last_section = box['text'], last_box['text']
q_reg = r'(\w|\W)*?(?:?|\?|\n|$)+'
Expand All @@ -94,7 +96,7 @@ def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
last_box['x0'] = box['x0']
if 'top' not in last_box:
last_box['top'] = box['top']
if last_bull and box['x0']-last_box['x0']>10:
if last_bull and box['x0'] - last_box['x0'] > 10:
return None, last_index
if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
return None, last_index
Expand Down Expand Up @@ -125,13 +127,14 @@ def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
return has_bull, index
return None, last_index


def index_int(index_str):
res = -1
try:
res=int(index_str)
res = int(index_str)
except ValueError:
try:
res=w2n.word_to_num(index_str)
res = w2n.word_to_num(index_str)
except ValueError:
try:
res = cn2an(index_str)
Expand All @@ -142,6 +145,7 @@ def index_int(index_str):
return -1
return res


def qbullets_category(sections):
global QUESTION_PATTERN
hits = [0] * len(QUESTION_PATTERN)
Expand Down Expand Up @@ -230,7 +234,10 @@ def is_english(texts):
return True
return False


def is_chinese(text):
if not text:
return False
chinese = 0
for ch in text:
if '\u4e00' <= ch <= '\u9fff':
Expand All @@ -239,6 +246,7 @@ def is_chinese(text):
return True
return False


def tokenize(d, t, eng):
d["content_with_weight"] = t
t = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t)
Expand Down Expand Up @@ -328,7 +336,7 @@ def remove_contents_table(sections, eng=False):
def get(i):
nonlocal sections
return (sections[i] if isinstance(sections[i],
type("")) else sections[i][0]).strip()
type("")) else sections[i][0]).strip()

if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$",
re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], re.IGNORECASE)):
Expand Down Expand Up @@ -378,9 +386,9 @@ def make_colon_as_title(sections):

def title_frequency(bull, sections):
bullets_size = len(BULLET_PATTERN[bull])
levels = [bullets_size+1 for _ in range(len(sections))]
levels = [bullets_size + 1 for _ in range(len(sections))]
if not sections or bull < 0:
return bullets_size+1, levels
return bullets_size + 1, levels

for i, (txt, layout) in enumerate(sections):
for j, p in enumerate(BULLET_PATTERN[bull]):
Expand All @@ -390,8 +398,8 @@ def title_frequency(bull, sections):
else:
if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
levels[i] = bullets_size
most_level = bullets_size+1
for level, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
most_level = bullets_size + 1
for level, c in sorted(Counter(levels).items(), key=lambda x: x[1] * -1):
if level <= bullets_size:
most_level = level
break
Expand All @@ -416,7 +424,6 @@ def hierarchical_merge(bull, sections, depth):
bullets_size = len(BULLET_PATTERN[bull])
levels = [[] for _ in range(bullets_size + 2)]


for i, (txt, layout) in enumerate(sections):
for j, p in enumerate(BULLET_PATTERN[bull]):
if re.match(p, txt.strip()):
Expand Down Expand Up @@ -531,7 +538,7 @@ def add_chunk(t, pos):
return cks


def docx_question_level(p, bull = -1):
def docx_question_level(p, bull=-1):
txt = re.sub(r"\u3000", " ", p.text).strip()
if p.style.name.startswith('Heading'):
return int(p.style.name.split(' ')[-1]), txt
Expand All @@ -540,10 +547,10 @@ def docx_question_level(p, bull = -1):
return 0, txt
for j, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, txt):
return j+1, txt
return j + 1, txt
return len(BULLET_PATTERN[bull]), txt


def concat_img(img1, img2):
if img1 and not img2:
return img1
Expand Down Expand Up @@ -594,4 +601,3 @@ def add_chunk(t, image, pos=""):
add_chunk(sec, image, '')

return cks, images

0 comments on commit f374dd3

Please sign in to comment.