mirror of
https://github.com/ArthurIdema/Zoekeend-Phrase-Indexing.git
synced 2025-10-26 16:24:21 +00:00
384 lines
14 KiB
Python
384 lines
14 KiB
Python
import pathlib
|
|
import sys
|
|
|
|
import duckdb
|
|
import ir_datasets
|
|
|
|
|
|
def copy_file(name_in, name_out):
    """Byte-for-byte copy of an existing file to a fresh destination.

    Refuses to read a missing source or to clobber an existing target.
    """
    src = pathlib.Path(name_in)
    dst = pathlib.Path(name_out)
    if not src.is_file():
        raise ValueError(f"File {name_in} does not exist.")
    if dst.is_file():
        raise ValueError(f"File {name_out} already exists.")
    # Whole-file read is fine here: index files are copied once, up front.
    dst.write_bytes(src.read_bytes())
|
|
|
|
|
|
def get_stats_stemmer(con):
    """Return the stemmer name recorded in the index's stats table."""
    rows = con.sql("SELECT stemmer FROM fts_main_documents.stats").fetchall()
    # stats holds a single row; take its only column.
    return rows[0][0]
|
|
|
|
|
|
def sample_by_values(con, column, threshold):
    """Takes one sample per unique value of len/prior.

    Creates a view ``sample`` with one (x, y) point per distinct value of
    *column* strictly greater than *threshold*:
      x = LN(number of docs whose value >= this value)  — the self-join of
          the histogram with itself (H1/H2) sums the counts of all values
          at or above H1's value, i.e. a log cumulative rank;
      y = LN(the value itself).
    These log-log points are later fitted by train_linear_regression.

    con: open DuckDB connection to an FTS-indexed database.
    column: docs column to sample ('len' or 'prior').
    threshold: values <= threshold are excluded from the histogram.
    """
    con.sql(f"""
        CREATE VIEW sample AS
        WITH histogram as (
            SELECT "{column}", COUNT(*) AS count
            FROM fts_main_documents.docs
            WHERE "{column}" > {threshold}
            GROUP BY "{column}"
        )
        SELECT LN(SUM(H2.count)) AS x, LN(H1."{column}") AS y
        FROM histogram H1, histogram H2
        WHERE H1."{column}" <= H2."{column}"
        GROUP BY H1."{column}"
    """)
|
|
|
|
|
|
def sample_by_fixed_points(con, column, threshold, total):
    """Takes {total} samples and averages len/prior for each.

    Splits the LN(docid + 1) axis, which spans [0, LN(num_docs + 1)], into
    ``total + 2`` equal buckets and keeps only buckets with range > 1 (the
    first kept bucket is stretched back to 0 by the CASE, so the low end of
    the axis is still covered). The view ``sample`` then holds one point per
    bucket:
      x = midpoint of the bucket on the LN(docid + 1) axis;
      y = LN(average of *column* over docs in the bucket, values above
          *threshold* only).
    Docids are assumed to have been renumbered by decreasing *column*
    beforehand (see renumber_doc_ids), so bucket position tracks rank.
    """
    con.sql(f"""
        CREATE VIEW sample AS
        WITH groups AS (
            SELECT (CASE WHEN range = 2 THEN 0 ELSE range END) *
                       LN(num_docs + 1) / ({total} + 2) AS group_start,
                   (range + 1) * LN(num_docs + 1) / ({total} + 2) AS group_end
            FROM RANGE({total} + 2), fts_main_documents.stats
            WHERE range > 1
        )
        SELECT (group_start + group_end) / 2 AS X, LN(AVG({column})) AS Y
        FROM groups, fts_main_documents.docs AS docs
        WHERE LN(docid + 1) >= group_start AND LN(docid + 1) < group_end
        AND "{column}" > {threshold}
        GROUP BY group_start, group_end
    """)
|
|
|
|
|
|
def sample_by_fixed_points_qrels(con, total):
    """
    Takes {total} samples and estimates the probability of relevance
    from the provided qrels.

    Same log-spaced bucketing over LN(docid + 1) as sample_by_fixed_points,
    but y is estimated from the main.qrels table (filled by insert_qrels)
    instead of a docs column:
      x = midpoint of the bucket on the LN(docid + 1) axis;
      y = LN(count of relevant (doc, query) pairs in the bucket divided by
          the bucket's width in docid space, EXP(group_end) - EXP(group_start))
          — i.e. a log density of relevant documents per docid.
    Only positively judged pairs (qrels.rel > 0) are counted, joined on
    docs.name = qrels.did.
    """
    con.sql(f"""
        CREATE VIEW sample AS
        WITH groups AS (
            SELECT (CASE WHEN range = 2 THEN 0 ELSE range END) *
                       LN(num_docs + 1) / ({total} + 2) AS group_start,
                   (range + 1) * LN(num_docs + 1) / ({total} + 2) AS group_end
            FROM RANGE({total} + 2), fts_main_documents.stats
            WHERE range > 1
        )
        SELECT (group_start + group_end) / 2 AS X,
               LN(COUNT(*)/(EXP(group_end) - EXP(group_start))) AS Y
        FROM groups, fts_main_documents.docs AS docs, qrels
        WHERE LN(docid + 1) >= group_start AND LN(docid + 1) < group_end
        AND docs.name = qrels.did
        AND qrels.rel > 0
        GROUP BY group_start, group_end
    """)
|
|
|
|
|
|
def print_sample_tsv(con, total=None):
    """Dump the ``sample`` view as tab-separated x/y pairs on stdout.

    If *total* is given and fewer points came back (empty buckets drop
    out of the sample), warn on stderr but still print what we have.
    """
    points = con.sql("SELECT x, y FROM sample ORDER BY x").fetchall()
    if total and len(points) != total:
        print(f"Warning: less than {total} datapoints.", file=sys.stderr)
    for point_x, point_y in points:
        print(str(point_x) + "\t" + str(point_y))
|
|
|
|
|
|
def train_linear_regression(con):
    """Approximate sample by using linear regression.

    Fits y = slope * x + intercept over the (x, y) points in the ``sample``
    view using the closed-form ordinary-least-squares solution (computed
    from the running sums N, Sx, Sy, Sxx, Sxy), then writes the model into
    fts_main_documents.stats. A non-negative slope is replaced by 0: the
    sampled quantity is expected to decrease with docid after renumbering,
    so a positive fit is treated as noise.
    """
    con.sql("""
        WITH sums AS (
            SELECT COUNT(*) AS N, SUM(x) AS Sx, SUM(y) AS Sy,
                   SUM(x*x) AS Sxx, SUM(x*y) AS Sxy
            FROM sample
        ),
        model AS (
            SELECT (Sy*Sxx - Sx*Sxy) / (N*Sxx - Sx*Sx) AS intercept,
                   (N*Sxy - Sx*Sy) / (N*Sxx - Sx*Sx) AS slope
            FROM sums
        )
        UPDATE fts_main_documents.stats AS stats
        SET intercept = model.intercept, slope =
            CASE WHEN model.slope < 0 THEN model.slope ELSE 0 END
        FROM model
    """)
|
|
|
|
|
|
def get_qrels_from_file(qrel_file):
    """Parse a TREC-style qrel file into [query_id, doc_id, relevance] rows.

    Each line is expected to be "qid Q0 did rel" separated by whitespace.
    Lines with relevance 0 are skipped, matching get_qrels_from_ir_datasets.

    Bug fix: relevance comes out of line.split() as a *string*, so the
    original comparison `relevance != 0` was always True and zero-relevance
    judgements were never filtered. Convert to int before comparing (and
    return ints, consistent with the ir_datasets variant).
    """
    inserts = []
    # NOTE(review): qrel files are assumed pure ASCII here — confirm if
    # datasets with non-ASCII doc ids are ever used.
    with open(qrel_file, "r", encoding="ascii") as file:
        for line in file:
            (query_id, _q0, doc_id, relevance) = line.split()
            relevance = int(relevance)
            if relevance != 0:
                inserts.append([query_id, doc_id, relevance])
    return inserts
|
|
|
|
|
|
def get_qrels_from_ir_datasets(qrels_tag):
    """Load judged [query_id, doc_id, relevance] rows via ir_datasets.

    Zero-relevance judgements are dropped.
    """
    return [
        [qrel.query_id, qrel.doc_id, qrel.relevance]
        for qrel in ir_datasets.load(qrels_tag).qrels_iter()
        if qrel.relevance != 0
    ]
|
|
|
|
|
|
def insert_qrels(con, qrels_tag):
    """(Re)create main.qrels and fill it from *qrels_tag*.

    The tag is first tried as an ir_datasets identifier; when that lookup
    raises KeyError it is treated as a qrel file path instead. Rows are
    inserted inside one explicit transaction.
    """
    con.sql("CREATE OR REPLACE TABLE main.qrels(qid TEXT, did TEXT, rel INT)")
    try:
        rows = get_qrels_from_ir_datasets(qrels_tag)
    except KeyError:
        # Not a known ir_datasets tag: fall back to reading it as a file.
        rows = get_qrels_from_file(qrels_tag)
    con.sql("BEGIN TRANSACTION")
    con.executemany("INSERT INTO qrels VALUES (?, ?, ?)", rows)
    con.sql("COMMIT")
|
|
|
|
|
|
def replace_bm25_fitted_doclen(con, stemmer):
    """Replace match_bm25 by a variant that uses fitted document lengths.

    Same pipeline as DuckDB FTS' stock match_bm25 macro (tokenize + stem the
    query, look up term ids, count term frequencies per doc, optional
    conjunctive filter), except the stored document length is replaced by
    the fitted length EXP(LN(docid) * stats.slope + stats.intercept), i.e.
    the length predicted from the document's rank by the regression trained
    in train_linear_regression.

    Bug fix: the FROM clause of the ``subscores`` CTE had a trailing comma
    after ``fts_main_documents.stats AS stats`` right before WHERE, which is
    a SQL syntax error and made macro creation fail (the match_lm twin in
    replace_lm_fitted_doclen has no such comma).

    NOTE(review): stats.avgdl is still the average of the *original*
    lengths, not of the fitted ones — confirm that is intended.
    """
    con.sql(f"""
        CREATE OR REPLACE MACRO fts_main_documents.match_bm25(docname, query_string, b := 0.75, k := 1.2, conjunctive := 0, fields := NULL) AS (
            WITH tokens AS (
                SELECT DISTINCT stem(unnest(fts_main_documents.tokenize(query_string)), '{stemmer}') AS t
            ),
            fieldids AS (
                SELECT fieldid
                FROM fts_main_documents.fields
                WHERE CASE WHEN fields IS NULL THEN 1 ELSE field IN (SELECT * FROM (SELECT UNNEST(string_split(fields, ','))) AS fsq) END
            ),
            qtermids AS (
                SELECT termid, df
                FROM fts_main_documents.dict AS dict,
                     tokens
                WHERE dict.term = tokens.t
            ),
            qterms AS (
                SELECT termid,
                       docid
                FROM fts_main_documents.terms AS terms
                WHERE CASE WHEN fields IS NULL THEN 1 ELSE fieldid IN (SELECT * FROM fieldids) END
                AND termid IN (SELECT qtermids.termid FROM qtermids)
            ),
            term_tf AS (
                SELECT termid, docid, COUNT(*) AS tf
                FROM qterms
                GROUP BY docid, termid
            ),
            cdocs AS (
                SELECT docid
                FROM qterms
                GROUP BY docid
                HAVING CASE WHEN conjunctive THEN COUNT(DISTINCT termid) = (SELECT COUNT(*) FROM tokens) ELSE 1 END
            ),
            subscores AS (
                SELECT docs.docid, EXP(LN(docs.docid)*stats.slope + stats.intercept) AS newlen, term_tf.termid, tf, df, (log((((stats.num_docs - df) + 0.5) / (df + 0.5))) * ((tf * (k + 1)) / (tf + (k * ((1 - b) + (b * (newlen / stats.avgdl))))))) AS subscore
                FROM term_tf, cdocs, fts_main_documents.docs AS docs, qtermids,
                     fts_main_documents.stats AS stats
                WHERE term_tf.docid = cdocs.docid
                AND term_tf.docid = docs.docid
                AND term_tf.termid = qtermids.termid
            ),
            scores AS (
                SELECT docid, sum(subscore) AS score
                FROM subscores
                GROUP BY docid
            )
            SELECT score
            FROM scores, fts_main_documents.docs AS docs
            WHERE scores.docid = docs.docid
            AND docs.name = docname
        )"""
    )
|
|
|
|
|
|
def replace_lm_fitted_doclen(con, stemmer):
    """Replace match_lm by a variant that uses fitted document lengths.

    Same pipeline as DuckDB FTS' stock match_lm macro (tokenize + stem the
    query, look up term ids, count term frequencies per doc, optional
    conjunctive filter), except the stored docs.len is replaced by the
    fitted length EXP(LN(docid) * stats.slope + stats.intercept) predicted
    by the regression trained in train_linear_regression. The final score
    is the query-likelihood language model with linear smoothing:
    LN(newlen) + sum over matched terms of
    LN(1 + lambda * tf * sumdf / ((1 - lambda) * df * newlen)).

    stemmer: stemmer name baked into the macro (from get_stats_stemmer),
    so query terms are stemmed the same way the index was.
    """
    con.sql(f"""
        CREATE OR REPLACE MACRO fts_main_documents.match_lm(docname, query_string, fields := NULL, lambda := 0.3, conjunctive := 0) AS (
            WITH tokens AS (
                SELECT DISTINCT stem(unnest(fts_main_documents.tokenize(query_string)), '{stemmer}') AS t
            ),
            fieldids AS (
                SELECT fieldid
                FROM fts_main_documents.fields
                WHERE CASE WHEN fields IS NULL THEN 1 ELSE field IN (SELECT * FROM (SELECT UNNEST(string_split(fields, ','))) AS fsq) END
            ),
            qtermids AS (
                SELECT termid, df
                FROM fts_main_documents.dict AS dict,
                     tokens
                WHERE dict.term = tokens.t
            ),
            qterms AS (
                SELECT termid,
                       docid
                FROM fts_main_documents.terms AS terms
                WHERE CASE WHEN fields IS NULL THEN 1 ELSE fieldid IN (SELECT * FROM fieldids) END
                AND termid IN (SELECT qtermids.termid FROM qtermids)
            ),
            term_tf AS (
                SELECT termid, docid, COUNT(*) AS tf
                FROM qterms
                GROUP BY docid, termid
            ),
            cdocs AS (
                SELECT docid
                FROM qterms
                GROUP BY docid
                HAVING CASE WHEN conjunctive THEN COUNT(DISTINCT termid) = (SELECT COUNT(*) FROM tokens) ELSE 1 END
            ),
            subscores AS (
                SELECT docs.docid, EXP(LN(docs.docid)*stats.slope + stats.intercept) AS newlen,
                       term_tf.termid, tf, df,
                       LN(1 + (lambda * tf * (SELECT sumdf FROM fts_main_documents.stats)) / ((1-lambda) * df * newlen)) AS subscore
                FROM term_tf, cdocs, fts_main_documents.docs AS docs, qtermids,
                     fts_main_documents.stats AS stats
                WHERE term_tf.docid = cdocs.docid
                AND term_tf.docid = docs.docid
                AND term_tf.termid = qtermids.termid
            ),
            scores AS (
                SELECT docid, LN(ANY_VALUE(newlen)) + sum(subscore) AS score
                FROM subscores
                GROUP BY docid
            )
            SELECT score
            FROM scores, fts_main_documents.docs AS docs
            WHERE scores.docid = docs.docid
            AND docs.name = docname
        )"""
    )
|
|
|
|
|
|
def replace_lm_fitted_prior(con, stemmer='none'):
    """
    Only use fitted prior, but keep on using the old document lengths.

    Differences from replace_lm_fitted_doclen:
      * smoothing still uses the stored docs.len, not a fitted length;
      * the fitted model enters only as a document prior: the score adds
        LN(docid) * stats.slope per document (the intercept is omitted —
        presumably it is constant across documents and thus rank-neutral;
        confirm);
      * tokens keeps duplicates (no DISTINCT) and qtermids counts the query
        term frequency qtf, which multiplies each term's subscore, so
        repeated query terms weigh more.
    """
    sql = f"""
        CREATE OR REPLACE MACRO fts_main_documents.match_lm(docname, query_string, fields := NULL, lambda := 0.3, conjunctive := 0) AS (
            WITH tokens AS (
                SELECT stem(unnest(fts_main_documents.tokenize(query_string)), '{stemmer}') AS t
            ),
            fieldids AS (
                SELECT fieldid
                FROM fts_main_documents.fields
                WHERE CASE WHEN fields IS NULL THEN 1 ELSE field IN (SELECT * FROM (SELECT UNNEST(string_split(fields, ','))) AS fsq) END
            ),
            qtermids AS (
                SELECT termid, df, COUNT(*) AS qtf
                FROM fts_main_documents.dict AS dict,
                     tokens
                WHERE dict.term = tokens.t
                GROUP BY termid, df
            ),
            qterms AS (
                SELECT termid,
                       docid
                FROM fts_main_documents.terms AS terms
                WHERE CASE WHEN fields IS NULL THEN 1 ELSE fieldid IN (SELECT * FROM fieldids) END
                AND termid IN (SELECT qtermids.termid FROM qtermids)
            ),
            term_tf AS (
                SELECT termid, docid, COUNT(*) AS tf
                FROM qterms
                GROUP BY docid, termid
            ),
            cdocs AS (
                SELECT docid
                FROM qterms
                GROUP BY docid
                HAVING CASE WHEN conjunctive THEN COUNT(DISTINCT termid) = (SELECT COUNT(*) FROM tokens) ELSE 1 END
            ),
            subscores AS (
                SELECT docs.docid, docs.len, term_tf.termid, term_tf.tf, qtermids.df,
                       qtermids.qtf * LN(1 + (lambda * tf * (SELECT ANY_VALUE(sumdf) FROM fts_main_documents.stats)) / ((1-lambda) * df * len)) AS subscore
                FROM term_tf, cdocs, fts_main_documents.docs AS docs, qtermids
                WHERE term_tf.docid = cdocs.docid
                AND term_tf.docid = docs.docid
                AND term_tf.termid = qtermids.termid
            ),
            scores AS (
                SELECT docid, (LN(docid)*(SELECT ANY_VALUE(slope) FROM fts_main_documents.stats)) + sum(subscore) AS score
                FROM subscores
                GROUP BY docid
            )
            SELECT score
            FROM scores, fts_main_documents.docs AS docs
            WHERE scores.docid = docs.docid
            AND docs.name = docname
        )
    """
    con.sql(sql)
|
|
|
|
|
|
def renumber_doc_ids(con, column):
    """Renumber docids by decreasing *column* ('len' or 'prior').

    ROW_NUMBER() OVER (ORDER BY column DESC, name ASC) assigns each document
    a new id — name breaks ties so the numbering is deterministic. The
    postings in fts_main_documents.terms are rewritten against the new ids
    (re-sorted by termid), the old docs/terms tables are dropped and the new
    ones renamed into place, and stats.index_type is marked 'fitted'.
    Note: the statements run as individual auto-committed steps, not one
    transaction.
    """
    con.sql(f"""
        -- renumber document ids by decreasing len/prior column
        CREATE TABLE fts_main_documents.docs_new AS
        SELECT ROW_NUMBER() over (ORDER BY "{column}" DESC, name ASC) newid, docs.*
        FROM fts_main_documents.docs AS docs;
        -- update postings
        CREATE TABLE fts_main_documents.terms_new AS
        SELECT D.newid as docid, T.fieldid, T.termid
        FROM fts_main_documents.terms T, fts_main_documents.docs_new D
        WHERE T.docid = D.docid
        ORDER BY T.termid;
        -- replace old by new data
        ALTER TABLE fts_main_documents.docs_new DROP COLUMN docid;
        ALTER TABLE fts_main_documents.docs_new RENAME COLUMN newid TO docid;
        DROP TABLE fts_main_documents.docs;
        DROP TABLE fts_main_documents.terms;
        ALTER TABLE fts_main_documents.docs_new RENAME TO docs;
        ALTER TABLE fts_main_documents.terms_new RENAME TO terms;
        UPDATE fts_main_documents.stats SET index_type = 'fitted';
    """)
|
|
|
|
|
|
def reindex_fitted_column(name_in, name_out, column='prior', total=None,
                          print_sample=False, threshold=0, qrels=None):
    """Copy index *name_in* to *name_out* and replace *column* by a fitted model.

    Pipeline: copy the DuckDB file, renumber docids by decreasing *column*,
    build a (x, y) sample view (from the column values or, with *qrels*,
    from relevance judgements), fit a linear model, drop the raw column,
    and install the fitted match_lm / match_bm25 macros.

    name_in: existing DuckDB FTS index file.
    name_out: output file; must not already exist.
    column: 'len' or 'prior'.
    total: number of fixed sample points; None samples one point per
        unique value (not supported together with qrels).
    print_sample: dump the sample as TSV on stdout before fitting.
    threshold: values <= threshold are ignored when sampling.
    qrels: qrel file path or ir_datasets tag; requires *total*.
    Raises ValueError for a bad column, missing/existing files, or
    qrels without total.

    Fix: the connection is now closed in a ``finally`` block, so a failure
    mid-pipeline no longer leaks the handle on the freshly copied file.
    """
    if column not in ['len', 'prior']:
        raise ValueError(f'Column "{column}" not allowed: use len or prior.')
    copy_file(name_in, name_out)
    con = duckdb.connect(name_out)
    try:
        renumber_doc_ids(con, column)
        try:
            con.sql("""
                ALTER TABLE fts_main_documents.stats ADD intercept DOUBLE;
                ALTER TABLE fts_main_documents.stats ADD slope DOUBLE;
            """)
        except duckdb.duckdb.CatalogException as e:
            # Columns already present (e.g. rerun on a fitted index): continue.
            print("Warning: " + str(e), file=sys.stderr)
        if qrels:
            insert_qrels(con, qrels)
            if total:
                sample_by_fixed_points_qrels(con, total)
            else:
                raise ValueError("Not implemented.")
        else:
            if total:
                sample_by_fixed_points(con, column, threshold, total)
            else:
                sample_by_values(con, column, threshold)
        if print_sample:
            print_sample_tsv(con, total)
        train_linear_regression(con)
        con.sql(f"""
            DROP VIEW sample;
            ALTER TABLE fts_main_documents.docs DROP COLUMN "{column}";
        """)
        stemmer = get_stats_stemmer(con)
        if column == 'len':
            replace_lm_fitted_doclen(con, stemmer=stemmer)
            replace_bm25_fitted_doclen(con, stemmer=stemmer)
        else:
            replace_lm_fitted_prior(con, stemmer=stemmer)
    finally:
        # Always release the database handle, even when a step above failed.
        con.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Example run: refit the robust index on document length, threshold 20.
    reindex_fitted_column(
        'robustZE.db',
        'robustZE_fitted20.db',
        column='len',
        total=None,
        print_sample=True,
        threshold=20,
        qrels=None,
    )
|