# Mirror of https://github.com/ArthurIdema/Zoekeend-Phrase-Indexing.git
# Synced 2025-10-26 16:24:21 +00:00 (100 lines, 2.8 KiB, Python)
"""
Zoekeend searcher.
Author: Djoerd Hiemstra
"""
|
|
|
|
import sys
|
|
|
|
import duckdb
|
|
import ir_datasets
|
|
|
|
|
|
def duckdb_search_lm(con, query, limit):
    """Return the best-scoring hits of the ``match_lm`` matcher for *query*.

    Args:
        con: Connection object exposing ``execute(sql, params)``; the result
            of that call must provide ``fetchall()``.
        query: Query string, bound to the ``$1`` placeholder.
        limit: Maximum number of rows to return, bound to ``$2``.

    Returns:
        Whatever ``fetchall()`` yields for the statement: rows of
        ``(docname, score, postings_cost)`` ordered by descending score.
    """
    statement = """
        SELECT docname, score, postings_cost
        FROM fts_main_documents.match_lm($1)
        ORDER BY score DESC
        LIMIT $2
    """
    # Positional parameters follow the $1, $2 order used in the statement.
    bound_values = [query, limit]
    result = con.execute(statement, bound_values)
    return result.fetchall()
|
|
|
|
# def duckdb_search_lm(con, query, limit, l):
#     print(f"Searching for: {query} with limit {limit} and l={l}")
#     sql = """
#         SELECT docname, score, postings_cost
#         FROM fts_main_documents.match_lm(docname, $1)
#         ORDER BY score DESC
#         LIMIT $2
#     """
#     return con.execute(sql, [query, limit]).fetchall()
|
|
|
|
def duckdb_search_bm25(con, query, limit, b, k):
    """Return the best-scoring hits of the ``match_bm25`` matcher for *query*.

    Every row of ``documents`` is scored in a subquery; rows whose score is
    NULL are dropped before ordering.

    Args:
        con: Connection object exposing ``execute(sql, params)``; the result
            of that call must provide ``fetchall()``.
        query: Query string, bound to ``$1``.
        limit: Maximum number of rows to return, bound to ``$4``.
        b: Value passed as the ``b`` argument of ``match_bm25`` (``$2``).
        k: Value passed as the ``k`` argument of ``match_bm25`` (``$3``).

    Returns:
        Whatever ``fetchall()`` yields for the statement: rows of
        ``(did, score)`` ordered by descending score.
    """
    statement = """
        SELECT did, score
        FROM (
          SELECT did, fts_main_documents.match_bm25(did, $1, b=$2, k=$3) AS score
          FROM documents) sq
        WHERE score IS NOT NULL
        ORDER BY score DESC
        LIMIT $4
    """
    # Note the order: limit is the last placeholder although it is the third
    # function argument.
    bound_values = [query, b, k, limit]
    result = con.execute(statement, bound_values)
    return result.fetchall()
|
|
|
|
class Query:
    """Plain query record holding an identifier and the query text."""

    def __init__(self, query_id, text):
        # Both values are stored exactly as given; nothing is parsed or
        # normalized here.
        self.query_id, self.text = query_id, text
|
|
|
|
|
|
def get_queries_from_file(query_file):
    """Yield a ``Query`` for every ``query_id<TAB>text`` line of *query_file*.

    The file is read lazily as UTF-8 text. Line terminators are stripped so
    the query text does not carry a trailing newline, blank lines are
    skipped, and only the first tab separates the fields (any further tabs
    stay part of the text).

    Args:
        query_file: Path of the tab-separated query file.

    Yields:
        Query: one object per non-blank line.

    Raises:
        ValueError: if a non-blank line contains no tab character.
    """
    with open(query_file, "r", encoding="utf-8") as file:
        for lineno, line in enumerate(file, start=1):
            line = line.rstrip("\r\n")
            if not line.strip():
                # Tolerate empty lines, e.g. a trailing blank line at EOF.
                continue
            query_id, separator, text = line.partition('\t')
            if not separator:
                raise ValueError(
                    f"{query_file}:{lineno}: expected 'query_id<TAB>text'"
                )
            yield Query(query_id, text)
|
|
|
|
|
|
def get_queries(query_tag):
    """Resolve *query_tag* to an iterable of query objects.

    Resolution order:
      1. The literal tag ``"custom"`` uses ``ze_eval.ir_dataset_test()``.
      2. Otherwise the tag is handed to ``ir_datasets.load``.
      3. If that lookup raises ``KeyError``, the tag is treated as the path
         of a tab-separated query file.

    Args:
        query_tag: ``"custom"``, an ir_datasets identifier, or a file path.

    Returns:
        The ``queries_iter()`` result of the selected dataset, or the
        generator produced by ``get_queries_from_file``.
    """
    if query_tag == "custom":
        # Imported lazily so ze_eval is only needed for the custom data set.
        from ze_eval import ir_dataset_test
        custom_dataset = ir_dataset_test()
        return custom_dataset.queries_iter()

    try:
        queries = ir_datasets.load(query_tag).queries_iter()
    except KeyError:
        # Lookup failed: fall back to reading the tag as a file path.
        queries = get_queries_from_file(query_tag)
    return queries
|
|
|
|
|
|
def search_run(db_name, query_tag, matcher='lm', run_tag=None,
               b=0.75, k=1.2, limit=1000, fileout=None,
               startq=None, endq=None):
    """Run every selected query against the index and write a run file.

    For each hit one line is written in the form
    ``qid Q0 docno rank score run_tag [postings_cost]``, where ``rank`` is
    the 0-based position of the hit in the result list. The trailing
    ``postings_cost`` column is present only when the matcher returns it
    (the 'lm' matcher does, the 'bm25' matcher does not).

    Args:
        db_name: Path of the DuckDB database; it is opened read-only.
        query_tag: Passed to ``get_queries`` to obtain the queries.
        matcher: ``'lm'`` or ``'bm25'``.
        run_tag: Tag written on every line; defaults to *matcher*.
        b: BM25 ``b`` parameter (used by the 'bm25' matcher only).
        k: BM25 ``k`` parameter (used by the 'bm25' matcher only).
        limit: Maximum number of hits per query.
        fileout: Output path; when falsy, lines go to ``sys.stdout``.
        startq: If not None, queries with ``int(query_id) < startq`` are
            skipped.
        endq: If not None, queries with ``int(query_id) > endq`` are skipped.

    Raises:
        ValueError: if *matcher* is not a known match function.
    """
    # Validate before touching any resource so a typo in the matcher name
    # cannot truncate an existing output file.
    if matcher not in ('lm', 'bm25'):
        raise ValueError(f"Unknown match function: {matcher}")
    if not run_tag:
        run_tag = matcher

    con = duckdb.connect(db_name, read_only=True)
    file = None
    try:
        file = open(fileout, "w") if fileout else sys.stdout
        for query in get_queries(query_tag):
            qid = query.query_id
            # Explicit None checks so that a bound of 0 is honoured.
            if startq is not None and int(qid) < startq:
                continue
            if endq is not None and int(qid) > endq:
                continue
            # Prefer the title when the query object has one, otherwise
            # fall back to its text.
            q_string = query.title if hasattr(query, 'title') else query.text
            if matcher == 'lm':
                hits = duckdb_search_lm(con, q_string, limit)
            else:
                hits = duckdb_search_bm25(con, q_string, limit, b, k)
            # Rows are (docno, score) or (docno, score, postings_cost)
            # depending on the matcher; extra columns are appended as-is.
            for rank, (docno, score, *extra) in enumerate(hits):
                fields = (qid, 'Q0', docno, rank, score, run_tag, *extra)
                file.write(' '.join(str(field) for field in fields) + '\n')
    finally:
        # Only close what this function opened: never close sys.stdout.
        if fileout and file is not None:
            file.close()
        con.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the default matcher for the query tag
    # 'cranfield.tsv' against the 'cran.db' database, writing to stdout.
    search_run(db_name='cran.db', query_tag='cranfield.tsv')
|