1276 行
39 KiB
Python
1276 行
39 KiB
Python
#!/usr/bin/env python3
|
||
"""Generate CSP J/S problems from local PDFs via RAG + LLM into SQLite."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import random
|
||
import re
|
||
import sqlite3
|
||
import subprocess
|
||
import time
|
||
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
|
||
from dataclasses import dataclass
|
||
from difflib import SequenceMatcher
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
import requests
|
||
|
||
RETRYABLE_HTTP_CODES = {429, 500, 502, 503, 504}
|
||
|
||
|
||
@dataclass
class ExistingProblem:
    """A problem row already stored in SQLite, loaded for duplicate detection."""

    id: int  # problems.id primary key
    title: str  # used for title-similarity dedupe
    statement_md: str  # Markdown statement, compared head-to-head for dedupe
|
||
|
||
|
||
@dataclass
class CorpusEntry:
    """One piece of RAG reference material (a PDF's text or a stored statement)."""

    path: str  # provenance label: filesystem path or "db:problem:<id>:<title>"
    text: str  # extracted/clipped text used as prompt context
    exam_hint: str  # "csp-j", "csp-s", or "" when undetermined
|
||
|
||
|
||
@dataclass
class CandidateProblem:
    """A fully generated problem candidate, ready for dedupe + insertion."""

    title: str
    difficulty: int  # clamped to 1..10 before construction
    statement_md: str  # full Markdown statement
    sample_input: str
    sample_output: str
    answer: str  # core answer notes (stored in llm_profile_json)
    explanation: str
    knowledge_points: list[str]
    tags: list[str]  # normalized tags; always includes the exam type
    exam_type: str  # "csp-j" or "csp-s"
    source_paths: list[str]  # RAG context sources used for this candidate
    source: str  # "llm" or "fallback"
    llm_error: str  # non-empty when the LLM call failed and fallback was used
|
||
|
||
|
||
@dataclass
class AttemptResult:
    """Outcome of one generation attempt executed in a worker thread."""

    attempt_id: int
    source_paths: list[str]  # context sources (empty when no context was found)
    candidate: CandidateProblem | None  # None means the attempt failed
    error: str  # reason when candidate is None; "" otherwise
|
||
|
||
|
||
def now_sec() -> int:
    """Return the current Unix timestamp truncated to whole seconds."""
    return int(time.time())
|
||
|
||
|
||
def normalize(text: str) -> str:
    """Canonicalize text for fuzzy comparison.

    Lowercases, keeps only ASCII alphanumerics, CJK ideographs and spaces,
    and collapses all runs of whitespace to single spaces.
    """
    lowered = text.strip().lower()
    collapsed = re.sub(r"\s+", " ", lowered)
    kept = re.sub(r"[^0-9a-z\u4e00-\u9fff ]+", " ", collapsed)
    return re.sub(r"\s+", " ", kept).strip()
|
||
|
||
|
||
def similarity(a: str, b: str) -> float:
    """Fuzzy similarity in [0, 1] between two normalized strings; 0.0 if either is empty."""
    if a and b:
        return SequenceMatcher(None, normalize(a), normalize(b)).ratio()
    return 0.0
|
||
|
||
|
||
def to_exam_type(text: str) -> str:
    """Classify free text as "csp-j" or "csp-s" via keyword markers.

    Junior markers win when both appear; returns "" when nothing matches.
    """
    lowered = text.lower()
    junior_markers = ("csp-j", "noip-j", "junior", "普及")
    senior_markers = ("csp-s", "noip-s", "senior", "提高")
    if any(marker in lowered for marker in junior_markers):
        return "csp-j"
    if any(marker in lowered for marker in senior_markers):
        return "csp-s"
    return ""
|
||
|
||
|
||
def ensure_problem_columns(conn: sqlite3.Connection) -> None:
    """Backfill newer columns onto an existing `problems` table (idempotent)."""
    cur = conn.cursor()
    cur.execute("PRAGMA table_info(problems)")
    present = {str(row[1]) for row in cur.fetchall()}
    migrations = (
        ("sample_input", "ALTER TABLE problems ADD COLUMN sample_input TEXT NOT NULL DEFAULT ''"),
        ("sample_output", "ALTER TABLE problems ADD COLUMN sample_output TEXT NOT NULL DEFAULT ''"),
        ("statement_url", "ALTER TABLE problems ADD COLUMN statement_url TEXT NOT NULL DEFAULT ''"),
        ("llm_profile_json", "ALTER TABLE problems ADD COLUMN llm_profile_json TEXT NOT NULL DEFAULT '{}'"),
    )
    for column, ddl in migrations:
        if column not in present:
            cur.execute(ddl)
    conn.commit()
|
||
|
||
|
||
def ensure_core_tables(conn: sqlite3.Connection) -> None:
    """Create the `problems` and `problem_tags` tables if absent.

    Safe to call repeatedly; commits once at the end.
    """
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS problems (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            slug TEXT NOT NULL UNIQUE,
            title TEXT NOT NULL,
            statement_md TEXT NOT NULL,
            difficulty INTEGER NOT NULL DEFAULT 1,
            source TEXT NOT NULL DEFAULT '',
            statement_url TEXT NOT NULL DEFAULT '',
            llm_profile_json TEXT NOT NULL DEFAULT '{}',
            sample_input TEXT NOT NULL DEFAULT '',
            sample_output TEXT NOT NULL DEFAULT '',
            created_at INTEGER NOT NULL
        )
        """
    )
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS problem_tags (
            problem_id INTEGER NOT NULL,
            tag TEXT NOT NULL,
            PRIMARY KEY(problem_id, tag)
        )
        """
    )
    # Tag lookups filter by tag value (see load_exam_type_weights / load_difficulty_pool).
    conn.execute("CREATE INDEX IF NOT EXISTS idx_problem_tags_tag ON problem_tags(tag)")
    conn.commit()
|
||
|
||
|
||
def ensure_import_tables(conn: sqlite3.Connection) -> None:
    """Create job-tracking tables (`import_jobs`, `import_job_items`) if absent.

    NOTE(review): the column name `trigger` is an SQLite keyword; it parses
    here unquoted, but quoting it would be safer if the schema evolves.
    """
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS import_jobs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            status TEXT NOT NULL,
            trigger TEXT NOT NULL DEFAULT 'manual',
            total_count INTEGER NOT NULL DEFAULT 0,
            processed_count INTEGER NOT NULL DEFAULT 0,
            success_count INTEGER NOT NULL DEFAULT 0,
            failed_count INTEGER NOT NULL DEFAULT 0,
            options_json TEXT NOT NULL DEFAULT '{}',
            last_error TEXT NOT NULL DEFAULT '',
            started_at INTEGER NOT NULL,
            finished_at INTEGER,
            updated_at INTEGER NOT NULL,
            created_at INTEGER NOT NULL
        )
        """
    )
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS import_job_items (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            job_id INTEGER NOT NULL,
            source_path TEXT NOT NULL,
            status TEXT NOT NULL DEFAULT 'queued',
            title TEXT NOT NULL DEFAULT '',
            difficulty INTEGER NOT NULL DEFAULT 0,
            problem_id INTEGER,
            error_text TEXT NOT NULL DEFAULT '',
            started_at INTEGER,
            finished_at INTEGER,
            updated_at INTEGER NOT NULL,
            created_at INTEGER NOT NULL,
            UNIQUE(job_id, source_path)
        )
        """
    )
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_import_jobs_created_at ON import_jobs(created_at DESC)"
    )
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_import_job_items_job_status "
        "ON import_job_items(job_id, status, updated_at DESC)"
    )
    conn.commit()
|
||
|
||
|
||
def create_import_job(
    conn: sqlite3.Connection, trigger: str, total_count: int, options_json: str
) -> int:
    """Insert a new `import_jobs` row in status 'running' and return its id.

    `trigger` defaults to 'manual' when empty; `total_count` is floored at 0.
    """
    ts = now_sec()
    cur = conn.cursor()
    cur.execute(
        """
        INSERT INTO import_jobs(
            status,trigger,total_count,processed_count,success_count,failed_count,
            options_json,last_error,started_at,finished_at,updated_at,created_at
        ) VALUES(?,?,?,?,?,?,?,?,?,?,?,?)
        """,
        (
            "running",
            trigger or "manual",
            max(0, total_count),
            0,
            0,
            0,
            options_json,
            "",
            ts,
            None,  # finished_at is set by finish_import_job / mark_job_failed_early
            ts,
            ts,
        ),
    )
    conn.commit()
    return int(cur.lastrowid)
|
||
|
||
|
||
def insert_item(
    conn: sqlite3.Connection,
    job_id: int,
    source_path: str,
    status: str,
    title: str,
    difficulty: int,
    problem_id: int | None,
    error_text: str,
) -> None:
    """Record one per-attempt row in `import_job_items` (upsert on job_id+source_path).

    `error_text` is truncated to 500 chars; all timestamps are set to now.
    """
    ts = now_sec()
    conn.execute(
        """
        INSERT OR REPLACE INTO import_job_items(
            job_id,source_path,status,title,difficulty,problem_id,error_text,
            started_at,finished_at,updated_at,created_at
        ) VALUES(?,?,?,?,?,?,?,?,?,?,?)
        """,
        (
            job_id,
            source_path,
            status,
            title,
            int(difficulty),
            problem_id,
            error_text[:500],
            ts,
            ts,
            ts,
            ts,
        ),
    )
    conn.commit()
|
||
|
||
|
||
def update_import_job_progress(
    conn: sqlite3.Connection,
    job_id: int,
    total_count: int,
    processed_count: int,
    success_count: int,
    failed_count: int,
    last_error: str,
) -> None:
    """Persist running counters for a job; counts floored at 0, error truncated to 500."""
    ts = now_sec()
    conn.execute(
        """
        UPDATE import_jobs
        SET total_count=?,
            processed_count=?,
            success_count=?,
            failed_count=?,
            last_error=?,
            updated_at=?
        WHERE id=?
        """,
        (
            max(0, total_count),
            max(0, processed_count),
            max(0, success_count),
            max(0, failed_count),
            last_error[:500],
            ts,
            job_id,
        ),
    )
    conn.commit()
|
||
|
||
|
||
def finish_import_job(
    conn: sqlite3.Connection,
    job_id: int,
    processed_count: int,
    success_count: int,
    failed_count: int,
    last_error: str,
    reached_target: bool,
) -> None:
    """Mark a job finished: 'completed' only when the target was reached with zero failures.

    NOTE(review): total_count is overwritten with max(processed_count, 0) here,
    collapsing the attempt budget to what actually ran — presumably intentional
    so the final row reads 100% processed; confirm against the UI that reads it.
    """
    ts = now_sec()
    status = "completed" if reached_target and failed_count == 0 else "completed_with_errors"
    conn.execute(
        """
        UPDATE import_jobs
        SET status=?,
            total_count=?,
            processed_count=?,
            success_count=?,
            failed_count=?,
            last_error=?,
            finished_at=?,
            updated_at=?
        WHERE id=?
        """,
        (
            status,
            max(processed_count, 0),
            processed_count,
            success_count,
            failed_count,
            last_error[:500],
            ts,
            ts,
            job_id,
        ),
    )
    conn.commit()
|
||
|
||
|
||
def mark_job_failed_early(conn: sqlite3.Connection, job_id: int, reason: str) -> None:
    """Mark a job 'failed' before any attempts ran (e.g. no corpus available)."""
    ts = now_sec()
    conn.execute(
        """
        UPDATE import_jobs
        SET status='failed',
            last_error=?,
            finished_at=?,
            updated_at=?
        WHERE id=?
        """,
        (reason[:500], ts, ts, job_id),
    )
    conn.commit()
|
||
|
||
|
||
def load_existing(conn: sqlite3.Connection) -> list[ExistingProblem]:
    """Load id/title/statement of every stored problem for dedupe checks."""
    cur = conn.execute("SELECT id,title,statement_md FROM problems")
    return [
        ExistingProblem(
            id=int(row[0]),
            title=str(row[1] or ""),
            statement_md=str(row[2] or ""),
        )
        for row in cur.fetchall()
    ]
|
||
|
||
|
||
def load_exam_type_weights(conn: sqlite3.Connection) -> tuple[int, int]:
    """Count problems tagged csp-j vs csp-s (case-insensitive) -> (j_count, s_count)."""
    counts = {"csp-j": 0, "csp-s": 0}
    cur = conn.execute(
        """
        SELECT lower(tag), COUNT(1)
        FROM problem_tags
        WHERE lower(tag) IN ('csp-j','csp-s')
        GROUP BY lower(tag)
        """
    )
    for raw_tag, raw_cnt in cur.fetchall():
        key = str(raw_tag or "")
        if key in counts:
            counts[key] = int(raw_cnt or 0)
    return counts["csp-j"], counts["csp-s"]
|
||
|
||
|
||
def load_difficulty_pool(conn: sqlite3.Connection, exam_type: str) -> list[int]:
    """Build a frequency-weighted sampling pool of difficulties for one exam type.

    Each difficulty (clamped to 1..10) is repeated once per existing problem,
    capped at 60 repetitions so a popular difficulty cannot dominate entirely.
    """
    cur = conn.execute(
        """
        SELECT p.difficulty, COUNT(1)
        FROM problems p
        JOIN problem_tags t ON t.problem_id=p.id
        WHERE lower(t.tag)=?
        GROUP BY p.difficulty
        """,
        (exam_type,),
    )
    pool: list[int] = []
    for raw_diff, raw_cnt in cur.fetchall():
        diff = min(10, max(1, int(raw_diff or 1)))
        weight = min(60, max(1, int(raw_cnt or 1)))
        pool += [diff] * weight
    return pool
|
||
|
||
|
||
def discover_pdf_files(pdf_dir: Path) -> list[Path]:
    """Recursively list regular *.pdf files under pdf_dir, sorted by path string.

    Returns [] when the directory does not exist or is not a directory.
    """
    if not pdf_dir.exists() or not pdf_dir.is_dir():
        return []
    return sorted((p for p in pdf_dir.rglob("*.pdf") if p.is_file()), key=str)
|
||
|
||
|
||
def build_corpus_from_existing_problems(
    existing: list[ExistingProblem], max_chars_per_item: int
) -> list[CorpusEntry]:
    """Fallback corpus: reuse up to the 800 most recent stored problem statements.

    Statements shorter than 80 characters are skipped; each kept statement is
    clipped to at least 500 characters, and its exam hint is inferred from the
    title plus the first 200 characters of the snippet.
    """
    out: list[CorpusEntry] = []
    for p in existing[-800:]:
        body = (p.statement_md or "").strip()
        if len(body) < 80:
            continue
        snippet = body[: max(500, max_chars_per_item)]
        out.append(
            CorpusEntry(
                path=f"db:problem:{p.id}:{p.title[:64]}",
                text=snippet,
                exam_hint=to_exam_type(f"{p.title} {snippet[:200]}"),
            )
        )
    return out
|
||
|
||
|
||
def extract_pdf_text(path: Path, max_pages: int, max_chars: int) -> str:
    """Extract and lightly normalize text from a PDF via the `pdftotext` CLI.

    Reads at most `max_pages` pages and returns at most max(500, max_chars)
    characters. Returns "" on any failure: non-zero exit, empty output,
    `pdftotext` missing from PATH, or the 60s subprocess timeout expiring.

    Fix: the original let FileNotFoundError (pdftotext not installed) and
    subprocess.TimeoutExpired escape, which would propagate out of the
    thread-pool workers in build_corpus and abort corpus construction;
    those are now treated like any other extraction failure.
    """
    cmd = [
        "pdftotext",
        "-layout",
        "-f",
        "1",
        "-l",
        str(max(1, max_pages)),
        str(path),
        "-",  # write extracted text to stdout
    ]
    try:
        proc = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
            timeout=60,
        )
    except (OSError, subprocess.TimeoutExpired):
        # pdftotext missing/not executable, or the PDF took too long: skip it.
        return ""
    if proc.returncode != 0:
        return ""

    text = proc.stdout.decode("utf-8", errors="ignore")
    # Normalize line endings, squeeze blank-line runs and horizontal whitespace.
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    text = re.sub(r"\n{3,}", "\n\n", text)
    text = re.sub(r"[ \t]+", " ", text)
    text = text.strip()
    if not text:
        return ""
    return text[: max(500, max_chars)]
|
||
|
||
|
||
def build_corpus(files: list[Path], workers: int, max_pages: int, max_chars: int) -> list[CorpusEntry]:
    """Extract text from each PDF in parallel, dropping files that yield <80 chars.

    Results keep the input file order; the exam hint is inferred from the path.
    """

    def extract_one(path: Path) -> CorpusEntry | None:
        body = extract_pdf_text(path, max_pages=max_pages, max_chars=max_chars)
        if len(body) < 80:
            return None
        return CorpusEntry(path=str(path), text=body, exam_hint=to_exam_type(str(path)))

    with ThreadPoolExecutor(max_workers=max(1, workers)) as pool:
        extracted = list(pool.map(extract_one, files))
    return [entry for entry in extracted if entry is not None]
|
||
|
||
|
||
def requests_retry_post(
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
    timeout: int,
    retries: int,
    sleep_sec: float,
) -> dict[str, Any]:
    """POST JSON with linear-backoff retries; return the parsed JSON object.

    Retries on network errors and on HTTP codes in RETRYABLE_HTTP_CODES, sleeping
    sleep_sec * attempt between tries. Raises RuntimeError on exhaustion, on any
    other 4xx/5xx, or when the response body is not a JSON object.
    Note: with retries <= 0 the loop never runs and the final RuntimeError fires.
    """
    last_error: Exception | None = None
    for i in range(1, retries + 1):
        try:
            resp = requests.post(url, headers=headers, json=body, timeout=timeout)
        except requests.RequestException as exc:
            last_error = exc
            if i < retries:
                time.sleep(sleep_sec * i)  # linear backoff: 1x, 2x, 3x ...
                continue
            raise RuntimeError(f"llm request failed: {exc}") from exc

        if resp.status_code in RETRYABLE_HTTP_CODES:
            if i < retries:
                time.sleep(sleep_sec * i)
                continue
            raise RuntimeError(f"llm retry exhausted: HTTP {resp.status_code}")

        if resp.status_code >= 400:
            # Non-retryable client/server error: fail immediately with a body excerpt.
            raise RuntimeError(f"llm request failed: HTTP {resp.status_code}: {resp.text[:300]}")

        payload = resp.json()
        if isinstance(payload, dict):
            return payload
        raise RuntimeError("llm response is not JSON object")

    if last_error:
        raise RuntimeError(f"llm request failed: {last_error}") from last_error
    raise RuntimeError("llm request failed")
|
||
|
||
|
||
def extract_json_object(text: str) -> dict[str, Any] | None:
|
||
raw = text.strip()
|
||
if raw.startswith("```"):
|
||
raw = re.sub(r"^```[a-zA-Z0-9_-]*", "", raw).strip()
|
||
raw = raw.removesuffix("```").strip()
|
||
try:
|
||
obj = json.loads(raw)
|
||
if isinstance(obj, dict):
|
||
return obj
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
match = re.search(r"\{[\s\S]*\}", text)
|
||
if not match:
|
||
return None
|
||
try:
|
||
obj = json.loads(match.group(0))
|
||
return obj if isinstance(obj, dict) else None
|
||
except json.JSONDecodeError:
|
||
return None
|
||
|
||
|
||
def maybe_duplicate(
    existing: list[ExistingProblem],
    title: str,
    statement_md: str,
    threshold: float,
) -> tuple[bool, int | None, float]:
    """Check a candidate against stored problems -> (is_dup, best_match_id, best_score).

    Score per existing problem is max(title similarity, 0.85*statement-head
    similarity + 0.15*title similarity); only the first 1600 chars of each
    statement are compared to bound cost. O(len(existing)) per call.
    """
    best_id = None
    best_score = 0.0
    statement_head = statement_md[:1600]
    for p in existing:
        t_sim = similarity(title, p.title)
        s_sim = similarity(statement_head, p.statement_md[:1600])
        score = max(t_sim, s_sim * 0.85 + t_sim * 0.15)
        if score > best_score:
            best_score = score
            best_id = p.id
    return best_score >= threshold, best_id, best_score
|
||
|
||
|
||
def choose_exam_type(rng: random.Random, j_weight: int, s_weight: int) -> str:
|
||
j = max(1, j_weight)
|
||
s = max(1, s_weight)
|
||
return "csp-j" if rng.randint(1, j + s) <= j else "csp-s"
|
||
|
||
|
||
def choose_difficulty(rng: random.Random, exam_type: str, pools: dict[str, list[int]]) -> int:
    """Sample a difficulty: from the historical pool when present, else a fixed prior."""
    pool = pools.get(exam_type) or []
    if pool:
        return int(rng.choice(pool))
    prior = [2, 2, 3, 3, 4] if exam_type == "csp-j" else [3, 4, 4, 5, 6]
    return int(rng.choice(prior))
|
||
|
||
|
||
def choose_context(
    rng: random.Random,
    corpus: list[CorpusEntry],
    exam_type: str,
    max_context_chars: int,
) -> list[CorpusEntry]:
    """Pick up to 3 corpus entries as prompt context, clipped to a char budget.

    Entries whose exam hint matches are preferred when at least two exist;
    otherwise the whole corpus is eligible. Returns [] for an empty corpus.
    """
    matching = [entry for entry in corpus if entry.exam_hint == exam_type]
    eligible = matching if len(matching) >= 2 else corpus
    if not eligible:
        return []

    chosen = rng.sample(eligible, k=min(3, len(eligible)))
    budget = max_context_chars
    clipped_rows: list[CorpusEntry] = []
    for entry in chosen:
        if budget <= 0:
            break
        piece = entry.text[:budget]
        clipped_rows.append(CorpusEntry(path=entry.path, text=piece, exam_hint=entry.exam_hint))
        budget -= len(piece)
    return clipped_rows
|
||
|
||
|
||
def build_prompt(
    exam_type: str,
    difficulty: int,
    context_rows: list[CorpusEntry],
    avoid_titles: list[str],
) -> str:
    """Assemble the Chinese-language generation prompt for one attempt.

    Embeds numbered RAG context sections, the target difficulty, and up to 25
    existing titles the model must avoid. The prompt demands a strict JSON
    response matching CandidateProblem's fields.
    """
    ctx_parts: list[str] = []
    for idx, row in enumerate(context_rows, start=1):
        ctx_parts.append(f"[参考材料{idx}] 来源: {row.path}\n{row.text}")
    ctx_text = "\n\n".join(ctx_parts)
    avoid_text = ";".join(avoid_titles[:25]) if avoid_titles else "无"
    exam_cn = "CSP-J" if exam_type == "csp-j" else "CSP-S"

    return f"""
你是资深 OI 出题人。请根据给定本地 PDF 参考材料,生成 1 道原创 {exam_cn} 题目。
要求:
1. 难度尽量贴近 {difficulty}(范围 1~10)。
2. 题型风格贴合 {exam_cn}。
3. 严禁与已有题目雷同,尤其避免这些标题:{avoid_text}
4. 输出必须是 JSON 对象,不要输出任何额外文本。
5. JSON 字段固定:
{{
"title": "题目标题",
"difficulty": 3,
"statement_md": "完整 Markdown 题面(题目描述/输入格式/输出格式/数据范围)",
"sample_input": "样例输入",
"sample_output": "样例输出",
"answer": "核心答案要点",
"explanation": "简要讲解",
"knowledge_points": ["知识点1","知识点2"],
"tags": ["csp-j 或 csp-s","算法标签"]
}}

参考材料(RAG):
{ctx_text}
""".strip()
|
||
|
||
|
||
def llm_generate(prompt: str, timeout: int, retries: int, sleep_sec: float) -> dict[str, Any]:
    """Call the chat-completions endpoint and return the parsed JSON problem object.

    Configuration comes from env vars: OI_LLM_API_URL (required),
    OI_LLM_API_KEY (optional bearer token), OI_LLM_MODEL (default "qwen3-max").
    Raises RuntimeError on missing URL, transport failure, or unparseable output.
    """
    url = os.getenv("OI_LLM_API_URL", "").strip()
    api_key = os.getenv("OI_LLM_API_KEY", "").strip()
    model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip()
    if not url:
        raise RuntimeError("missing OI_LLM_API_URL")

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    body = {
        "model": model,
        "stream": False,
        "temperature": 0.7,
        "messages": [
            {"role": "system", "content": "你是 OI/CSP 出题助手。只输出 JSON。"},
            {"role": "user", "content": prompt},
        ],
    }

    payload = requests_retry_post(
        url=url,
        headers=headers,
        body=body,
        timeout=timeout,
        retries=retries,
        sleep_sec=sleep_sec,
    )
    # OpenAI-compatible shape: choices[0].message.content holds the model text.
    choices = payload.get("choices") if isinstance(payload, dict) else None
    if not isinstance(choices, list) or not choices:
        raise RuntimeError("llm response missing choices")
    content = ((choices[0] or {}).get("message") or {}).get("content")
    if not content:
        raise RuntimeError("llm response missing content")

    obj = extract_json_object(str(content))
    if not isinstance(obj, dict):
        raise RuntimeError("llm returned non-json content")
    return obj
|
||
|
||
|
||
def fallback_candidate(
    attempt_id: int,
    exam_type: str,
    difficulty: int,
    source_paths: list[str],
    llm_error: str,
) -> CandidateProblem:
    """Build a canned placeholder problem when the LLM call fails.

    The title embeds the current timestamp and attempt id to stay unique;
    the statement is a fixed prefix-sum style draft. The originating LLM
    error is carried along in `llm_error` for later persistence.
    """
    exam_cn = "CSP-J" if exam_type == "csp-j" else "CSP-S"
    title = f"{exam_cn} 训练题·生成草案 {now_sec()}-{attempt_id}"
    statement_md = """
# 题目描述
给定一个长度为 n 的整数序列,请你计算满足特定条件的子区间数量。

## 输入格式
第一行一个整数 n。
第二行 n 个整数 a_i。

## 输出格式
输出一个整数,表示答案。

## 数据范围
- 1 <= n <= 2e5
- |a_i| <= 1e9

## 提示
可尝试前缀和、哈希计数或双指针等思路。
""".strip()

    return CandidateProblem(
        title=title,
        difficulty=max(1, min(10, difficulty)),  # clamp to the 1..10 scale
        statement_md=statement_md,
        sample_input="6\n1 2 3 4 5 6\n",
        sample_output="0\n",
        answer="可从前缀和统计角度考虑。",
        explanation="使用前缀和 + 哈希统计满足条件的区间。",
        knowledge_points=["前缀和", "哈希"],
        tags=[exam_type, "fallback-generated", "local-pdf-rag"],
        exam_type=exam_type,
        source_paths=source_paths,
        source="fallback",
        llm_error=llm_error,
    )
|
||
|
||
|
||
def build_candidate_from_json(
    obj: dict[str, Any],
    exam_type: str,
    difficulty_hint: int,
    source_paths: list[str],
    source: str,
    llm_error: str,
) -> CandidateProblem:
    """Validate/coerce the LLM's JSON into a CandidateProblem.

    Raises RuntimeError when the title is empty. Difficulty falls back to the
    hint (then 3) and is clamped to 1..10. A missing statement_md is assembled
    from description / input_format / output_format fields. Tags are normalized,
    deduplicated, and always include the exam type and "local-pdf-rag".
    """
    title = str(obj.get("title") or "").strip()
    if not title:
        raise RuntimeError("generated title is empty")

    difficulty = int(obj.get("difficulty") or difficulty_hint or 3)
    difficulty = max(1, min(10, difficulty))

    statement_md = str(obj.get("statement_md") or "").strip()
    if not statement_md:
        # Reassemble a statement from the alternative field layout some models emit.
        desc = str(obj.get("description") or "").strip()
        in_fmt = str(obj.get("input_format") or "").strip()
        out_fmt = str(obj.get("output_format") or "").strip()
        statement_md = (
            "# 题目描述\n"
            + (desc or "请根据题意补全。")
            + "\n\n## 输入格式\n"
            + (in_fmt or "见题目描述")
            + "\n\n## 输出格式\n"
            + (out_fmt or "见题目描述")
        ).strip()

    sample_input = str(obj.get("sample_input") or "").strip()
    sample_output = str(obj.get("sample_output") or "").strip()
    answer = str(obj.get("answer") or "").strip()
    explanation = str(obj.get("explanation") or "").strip()

    kp_raw = obj.get("knowledge_points")
    knowledge_points = [str(x).strip() for x in kp_raw] if isinstance(kp_raw, list) else []
    knowledge_points = [x for x in knowledge_points if x][:8]  # cap at 8 entries

    tags_raw = obj.get("tags")
    tags = [str(x).strip() for x in tags_raw] if isinstance(tags_raw, list) else []
    tags = [x for x in tags if x][:12]  # cap model-provided tags at 12
    tags.extend([exam_type, "local-pdf-rag"])

    # Normalize each tag (lowercase, spaces -> hyphens) and drop duplicates,
    # preserving first-seen order.
    normalized_tags: list[str] = []
    seen: set[str] = set()
    for tag in tags:
        n = normalize(tag).replace(" ", "-")
        if not n or n in seen:
            continue
        seen.add(n)
        normalized_tags.append(n)

    return CandidateProblem(
        title=title,
        difficulty=difficulty,
        statement_md=statement_md,
        sample_input=sample_input,
        sample_output=sample_output,
        answer=answer,
        explanation=explanation,
        knowledge_points=knowledge_points,
        tags=normalized_tags,
        exam_type=exam_type,
        source_paths=source_paths,
        source=source,
        llm_error=llm_error,
    )
|
||
|
||
|
||
def make_slug(title: str, attempt_id: int) -> str:
    """Build a unique URL-safe slug: normalized-title stem + timestamp + attempt id."""
    stem = normalize(title).replace(" ", "-")
    stem = re.sub(r"[^a-z0-9\-]+", "", stem).strip("-") or "local-pdf-rag"
    return f"{stem[:48]}-{now_sec()}-{attempt_id}"
|
||
|
||
|
||
def insert_problem(conn: sqlite3.Connection, cand: CandidateProblem, attempt_id: int) -> int:
    """Persist a candidate into `problems` (+ its tags) and return the new row id.

    Extra metadata (answer, explanation, knowledge points, RAG sources, and a
    truncated LLM error if any) is serialized into llm_profile_json.
    Commits once after the problem row and all tag rows are written.
    """
    ts = now_sec()
    profile = {
        "schema_version": 1,
        "platform": "llm-local-pdf-rag" if cand.source == "llm" else "fallback-local-pdf-rag",
        "difficulty": cand.difficulty,
        "answer": cand.answer,
        "explanation": cand.explanation,
        "knowledge_points": cand.knowledge_points,
        "tags": cand.tags,
        "exam_type": cand.exam_type,
        "source_paths": cand.source_paths,
        "generated_at": ts,
    }
    if cand.llm_error:
        profile["llm_error"] = cand.llm_error[:300]

    cur = conn.cursor()
    cur.execute(
        """
        INSERT INTO problems(
            slug,title,statement_md,difficulty,source,statement_url,llm_profile_json,
            sample_input,sample_output,created_at
        ) VALUES(?,?,?,?,?,?,?,?,?,?)
        """,
        (
            make_slug(cand.title, attempt_id),
            cand.title,
            cand.statement_md,
            cand.difficulty,
            f"llm:{cand.exam_type}:local-pdf-rag",
            "",  # statement_url: generated problems have no external URL
            json.dumps(profile, ensure_ascii=False),
            cand.sample_input,
            cand.sample_output,
            ts,
        ),
    )
    problem_id = int(cur.lastrowid)

    for tag in cand.tags:
        cur.execute(
            "INSERT OR IGNORE INTO problem_tags(problem_id,tag) VALUES(?,?)",
            (problem_id, tag),
        )
    conn.commit()
    return problem_id
|
||
|
||
|
||
def run_attempt(
    attempt_id: int,
    corpus: list[CorpusEntry],
    max_context_chars: int,
    j_weight: int,
    s_weight: int,
    pools: dict[str, list[int]],
    existing_titles: list[str],
    timeout: int,
    retries: int,
    retry_sleep_sec: float,
) -> AttemptResult:
    """Run one generation attempt (intended to execute in a worker thread).

    Picks exam type, difficulty, and RAG context with a per-attempt RNG, then
    calls the LLM. On LLM failure a canned fallback candidate is returned
    (error=""); only an invalid LLM payload or missing context yields a
    candidate-less result with a non-empty error.
    """
    # Per-attempt RNG seeded from wall clock + attempt id so concurrent
    # attempts draw different contexts.
    rng = random.Random(now_sec() + attempt_id * 131)
    exam_type = choose_exam_type(rng, j_weight, s_weight)
    difficulty = choose_difficulty(rng, exam_type, pools)
    context_rows = choose_context(rng, corpus, exam_type, max_context_chars=max_context_chars)
    if not context_rows:
        return AttemptResult(attempt_id=attempt_id, source_paths=[], candidate=None, error="no context")

    source_paths = [x.path for x in context_rows]
    avoid_titles = rng.sample(existing_titles, k=min(30, len(existing_titles))) if existing_titles else []
    prompt = build_prompt(exam_type, difficulty, context_rows, avoid_titles)

    source = "llm"
    llm_error = ""
    try:
        obj = llm_generate(prompt, timeout=timeout, retries=retries, sleep_sec=retry_sleep_sec)
    except Exception as exc:
        # Any LLM failure degrades to the fallback draft rather than failing
        # the attempt; the error text travels with the candidate.
        source = "fallback"
        llm_error = str(exc)
        return AttemptResult(
            attempt_id=attempt_id,
            source_paths=source_paths,
            candidate=fallback_candidate(
                attempt_id=attempt_id,
                exam_type=exam_type,
                difficulty=difficulty,
                source_paths=source_paths,
                llm_error=llm_error,
            ),
            error="",
        )

    try:
        candidate = build_candidate_from_json(
            obj=obj,
            exam_type=exam_type,
            difficulty_hint=difficulty,
            source_paths=source_paths,
            source=source,
            llm_error=llm_error,
        )
    except Exception as exc:
        return AttemptResult(
            attempt_id=attempt_id,
            source_paths=source_paths,
            candidate=None,
            error=f"invalid generated payload: {exc}",
        )

    return AttemptResult(attempt_id=attempt_id, source_paths=source_paths, candidate=candidate, error="")
|
||
|
||
|
||
def main() -> int:
    """CLI entry point: generate problems until the DB holds --target-total rows.

    Flow: parse args -> open/migrate SQLite -> create an import job ->
    build a RAG corpus (PDFs, else existing statements) -> run attempts on a
    thread pool, deduplicating and inserting as results arrive -> finalize the
    job and print a JSON summary. Returns 0 when the target was reached, 1
    otherwise (or on an unrecoverable early failure).
    """
    parser = argparse.ArgumentParser(description="Generate CSP J/S from local PDF RAG")
    parser.add_argument("--db-path", required=True)
    parser.add_argument("--pdf-dir", default="/data/local_pdfs")
    parser.add_argument("--target-total", type=int, default=5000)
    parser.add_argument("--workers", type=int, default=3)
    parser.add_argument("--dedupe-threshold", type=float, default=0.72)
    parser.add_argument("--max-attempt-multiplier", type=int, default=8)
    parser.add_argument("--max-pages", type=int, default=6)
    parser.add_argument("--max-chars-per-pdf", type=int, default=12000)
    parser.add_argument("--max-context-chars", type=int, default=12000)
    parser.add_argument("--timeout", type=int, default=90)
    parser.add_argument("--retries", type=int, default=4)
    parser.add_argument("--retry-sleep-sec", type=float, default=1.5)
    parser.add_argument("--job-trigger", default="manual")
    args = parser.parse_args()

    # Clamp user-supplied knobs to sane operational ranges.
    workers = max(1, min(16, args.workers))
    target_total = max(1, min(50000, args.target_total))
    dedupe_threshold = max(0.1, min(0.99, args.dedupe_threshold))

    conn = sqlite3.connect(args.db_path)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys=ON")
    conn.execute("PRAGMA busy_timeout=5000")

    ensure_core_tables(conn)
    ensure_problem_columns(conn)
    ensure_import_tables(conn)

    existing = load_existing(conn)
    current_total = len(existing)
    need = max(0, target_total - current_total)
    # Budget extra attempts to absorb duplicates/failures.
    attempts_budget = max(need, need * max(2, args.max_attempt_multiplier))

    options_json = json.dumps(
        {
            "source": "local_pdf_rag",
            "mode": "local_pdf_rag",
            "pdf_dir": args.pdf_dir,
            "local_pdf_dir": args.pdf_dir,
            "workers": workers,
            "target_total": target_total,
            "max_attempt_multiplier": max(2, args.max_attempt_multiplier),
            "dedupe_threshold": dedupe_threshold,
        },
        ensure_ascii=False,
    )
    job_id = create_import_job(conn, args.job_trigger, attempts_budget, options_json)

    # Early exit: the DB already holds enough problems.
    if need <= 0:
        finish_import_job(
            conn,
            job_id=job_id,
            processed_count=0,
            success_count=0,
            failed_count=0,
            last_error=f"target reached: current={current_total}, target={target_total}",
            reached_target=True,
        )
        print(
            json.dumps(
                {
                    "job_id": job_id,
                    "target_total": target_total,
                    "current_total": current_total,
                    "inserted": 0,
                    "message": "target already reached",
                },
                ensure_ascii=False,
                indent=2,
            )
        )
        conn.close()
        return 0

    # Build the RAG corpus: PDFs first, existing statements as fallback.
    pdf_files = discover_pdf_files(Path(args.pdf_dir))
    corpus: list[CorpusEntry] = []
    fallback_reason = ""

    if pdf_files:
        corpus = build_corpus(
            files=pdf_files,
            workers=workers,
            max_pages=args.max_pages,
            max_chars=args.max_chars_per_pdf,
        )
        if not corpus:
            fallback_reason = "all pdf extraction failed or too short, fallback to existing problems"
    else:
        fallback_reason = f"no pdf files found in: {args.pdf_dir}, fallback to existing problems"

    if not corpus:
        corpus = build_corpus_from_existing_problems(existing, args.max_chars_per_pdf)

    if not corpus:
        # Nothing to prompt with at all: fail the job before any attempts.
        reason = f"{fallback_reason}; and no usable existing problems" if fallback_reason else \
            "no corpus available from pdfs or existing problems"
        mark_job_failed_early(conn, job_id, reason)
        print(json.dumps({"job_id": job_id, "error": reason}, ensure_ascii=False, indent=2))
        conn.close()
        return 1

    if fallback_reason:
        print(
            json.dumps(
                {
                    "job_id": job_id,
                    "warning": fallback_reason,
                    "fallback_corpus_size": len(corpus),
                },
                ensure_ascii=False,
                indent=2,
            )
        )

    j_weight, s_weight = load_exam_type_weights(conn)
    pools = {
        "csp-j": load_difficulty_pool(conn, "csp-j"),
        "csp-s": load_difficulty_pool(conn, "csp-s"),
    }

    inserted = 0
    skipped_duplicate = 0
    failed = 0
    processed = 0
    remaining_needed = need
    last_error = ""

    next_attempt = 1
    max_attempt = max(need, attempts_budget)
    futures: dict[Future[AttemptResult], int] = {}

    def submit_one(executor: ThreadPoolExecutor, attempt_id: int) -> None:
        # Snapshot titles so in-flight attempts see the latest dedupe set.
        titles_snapshot = [x.title for x in existing]
        future = executor.submit(
            run_attempt,
            attempt_id,
            corpus,
            max(1000, args.max_context_chars),
            j_weight,
            s_weight,
            pools,
            titles_snapshot,
            args.timeout,
            args.retries,
            args.retry_sleep_sec,
        )
        futures[future] = attempt_id

    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Prime the pool with up to `workers` attempts.
        while next_attempt <= max_attempt and len(futures) < workers:
            submit_one(executor, next_attempt)
            next_attempt += 1

        while futures:
            done, _ = wait(list(futures.keys()), return_when=FIRST_COMPLETED)
            for future in done:
                attempt_id = futures.pop(future)
                processed += 1
                source_key = f"attempt-{attempt_id:06d}"
                try:
                    result = future.result()
                except Exception as exc:
                    # The worker itself raised (should be rare: run_attempt
                    # catches LLM errors internally).
                    failed += 1
                    last_error = str(exc)
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_key,
                        status="failed",
                        title="",
                        difficulty=0,
                        problem_id=None,
                        error_text=f"worker panic: {exc}",
                    )
                    update_import_job_progress(
                        conn,
                        job_id=job_id,
                        total_count=max_attempt,
                        processed_count=processed,
                        success_count=inserted,
                        failed_count=failed,
                        last_error=last_error,
                    )
                    continue

                joined_source = " | ".join(result.source_paths[:2])
                source_path = f"{source_key}:{joined_source}" if joined_source else source_key

                if result.candidate is None:
                    failed += 1
                    last_error = result.error or "unknown generation error"
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_path,
                        status="failed",
                        title="",
                        difficulty=0,
                        problem_id=None,
                        error_text=last_error,
                    )
                    update_import_job_progress(
                        conn,
                        job_id=job_id,
                        total_count=max_attempt,
                        processed_count=processed,
                        success_count=inserted,
                        failed_count=failed,
                        last_error=last_error,
                    )
                    continue

                # A still-in-flight result can arrive after the target was met.
                if remaining_needed <= 0:
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_path,
                        status="skipped",
                        title=result.candidate.title,
                        difficulty=result.candidate.difficulty,
                        problem_id=None,
                        error_text="target reached before insert",
                    )
                    update_import_job_progress(
                        conn,
                        job_id=job_id,
                        total_count=max_attempt,
                        processed_count=processed,
                        success_count=inserted,
                        failed_count=failed,
                        last_error=last_error,
                    )
                    continue

                dup, dup_id, dup_score = maybe_duplicate(
                    existing,
                    result.candidate.title,
                    result.candidate.statement_md,
                    dedupe_threshold,
                )
                if dup:
                    skipped_duplicate += 1
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_path,
                        status="skipped",
                        title=result.candidate.title,
                        difficulty=result.candidate.difficulty,
                        problem_id=None,
                        error_text=f"duplicate with problem_id={dup_id}, similarity={dup_score:.4f}",
                    )
                    update_import_job_progress(
                        conn,
                        job_id=job_id,
                        total_count=max_attempt,
                        processed_count=processed,
                        success_count=inserted,
                        failed_count=failed,
                        last_error=last_error,
                    )
                    continue

                try:
                    problem_id = insert_problem(conn, result.candidate, attempt_id)
                    inserted += 1
                    remaining_needed -= 1
                    # Keep the in-memory dedupe set current for later attempts.
                    existing.append(
                        ExistingProblem(
                            id=problem_id,
                            title=result.candidate.title,
                            statement_md=result.candidate.statement_md,
                        )
                    )
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_path,
                        status="success",
                        title=result.candidate.title,
                        difficulty=result.candidate.difficulty,
                        problem_id=problem_id,
                        error_text=(
                            "source=llm"
                            if result.candidate.source == "llm"
                            else f"source=fallback; {result.candidate.llm_error[:280]}"
                        ),
                    )
                except Exception as exc:
                    failed += 1
                    last_error = str(exc)
                    insert_item(
                        conn,
                        job_id,
                        source_path=source_path,
                        status="failed",
                        title=result.candidate.title,
                        difficulty=result.candidate.difficulty,
                        problem_id=None,
                        error_text=f"insert failed: {exc}",
                    )

                update_import_job_progress(
                    conn,
                    job_id=job_id,
                    total_count=max_attempt,
                    processed_count=processed,
                    success_count=inserted,
                    failed_count=failed,
                    last_error=last_error,
                )

            # Refill the pool while work and budget remain.
            while remaining_needed > 0 and next_attempt <= max_attempt and len(futures) < workers:
                submit_one(executor, next_attempt)
                next_attempt += 1

            if remaining_needed <= 0 and not futures:
                break

    reached_target = remaining_needed <= 0
    summary = (
        f"inserted={inserted}, skipped_duplicate={skipped_duplicate}, "
        f"failed={failed}, target_remaining={remaining_needed}"
    )
    if last_error:
        summary = f"{summary}; last_error={last_error[:180]}"

    finish_import_job(
        conn,
        job_id=job_id,
        processed_count=processed,
        success_count=inserted,
        failed_count=failed,
        last_error=summary,
        reached_target=reached_target,
    )

    conn.close()

    print(
        json.dumps(
            {
                "job_id": job_id,
                "pdf_dir": args.pdf_dir,
                "target_total": target_total,
                "current_total_before": current_total,
                "need": need,
                "attempt_budget": max_attempt,
                "processed": processed,
                "inserted": inserted,
                "skipped_duplicate": skipped_duplicate,
                "failed": failed,
                "reached_target": reached_target,
                "usable_pdf_count": len(corpus),
            },
            ensure_ascii=False,
            indent=2,
        )
    )

    return 0 if reached_target else 1
|
||
|
||
|
||
# Script entry point: exit with main()'s return code (0 on success).
if __name__ == "__main__":
    raise SystemExit(main())
|