Zoekeend-Phrase-Indexing/ze_index_export.py

"""
Zoekeend CIFF exporter.

Exports a DuckDB full-text-search index (the ``fts_main_documents`` schema)
to the Common Index File Format (CIFF): a header, gap-encoded postings
lists, and per-document records.

Author: Gijs Hendriksen
"""

from typing import Iterable, Type, TypeVar

import duckdb
from ciff_toolkit.ciff_pb2 import DocRecord, Header, PostingsList
from ciff_toolkit.write import CiffWriter
from google.protobuf.message import Message
from tqdm import tqdm

# Any protobuf message type produced by the helpers below.
M = TypeVar('M', bound=Message)


def _create_message_from_row(row: tuple | dict, message_type: Type[M]) -> M:
    """Build a protobuf message of ``message_type`` from a DuckDB result row."""
    if isinstance(row, tuple):
        # Positional rows: pair values with the message fields in declaration order.
        mapping = zip(message_type.DESCRIPTOR.fields, row)
    else:
        # Mapping rows: look each field up by name.
        mapping = [(field, row[field.name]) for field in message_type.DESCRIPTOR.fields]

    msg = message_type()
    for field, value in mapping:
        if field.label == field.LABEL_REPEATED:
            # Repeated message fields (e.g. the postings of a postings list)
            # are built recursively from their element type.
            for x in value:
                getattr(msg, field.name).append(
                    _create_message_from_row(x, field.message_type._concrete_class))
        else:
            setattr(msg, field.name, value)
    return msg
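
# A minimal illustration of the mapping above (assuming the CIFF DocRecord
# field order of docid, collection_docid, doclength):
#
#   record = _create_message_from_row((5, 'doc-5', 42), DocRecord)
#   record.collection_docid  # -> 'doc-5'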


def create_protobuf_messages_from_result(result: duckdb.DuckDBPyRelation, message_type: Type[M],
                                         batch_size: int = 1024) -> Iterable[M]:
    """Stream protobuf messages from a DuckDB result, ``batch_size`` rows at a time."""
    try:
        # Fast path: convert Arrow record batches to messages with protarrow.
        import protarrow

        for batch in result.fetch_arrow_reader(batch_size):
            yield from protarrow.record_batch_to_messages(batch, message_type)
    except ImportError:
        # Fallback: build each message from plain Python rows.
        while batch := result.fetchmany(batch_size):
            for row in batch:
                yield _create_message_from_row(row, message_type)


def create_ciff_header(conn: duckdb.DuckDBPyConnection, description: str) -> Header:
    """Build the CIFF header from the index statistics in ``fts_main_documents``."""
    header_info = conn.execute("""
        SELECT
            1 AS version,
            (SELECT COUNT(*) FROM fts_main_documents.dict) AS num_postings_lists,
            num_docs,
            (SELECT COUNT(*) FROM fts_main_documents.dict) AS total_postings_lists,
            num_docs AS total_docs,
            (SELECT SUM(len) FROM fts_main_documents.docs)::BIGINT AS total_terms_in_collection,
            avgdl AS average_doclength,
            ? AS description,
        FROM fts_main_documents.stats
    """, [description])

    # The stats table yields exactly one row, hence exactly one Header message.
    header, = create_protobuf_messages_from_result(header_info, Header)
    return header


def create_ciff_postings_lists(conn: duckdb.DuckDBPyConnection, batch_size: int = 1024) -> Iterable[PostingsList]:
    """Yield one CIFF postings list per term, with gap-encoded document ids."""
    postings_info = conn.sql("""
        WITH postings AS (
            -- Term frequency per (term, document) pair.
            SELECT termid, docid, COUNT(*) AS tf
            FROM fts_main_documents.terms
            GROUP BY ALL
        ),
        gapped_postings AS (
            -- CIFF stores docid gaps rather than absolute ids: each posting
            -- records the distance to the previous docid in the same list.
            SELECT *, docid - lag(docid, 1, 0) OVER (PARTITION BY termid ORDER BY docid) AS gap
            FROM postings
        ),
        grouped_postings AS (
            -- The gap takes the place of the docid field in the Posting struct.
            SELECT termid,
                   list(row(gap, tf)::STRUCT(docid BIGINT, tf BIGINT) ORDER BY docid) AS postings,
                   SUM(tf)::BIGINT AS cf
            FROM gapped_postings
            GROUP BY termid
        )
        SELECT term, df, cf, postings
        FROM grouped_postings
        JOIN fts_main_documents.dict USING (termid)
        ORDER BY term;
    """)

    yield from create_protobuf_messages_from_result(postings_info, PostingsList, batch_size=batch_size)
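
# Consumers recover absolute document ids from the gaps with a running sum,
# e.g. (illustration only):
#
#   docid = 0
#   for posting in postings_list.postings:
#       docid += posting.docid  # the docid field holds the gap
#       ...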


def create_ciff_doc_records(conn: duckdb.DuckDBPyConnection, batch_size: int = 1024) -> Iterable[DocRecord]:
    """Yield one CIFF doc record (internal id, external id, length) per document."""
    docs_info = conn.sql("""
        SELECT
            docid,
            name AS collection_docid,
            len AS doclength,
        FROM fts_main_documents.docs
        ORDER BY collection_docid
    """)

    yield from create_protobuf_messages_from_result(docs_info, DocRecord, batch_size=batch_size)


def ciff_export(db_name: str, file_name: str, description: str, batch_size: int = 1024):
    """Export the DuckDB index in ``db_name`` to the CIFF file ``file_name``."""
    with duckdb.connect(db_name) as conn, CiffWriter(file_name) as writer:
        header = create_ciff_header(conn, description)
        print(header)  # log the collection statistics before the (long) export

        writer.write_header(header)
        writer.write_postings_lists(tqdm(create_ciff_postings_lists(conn, batch_size=batch_size),
                                         total=header.num_postings_lists,
                                         desc='Writing posting lists', unit='pl'))
        writer.write_documents(tqdm(create_ciff_doc_records(conn, batch_size=batch_size),
                                    total=header.num_docs,
                                    desc='Writing documents', unit='d'))


if __name__ == '__main__':
    ciff_export('index.db', 'index-copy.ciff.gz', 'OWS.eu index', batch_size=2**12)
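
# To sanity-check an export, the file can be streamed back with the companion
# reader (a sketch; assumes ciff_toolkit.read.CiffReader and its read_header
# method, which are not used in this module):
#
#   from ciff_toolkit.read import CiffReader
#
#   with CiffReader('index-copy.ciff.gz') as reader:
#       print(reader.read_header())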