文件
csp/scripts/import_local_pdf_rag.py

1276 行
39 KiB
Python

此文件含有模棱两可的 Unicode 字符
此文件含有可能会与其他字符混淆的 Unicode 字符。 如果您是想特意这样的,可以安全地忽略该警告。 使用 Escape 按钮显示他们。
#!/usr/bin/env python3
"""Generate CSP J/S problems from local PDFs via RAG + LLM into SQLite."""
from __future__ import annotations
import argparse
import json
import os
import random
import re
import sqlite3
import subprocess
import time
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from dataclasses import dataclass
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any
import requests
# HTTP status codes treated as transient; requests_retry_post retries on these.
RETRYABLE_HTTP_CODES = {429, 500, 502, 503, 504}
@dataclass
class ExistingProblem:
    """A problem row already stored in SQLite, used for duplicate detection."""

    id: int  # problems.id primary key
    title: str
    statement_md: str  # Markdown statement body
@dataclass
class CorpusEntry:
    """One retrieval unit of RAG context (a PDF excerpt or a stored problem)."""

    path: str  # filesystem path, or "db:problem:<id>:<title>" pseudo-path for DB fallback
    text: str  # extracted text, already clipped to a character budget
    exam_hint: str  # "csp-j" / "csp-s" / "" guessed from the source text or path
@dataclass
class CandidateProblem:
    """A fully-formed generated problem ready to be inserted (LLM or fallback)."""

    title: str
    difficulty: int  # clamped to 1..10
    statement_md: str  # Markdown problem statement
    sample_input: str
    sample_output: str
    answer: str  # answer outline, stored in llm_profile_json (no dedicated column)
    explanation: str
    knowledge_points: list[str]
    tags: list[str]  # normalized; always includes exam type + "local-pdf-rag"
    exam_type: str  # "csp-j" or "csp-s"
    source_paths: list[str]  # corpus entries that served as RAG context
    source: str  # "llm" or "fallback"
    llm_error: str  # last LLM error when source == "fallback", else ""
@dataclass
class AttemptResult:
    """Outcome of one generation attempt, returned from a worker thread."""

    attempt_id: int
    source_paths: list[str]
    candidate: CandidateProblem | None  # None when nothing usable was produced
    error: str  # reason when candidate is None, else ""
def now_sec() -> int:
    """Return the current Unix timestamp truncated to whole seconds."""
    seconds = time.time()
    return int(seconds)
def normalize(text: str) -> str:
    """Lowercase *text* and keep only digits, ASCII letters, CJK chars and single spaces."""
    lowered = text.lower().strip()
    collapsed = re.sub(r"\s+", " ", lowered)
    kept = re.sub(r"[^0-9a-z\u4e00-\u9fff ]+", " ", collapsed)
    return re.sub(r"\s+", " ", kept).strip()
def similarity(a: str, b: str) -> float:
    """Fuzzy similarity in [0, 1] between the normalized forms of two strings."""
    if not a or not b:
        return 0.0
    matcher = SequenceMatcher(None, normalize(a), normalize(b))
    return matcher.ratio()
def to_exam_type(text: str) -> str:
    """Guess 'csp-j' / 'csp-s' from free text; return "" when neither matches.

    Junior markers are checked first, mirroring the original precedence.
    """
    lowered = text.lower()
    junior_markers = ("csp-j", "noip-j", "junior", "普及")
    if any(marker in lowered for marker in junior_markers):
        return "csp-j"
    senior_markers = ("csp-s", "noip-s", "senior", "提高")
    if any(marker in lowered for marker in senior_markers):
        return "csp-s"
    return ""
def ensure_problem_columns(conn: sqlite3.Connection) -> None:
    """Idempotently add optional columns this importer needs to `problems`."""
    cur = conn.cursor()
    cur.execute("PRAGMA table_info(problems)")
    present = {str(row[1]) for row in cur.fetchall()}
    migrations = {
        "sample_input": "ALTER TABLE problems ADD COLUMN sample_input TEXT NOT NULL DEFAULT ''",
        "sample_output": "ALTER TABLE problems ADD COLUMN sample_output TEXT NOT NULL DEFAULT ''",
        "statement_url": "ALTER TABLE problems ADD COLUMN statement_url TEXT NOT NULL DEFAULT ''",
        "llm_profile_json": "ALTER TABLE problems ADD COLUMN llm_profile_json TEXT NOT NULL DEFAULT '{}'",
    }
    for column, statement in migrations.items():
        if column not in present:
            cur.execute(statement)
    conn.commit()
def ensure_core_tables(conn: sqlite3.Connection) -> None:
    """Create the problems and problem_tags tables (and tag index) if absent."""
    conn.execute(
        """
CREATE TABLE IF NOT EXISTS problems (
id INTEGER PRIMARY KEY AUTOINCREMENT,
slug TEXT NOT NULL UNIQUE,
title TEXT NOT NULL,
statement_md TEXT NOT NULL,
difficulty INTEGER NOT NULL DEFAULT 1,
source TEXT NOT NULL DEFAULT '',
statement_url TEXT NOT NULL DEFAULT '',
llm_profile_json TEXT NOT NULL DEFAULT '{}',
sample_input TEXT NOT NULL DEFAULT '',
sample_output TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL
)
"""
    )
    conn.execute(
        """
CREATE TABLE IF NOT EXISTS problem_tags (
problem_id INTEGER NOT NULL,
tag TEXT NOT NULL,
PRIMARY KEY(problem_id, tag)
)
"""
    )
    # Speeds up tag-based queries (exam-type weights, difficulty pools).
    conn.execute("CREATE INDEX IF NOT EXISTS idx_problem_tags_tag ON problem_tags(tag)")
    conn.commit()
def ensure_import_tables(conn: sqlite3.Connection) -> None:
    """Create the import-job bookkeeping tables and their indexes if absent.

    import_jobs tracks one generation run; import_job_items holds one audit
    row per attempt, upserted on the (job_id, source_path) unique key.
    """
    conn.execute(
        """
CREATE TABLE IF NOT EXISTS import_jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
status TEXT NOT NULL,
trigger TEXT NOT NULL DEFAULT 'manual',
total_count INTEGER NOT NULL DEFAULT 0,
processed_count INTEGER NOT NULL DEFAULT 0,
success_count INTEGER NOT NULL DEFAULT 0,
failed_count INTEGER NOT NULL DEFAULT 0,
options_json TEXT NOT NULL DEFAULT '{}',
last_error TEXT NOT NULL DEFAULT '',
started_at INTEGER NOT NULL,
finished_at INTEGER,
updated_at INTEGER NOT NULL,
created_at INTEGER NOT NULL
)
"""
    )
    conn.execute(
        """
CREATE TABLE IF NOT EXISTS import_job_items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
job_id INTEGER NOT NULL,
source_path TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
title TEXT NOT NULL DEFAULT '',
difficulty INTEGER NOT NULL DEFAULT 0,
problem_id INTEGER,
error_text TEXT NOT NULL DEFAULT '',
started_at INTEGER,
finished_at INTEGER,
updated_at INTEGER NOT NULL,
created_at INTEGER NOT NULL,
UNIQUE(job_id, source_path)
)
"""
    )
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_import_jobs_created_at ON import_jobs(created_at DESC)"
    )
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_import_job_items_job_status "
        "ON import_job_items(job_id, status, updated_at DESC)"
    )
    conn.commit()
def create_import_job(
    conn: sqlite3.Connection, trigger: str, total_count: int, options_json: str
) -> int:
    """Insert a new import_jobs row in 'running' state and return its row id."""
    ts = now_sec()
    cur = conn.cursor()
    cur.execute(
        """
INSERT INTO import_jobs(
status,trigger,total_count,processed_count,success_count,failed_count,
options_json,last_error,started_at,finished_at,updated_at,created_at
) VALUES(?,?,?,?,?,?,?,?,?,?,?,?)
""",
        (
            "running",
            trigger or "manual",  # default trigger label when none supplied
            max(0, total_count),  # never store a negative budget
            0,  # processed_count
            0,  # success_count
            0,  # failed_count
            options_json,
            "",  # last_error
            ts,  # started_at
            None,  # finished_at: still running
            ts,  # updated_at
            ts,  # created_at
        ),
    )
    conn.commit()
    return int(cur.lastrowid)
def insert_item(
    conn: sqlite3.Connection,
    job_id: int,
    source_path: str,
    status: str,
    title: str,
    difficulty: int,
    problem_id: int | None,
    error_text: str,
) -> None:
    """Record (or overwrite) one per-attempt audit row for a job.

    INSERT OR REPLACE upserts on the (job_id, source_path) unique key, so a
    later write for the same attempt key replaces the earlier row.
    """
    ts = now_sec()
    conn.execute(
        """
INSERT OR REPLACE INTO import_job_items(
job_id,source_path,status,title,difficulty,problem_id,error_text,
started_at,finished_at,updated_at,created_at
) VALUES(?,?,?,?,?,?,?,?,?,?,?)
""",
        (
            job_id,
            source_path,
            status,
            title,
            int(difficulty),
            problem_id,
            error_text[:500],  # bound stored error text
            ts,  # started_at
            ts,  # finished_at (items are written once, after the attempt)
            ts,  # updated_at
            ts,  # created_at
        ),
    )
    conn.commit()
def update_import_job_progress(
    conn: sqlite3.Connection,
    job_id: int,
    total_count: int,
    processed_count: int,
    success_count: int,
    failed_count: int,
    last_error: str,
) -> None:
    """Refresh a running job's counters (clamped to >= 0) and last error."""
    ts = now_sec()
    conn.execute(
        """
UPDATE import_jobs
SET total_count=?,
processed_count=?,
success_count=?,
failed_count=?,
last_error=?,
updated_at=?
WHERE id=?
""",
        (
            max(0, total_count),
            max(0, processed_count),
            max(0, success_count),
            max(0, failed_count),
            last_error[:500],  # bound stored error text
            ts,
            job_id,
        ),
    )
    conn.commit()
def finish_import_job(
    conn: sqlite3.Connection,
    job_id: int,
    processed_count: int,
    success_count: int,
    failed_count: int,
    last_error: str,
    reached_target: bool,
) -> None:
    """Mark a job finished.

    Status is 'completed' only when the target was reached with zero
    failures; otherwise 'completed_with_errors'.
    """
    ts = now_sec()
    status = "completed" if reached_target and failed_count == 0 else "completed_with_errors"
    conn.execute(
        """
UPDATE import_jobs
SET status=?,
total_count=?,
processed_count=?,
success_count=?,
failed_count=?,
last_error=?,
finished_at=?,
updated_at=?
WHERE id=?
""",
        (
            status,
            # NOTE(review): total_count is overwritten here with the processed
            # count, replacing the original attempt budget — looks like a
            # deliberate "shrink to actual" on finish; confirm consumers of
            # import_jobs expect this.
            max(processed_count, 0),
            processed_count,
            success_count,
            failed_count,
            last_error[:500],  # bound stored error text
            ts,
            ts,
            job_id,
        ),
    )
    conn.commit()
def mark_job_failed_early(conn: sqlite3.Connection, job_id: int, reason: str) -> None:
    """Flag a job as failed before any attempt ran (e.g. no usable corpus)."""
    finished = now_sec()
    params = (reason[:500], finished, finished, job_id)
    conn.execute(
        """
UPDATE import_jobs
SET status='failed',
last_error=?,
finished_at=?,
updated_at=?
WHERE id=?
""",
        params,
    )
    conn.commit()
def load_existing(conn: sqlite3.Connection) -> list[ExistingProblem]:
    """Load every stored problem (id, title, statement) for dedupe checks."""
    cur = conn.execute("SELECT id,title,statement_md FROM problems")
    return [
        ExistingProblem(
            id=int(record[0]),
            title=str(record[1] or ""),
            statement_md=str(record[2] or ""),
        )
        for record in cur.fetchall()
    ]
def load_exam_type_weights(conn: sqlite3.Connection) -> tuple[int, int]:
    """Count stored problems tagged csp-j / csp-s; return (j_count, s_count)."""
    counts = {"csp-j": 0, "csp-s": 0}
    cur = conn.execute(
        """
SELECT lower(tag), COUNT(1)
FROM problem_tags
WHERE lower(tag) IN ('csp-j','csp-s')
GROUP BY lower(tag)
"""
    )
    for tag_value, tag_count in cur.fetchall():
        key = str(tag_value or "")
        if key in counts:
            counts[key] = int(tag_count or 0)
    return counts["csp-j"], counts["csp-s"]
def load_difficulty_pool(conn: sqlite3.Connection, exam_type: str) -> list[int]:
    """Build a weighted sampling pool of difficulties seen for *exam_type*.

    Each difficulty level appears once per stored problem, capped at 60
    repetitions per level, and is clamped into [1, 10].
    """
    cur = conn.execute(
        """
SELECT p.difficulty, COUNT(1)
FROM problems p
JOIN problem_tags t ON t.problem_id=p.id
WHERE lower(t.tag)=?
GROUP BY p.difficulty
""",
        (exam_type,),
    )
    pool: list[int] = []
    for difficulty_raw, count_raw in cur.fetchall():
        level = max(1, min(10, int(difficulty_raw or 1)))
        weight = min(60, max(1, int(count_raw or 1)))
        pool.extend([level] * weight)
    return pool
def discover_pdf_files(pdf_dir: Path) -> list[Path]:
    """Recursively find PDF files under *pdf_dir*, sorted by full path.

    Fix: the extension check is now case-insensitive — the previous
    ``rglob("*.pdf")`` silently skipped files named ``*.PDF`` on
    case-sensitive filesystems.
    """
    if not pdf_dir.is_dir():  # also covers the non-existent case
        return []
    rows = [p for p in pdf_dir.rglob("*") if p.is_file() and p.suffix.lower() == ".pdf"]
    rows.sort(key=str)
    return rows
def build_corpus_from_existing_problems(
    existing: list[ExistingProblem], max_chars_per_item: int
) -> list[CorpusEntry]:
    """Fall back to stored problem statements as RAG corpus when no PDFs are usable.

    Only the most recent 800 problems are considered, and statements shorter
    than 80 characters are dropped as carrying too little signal.
    """
    if not existing:
        return []
    keep = min(800, len(existing))
    entries: list[CorpusEntry] = []
    for problem in existing[-keep:]:
        body = (problem.statement_md or "").strip()
        if len(body) < 80:
            continue
        snippet = body[: max(500, max_chars_per_item)]
        hint = to_exam_type(f"{problem.title} {snippet[:200]}")
        entries.append(
            CorpusEntry(
                path=f"db:problem:{problem.id}:{problem.title[:64]}",
                text=snippet,
                exam_hint=hint,
            )
        )
    return entries
def extract_pdf_text(path: Path, max_pages: int, max_chars: int) -> str:
    """Extract plain text from the first *max_pages* pages of a PDF via `pdftotext`.

    Fix: returns "" on *any* subprocess failure. Previously a missing
    ``pdftotext`` binary (FileNotFoundError) or a hang on a corrupt PDF
    (subprocess.TimeoutExpired from timeout=60) escaped and crashed the
    corpus-building worker.
    """
    cmd = [
        "pdftotext",
        "-layout",
        "-f",
        "1",
        "-l",
        str(max(1, max_pages)),
        str(path),
        "-",  # write extracted text to stdout
    ]
    try:
        proc = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
            timeout=60,
        )
    except (OSError, subprocess.TimeoutExpired):
        # Binary missing/not executable, or extraction hung: treat as unusable.
        return ""
    if proc.returncode != 0:
        return ""
    text = proc.stdout.decode("utf-8", errors="ignore")
    # Normalize line endings, then collapse excessive blank lines and runs of spaces.
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    text = re.sub(r"\n{3,}", "\n\n", text)
    text = re.sub(r"[ \t]+", " ", text)
    text = text.strip()
    if not text:
        return ""
    return text[: max(500, max_chars)]
def build_corpus(files: list[Path], workers: int, max_pages: int, max_chars: int) -> list[CorpusEntry]:
    """Extract text from *files* in parallel, keeping only usable entries.

    Fix: a raising extraction no longer aborts the whole build — previously
    ``future.result()`` re-raised any worker exception (e.g. a missing
    ``pdftotext`` binary) and killed the entire corpus pass. Each failed
    file is now simply skipped, keeping the build best-effort.
    """
    corpus: list[CorpusEntry] = []

    def work(path: Path) -> CorpusEntry | None:
        # Entries shorter than 80 chars carry too little signal for RAG context.
        text = extract_pdf_text(path, max_pages=max_pages, max_chars=max_chars)
        if len(text) < 80:
            return None
        return CorpusEntry(path=str(path), text=text, exam_hint=to_exam_type(str(path)))

    with ThreadPoolExecutor(max_workers=max(1, workers)) as executor:
        futures = [executor.submit(work, path) for path in files]
        for future in futures:
            try:
                row = future.result()
            except Exception:
                # Best-effort: one broken PDF must not kill the import job.
                continue
            if row is not None:
                corpus.append(row)
    return corpus
def requests_retry_post(
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
    timeout: int,
    retries: int,
    sleep_sec: float,
) -> dict[str, Any]:
    """POST *body* as JSON with linear-backoff retries; return the decoded JSON object.

    Retries up to *retries* times on network errors and on the status codes in
    RETRYABLE_HTTP_CODES, sleeping sleep_sec * attempt between tries. Raises
    RuntimeError on exhaustion, on non-retryable HTTP errors, and on a
    successful response whose body is not a JSON object.
    """
    last_error: Exception | None = None
    for i in range(1, retries + 1):
        try:
            resp = requests.post(url, headers=headers, json=body, timeout=timeout)
        except requests.RequestException as exc:
            last_error = exc
            if i < retries:
                time.sleep(sleep_sec * i)  # linear backoff: 1x, 2x, 3x, ...
                continue
            raise RuntimeError(f"llm request failed: {exc}") from exc
        if resp.status_code in RETRYABLE_HTTP_CODES:
            # Transient server/ratelimit error: back off and retry.
            if i < retries:
                time.sleep(sleep_sec * i)
                continue
            raise RuntimeError(f"llm retry exhausted: HTTP {resp.status_code}")
        if resp.status_code >= 400:
            # Non-retryable client/server error: fail fast with a body snippet.
            raise RuntimeError(f"llm request failed: HTTP {resp.status_code}: {resp.text[:300]}")
        payload = resp.json()
        if isinstance(payload, dict):
            return payload
        raise RuntimeError("llm response is not JSON object")
    # Only reachable when retries < 1 (the loop body never ran).
    if last_error:
        raise RuntimeError(f"llm request failed: {last_error}") from last_error
    raise RuntimeError("llm request failed")
def extract_json_object(text: str) -> dict[str, Any] | None:
    """Best-effort parse of an LLM reply into a JSON object.

    Handles raw JSON, JSON inside a ``` fence (optionally language-tagged),
    and a JSON object embedded in surrounding prose. Returns None when no
    object can be recovered.
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        # Drop the opening fence (with optional language tag) and closing fence.
        stripped = re.sub(r"^```[a-zA-Z0-9_-]*", "", stripped).strip()
        stripped = stripped.removesuffix("```").strip()
    try:
        parsed = json.loads(stripped)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    # Fall back to the widest {...} span anywhere in the original text.
    found = re.search(r"\{[\s\S]*\}", text)
    if found is None:
        return None
    try:
        candidate = json.loads(found.group(0))
    except json.JSONDecodeError:
        return None
    return candidate if isinstance(candidate, dict) else None
def maybe_duplicate(
    existing: list[ExistingProblem],
    title: str,
    statement_md: str,
    threshold: float,
) -> tuple[bool, int | None, float]:
    """Find the closest existing problem and report whether it crosses *threshold*.

    Returns (is_duplicate, best_matching_id, best_score). Only the first
    1600 chars of each statement are compared to keep this fast.
    """
    best_id: int | None = None
    best_score = 0.0
    head = statement_md[:1600]
    for problem in existing:
        title_score = similarity(title, problem.title)
        body_score = similarity(head, problem.statement_md[:1600])
        # Blend body similarity with a small title weight; a near-identical
        # title alone is also enough to flag a duplicate.
        combined = max(title_score, body_score * 0.85 + title_score * 0.15)
        if combined > best_score:
            best_score = combined
            best_id = problem.id
    return best_score >= threshold, best_id, best_score
def choose_exam_type(rng: random.Random, j_weight: int, s_weight: int) -> str:
    """Pick 'csp-j' or 'csp-s' with probability proportional to the weights (each floored at 1)."""
    junior = max(1, j_weight)
    senior = max(1, s_weight)
    roll = rng.randint(1, junior + senior)
    return "csp-j" if roll <= junior else "csp-s"
def choose_difficulty(rng: random.Random, exam_type: str, pools: dict[str, list[int]]) -> int:
    """Sample a difficulty from the historical pool, or from a sensible default range."""
    pool = pools.get(exam_type) or []
    if pool:
        return int(rng.choice(pool))
    # No history for this exam type: easier defaults for csp-j, harder for csp-s.
    defaults = [2, 2, 3, 3, 4] if exam_type == "csp-j" else [3, 4, 4, 5, 6]
    return int(rng.choice(defaults))
def choose_context(
    rng: random.Random,
    corpus: list[CorpusEntry],
    exam_type: str,
    max_context_chars: int,
) -> list[CorpusEntry]:
    """Sample up to three corpus entries, clipped to a total character budget.

    Entries whose exam hint matches *exam_type* are preferred, but only when
    at least two such entries exist; otherwise the whole corpus is sampled.
    """
    matching = [entry for entry in corpus if entry.exam_hint == exam_type]
    pool = matching if len(matching) >= 2 else corpus
    if not pool:
        return []
    chosen = rng.sample(pool, k=min(3, len(pool)))
    budget = max_context_chars
    selected: list[CorpusEntry] = []
    for entry in chosen:
        if budget <= 0:
            break
        clipped = entry.text[:budget]
        selected.append(CorpusEntry(path=entry.path, text=clipped, exam_hint=entry.exam_hint))
        budget -= len(clipped)
    return selected
def build_prompt(
    exam_type: str,
    difficulty: int,
    context_rows: list[CorpusEntry],
    avoid_titles: list[str],
) -> str:
    """Assemble the generation prompt from exam type, target difficulty and RAG context.

    Fix: the avoid-list is now joined with ", " — the previous ``"".join()``
    concatenated up to 25 titles into one unreadable run of characters,
    weakening the "do not duplicate these" instruction to the model.
    """
    ctx_parts: list[str] = []
    for idx, row in enumerate(context_rows, start=1):
        ctx_parts.append(f"[参考材料{idx}] 来源: {row.path}\n{row.text}")
    ctx_text = "\n\n".join(ctx_parts)
    avoid_text = ", ".join(avoid_titles[:25]) if avoid_titles else ""
    exam_cn = "CSP-J" if exam_type == "csp-j" else "CSP-S"
    return f"""
你是资深 OI 出题人。请根据给定本地 PDF 参考材料,生成 1 道原创 {exam_cn} 题目。
要求:
1. 难度尽量贴近 {difficulty}(范围 1~10
2. 题型风格贴合 {exam_cn}
3. 严禁与已有题目雷同,尤其避免这些标题:{avoid_text}
4. 输出必须是 JSON 对象,不要输出任何额外文本。
5. JSON 字段固定:
{{
"title": "题目标题",
"difficulty": 3,
"statement_md": "完整 Markdown 题面(题目描述/输入格式/输出格式/数据范围)",
"sample_input": "样例输入",
"sample_output": "样例输出",
"answer": "核心答案要点",
"explanation": "简要讲解",
"knowledge_points": ["知识点1","知识点2"],
"tags": ["csp-j 或 csp-s","算法标签"]
}}
参考材料RAG
{ctx_text}
""".strip()
def llm_generate(prompt: str, timeout: int, retries: int, sleep_sec: float) -> dict[str, Any]:
    """Call the OpenAI-compatible chat endpoint configured via OI_LLM_* env vars.

    Returns the parsed JSON object from the model's reply; raises RuntimeError
    when configuration is missing or the response cannot be used.
    """
    url = os.getenv("OI_LLM_API_URL", "").strip()
    if not url:
        raise RuntimeError("missing OI_LLM_API_URL")
    api_key = os.getenv("OI_LLM_API_KEY", "").strip()
    model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip()
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    body = {
        "model": model,
        "stream": False,
        "temperature": 0.7,
        "messages": [
            {"role": "system", "content": "你是 OI/CSP 出题助手。只输出 JSON。"},
            {"role": "user", "content": prompt},
        ],
    }
    payload = requests_retry_post(
        url=url,
        headers=headers,
        body=body,
        timeout=timeout,
        retries=retries,
        sleep_sec=sleep_sec,
    )
    choices = payload.get("choices") if isinstance(payload, dict) else None
    if not isinstance(choices, list) or not choices:
        raise RuntimeError("llm response missing choices")
    message = (choices[0] or {}).get("message") or {}
    content = message.get("content")
    if not content:
        raise RuntimeError("llm response missing content")
    parsed = extract_json_object(str(content))
    if not isinstance(parsed, dict):
        raise RuntimeError("llm returned non-json content")
    return parsed
def fallback_candidate(
    attempt_id: int,
    exam_type: str,
    difficulty: int,
    source_paths: list[str],
    llm_error: str,
) -> CandidateProblem:
    """Build a deterministic placeholder problem used when the LLM call fails.

    The result is clearly marked via the "fallback-generated" tag and carries
    the LLM error text for later auditing.
    """
    exam_label = "CSP-J" if exam_type == "csp-j" else "CSP-S"
    statement_md = """
# 题目描述
给定一个长度为 n 的整数序列,请你计算满足特定条件的子区间数量。
## 输入格式
第一行一个整数 n。
第二行 n 个整数 a_i。
## 输出格式
输出一个整数,表示答案。
## 数据范围
- 1 <= n <= 2e5
- |a_i| <= 1e9
## 提示
可尝试前缀和、哈希计数或双指针等思路。
""".strip()
    clamped = max(1, min(10, difficulty))
    return CandidateProblem(
        title=f"{exam_label} 训练题·生成草案 {now_sec()}-{attempt_id}",
        difficulty=clamped,
        statement_md=statement_md,
        sample_input="6\n1 2 3 4 5 6\n",
        sample_output="0\n",
        answer="可从前缀和统计角度考虑。",
        explanation="使用前缀和 + 哈希统计满足条件的区间。",
        knowledge_points=["前缀和", "哈希"],
        tags=[exam_type, "fallback-generated", "local-pdf-rag"],
        exam_type=exam_type,
        source_paths=source_paths,
        source="fallback",
        llm_error=llm_error,
    )
def build_candidate_from_json(
    obj: dict[str, Any],
    exam_type: str,
    difficulty_hint: int,
    source_paths: list[str],
    source: str,
    llm_error: str,
) -> CandidateProblem:
    """Validate and normalize a decoded LLM payload into a CandidateProblem.

    Raises:
        RuntimeError: when the payload has no usable title.
    """
    title = str(obj.get("title") or "").strip()
    if not title:
        raise RuntimeError("generated title is empty")
    # Difficulty: model value, else the requested hint, else 3; clamped to 1..10.
    difficulty = int(obj.get("difficulty") or difficulty_hint or 3)
    difficulty = max(1, min(10, difficulty))
    statement_md = str(obj.get("statement_md") or "").strip()
    if not statement_md:
        # Some replies split the statement into separate fields; stitch them
        # back together into a single Markdown document.
        desc = str(obj.get("description") or "").strip()
        in_fmt = str(obj.get("input_format") or "").strip()
        out_fmt = str(obj.get("output_format") or "").strip()
        statement_md = (
            "# 题目描述\n"
            + (desc or "请根据题意补全。")
            + "\n\n## 输入格式\n"
            + (in_fmt or "见题目描述")
            + "\n\n## 输出格式\n"
            + (out_fmt or "见题目描述")
        ).strip()
    sample_input = str(obj.get("sample_input") or "").strip()
    sample_output = str(obj.get("sample_output") or "").strip()
    answer = str(obj.get("answer") or "").strip()
    explanation = str(obj.get("explanation") or "").strip()
    kp_raw = obj.get("knowledge_points")
    # Keep at most 8 non-empty knowledge points.
    knowledge_points = [str(x).strip() for x in kp_raw] if isinstance(kp_raw, list) else []
    knowledge_points = [x for x in knowledge_points if x][:8]
    tags_raw = obj.get("tags")
    tags = [str(x).strip() for x in tags_raw] if isinstance(tags_raw, list) else []
    tags = [x for x in tags if x][:12]
    # Always append the exam-type and provenance tags.
    tags.extend([exam_type, "local-pdf-rag"])
    # Normalize (lowercase, hyphenated) and de-duplicate, preserving order.
    normalized_tags: list[str] = []
    seen: set[str] = set()
    for tag in tags:
        n = normalize(tag).replace(" ", "-")
        if not n or n in seen:
            continue
        seen.add(n)
        normalized_tags.append(n)
    return CandidateProblem(
        title=title,
        difficulty=difficulty,
        statement_md=statement_md,
        sample_input=sample_input,
        sample_output=sample_output,
        answer=answer,
        explanation=explanation,
        knowledge_points=knowledge_points,
        tags=normalized_tags,
        exam_type=exam_type,
        source_paths=source_paths,
        source=source,
        llm_error=llm_error,
    )
def make_slug(title: str, attempt_id: int) -> str:
    """Derive a unique ASCII slug from *title*, suffixed with timestamp and attempt id."""
    hyphenated = normalize(title).replace(" ", "-")
    stem = re.sub(r"[^a-z0-9\-]+", "", hyphenated).strip("-")
    base = stem or "local-pdf-rag"
    return f"{base[:48]}-{now_sec()}-{attempt_id}"
def insert_problem(conn: sqlite3.Connection, cand: CandidateProblem, attempt_id: int) -> int:
    """Persist a candidate as a new problems row (plus its tags) and return its id.

    Generation metadata (answer, explanation, provenance) is packed into the
    llm_profile_json blob rather than dedicated columns.
    """
    ts = now_sec()
    profile = {
        "schema_version": 1,
        "platform": "llm-local-pdf-rag" if cand.source == "llm" else "fallback-local-pdf-rag",
        "difficulty": cand.difficulty,
        "answer": cand.answer,
        "explanation": cand.explanation,
        "knowledge_points": cand.knowledge_points,
        "tags": cand.tags,
        "exam_type": cand.exam_type,
        "source_paths": cand.source_paths,
        "generated_at": ts,
    }
    if cand.llm_error:
        profile["llm_error"] = cand.llm_error[:300]  # keep the profile blob bounded
    cur = conn.cursor()
    cur.execute(
        """
INSERT INTO problems(
slug,title,statement_md,difficulty,source,statement_url,llm_profile_json,
sample_input,sample_output,created_at
) VALUES(?,?,?,?,?,?,?,?,?,?)
""",
        (
            make_slug(cand.title, attempt_id),
            cand.title,
            cand.statement_md,
            cand.difficulty,
            f"llm:{cand.exam_type}:local-pdf-rag",
            "",  # statement_url: generated problems have no external source page
            json.dumps(profile, ensure_ascii=False),
            cand.sample_input,
            cand.sample_output,
            ts,
        ),
    )
    problem_id = int(cur.lastrowid)
    # Tag rows are idempotent thanks to the (problem_id, tag) primary key.
    for tag in cand.tags:
        cur.execute(
            "INSERT OR IGNORE INTO problem_tags(problem_id,tag) VALUES(?,?)",
            (problem_id, tag),
        )
    conn.commit()
    return problem_id
def run_attempt(
    attempt_id: int,
    corpus: list[CorpusEntry],
    max_context_chars: int,
    j_weight: int,
    s_weight: int,
    pools: dict[str, list[int]],
    existing_titles: list[str],
    timeout: int,
    retries: int,
    retry_sleep_sec: float,
) -> AttemptResult:
    """Run one generation attempt: pick exam type/difficulty/context and call the LLM.

    Falls back to a placeholder candidate when the LLM call fails; returns an
    error result only when context is empty or the LLM reply cannot be turned
    into a candidate.
    """
    # Per-attempt seed so parallel attempts make different random choices.
    rng = random.Random(now_sec() + attempt_id * 131)
    exam_type = choose_exam_type(rng, j_weight, s_weight)
    difficulty = choose_difficulty(rng, exam_type, pools)
    context_rows = choose_context(rng, corpus, exam_type, max_context_chars=max_context_chars)
    if not context_rows:
        return AttemptResult(attempt_id=attempt_id, source_paths=[], candidate=None, error="no context")
    source_paths = [x.path for x in context_rows]
    # Sample up to 30 existing titles so the prompt can steer away from them.
    avoid_titles = rng.sample(existing_titles, k=min(30, len(existing_titles))) if existing_titles else []
    prompt = build_prompt(exam_type, difficulty, context_rows, avoid_titles)
    source = "llm"
    llm_error = ""
    try:
        obj = llm_generate(prompt, timeout=timeout, retries=retries, sleep_sec=retry_sleep_sec)
    except Exception as exc:
        # LLM unavailable: still produce a clearly-tagged fallback problem.
        source = "fallback"
        llm_error = str(exc)
        return AttemptResult(
            attempt_id=attempt_id,
            source_paths=source_paths,
            candidate=fallback_candidate(
                attempt_id=attempt_id,
                exam_type=exam_type,
                difficulty=difficulty,
                source_paths=source_paths,
                llm_error=llm_error,
            ),
            error="",
        )
    try:
        candidate = build_candidate_from_json(
            obj=obj,
            exam_type=exam_type,
            difficulty_hint=difficulty,
            source_paths=source_paths,
            source=source,
            llm_error=llm_error,
        )
    except Exception as exc:
        # The model replied but the payload is unusable (e.g. missing title).
        return AttemptResult(
            attempt_id=attempt_id,
            source_paths=source_paths,
            candidate=None,
            error=f"invalid generated payload: {exc}",
        )
    return AttemptResult(attempt_id=attempt_id, source_paths=source_paths, candidate=candidate, error="")
def main() -> int:
    """CLI entry: top up the problems table to --target-total via local-PDF RAG.

    Workflow: ensure schema -> build corpus (PDFs, else existing problems) ->
    run generation attempts on a thread pool, deduping against existing
    problems -> record per-attempt audit rows and job progress in SQLite.
    Returns 0 when the target count is reached, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Generate CSP J/S from local PDF RAG")
    parser.add_argument("--db-path", required=True)
    parser.add_argument("--pdf-dir", default="/data/local_pdfs")
    parser.add_argument("--target-total", type=int, default=5000)
    parser.add_argument("--workers", type=int, default=3)
    parser.add_argument("--dedupe-threshold", type=float, default=0.72)
    parser.add_argument("--max-attempt-multiplier", type=int, default=8)
    parser.add_argument("--max-pages", type=int, default=6)
    parser.add_argument("--max-chars-per-pdf", type=int, default=12000)
    parser.add_argument("--max-context-chars", type=int, default=12000)
    parser.add_argument("--timeout", type=int, default=90)
    parser.add_argument("--retries", type=int, default=4)
    parser.add_argument("--retry-sleep-sec", type=float, default=1.5)
    parser.add_argument("--job-trigger", default="manual")
    args = parser.parse_args()
    # Clamp CLI knobs into safe ranges.
    workers = max(1, min(16, args.workers))
    target_total = max(1, min(50000, args.target_total))
    dedupe_threshold = max(0.1, min(0.99, args.dedupe_threshold))
    conn = sqlite3.connect(args.db_path)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys=ON")
    conn.execute("PRAGMA busy_timeout=5000")  # tolerate a briefly locked DB
    ensure_core_tables(conn)
    ensure_problem_columns(conn)
    ensure_import_tables(conn)
    existing = load_existing(conn)
    current_total = len(existing)
    need = max(0, target_total - current_total)
    # Budget several attempts per needed problem: duplicates/failures burn attempts.
    attempts_budget = max(need, need * max(2, args.max_attempt_multiplier))
    options_json = json.dumps(
        {
            "source": "local_pdf_rag",
            "mode": "local_pdf_rag",
            "pdf_dir": args.pdf_dir,
            "local_pdf_dir": args.pdf_dir,
            "workers": workers,
            "target_total": target_total,
            "max_attempt_multiplier": max(2, args.max_attempt_multiplier),
            "dedupe_threshold": dedupe_threshold,
        },
        ensure_ascii=False,
    )
    job_id = create_import_job(conn, args.job_trigger, attempts_budget, options_json)
    if need <= 0:
        # Already at (or above) target: close the job immediately and exit 0.
        finish_import_job(
            conn,
            job_id=job_id,
            processed_count=0,
            success_count=0,
            failed_count=0,
            last_error=f"target reached: current={current_total}, target={target_total}",
            reached_target=True,
        )
        print(
            json.dumps(
                {
                    "job_id": job_id,
                    "target_total": target_total,
                    "current_total": current_total,
                    "inserted": 0,
                    "message": "target already reached",
                },
                ensure_ascii=False,
                indent=2,
            )
        )
        conn.close()
        return 0
    pdf_files = discover_pdf_files(Path(args.pdf_dir))
    corpus: list[CorpusEntry] = []
    fallback_reason = ""
    if pdf_files:
        corpus = build_corpus(
            files=pdf_files,
            workers=workers,
            max_pages=args.max_pages,
            max_chars=args.max_chars_per_pdf,
        )
        if not corpus:
            fallback_reason = "all pdf extraction failed or too short, fallback to existing problems"
    else:
        fallback_reason = f"no pdf files found in: {args.pdf_dir}, fallback to existing problems"
    if not corpus:
        # Secondary corpus: reuse stored problem statements as RAG context.
        corpus = build_corpus_from_existing_problems(existing, args.max_chars_per_pdf)
    if not corpus:
        # No corpus at all: fail the job early rather than spin attempts.
        reason = f"{fallback_reason}; and no usable existing problems" if fallback_reason else \
            "no corpus available from pdfs or existing problems"
        mark_job_failed_early(conn, job_id, reason)
        print(json.dumps({"job_id": job_id, "error": reason}, ensure_ascii=False, indent=2))
        conn.close()
        return 1
    if fallback_reason:
        print(
            json.dumps(
                {
                    "job_id": job_id,
                    "warning": fallback_reason,
                    "fallback_corpus_size": len(corpus),
                },
                ensure_ascii=False,
                indent=2,
            )
        )
    # Bias exam-type and difficulty choices toward the existing distribution.
    j_weight, s_weight = load_exam_type_weights(conn)
    pools = {
        "csp-j": load_difficulty_pool(conn, "csp-j"),
        "csp-s": load_difficulty_pool(conn, "csp-s"),
    }
    inserted = 0
    skipped_duplicate = 0
    failed = 0
    processed = 0
    remaining_needed = need
    last_error = ""
    next_attempt = 1
    max_attempt = max(need, attempts_budget)
    futures: dict[Future[AttemptResult], int] = {}

    def submit_one(executor: ThreadPoolExecutor, attempt_id: int) -> None:
        # Snapshot titles so the worker sees a stable avoid-list.
        titles_snapshot = [x.title for x in existing]
        future = executor.submit(
            run_attempt,
            attempt_id,
            corpus,
            max(1000, args.max_context_chars),
            j_weight,
            s_weight,
            pools,
            titles_snapshot,
            args.timeout,
            args.retries,
            args.retry_sleep_sec,
        )
        futures[future] = attempt_id

    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Prime the pool with up to `workers` in-flight attempts.
        while next_attempt <= max_attempt and len(futures) < workers:
            submit_one(executor, next_attempt)
            next_attempt += 1
        while futures:
            done, _ = wait(list(futures.keys()), return_when=FIRST_COMPLETED)
            for future in done:
                attempt_id = futures.pop(future)
                processed += 1
                source_key = f"attempt-{attempt_id:06d}"
                try:
                    result = future.result()
                except Exception as exc:
                    # The worker raised outside run_attempt's own handling.
                    failed += 1
                    last_error = str(exc)
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_key,
                        status="failed",
                        title="",
                        difficulty=0,
                        problem_id=None,
                        error_text=f"worker panic: {exc}",
                    )
                    update_import_job_progress(
                        conn,
                        job_id=job_id,
                        total_count=max_attempt,
                        processed_count=processed,
                        success_count=inserted,
                        failed_count=failed,
                        last_error=last_error,
                    )
                    continue
                # Audit key: attempt id plus (up to two) context source paths.
                joined_source = " | ".join(result.source_paths[:2])
                source_path = f"{source_key}:{joined_source}" if joined_source else source_key
                if result.candidate is None:
                    failed += 1
                    last_error = result.error or "unknown generation error"
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_path,
                        status="failed",
                        title="",
                        difficulty=0,
                        problem_id=None,
                        error_text=last_error,
                    )
                    update_import_job_progress(
                        conn,
                        job_id=job_id,
                        total_count=max_attempt,
                        processed_count=processed,
                        success_count=inserted,
                        failed_count=failed,
                        last_error=last_error,
                    )
                    continue
                if remaining_needed <= 0:
                    # A concurrent attempt already hit the target; record and skip.
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_path,
                        status="skipped",
                        title=result.candidate.title,
                        difficulty=result.candidate.difficulty,
                        problem_id=None,
                        error_text="target reached before insert",
                    )
                    update_import_job_progress(
                        conn,
                        job_id=job_id,
                        total_count=max_attempt,
                        processed_count=processed,
                        success_count=inserted,
                        failed_count=failed,
                        last_error=last_error,
                    )
                    continue
                dup, dup_id, dup_score = maybe_duplicate(
                    existing,
                    result.candidate.title,
                    result.candidate.statement_md,
                    dedupe_threshold,
                )
                if dup:
                    skipped_duplicate += 1
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_path,
                        status="skipped",
                        title=result.candidate.title,
                        difficulty=result.candidate.difficulty,
                        problem_id=None,
                        error_text=f"duplicate with problem_id={dup_id}, similarity={dup_score:.4f}",
                    )
                    update_import_job_progress(
                        conn,
                        job_id=job_id,
                        total_count=max_attempt,
                        processed_count=processed,
                        success_count=inserted,
                        failed_count=failed,
                        last_error=last_error,
                    )
                    continue
                try:
                    problem_id = insert_problem(conn, result.candidate, attempt_id)
                    inserted += 1
                    remaining_needed -= 1
                    # Make the new problem visible to subsequent dedupe checks.
                    existing.append(
                        ExistingProblem(
                            id=problem_id,
                            title=result.candidate.title,
                            statement_md=result.candidate.statement_md,
                        )
                    )
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_path,
                        status="success",
                        title=result.candidate.title,
                        difficulty=result.candidate.difficulty,
                        problem_id=problem_id,
                        error_text=(
                            "source=llm"
                            if result.candidate.source == "llm"
                            else f"source=fallback; {result.candidate.llm_error[:280]}"
                        ),
                    )
                except Exception as exc:
                    failed += 1
                    last_error = str(exc)
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_path,
                        status="failed",
                        title=result.candidate.title,
                        difficulty=result.candidate.difficulty,
                        problem_id=None,
                        error_text=f"insert failed: {exc}",
                    )
                update_import_job_progress(
                    conn,
                    job_id=job_id,
                    total_count=max_attempt,
                    processed_count=processed,
                    success_count=inserted,
                    failed_count=failed,
                    last_error=last_error,
                )
            # Refill the pool while work and attempt budget remain.
            while remaining_needed > 0 and next_attempt <= max_attempt and len(futures) < workers:
                submit_one(executor, next_attempt)
                next_attempt += 1
            if remaining_needed <= 0 and not futures:
                break
    reached_target = remaining_needed <= 0
    summary = (
        f"inserted={inserted}, skipped_duplicate={skipped_duplicate}, "
        f"failed={failed}, target_remaining={remaining_needed}"
    )
    if last_error:
        summary = f"{summary}; last_error={last_error[:180]}"
    finish_import_job(
        conn,
        job_id=job_id,
        processed_count=processed,
        success_count=inserted,
        failed_count=failed,
        last_error=summary,
        reached_target=reached_target,
    )
    conn.close()
    print(
        json.dumps(
            {
                "job_id": job_id,
                "pdf_dir": args.pdf_dir,
                "target_total": target_total,
                "current_total_before": current_total,
                "need": need,
                "attempt_budget": max_attempt,
                "processed": processed,
                "inserted": inserted,
                "skipped_duplicate": skipped_duplicate,
                "failed": failed,
                "reached_target": reached_target,
                "usable_pdf_count": len(corpus),
            },
            ensure_ascii=False,
            indent=2,
        )
    )
    return 0 if reached_target else 1
if __name__ == "__main__":
    # Propagate main()'s status code (0 = target reached, 1 = otherwise).
    raise SystemExit(main())