# Source: ai-robot-core/ai-service/app/services/document/markdown_chunker.py
# (export metadata: 772 lines, 25 KiB, Python)

"""
Markdown intelligent chunker with structure-aware splitting.
Supports headers, code blocks, tables, lists, and preserves context.
"""
import logging
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any

# Module-level logger following the `logging.getLogger(__name__)` convention.
# NOTE(review): no log calls are visible in this module — presumably kept for
# downstream use; confirm before removing.
logger = logging.getLogger(__name__)
class MarkdownElementType(Enum):
    """Types of Markdown elements recognized by the parser/chunker.

    The string values are what `MarkdownElement.to_dict()` and
    `MarkdownChunk.to_dict()` emit in their "type"/"element_type" fields.
    """
    HEADER = "header"
    PARAGRAPH = "paragraph"
    CODE_BLOCK = "code_block"
    INLINE_CODE = "inline_code"
    TABLE = "table"
    LIST = "list"
    BLOCKQUOTE = "blockquote"
    HORIZONTAL_RULE = "horizontal_rule"
    IMAGE = "image"
    LINK = "link"
    TEXT = "text"
    # NOTE(review): INLINE_CODE, IMAGE, LINK and TEXT are never produced by
    # the extractors in this module — presumably reserved for callers or
    # future extraction passes; confirm before removing.
@dataclass
class MarkdownElement:
    """A single structural element parsed out of a Markdown document.

    Attributes:
        type: Kind of element (header, code block, table, ...).
        content: Raw text of the element (without surrounding markup
            such as code fences or blockquote markers).
        level: Header level (1-6) for headers; 0 otherwise.
        language: Language tag for fenced code blocks; "" otherwise.
        metadata: Type-specific extras (e.g. table headers, item counts).
        line_start: First source line (0-based) covered by the element.
        line_end: Last source line (0-based, inclusive).
    """
    type: MarkdownElementType
    content: str
    level: int = 0
    language: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)
    line_start: int = 0
    line_end: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; the enum collapses to its string value."""
        serialized: dict[str, Any] = {"type": self.type.value}
        for attr in ("content", "level", "language", "metadata",
                     "line_start", "line_end"):
            serialized[attr] = getattr(self, attr)
        return serialized
@dataclass
class MarkdownChunk:
    """A chunk of Markdown content carrying its surrounding context.

    Attributes:
        chunk_id: Stable identifier, e.g. "<doc_id>_chunk_<n>".
        content: Chunk text, re-rendered with its markup (fences, "> ", ...).
        element_type: The element kind this chunk was produced from.
        header_context: Header titles from the document root down to the
            section that contains this chunk.
        level: Header level of the source element; 0 for non-headers.
        language: Code-block language tag; "" otherwise.
        metadata: Element metadata plus splitting info (is_partial, part).
    """
    chunk_id: str
    content: str
    element_type: MarkdownElementType
    header_context: list[str]
    level: int = 0
    language: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; the enum collapses to its string value."""
        serialized: dict[str, Any] = {
            "chunk_id": self.chunk_id,
            "content": self.content,
            "element_type": self.element_type.value,
        }
        for attr in ("header_context", "level", "language", "metadata"):
            serialized[attr] = getattr(self, attr)
        return serialized
class MarkdownParser:
    """
    Parser for Markdown documents.

    Extracts structured elements from Markdown text.  Fenced code blocks
    and tables are extracted first and their line ranges become
    "protected", so the remaining extractors never re-interpret lines
    inside them.  All extraction is line-based; line numbers are 0-based
    and inclusive.
    """

    # Header with optional trailing closing hashes ("## Title ##").
    HEADER_PATTERN = re.compile(r'^(#{1,6})\s+(.+?)(?:\s+#+)?$', re.MULTILINE)
    # NOTE(review): CODE_BLOCK_PATTERN and TABLE_PATTERN are not used by the
    # line-based extractors below; kept for API stability.
    CODE_BLOCK_PATTERN = re.compile(r'^```(\w*)\n(.*?)^```', re.MULTILINE | re.DOTALL)
    TABLE_PATTERN = re.compile(r'^(\|.+\|)\n(\|[-:\s|]+\|)\n((?:\|.+\|\n?)+)', re.MULTILINE)
    # List item: optional indent, then a bullet (-, *, +) or ordered marker
    # ("12."), then the item text.  Groups: (1) indent, (2) marker, (3) text.
    # Fixed: the previous alternation r'^([ \t]*[-*+]|\d+\.)...' only matched
    # ordered items at column 0, so indented ordered lists fell through to
    # paragraph handling.
    LIST_PATTERN = re.compile(r'^([ \t]*)([-*+]|\d+\.)\s+(.+)$', re.MULTILINE)
    BLOCKQUOTE_PATTERN = re.compile(r'^>\s*(.+)$', re.MULTILINE)
    HR_PATTERN = re.compile(r'^[-*_]{3,}\s*$', re.MULTILINE)
    IMAGE_PATTERN = re.compile(r'!\[([^\]]*)\]\(([^)]+)\)')
    LINK_PATTERN = re.compile(r'\[([^\]]+)\]\(([^)]+)\)')
    INLINE_CODE_PATTERN = re.compile(r'`([^`]+)`')

    def parse(self, text: str) -> list[MarkdownElement]:
        """
        Parse Markdown text into structured elements.

        Args:
            text: Raw Markdown text

        Returns:
            List of MarkdownElement objects, sorted by starting line.
        """
        elements: list[MarkdownElement] = []
        lines = text.split('\n')
        # Opaque constructs first; their ranges shield interior lines from
        # the header/list/blockquote/paragraph passes.
        code_block_ranges = self._extract_code_blocks(text, lines, elements)
        table_ranges = self._extract_tables(text, lines, elements)
        protected_ranges = code_block_ranges + table_ranges
        self._extract_headers(lines, elements, protected_ranges)
        self._extract_lists(lines, elements, protected_ranges)
        self._extract_blockquotes(lines, elements, protected_ranges)
        self._extract_horizontal_rules(lines, elements, protected_ranges)
        self._fill_paragraphs(lines, elements, protected_ranges)
        elements.sort(key=lambda e: e.line_start)
        return elements

    def _extract_code_blocks(
        self,
        text: str,
        lines: list[str],
        elements: list[MarkdownElement],
    ) -> list[tuple[int, int]]:
        """Extract fenced code blocks with language info.

        Appends CODE_BLOCK elements and returns their (start, end) line
        ranges.  `text` is unused (extraction is line-based) but kept for
        signature stability.  An unterminated fence at EOF is silently
        discarded.
        """
        ranges = []
        in_code_block = False
        code_start = 0
        language = ""
        code_content: list[str] = []
        for i, line in enumerate(lines):
            if line.strip().startswith('```'):
                if not in_code_block:
                    # Opening fence: the info string after ``` is the language.
                    in_code_block = True
                    code_start = i
                    language = line.strip()[3:].strip()
                    code_content = []
                else:
                    # Closing fence: emit the accumulated block (fences
                    # included in the line range, excluded from content).
                    in_code_block = False
                    elements.append(MarkdownElement(
                        type=MarkdownElementType.CODE_BLOCK,
                        content='\n'.join(code_content),
                        language=language,
                        line_start=code_start,
                        line_end=i,
                        metadata={"language": language},
                    ))
                    ranges.append((code_start, i))
            elif in_code_block:
                code_content.append(line)
        return ranges

    def _extract_tables(
        self,
        text: str,
        lines: list[str],
        elements: list[MarkdownElement],
    ) -> list[tuple[int, int]]:
        """Extract Markdown tables (header row + separator row + data rows).

        Appends TABLE elements and returns their (start, end) line ranges.
        `text` is unused (extraction is line-based) but kept for signature
        stability.
        """
        ranges = []
        i = 0
        while i < len(lines):
            line = lines[i]
            if '|' in line and i + 1 < len(lines):
                next_line = lines[i + 1]
                # A table starts where a pipe row is followed by a
                # separator row made only of | - : and whitespace.
                if '|' in next_line and re.match(r'^[\|\-\:\s]+$', next_line.strip()):
                    table_lines = [line, next_line]
                    j = i + 2
                    while j < len(lines) and '|' in lines[j]:
                        table_lines.append(lines[j])
                        j += 1
                    table_content = '\n'.join(table_lines)
                    headers = [h.strip() for h in line.split('|') if h.strip()]
                    row_count = len(table_lines) - 2  # data rows only
                    elements.append(MarkdownElement(
                        type=MarkdownElementType.TABLE,
                        content=table_content,
                        line_start=i,
                        line_end=j - 1,
                        metadata={
                            "headers": headers,
                            "row_count": row_count,
                        },
                    ))
                    ranges.append((i, j - 1))
                    i = j
                    continue
            i += 1
        return ranges

    def _is_in_protected_range(self, line_num: int, ranges: list[tuple[int, int]]) -> bool:
        """Return True if `line_num` falls inside any (start, end) range (inclusive)."""
        for start, end in ranges:
            if start <= line_num <= end:
                return True
        return False

    def _extract_headers(
        self,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]],
    ) -> None:
        """Extract ATX headers (# ... ######) with their level."""
        for i, line in enumerate(lines):
            if self._is_in_protected_range(i, protected_ranges):
                continue
            match = self.HEADER_PATTERN.match(line)
            if match:
                level = len(match.group(1))  # number of leading '#'
                title = match.group(2).strip()
                elements.append(MarkdownElement(
                    type=MarkdownElementType.HEADER,
                    content=title,
                    level=level,
                    line_start=i,
                    line_end=i,
                    metadata={"level": level},
                ))

    def _extract_lists(
        self,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]],
    ) -> None:
        """Extract contiguous list blocks.

        Blank lines inside a list keep it open; any other non-item line
        (or a protected range) terminates it.  Ordered-ness is detected
        from the first item's marker.
        """
        in_list = False
        list_start = 0
        list_items: list[tuple[int, str]] = []
        list_ordered = False
        for i, line in enumerate(lines):
            if self._is_in_protected_range(i, protected_ranges):
                if in_list:
                    self._save_list(elements, list_start, i - 1, list_items, list_ordered)
                    in_list = False
                    list_items = []
                continue
            match = self.LIST_PATTERN.match(line)
            if match:
                indent = len(line) - len(line.lstrip())
                item_content = match.group(3)
                if not in_list:
                    in_list = True
                    list_start = i
                    # Ordered markers are "N."; bullets are -, * or +.
                    list_ordered = match.group(2)[0].isdigit()
                    list_items = [(indent, item_content)]
                else:
                    list_items.append((indent, item_content))
            else:
                if in_list:
                    if line.strip() == '':
                        continue  # blank line inside a list keeps it open
                    else:
                        self._save_list(elements, list_start, i - 1, list_items, list_ordered)
                        in_list = False
                        list_items = []
        if in_list:
            self._save_list(elements, list_start, len(lines) - 1, list_items, list_ordered)

    def _save_list(
        self,
        elements: list[MarkdownElement],
        start: int,
        end: int,
        items: list[tuple[int, str]],
        is_ordered: bool = False,
    ) -> None:
        """Append a LIST element built from (indent, text) item tuples.

        Content keeps only the item text (markers and indent are dropped).
        """
        if not items:
            return
        content = '\n'.join(item[1] for item in items)
        elements.append(MarkdownElement(
            type=MarkdownElementType.LIST,
            content=content,
            line_start=start,
            line_end=end,
            metadata={
                "item_count": len(items),
                # Fixed: was hard-coded False; now reflects the list marker.
                "is_ordered": is_ordered,
            },
        ))

    def _extract_blockquotes(
        self,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]],
    ) -> None:
        """Extract contiguous blockquote runs; any non-'>' line ends the run."""
        in_quote = False
        quote_start = 0
        quote_lines: list[str] = []
        for i, line in enumerate(lines):
            if self._is_in_protected_range(i, protected_ranges):
                if in_quote:
                    self._save_blockquote(elements, quote_start, i - 1, quote_lines)
                    in_quote = False
                    quote_lines = []
                continue
            match = self.BLOCKQUOTE_PATTERN.match(line)
            if match:
                if not in_quote:
                    in_quote = True
                    quote_start = i
                quote_lines.append(match.group(1))  # text without the '>' marker
            else:
                if in_quote:
                    self._save_blockquote(elements, quote_start, i - 1, quote_lines)
                    in_quote = False
                    quote_lines = []
        if in_quote:
            self._save_blockquote(elements, quote_start, len(lines) - 1, quote_lines)

    def _save_blockquote(
        self,
        elements: list[MarkdownElement],
        start: int,
        end: int,
        lines: list[str],
    ) -> None:
        """Append a BLOCKQUOTE element from the stripped quote lines."""
        if not lines:
            return
        elements.append(MarkdownElement(
            type=MarkdownElementType.BLOCKQUOTE,
            content='\n'.join(lines),
            line_start=start,
            line_end=end,
        ))

    def _extract_horizontal_rules(
        self,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]],
    ) -> None:
        """Extract horizontal rules (---, ***, ___ of length >= 3)."""
        for i, line in enumerate(lines):
            if self._is_in_protected_range(i, protected_ranges):
                continue
            if self.HR_PATTERN.match(line):
                elements.append(MarkdownElement(
                    type=MarkdownElementType.HORIZONTAL_RULE,
                    content=line,
                    line_start=i,
                    line_end=i,
                ))

    def _fill_paragraphs(
        self,
        lines: list[str],
        elements: list[MarkdownElement],
        protected_ranges: list[tuple[int, int]],
    ) -> None:
        """Turn every remaining unclaimed, non-blank run of lines into a PARAGRAPH."""
        # Mark every line already claimed by a protected range or element.
        occupied = set()
        for start, end in protected_ranges:
            for i in range(start, end + 1):
                occupied.add(i)
        for elem in elements:
            for i in range(elem.line_start, elem.line_end + 1):
                occupied.add(i)
        i = 0
        while i < len(lines):
            if i in occupied:
                i += 1
                continue
            if lines[i].strip() == '':
                i += 1
                continue
            para_start = i
            para_lines = []
            # Consume a contiguous run of free, non-blank lines.
            while i < len(lines) and i not in occupied and lines[i].strip() != '':
                para_lines.append(lines[i])
                occupied.add(i)
                i += 1
            if para_lines:
                elements.append(MarkdownElement(
                    type=MarkdownElementType.PARAGRAPH,
                    content='\n'.join(para_lines),
                    line_start=para_start,
                    line_end=i - 1,
                ))
class MarkdownChunker:
    """
    Intelligent chunker for Markdown documents.

    Features:
    - Structure-aware splitting (headers, code blocks, tables, lists)
    - Context preservation (header hierarchy)
    - Configurable chunk size
    - Metadata extraction

    Headers themselves are never emitted as chunks; they only feed the
    `header_context` attached to subsequent chunks.
    """

    def __init__(
        self,
        max_chunk_size: int = 1000,
        min_chunk_size: int = 100,
        chunk_overlap: int = 50,
        preserve_code_blocks: bool = True,
        preserve_tables: bool = True,
        preserve_lists: bool = True,
        include_header_context: bool = True,
    ):
        """
        Args:
            max_chunk_size: Maximum chunk size in characters; larger
                elements are split by the type-specific splitters.
            min_chunk_size: Minimum chunk size.
                NOTE(review): stored but never consulted — confirm intent.
            chunk_overlap: Character overlap between chunks.
                NOTE(review): stored but never consulted — confirm intent.
            preserve_code_blocks: NOTE(review): stored but never consulted;
                code-block splitting always re-adds the fences regardless.
            preserve_tables: NOTE(review): stored but never consulted;
                table splitting always repeats the header row regardless.
            preserve_lists: NOTE(review): stored but never consulted.
            include_header_context: Attach the current header hierarchy
                to every emitted chunk.
        """
        self._max_chunk_size = max_chunk_size
        self._min_chunk_size = min_chunk_size
        self._chunk_overlap = chunk_overlap
        self._preserve_code_blocks = preserve_code_blocks
        self._preserve_tables = preserve_tables
        self._preserve_lists = preserve_lists
        self._include_header_context = include_header_context
        self._parser = MarkdownParser()

    def chunk(self, text: str, doc_id: str = "") -> list[MarkdownChunk]:
        """
        Chunk Markdown text into structured segments.

        Args:
            text: Raw Markdown text
            doc_id: Optional document ID used as a chunk-ID prefix

        Returns:
            List of MarkdownChunk objects
        """
        elements = self._parser.parse(text)
        chunks: list[MarkdownChunk] = []
        header_stack: list[str] = []
        chunk_index = 0
        for elem in elements:
            if elem.type == MarkdownElementType.HEADER:
                # A level-N header replaces everything at depth >= N on the
                # stack.  (Fixed: previous loop had a redundant inner guard
                # that would spin forever if level were ever 0.)
                level = elem.level
                while header_stack and len(header_stack) >= level:
                    header_stack.pop()
                header_stack.append(elem.content)
                continue
            if elem.type == MarkdownElementType.HORIZONTAL_RULE:
                continue  # purely decorative; never emitted as a chunk
            chunk_content = self._format_element_content(elem)
            if not chunk_content:
                continue
            chunk_id = f"{doc_id}_chunk_{chunk_index}" if doc_id else f"chunk_{chunk_index}"
            header_context: list[str] = []
            if self._include_header_context:
                header_context = header_stack.copy()
            if len(chunk_content) > self._max_chunk_size:
                sub_chunks = self._split_large_element(
                    elem,
                    chunk_id,
                    header_context,
                    chunk_index,
                )
                chunks.extend(sub_chunks)
                chunk_index += len(sub_chunks)
            else:
                chunks.append(MarkdownChunk(
                    chunk_id=chunk_id,
                    content=chunk_content,
                    element_type=elem.type,
                    header_context=header_context,
                    level=elem.level,
                    language=elem.language,
                    # Copy so later mutation of chunk metadata cannot leak
                    # back into the parsed element (the split paths already
                    # copy via {**elem.metadata, ...}).
                    metadata=dict(elem.metadata),
                ))
                chunk_index += 1
        return chunks

    def _format_element_content(self, elem: MarkdownElement) -> str:
        """Re-render element content with its Markdown markup by type."""
        if elem.type == MarkdownElementType.CODE_BLOCK:
            lang = elem.language or ""
            return f"```{lang}\n{elem.content}\n```"
        elif elem.type == MarkdownElementType.TABLE:
            return elem.content
        elif elem.type == MarkdownElementType.LIST:
            return elem.content
        elif elem.type == MarkdownElementType.BLOCKQUOTE:
            # Parser stripped the '>' markers; restore them.
            lines = elem.content.split('\n')
            return '\n'.join([f"> {line}" for line in lines])
        elif elem.type == MarkdownElementType.PARAGRAPH:
            return elem.content
        return elem.content

    def _split_large_element(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        start_index: int,
    ) -> list[MarkdownChunk]:
        """Dispatch an oversized element to its type-specific splitter."""
        if elem.type == MarkdownElementType.CODE_BLOCK:
            chunks = self._split_code_block(elem, base_id, header_context, start_index)
        elif elem.type == MarkdownElementType.TABLE:
            chunks = self._split_table(elem, base_id, header_context, start_index)
        elif elem.type == MarkdownElementType.LIST:
            chunks = self._split_list(elem, base_id, header_context, start_index)
        else:
            chunks = self._split_text(elem, base_id, header_context, start_index)
        return chunks

    def _split_code_block(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        start_index: int,
    ) -> list[MarkdownChunk]:
        """Split a code block line-wise; every part is re-fenced with the language."""
        chunks: list[MarkdownChunk] = []
        lines = elem.content.split('\n')
        current_lines: list[str] = []
        current_size = 0
        sub_index = 0
        for line in lines:
            # +1 accounts for the newline that joins the lines.
            if current_size + len(line) + 1 > self._max_chunk_size and current_lines:
                chunk_content = f"```{elem.language}\n" + '\n'.join(current_lines) + "\n```"
                chunks.append(MarkdownChunk(
                    chunk_id=f"{base_id}_{sub_index}",
                    content=chunk_content,
                    element_type=MarkdownElementType.CODE_BLOCK,
                    header_context=header_context,
                    language=elem.language,
                    metadata={**elem.metadata, "is_partial": True, "part": sub_index + 1},
                ))
                sub_index += 1
                current_lines = []
                current_size = 0
            current_lines.append(line)
            current_size += len(line) + 1
        if current_lines:
            chunk_content = f"```{elem.language}\n" + '\n'.join(current_lines) + "\n```"
            chunks.append(MarkdownChunk(
                chunk_id=f"{base_id}_{sub_index}",
                content=chunk_content,
                element_type=MarkdownElementType.CODE_BLOCK,
                header_context=header_context,
                language=elem.language,
                # is_partial is False when the whole block fit in one part.
                metadata={**elem.metadata, "is_partial": sub_index > 0, "part": sub_index + 1},
            ))
        return chunks

    def _split_table(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        start_index: int,
    ) -> list[MarkdownChunk]:
        """Split a table row-wise; every part repeats the header + separator rows."""
        chunks: list[MarkdownChunk] = []
        lines = elem.content.split('\n')
        if len(lines) < 2:
            # Degenerate table: emit as a single chunk unchanged.
            return [MarkdownChunk(
                chunk_id=f"{base_id}_0",
                content=elem.content,
                element_type=MarkdownElementType.TABLE,
                header_context=header_context,
                metadata=elem.metadata,
            )]
        header_line = lines[0]
        separator_line = lines[1]
        data_lines = lines[2:]
        current_lines = [header_line, separator_line]
        current_size = len(header_line) + len(separator_line) + 2
        sub_index = 0
        for line in data_lines:
            # len > 2 ensures each part carries at least one data row.
            if current_size + len(line) + 1 > self._max_chunk_size and len(current_lines) > 2:
                chunks.append(MarkdownChunk(
                    chunk_id=f"{base_id}_{sub_index}",
                    content='\n'.join(current_lines),
                    element_type=MarkdownElementType.TABLE,
                    header_context=header_context,
                    metadata={**elem.metadata, "is_partial": True, "part": sub_index + 1},
                ))
                sub_index += 1
                current_lines = [header_line, separator_line]
                current_size = len(header_line) + len(separator_line) + 2
            current_lines.append(line)
            current_size += len(line) + 1
        if len(current_lines) > 2:
            chunks.append(MarkdownChunk(
                chunk_id=f"{base_id}_{sub_index}",
                content='\n'.join(current_lines),
                element_type=MarkdownElementType.TABLE,
                header_context=header_context,
                metadata={**elem.metadata, "is_partial": sub_index > 0, "part": sub_index + 1},
            ))
        return chunks

    def _split_list(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        start_index: int,
    ) -> list[MarkdownChunk]:
        """Split a list item-wise (one parsed item per content line)."""
        chunks: list[MarkdownChunk] = []
        items = elem.content.split('\n')
        current_items: list[str] = []
        current_size = 0
        sub_index = 0
        for item in items:
            if current_size + len(item) + 1 > self._max_chunk_size and current_items:
                chunks.append(MarkdownChunk(
                    chunk_id=f"{base_id}_{sub_index}",
                    content='\n'.join(current_items),
                    element_type=MarkdownElementType.LIST,
                    header_context=header_context,
                    metadata={**elem.metadata, "is_partial": True, "part": sub_index + 1},
                ))
                sub_index += 1
                current_items = []
                current_size = 0
            current_items.append(item)
            current_size += len(item) + 1
        if current_items:
            chunks.append(MarkdownChunk(
                chunk_id=f"{base_id}_{sub_index}",
                content='\n'.join(current_items),
                element_type=MarkdownElementType.LIST,
                header_context=header_context,
                metadata={**elem.metadata, "is_partial": sub_index > 0, "part": sub_index + 1},
            ))
        return chunks

    def _split_text(
        self,
        elem: MarkdownElement,
        base_id: str,
        header_context: list[str],
        start_index: int,
    ) -> list[MarkdownChunk]:
        """Split text content paragraph-by-paragraph (on blank lines).

        A single paragraph longer than max_chunk_size is emitted as one
        oversized chunk — there is no sentence-level fallback.
        """
        chunks: list[MarkdownChunk] = []
        text = elem.content
        sub_index = 0
        paragraphs = text.split('\n\n')
        current_content = ""
        current_size = 0
        for para in paragraphs:
            # +2 accounts for the blank-line separator re-added below.
            if current_size + len(para) + 2 > self._max_chunk_size and current_content:
                chunks.append(MarkdownChunk(
                    chunk_id=f"{base_id}_{sub_index}",
                    content=current_content.strip(),
                    element_type=elem.type,
                    header_context=header_context,
                    metadata={**elem.metadata, "is_partial": True, "part": sub_index + 1},
                ))
                sub_index += 1
                current_content = ""
                current_size = 0
            current_content += para + "\n\n"
            current_size += len(para) + 2
        if current_content.strip():
            chunks.append(MarkdownChunk(
                chunk_id=f"{base_id}_{sub_index}",
                content=current_content.strip(),
                element_type=elem.type,
                header_context=header_context,
                metadata={**elem.metadata, "is_partial": sub_index > 0, "part": sub_index + 1},
            ))
        return chunks
def chunk_markdown(
    text: str,
    doc_id: str = "",
    max_chunk_size: int = 1000,
    min_chunk_size: int = 100,
    preserve_code_blocks: bool = True,
    preserve_tables: bool = True,
    preserve_lists: bool = True,
    include_header_context: bool = True,
    chunk_overlap: int = 50,
) -> list[dict[str, Any]]:
    """
    Convenience function to chunk Markdown text.

    Args:
        text: Raw Markdown text
        doc_id: Optional document ID
        max_chunk_size: Maximum chunk size in characters
        min_chunk_size: Minimum chunk size in characters
        preserve_code_blocks: Whether to preserve code blocks
        preserve_tables: Whether to preserve tables
        preserve_lists: Whether to preserve lists
        include_header_context: Whether to include header context
        chunk_overlap: Character overlap between chunks (previously not
            exposed here; appended last to keep positional calls working,
            default matches MarkdownChunker's)

    Returns:
        List of chunk dictionaries (MarkdownChunk.to_dict output)
    """
    chunker = MarkdownChunker(
        max_chunk_size=max_chunk_size,
        min_chunk_size=min_chunk_size,
        chunk_overlap=chunk_overlap,
        preserve_code_blocks=preserve_code_blocks,
        preserve_tables=preserve_tables,
        preserve_lists=preserve_lists,
        include_header_context=include_header_context,
    )
    chunks = chunker.chunk(text, doc_id)
    return [chunk.to_dict() for chunk in chunks]