Zoekeend-Phrase-Indexing/ze_reindex_prior.py
2025-08-19 17:23:02 +02:00

115 lines
4.1 KiB
Python

import pathlib
import sys
import duckdb
def copy_file(name_in, name_out):
    """Copy the existing file ``name_in`` to the new file ``name_out``.

    The data is streamed in chunks instead of being read into memory at
    once, so large database files do not have to fit in RAM. The output is
    never overwritten.

    Args:
        name_in: path of the file to copy; must be an existing regular file.
        name_out: path of the copy; must not exist yet.

    Raises:
        ValueError: if ``name_in`` is not an existing regular file, or if
            ``name_out`` already exists (as a file, directory or link).
    """
    import shutil

    source = pathlib.Path(name_in)
    if not source.is_file():
        raise ValueError(f"File {name_in} does not exist.")
    target = pathlib.Path(name_out)
    # exists() (not is_file()) so an existing directory is also rejected
    # with the documented ValueError.
    if target.exists():
        raise ValueError(f"File {name_out} already exists.")
    try:
        # Mode 'xb' creates the file exclusively: if something creates
        # name_out between the check above and this open, it is not clobbered.
        with source.open("rb") as fin, target.open("xb") as fout:
            shutil.copyfileobj(fin, fout)
    except FileExistsError as err:
        raise ValueError(f"File {name_out} already exists.") from err
def get_stats_stemmer(con):
    """Return the stemmer name stored in the FTS index statistics.

    Runs a query for the ``stemmer`` column of ``fts_main_documents.stats``
    on ``con`` and returns the value found in the first result row.
    """
    query = "SELECT stemmer FROM fts_main_documents.stats"
    rows = con.sql(query).fetchall()
    first_row = rows[0]
    return first_row[0]
def replace_lm_prior(con, stemmer):
    """(Re)define the ``fts_main_documents.match_lm`` scoring macro with a prior.

    Executes a ``CREATE OR REPLACE MACRO`` on ``con`` that defines
    ``fts_main_documents.match_lm(docname, query_string, fields := NULL,
    lambda := 0.3, conjunctive := 0)``. The macro:

    * tokenizes ``query_string`` with ``fts_main_documents.tokenize``,
      stems each token with ``stem(..., stemmer)`` and de-duplicates them;
    * if ``fields`` is not NULL, splits it on ',' into field names and only
      counts term occurrences in those fields;
    * if ``conjunctive`` is true, keeps only documents containing as many
      distinct query terms as there are distinct query tokens;
    * scores each matching (document, term) pair as
      ``LN(1 + lambda * tf * sumdf / ((1 - lambda) * df * len))``, where
      ``tf`` is counted from ``fts_main_documents.terms``, ``df`` comes from
      ``fts_main_documents.dict``, ``len`` from ``fts_main_documents.docs``
      and ``sumdf`` from ``fts_main_documents.stats``;
    * returns ``LN(prior) + sum(subscores)`` for the document whose
      ``name`` equals ``docname``.

    The per-term formula looks like a linearly interpolated
    (Jelinek-Mercer style) language model that uses document frequencies
    as collection statistics, with ``lambda`` as the interpolation weight.
    NOTE(review): confirm this is the intended model.

    Args:
        con: open DuckDB connection to the index database. The macro reads
            ``fts_main_documents.docs.prior``, so that column must exist.
        stemmer: stemmer name handed to DuckDB's ``stem`` function. The
            caller in this module passes the value stored in
            ``fts_main_documents.stats``.
    """
    # NOTE(review): `stemmer` is pasted into a single-quoted SQL literal
    # without escaping; a value containing a quote would break the macro.
    # NOTE(review): a NULL prior makes the whole score NULL; how LN() treats
    # a prior <= 0 depends on DuckDB -- confirm priors are strictly positive.
    # Assumption: a document matching no query term produces no row in the
    # macro's scalar subquery, i.e. a NULL score -- TODO confirm.
    con.sql(f"""
CREATE OR REPLACE MACRO fts_main_documents.match_lm(docname, query_string, fields := NULL, lambda := 0.3, conjunctive := 0) AS (
WITH tokens AS (
SELECT DISTINCT stem(unnest(fts_main_documents.tokenize(query_string)), '{stemmer}') AS t
),
fieldids AS (
SELECT fieldid
FROM fts_main_documents.fields
WHERE CASE WHEN ((fields IS NULL)) THEN (1) ELSE (field = ANY(SELECT * FROM (SELECT unnest(string_split(fields, ','))) AS fsq)) END
),
qtermids AS (
SELECT termid, df
FROM fts_main_documents.dict AS dict, tokens
WHERE (dict.term = tokens.t)
),
qterms AS (
SELECT termid, docid
FROM fts_main_documents.terms AS terms
WHERE (CASE WHEN ((fields IS NULL)) THEN (1) ELSE (fieldid = ANY(SELECT * FROM fieldids)) END
AND (termid = ANY(SELECT qtermids.termid FROM qtermids)))
),
term_tf AS (
SELECT termid, docid, count_star() AS tf
FROM qterms
GROUP BY docid, termid
),
cdocs AS (
SELECT docid
FROM qterms
GROUP BY docid
HAVING CASE WHEN (conjunctive) THEN ((count(DISTINCT termid) = (SELECT count_star() FROM tokens))) ELSE 1 END
),
subscores AS (
SELECT docs.docid, prior, len, term_tf.termid, tf, df, LN(1 + (lambda * tf * (SELECT ANY_VALUE(sumdf) FROM fts_main_documents.stats)) / ((1-lambda) * df * len)) AS subscore
FROM term_tf, cdocs, fts_main_documents.docs AS docs, qtermids
WHERE ((term_tf.docid = cdocs.docid)
AND (term_tf.docid = docs.docid)
AND (term_tf.termid = qtermids.termid))
),
scores AS (
SELECT docid, LN(ANY_VALUE(prior)) + sum(subscore) AS score FROM subscores GROUP BY docid
)
SELECT score FROM scores, fts_main_documents.docs AS docs
WHERE ((scores.docid = docs.docid) AND (docs."name" = docname)))
""")
def insert_priors(con, csv_file, default):
    """Set ``fts_main_documents.docs.prior`` from a CSV file of priors.

    The file is read with DuckDB's ``read_csv`` and joined on
    ``docs.name = priors.did``, so it must provide the columns ``did`` and
    ``prior``. Documents not listed in the file keep their current prior
    (NULL for a freshly added column) unless ``default`` is given.

    Args:
        con: open DuckDB connection to the index database.
        csv_file: path of the CSV file (str or path-like). A value that is
            already a single-quoted SQL string literal is used unchanged.
        default: value assigned to every document whose prior is still NULL
            after the CSV update. When None, a warning with the number of
            such documents is printed to stderr instead.
    """
    path = str(csv_file)
    if len(path) >= 2 and path.startswith("'") and path.endswith("'"):
        # Already a quoted SQL string literal: keep accepting it as before.
        csv_literal = path
    else:
        # Quote the path as a SQL string literal (doubling embedded quotes).
        # Interpolated bare, a name like test_priors.csv would be parsed as
        # an identifier rather than as a file name string.
        csv_literal = "'" + path.replace("'", "''") + "'"
    con.sql(f"""
UPDATE fts_main_documents.docs AS docs
SET prior = priors.prior
FROM read_csv({csv_literal}) AS priors
WHERE docs.name = priors.did
""")
    if default is not None:
        # NOTE(review): `default` is interpolated verbatim into the SQL, so it
        # must be a numeric literal (or another trusted SQL expression).
        con.sql(f"""
UPDATE fts_main_documents.docs
SET prior = {default}
WHERE prior IS NULL
""")
    else:
        missing = con.sql("""
SELECT COUNT(*)
FROM fts_main_documents.docs
WHERE prior IS NULL
""").fetchall()[0][0]
        if missing > 0:
            print(f"Warning: {missing} rows missing from file. Use --default", file=sys.stderr)
def reindex_prior(name_in, name_out, csv_file=None, default=None, init=None):
    """Copy an FTS index database and add document priors to the copy.

    ``name_in`` is copied to the new file ``name_out`` (which must not exist
    yet), and only the copy is modified:

    1. a DOUBLE column ``prior`` is added to ``fts_main_documents.docs``;
    2. the priors are filled from ``csv_file`` via ``insert_priors`` (using
       ``default`` for unlisted documents). When no file is given, they are
       set by ``init``: ``'len'`` copies the document length and
       ``'uniform'`` sets every prior to 1;
    3. the ``match_lm`` macro is redefined (``replace_lm_prior``) with the
       stemmer recorded in the index statistics.

    If both ``csv_file`` and ``init`` are given, ``init`` is ignored with a
    warning on stderr.

    Raises:
        ValueError: if ``init`` would be used but is not ``'len'`` or
            ``'uniform'`` (raised before anything is copied), or if
            ``copy_file`` rejects the input/output paths.
    """
    init_updates = {
        'len': "UPDATE fts_main_documents.docs SET prior = len",
        'uniform': "UPDATE fts_main_documents.docs SET prior = 1",
    }
    if csv_file and init:
        print(f"Warning: init={init} ignored.", file=sys.stderr)
    elif init and init not in init_updates:
        # Validate up front so a bad option does not leave a half-built
        # output database behind (copy_file refuses to overwrite it later).
        raise ValueError(f'Unknown value for init: {init}')
    copy_file(name_in, name_out)
    con = duckdb.connect(name_out)
    try:
        con.sql("ALTER TABLE fts_main_documents.docs ADD prior DOUBLE")
        if csv_file:
            insert_priors(con, csv_file, default)
        elif init:
            con.sql(init_updates[init])
        stemmer = get_stats_stemmer(con)
        replace_lm_prior(con, stemmer=stemmer)
    finally:
        # Release the database file even if one of the steps above fails.
        con.close()
if __name__ == "__main__":
    # Demo run with hard-coded file names: copy cran.db to cran_prior.db
    # and load the document priors from test_priors.csv.
    demo_options = {"csv_file": 'test_priors.csv'}
    reindex_prior('cran.db', 'cran_prior.db', **demo_options)