# csp/scripts/import_winterant_oi.py
#!/usr/bin/env python3
"""Import OI statement PDFs from winterant/oi with PDF + LLM pipeline.
Pipeline:
1) Discover candidate PDF files from GitHub tree.
2) Optionally clear old problem set.
3) Download each PDF to local cache.
4) Extract text via pdftotext.
5) Call LLM with retry on HTTP 500/502/503/504.
6) Upsert into SQLite problems/problem_tags.
7) Persist progress/result into import_jobs/import_job_items for UI status pages.
"""
from __future__ import annotations
import argparse
import fcntl
import hashlib
import json
import os
import re
import shutil
import sqlite3
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.parse import quote
import requests
# Upstream GitHub repository that hosts the OI statement PDFs.
DEFAULT_OWNER = "winterant"
DEFAULT_REPO = "oi"
# Prefix recorded into problems.source so imported rows can be found/cleared later.
SOURCE_PREFIX = "winterant/oi"
# HTTP status codes considered transient; requests hitting these are retried.
RETRYABLE_HTTP_CODES = {500, 502, 503, 504}
# Only files with these extensions are treated as problem statements.
PDF_EXTS = {".pdf"}
# Repo sub-trees that may contain statements. NOTE(review): "csp-x" has no
# trailing slash, so it also matches e.g. "csp-x1/" -- looks intentional; confirm.
ROOT_PREFIXES = ("ioi/", "noi/", "noip/", "csp-j/", "csp-s/", "csp-x")
# Substrings (matched case-insensitively anywhere in the lowercased path) that
# mark a file as NOT a statement: solutions, answers, test data, templates, etc.
# NOTE(review): short keys like "sol" and "data" also match inside longer words
# ("console", "database") -- verify this over-matching is acceptable.
SKIP_KEYWORDS = (
    "solution",
    "sol",
    "answer",
    "answers",
    "readme",
    "sample",
    "data",
    "testdata",
    "overview",
    "answer-sheet",
    "standard",
    "templates",
    "模板",
    "答案",
    "题解",
    "解析",
    "答题",
)
@dataclass
class ProblemRecord:
    """One fully-prepared problem row, ready to upsert into SQLite."""

    slug: str  # stable unique key derived from the repo path (see make_slug)
    title: str  # LLM title or normalize_title() fallback
    statement_md: str  # rendered markdown stored in problems.statement_md
    difficulty: int  # clamped to 1..5
    source: str  # "winterant/oi:<repo path>"
    statement_url: str  # GitHub blob URL of the original PDF
    llm_profile_json: str  # serialized enrichment profile (see build_record)
    sample_input: str
    sample_output: str
    tags: list[str]  # lowercased, deduplicated, written to problem_tags
@dataclass
class WorkResult:
    """Outcome of one worker task; record is None when the item failed."""

    path: str  # repo path of the PDF this result belongs to
    record: ProblemRecord | None
    llm_ok: bool  # True only when the LLM returned parseable JSON
    error: str  # empty on success; stringified exception otherwise
def now_sec() -> int:
    """Current Unix time, truncated to whole seconds."""
    timestamp = time.time()
    return int(timestamp)
def acquire_import_lock(lock_file: str):
    """Take an exclusive, non-blocking flock on *lock_file*.

    Returns the open file handle; the caller must keep it alive to hold the
    lock. Raises RuntimeError when another import process already holds it.
    """
    path = Path(lock_file)
    path.parent.mkdir(parents=True, exist_ok=True)
    fh = path.open("a+")
    try:
        fcntl.flock(fh.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
    except OSError as exc:
        fh.close()
        message = f"another import process is running (lock_file={lock_file}): {exc}"
        raise RuntimeError(message) from exc
    return fh
def sleep_backoff(base_sleep: float, attempt: int) -> None:
    """Linear backoff: sleep for base_sleep multiplied by the attempt number."""
    delay = base_sleep * attempt
    time.sleep(delay)
def to_bool(value: str | None, default: bool = False) -> bool:
    """Parse common truthy/falsy strings; anything unrecognized yields *default*."""
    if value is None:
        return default
    token = value.strip().lower()
    if token in {"1", "true", "yes", "on"}:
        return True
    return False if token in {"0", "false", "no", "off"} else default
def requests_with_retry(
    method: str,
    url: str,
    *,
    headers: dict[str, str] | None = None,
    json_body: dict[str, Any] | None = None,
    timeout: int = 90,
    stream: bool = False,
    retries: int = 5,
    base_sleep: float = 1.5,
) -> requests.Response:
    """HTTP request with retry on transient failures and optional proxy bypass.

    Each attempt tries up to two transports: "direct" (a Session with
    trust_env=False, i.e. ignoring proxy/netrc environment) and the plain
    requests path. The order is driven by OI_IMPORT_PREFER_DIRECT /
    OI_IMPORT_DIRECT_FALLBACK. Retries with linear backoff happen on
    connection-level errors and on the codes in RETRYABLE_HTTP_CODES; any
    other response (including 4xx) is returned to the caller unchanged.
    Raises RuntimeError when every attempt failed.
    """
    allow_direct_fallback = to_bool(os.getenv("OI_IMPORT_DIRECT_FALLBACK"), True)
    prefer_direct = to_bool(os.getenv("OI_IMPORT_PREFER_DIRECT"), True)
    # True == direct (trust_env off); False == default requests behavior.
    request_order = [True, False] if allow_direct_fallback and prefer_direct else [False, True]
    if not allow_direct_fallback:
        request_order = [False]

    def do_request(use_direct: bool) -> requests.Response:
        # One transport attempt; direct mode ignores proxy env vars entirely.
        if not use_direct:
            return requests.request(
                method,
                url,
                headers=headers,
                json=json_body,
                timeout=timeout,
                stream=stream,
            )
        with requests.Session() as session:
            session.trust_env = False
            return session.request(
                method,
                url,
                headers=headers,
                json=json_body,
                timeout=timeout,
                stream=stream,
            )

    last_error: Exception | None = None
    last_resp: requests.Response | None = None
    for attempt in range(1, retries + 1):
        seen_retryable = False
        for use_direct in request_order:
            try:
                resp = do_request(use_direct)
                last_resp = resp
            except requests.RequestException as exc:
                last_error = exc
                continue
            if resp.status_code in RETRYABLE_HTTP_CODES:
                seen_retryable = True
                continue
            # Success or a non-retryable status: hand it back as-is.
            return resp
        if attempt < retries and seen_retryable:
            sleep_backoff(base_sleep, attempt)
            continue
        # NOTE(review): last_error persists across attempts, so a connection
        # error on attempt 1 keeps later attempts retrying even if they only
        # saw non-retryable statuses -- confirm this is intended.
        if attempt < retries and last_error is not None:
            sleep_backoff(base_sleep, attempt)
            continue
        break
    if last_error:
        raise RuntimeError(f"request failed after retry: {url}: {last_error}") from last_error
    if last_resp is not None:
        raise RuntimeError(
            f"request failed after retry: {url}: status={last_resp.status_code} "
            f"body={last_resp.text[:300]}"
        )
    raise RuntimeError(f"request failed after retry: {url}: unknown error")
def github_headers() -> dict[str, str]:
    """Build GitHub API headers, adding a bearer token when GITHUB_TOKEN is set."""
    token = os.getenv("GITHUB_TOKEN", "").strip()
    headers = {"Accept": "application/vnd.github+json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    return headers
def api_get_json(url: str, *, timeout: int = 30, retries: int = 5, base_sleep: float = 1.5) -> dict:
    """GET *url* with retry and return the decoded JSON body.

    Raises RuntimeError for any HTTP status >= 400.
    """
    response = requests_with_retry(
        "GET",
        url,
        headers=github_headers(),
        timeout=timeout,
        retries=retries,
        base_sleep=base_sleep,
    )
    if response.status_code >= 400:
        raise RuntimeError(f"GET {url} failed: {response.status_code}: {response.text[:200]}")
    return response.json()
def load_repo_tree(owner: str, repo: str) -> tuple[str, list[dict]]:
    """Return (default branch, flat git tree entries) for *owner*/*repo*.

    Aborts when the GitHub tree API truncates its response, since a partial
    listing would silently drop candidate problems.
    """
    meta = api_get_json(f"https://api.github.com/repos/{owner}/{repo}")
    branch = meta.get("default_branch") or "master"
    tree_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{quote(branch)}?recursive=1"
    tree = api_get_json(tree_url)
    if tree.get("truncated"):
        raise RuntimeError("github tree API truncated response; aborting import")
    return branch, tree.get("tree", [])
def looks_like_problem_file(path: str) -> bool:
    """True when *path* looks like a problem-statement PDF worth importing."""
    lowered = path.lower()
    if os.path.splitext(lowered)[1] not in PDF_EXTS:
        return False
    if not lowered.startswith(ROOT_PREFIXES):
        return False
    return not any(keyword in lowered for keyword in SKIP_KEYWORDS)
def normalize_title(path: str) -> str:
    """Derive a human-readable title from a repo file path.

    Generic stems ("day1", "problem", ...) and very short ones (<= 2 chars)
    get the parent directory prefixed so titles stay distinguishable.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    stem = stem.replace("_", " ").replace("-", " ").strip()
    parent = os.path.basename(os.path.dirname(path)).strip()
    generic_names = {"day0", "day1", "day2", "problem", "task", "test"}
    if stem.lower() in generic_names or len(stem) <= 2:
        return f"{parent} {stem}".strip()
    return stem
def estimate_difficulty(path: str) -> int:
    """Heuristic 1-5 difficulty based on which contest tree the file lives in."""
    lowered = path.lower()
    if lowered.startswith("ioi/"):
        return 5
    if lowered.startswith("noi/"):
        return 4
    if "提高" in path or lowered.startswith("csp-s/"):
        return 3
    is_junior = "普及" in path or lowered.startswith(("csp-j/", "csp-x"))
    return 2 if is_junior else 3
def build_tags(path: str) -> list[str]:
    """Collect sorted tags: contest root, day/round markers, level, first year."""
    lowered = path.lower()
    tags = {path.split("/", 1)[0].lower()}
    for marker in ("day1", "day2", "round1", "round2"):
        if marker in lowered:
            tags.add(marker)
    if "提高" in path:
        tags.add("senior")
    if "普及" in path:
        tags.add("junior")
    # Tag only the first path segment that looks like a 19xx/20xx year.
    for segment in path.split("/"):
        if re.fullmatch(r"(19|20)\d{2}", segment):
            tags.add(segment)
            break
    return sorted(tags)
def make_slug(path: str) -> str:
    """Build a short, collision-resistant slug: sanitized stem + sha1 prefix.

    The stem is capped at 54 chars and the 8-char digest makes slugs unique
    even when different paths sanitize to the same stem.
    """
    stem = re.sub(r"[^a-z0-9]+", "-", os.path.splitext(path)[0].lower()).strip("-")
    stem = stem or "oi"
    if len(stem) > 54:
        stem = stem[:54].rstrip("-")
    suffix = hashlib.sha1(path.encode("utf-8")).hexdigest()[:8]
    return f"{stem}-{suffix}"
def build_urls(owner: str, repo: str, branch: str, path: str) -> tuple[str, str]:
    """Return (GitHub blob page URL, raw download URL) for a file in the repo."""
    enc_path = quote(path, safe="/")
    enc_branch = quote(branch, safe="")
    repo_slug = f"{owner}/{repo}"
    blob = f"https://github.com/{repo_slug}/blob/{enc_branch}/{enc_path}"
    raw = f"https://raw.githubusercontent.com/{repo_slug}/{enc_branch}/{enc_path}"
    return blob, raw
def download_pdf(
    owner: str,
    repo: str,
    branch: str,
    path: str,
    cache_dir: Path,
    retry_max: int,
    retry_sleep: float,
) -> tuple[Path, str, str]:
    """Fetch one PDF into the local cache.

    Returns (local_path, blob_url, raw_url). Primary path: requests with
    retry, streamed to disk in 64 KiB chunks. On ANY failure it falls back to
    the `curl` CLI (when installed) with curl's own retry flags; if curl is
    missing the original exception is re-raised. Raises RuntimeError for an
    HTTP status >= 400 or a zero-byte download.
    """
    blob_url, raw_url = build_urls(owner, repo, branch, path)
    local_path = cache_dir / Path(path)
    local_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        resp = requests_with_retry(
            "GET",
            raw_url,
            timeout=45,
            stream=True,
            retries=retry_max,
            base_sleep=retry_sleep,
        )
        if resp.status_code >= 400:
            raise RuntimeError(f"download pdf failed: {raw_url}: {resp.status_code}")
        with local_path.open("wb") as f:
            # Stream to keep memory flat for large PDFs.
            for chunk in resp.iter_content(chunk_size=1024 * 64):
                if chunk:
                    f.write(chunk)
    except Exception:
        # Best-effort fallback to curl; -o overwrites any partial file left by
        # the failed requests attempt.
        curl_bin = shutil.which("curl")
        if not curl_bin:
            raise
        cmd = [
            curl_bin,
            "-fL",
            "--retry",
            str(max(1, retry_max)),
            "--retry-delay",
            str(max(1, int(retry_sleep))),
            "--connect-timeout",
            "10",
            "--max-time",
            "120",
            "-o",
            str(local_path),
            raw_url,
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    if local_path.stat().st_size <= 0:
        raise RuntimeError(f"downloaded empty pdf: {raw_url}")
    return local_path, blob_url, raw_url
def extract_pdf_text(pdf_path: Path, max_pages: int, max_chars: int) -> str:
    """Extract plain text from the first *max_pages* pages of a PDF.

    Shells out to poppler's ``pdftotext``. Returns "" on any failure
    (missing binary, timeout, unreadable PDF) so callers can fall back
    gracefully. Output is normalized (CRLF -> LF, blank runs collapsed)
    and truncated to *max_chars*.

    Fix: the ``.txt`` sidecar that pdftotext writes next to the cached PDF
    was previously never removed and accumulated in the cache directory;
    it is now deleted unconditionally once read.
    """
    txt_path = pdf_path.with_suffix(pdf_path.suffix + ".txt")
    cmd = [
        "pdftotext",
        "-enc",
        "UTF-8",
        "-f",
        "1",
        "-l",
        str(max_pages),
        str(pdf_path),
        str(txt_path),
    ]
    timeout_sec = float(os.getenv("OI_PDFTOTEXT_TIMEOUT_SEC", "20"))
    try:
        subprocess.run(
            cmd,
            check=False,  # a non-zero pdftotext exit just yields no/partial text
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=max(1.0, timeout_sec),
        )
        if not txt_path.exists():
            return ""
        text = txt_path.read_text(encoding="utf-8", errors="ignore")
    except Exception:
        # Missing binary, timeout, or unreadable sidecar: treat as "no text".
        return ""
    finally:
        try:
            txt_path.unlink(missing_ok=True)
        except OSError:
            pass  # cleanup is best-effort; never mask the extraction result
    text = re.sub(r"\r\n?", "\n", text)
    text = re.sub(r"\n{3,}", "\n\n", text).strip()
    return text[:max_chars]
def extract_first_json(text: str) -> dict | None:
    """Parse the first JSON object found in *text*, or None.

    Handles raw JSON, ```json fenced blocks, and JSON embedded in prose.

    Fix: the previous greedy ``\\{.*\\}`` fallback spanned from the first
    "{" to the LAST "}" in the text, so a reply containing several objects
    (or trailing braces) parsed to nothing. A balanced-brace scan is now
    tried last; it only converts former None results into dicts, so the
    change is backward-compatible.
    """
    text = text.strip()
    if not text:
        return None
    # Strip a surrounding markdown code fence if present.
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)
    try:
        parsed = json.loads(text)
        return parsed if isinstance(parsed, dict) else None
    except json.JSONDecodeError:
        pass
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if match:
        try:
            parsed = json.loads(match.group(0))
            if isinstance(parsed, dict):
                return parsed
        except json.JSONDecodeError:
            pass
    return _first_balanced_json_object(text)


def _first_balanced_json_object(text: str) -> dict | None:
    """Scan for the first brace-balanced substring that parses as a dict."""
    start = text.find("{")
    while start != -1:
        depth = 0
        in_string = False
        escaped = False
        for idx in range(start, len(text)):
            ch = text[idx]
            if in_string:
                # Track string state so braces inside JSON strings are ignored.
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    in_string = False
                continue
            if ch == '"':
                in_string = True
            elif ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    try:
                        parsed = json.loads(text[start : idx + 1])
                        if isinstance(parsed, dict):
                            return parsed
                    except json.JSONDecodeError:
                        pass
                    break
        start = text.find("{", start + 1)
    return None
def llm_is_enabled() -> bool:
    """True when both the LLM endpoint URL and API key are configured."""
    url = os.getenv("OI_LLM_API_URL", "").strip()
    key = os.getenv("OI_LLM_API_KEY", "").strip()
    return bool(url) and bool(key)
def read_stream_content(resp: requests.Response) -> str:
    """Reassemble text from an OpenAI-style SSE streaming response.

    Accepts both streaming ``delta`` chunks and non-delta ``message``
    payloads; stops at the ``[DONE]`` sentinel and skips malformed events.
    """
    parts: list[str] = []
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw:
            continue
        line = raw.decode("utf-8", errors="ignore") if isinstance(raw, bytes) else raw
        if not line.startswith("data:"):
            continue
        payload = line[5:].strip()
        if payload == "[DONE]":
            break
        try:
            event = json.loads(payload)
        except json.JSONDecodeError:
            continue
        choice = (event.get("choices") or [{}])[0]
        delta_text = (choice.get("delta") or {}).get("content")
        if isinstance(delta_text, str):
            parts.append(delta_text)
        message_text = (choice.get("message") or {}).get("content")
        if isinstance(message_text, str):
            parts.append(message_text)
    return "".join(parts).strip()
def call_llm_with_retry(prompt: str) -> dict | None:
    """Call the configured OpenAI-compatible chat endpoint and parse its JSON.

    Configuration comes from OI_LLM_* environment variables. Retries (linear
    backoff, OI_LLM_RETRY_MAX attempts) on transport errors, on status codes
    in RETRYABLE_HTTP_CODES, on malformed bodies, and on replies containing
    no JSON object. Returns None when unconfigured, on any other HTTP error
    (>= 400), or after exhausting all attempts.
    """
    api_url = os.getenv("OI_LLM_API_URL", "").strip()
    api_key = os.getenv("OI_LLM_API_KEY", "").strip()
    model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip()
    retries = int(os.getenv("OI_LLM_RETRY_MAX", "5"))
    base_sleep = float(os.getenv("OI_LLM_RETRY_SLEEP_SEC", "1.5"))
    stream_enabled = to_bool(os.getenv("OI_LLM_STREAM"), False)
    if not api_url or not api_key:
        return None
    payload = {
        "model": model,
        "messages": [
            {
                "role": "system",
                # The system message pins the exact JSON schema the import
                # pipeline expects back (see build_record's field extraction).
                "content": (
                    "Return strict JSON only. "
                    "Schema: {"
                    "\"title\":string,"
                    "\"difficulty\":1-5,"
                    "\"answer\":string,"
                    "\"explanation\":string,"
                    "\"knowledge_points\":[string],"
                    "\"tags\":[string],"
                    "\"statement_summary_md\":string,"
                    "\"sample_input\":string,"
                    "\"sample_output\":string"
                    "}."
                ),
            },
            {"role": "user", "content": prompt},
        ],
        "stream": stream_enabled,
        "temperature": 0.2,
    }
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    for attempt in range(1, retries + 1):
        try:
            # retries=1: this loop owns the retry policy, not the transport.
            resp = requests_with_retry(
                "POST",
                api_url,
                headers=headers,
                json_body=payload,
                timeout=120,
                stream=stream_enabled,
                retries=1,
                base_sleep=base_sleep,
            )
        except Exception:
            if attempt == retries:
                return None
            sleep_backoff(base_sleep, attempt)
            continue
        if resp.status_code in RETRYABLE_HTTP_CODES:
            if attempt == retries:
                return None
            sleep_backoff(base_sleep, attempt)
            continue
        if resp.status_code >= 400:
            # Non-retryable client/server error: give up immediately.
            return None
        try:
            if stream_enabled:
                content = read_stream_content(resp)
            else:
                body = resp.json()
                content = body["choices"][0]["message"]["content"]
        except Exception:
            if attempt == retries:
                return None
            sleep_backoff(base_sleep, attempt)
            continue
        parsed = extract_first_json(content)
        if parsed is not None:
            return parsed
        # Reply contained no JSON object; retry with backoff.
        if attempt < retries:
            sleep_backoff(base_sleep, attempt)
    return None
def as_str(v: Any, default: str = "") -> str:
    """Coerce *v* to a stripped string; None maps to *default*."""
    if v is None:
        return default
    if isinstance(v, str):
        return v.strip()
    return str(v).strip()
def as_difficulty(v: Any, default: int) -> int:
    """Coerce *v* to an int clamped into the 1..5 difficulty range."""
    try:
        value = int(v)
    except Exception:
        value = default
    return min(5, max(1, value))
def as_list_of_str(v: Any, *, max_items: int, max_len: int) -> list[str]:
    """Sanitize *v* into a deduplicated list of non-empty strings.

    Non-list input yields []. Items are stripped, truncated to *max_len*,
    deduplicated case-insensitively (keeping first-seen casing/order), and
    capped at *max_items*.
    """
    if not isinstance(v, list):
        return []
    result: list[str] = []
    seen_keys: set[str] = set()
    for raw in v:
        item = as_str(raw)
        if not item:
            continue
        if len(item) > max_len:
            item = item[:max_len].strip()
        dedupe_key = item.lower()
        if dedupe_key in seen_keys:
            continue
        seen_keys.add(dedupe_key)
        result.append(item)
        if len(result) >= max_items:
            break
    return result
def build_llm_prompt(
    *,
    path: str,
    fallback_title: str,
    fallback_difficulty: int,
    blob_url: str,
    raw_url: str,
    extracted_text: str,
) -> str:
    """Compose the user prompt sent to the LLM for a single PDF."""
    snippet = extracted_text.strip() or "(empty; rely on URL and metadata)"
    instructions = (
        "You are an olympiad/CSP exam analyst.\n"
        "Read this PDF metadata and OCR text, then return strict JSON only.\n"
        "Keep answer and explanation concise.\n"
        "Difficulty range is 1..5.\n\n"
    )
    metadata = (
        f"File path: {path}\n"
        f"Accessible PDF URL: {raw_url}\n"
        f"GitHub page URL: {blob_url}\n"
        f"Fallback title: {fallback_title}\n"
        f"Fallback difficulty: {fallback_difficulty}\n\n"
    )
    return instructions + metadata + "Extracted text snippet:\n" + f"{snippet}\n"
def build_statement_markdown(
    *,
    title: str,
    path: str,
    blob_url: str,
    summary_md: str,
    answer: str,
    explanation: str,
    knowledge_points: list[str],
) -> str:
    """Render the markdown stored in problems.statement_md.

    Always emits the title and source header; Summary / Knowledge Points /
    Answer / Explanation sections appear only when non-empty.
    """
    parts: list[str] = [
        f"# {title}",
        "",
        f"- Source repo: `{SOURCE_PREFIX}`",
        f"- Source file: `{path}`",
        f"- Original statement PDF: {blob_url}",
    ]
    if summary_md:
        parts.extend(["", "## Summary", "", summary_md])
    if knowledge_points:
        parts.extend(["", "## Knowledge Points", ""])
        parts.extend(f"- {point}" for point in knowledge_points)
    if answer:
        parts.extend(["", "## Answer", "", answer])
    if explanation:
        parts.extend(["", "## Explanation", "", explanation])
    return "\n".join(parts).strip()
def fallback_summary_from_text(text: str, max_chars: int = 500) -> str:
    """Summary of the first few OCR lines, used when the LLM gives none."""
    meaningful = [line.strip() for line in text.splitlines() if line.strip()]
    if not meaningful:
        return "题面 OCR 摘要生成失败,请参考原始 PDF。"
    summary = " ".join(meaningful[:8]).strip()
    if len(summary) > max_chars:
        summary = summary[:max_chars].rstrip() + "..."
    return summary
def fallback_knowledge_points(path: str) -> list[str]:
    """Heuristic knowledge-point tags used when the LLM provides none.

    Base points come from the contest tree; round/day markers append extras.
    Output is deduplicated (case-insensitive) and capped at 6 items.
    """
    lowered = path.lower()
    if lowered.startswith("csp-j/"):
        candidates = ["模拟", "基础算法"]
    elif lowered.startswith("csp-s/"):
        candidates = ["数据结构", "算法设计"]
    elif lowered.startswith(("noi/", "ioi/")):
        candidates = ["高级算法", "复杂度分析"]
    else:
        candidates = ["算法基础"]
    if "round2" in lowered or "day2" in lowered:
        candidates.append("综合应用")
    if "round1" in lowered or "day1" in lowered:
        candidates.append("基础能力")
    unique: list[str] = []
    seen: set[str] = set()
    for point in candidates:
        marker = point.lower()
        if marker not in seen:
            seen.add(marker)
            unique.append(point)
        if len(unique) >= 6:
            break
    return unique
def build_record(
    *,
    owner: str,
    repo: str,
    branch: str,
    path: str,
    cache_dir: Path,
    pdf_retry_max: int,
    pdf_retry_sleep: float,
    llm_enabled: bool,
    pdf_text_max_pages: int,
    pdf_text_max_chars: int,
) -> tuple[ProblemRecord, bool]:
    """Download one PDF, extract text, optionally enrich via LLM, build the record.

    Returns (record, llm_ok) where llm_ok is True only when the LLM was both
    enabled and returned parseable JSON. Every LLM-derived field has a
    deterministic fallback, so the record is always complete. Propagates
    exceptions from the download step (handled by worker_build_record).
    """
    fallback_title = normalize_title(path)
    fallback_difficulty = estimate_difficulty(path)
    base_tags = build_tags(path)
    pdf_path, blob_url, raw_url = download_pdf(
        owner, repo, branch, path, cache_dir, pdf_retry_max, pdf_retry_sleep
    )
    # OCR text feeds both the LLM prompt and the no-LLM summary fallback.
    extracted_text = extract_pdf_text(pdf_path, pdf_text_max_pages, pdf_text_max_chars)
    llm_ok = False
    llm_data: dict[str, Any] = {}
    if llm_enabled:
        prompt = build_llm_prompt(
            path=path,
            fallback_title=fallback_title,
            fallback_difficulty=fallback_difficulty,
            blob_url=blob_url,
            raw_url=raw_url,
            extracted_text=extracted_text,
        )
        parsed = call_llm_with_retry(prompt)
        if isinstance(parsed, dict):
            llm_data = parsed
            llm_ok = True
    # LLM fields are trusted only when llm_ok; samples are taken regardless
    # (llm_data is empty otherwise, so they just default to "").
    title = as_str(llm_data.get("title"), fallback_title) if llm_ok else fallback_title
    title = title[:180] if title else fallback_title
    difficulty = as_difficulty(llm_data.get("difficulty"), fallback_difficulty) if llm_ok else fallback_difficulty
    answer = as_str(llm_data.get("answer"), "") if llm_ok else ""
    explanation = as_str(llm_data.get("explanation"), "") if llm_ok else ""
    summary_md = as_str(llm_data.get("statement_summary_md"), "") if llm_ok else ""
    sample_input = as_str(llm_data.get("sample_input"), "")
    sample_output = as_str(llm_data.get("sample_output"), "")
    knowledge_points = as_list_of_str(llm_data.get("knowledge_points"), max_items=12, max_len=120)
    llm_tags = as_list_of_str(llm_data.get("tags"), max_items=16, max_len=48)
    # Merge path-derived and LLM tags, lowercased and capped at 20.
    tags = sorted({*(t.lower() for t in base_tags), *(t.lower() for t in llm_tags)})
    tags = [t for t in tags if t][:20]
    if not summary_md:
        summary_md = fallback_summary_from_text(extracted_text)
    if not knowledge_points:
        knowledge_points = fallback_knowledge_points(path)
    if not answer:
        answer = "LLM 识别失败;请参考原始 PDF 与样例输出(待人工复核)。"
    if not explanation:
        explanation = "已保留原题 PDF 与 OCR 摘要,可据此补充标准解法。"
    statement_md = build_statement_markdown(
        title=title,
        path=path,
        blob_url=blob_url,
        summary_md=summary_md,
        answer=answer,
        explanation=explanation,
        knowledge_points=knowledge_points,
    )
    # Full enrichment profile persisted as JSON alongside the row for the UI.
    profile = {
        "schema_version": 1,
        "title": title,
        "difficulty": difficulty,
        "answer": answer,
        "explanation": explanation,
        "knowledge_points": knowledge_points,
        "tags": tags,
        "statement_summary_md": summary_md,
        "sample_input": sample_input,
        "sample_output": sample_output,
        "source": {
            "repo": SOURCE_PREFIX,
            "path": path,
            "blob_url": blob_url,
            "raw_pdf_url": raw_url,
        },
        "llm": {
            "enabled": llm_enabled,
            "success": llm_ok,
            "model": os.getenv("OI_LLM_MODEL", "qwen3-max").strip(),
            "generated_at": now_sec(),
        },
    }
    record = ProblemRecord(
        slug=make_slug(path),
        title=title,
        statement_md=statement_md,
        difficulty=difficulty,
        source=f"{SOURCE_PREFIX}:{path}",
        statement_url=blob_url,
        llm_profile_json=json.dumps(profile, ensure_ascii=False),
        sample_input=sample_input,
        sample_output=sample_output,
        tags=tags,
    )
    return record, llm_ok
def ensure_problem_columns(conn: sqlite3.Connection) -> None:
    """Add the newer optional columns to ``problems`` if missing (idempotent)."""
    cur = conn.cursor()
    cur.execute("PRAGMA table_info(problems)")
    present = {str(row[1]) for row in cur.fetchall()}
    migrations = {
        "sample_input": "ALTER TABLE problems ADD COLUMN sample_input TEXT NOT NULL DEFAULT ''",
        "sample_output": "ALTER TABLE problems ADD COLUMN sample_output TEXT NOT NULL DEFAULT ''",
        "statement_url": "ALTER TABLE problems ADD COLUMN statement_url TEXT NOT NULL DEFAULT ''",
        "llm_profile_json": "ALTER TABLE problems ADD COLUMN llm_profile_json TEXT NOT NULL DEFAULT '{}'",
    }
    for column, ddl in migrations.items():
        if column not in present:
            cur.execute(ddl)
    conn.commit()
def ensure_import_tables(conn: sqlite3.Connection) -> None:
    """Create the import-job bookkeeping tables and indexes (idempotent).

    import_jobs tracks one row per pipeline run; import_job_items tracks one
    row per candidate PDF, unique per (job_id, source_path).
    """
    schema = """
    CREATE TABLE IF NOT EXISTS import_jobs (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        status TEXT NOT NULL,
        trigger TEXT NOT NULL DEFAULT 'manual',
        total_count INTEGER NOT NULL DEFAULT 0,
        processed_count INTEGER NOT NULL DEFAULT 0,
        success_count INTEGER NOT NULL DEFAULT 0,
        failed_count INTEGER NOT NULL DEFAULT 0,
        options_json TEXT NOT NULL DEFAULT '{}',
        last_error TEXT NOT NULL DEFAULT '',
        started_at INTEGER NOT NULL,
        finished_at INTEGER,
        updated_at INTEGER NOT NULL,
        created_at INTEGER NOT NULL
    );
    CREATE TABLE IF NOT EXISTS import_job_items (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        job_id INTEGER NOT NULL,
        source_path TEXT NOT NULL,
        status TEXT NOT NULL DEFAULT 'queued',
        title TEXT NOT NULL DEFAULT '',
        difficulty INTEGER NOT NULL DEFAULT 0,
        problem_id INTEGER,
        error_text TEXT NOT NULL DEFAULT '',
        started_at INTEGER,
        finished_at INTEGER,
        updated_at INTEGER NOT NULL,
        created_at INTEGER NOT NULL,
        UNIQUE(job_id, source_path)
    );
    CREATE INDEX IF NOT EXISTS idx_import_jobs_created_at ON import_jobs(created_at DESC);
    CREATE INDEX IF NOT EXISTS idx_import_jobs_status ON import_jobs(status, updated_at DESC);
    CREATE INDEX IF NOT EXISTS idx_import_job_items_job_status
        ON import_job_items(job_id, status, updated_at DESC);
    """
    conn.cursor().executescript(schema)
    conn.commit()
def mark_stale_jobs(conn: sqlite3.Connection) -> None:
    """Mark any leftover 'running' jobs as failed (crash/kill recovery)."""
    stamp = now_sec()
    sql = """
        UPDATE import_jobs
        SET status='failed',
            last_error=CASE
                WHEN last_error='' THEN 'interrupted before finish'
                ELSE last_error
            END,
            finished_at=?,
            updated_at=?
        WHERE status='running'
    """
    conn.execute(sql, (stamp, stamp))
    conn.commit()
def create_job(conn: sqlite3.Connection, trigger: str, options_json: str, total_count: int) -> int:
    """Insert a new 'running' import job row and return its id."""
    stamp = now_sec()
    cur = conn.cursor()
    sql = """
        INSERT INTO import_jobs(
            status, trigger, total_count, processed_count, success_count, failed_count,
            options_json, last_error, started_at, finished_at, updated_at, created_at
        ) VALUES(?, ?, ?, 0, 0, 0, ?, '', ?, NULL, ?, ?)
    """
    cur.execute(sql, ("running", trigger, total_count, options_json, stamp, stamp, stamp))
    conn.commit()
    return int(cur.lastrowid)
def seed_job_items(conn: sqlite3.Connection, job_id: int, paths: list[str]) -> None:
    """Pre-create one 'queued' item row per candidate path for this job."""
    stamp = now_sec()
    rows = [(job_id, source_path, stamp, stamp) for source_path in paths]
    sql = """
        INSERT OR IGNORE INTO import_job_items(
            job_id, source_path, status, title, difficulty, problem_id, error_text,
            started_at, finished_at, updated_at, created_at
        ) VALUES(?, ?, 'queued', '', 0, NULL, '', NULL, NULL, ?, ?)
    """
    conn.executemany(sql, rows)
    conn.commit()
def mark_item_running(db_path: str, job_id: int, source_path: str) -> None:
    """Flip one job item to 'running'.

    Opens a fresh connection because this runs inside worker threads and the
    main connection is not shared across threads.
    """
    stamp = now_sec()
    with sqlite3.connect(db_path) as conn:
        sql = """
            UPDATE import_job_items
            SET status='running', started_at=?, updated_at=?
            WHERE job_id=? AND source_path=?
        """
        conn.execute(sql, (stamp, stamp, job_id, source_path))
        conn.commit()
def update_item_success(
    conn: sqlite3.Connection,
    *,
    job_id: int,
    source_path: str,
    title: str,
    difficulty: int,
    problem_id: int,
) -> None:
    """Record a finished item: status 'success', linked problem, cleared error."""
    stamp = now_sec()
    sql = """
        UPDATE import_job_items
        SET status='success', title=?, difficulty=?, problem_id=?, error_text='',
            finished_at=?, updated_at=?
        WHERE job_id=? AND source_path=?
    """
    conn.execute(sql, (title, difficulty, problem_id, stamp, stamp, job_id, source_path))
def update_item_failed(conn: sqlite3.Connection, *, job_id: int, source_path: str, error_text: str) -> None:
    """Record a failed item; error text is truncated to 1000 chars."""
    stamp = now_sec()
    sql = """
        UPDATE import_job_items
        SET status='failed', error_text=?, finished_at=?, updated_at=?
        WHERE job_id=? AND source_path=?
    """
    conn.execute(sql, (error_text[:1000], stamp, stamp, job_id, source_path))
def update_job_progress(
    conn: sqlite3.Connection,
    *,
    job_id: int,
    processed: int,
    success: int,
    failed: int,
) -> None:
    """Refresh the job's running counters so the UI can show live progress."""
    stamp = now_sec()
    sql = """
        UPDATE import_jobs
        SET processed_count=?, success_count=?, failed_count=?, updated_at=?
        WHERE id=?
    """
    conn.execute(sql, (processed, success, failed, stamp, job_id))
def finish_job(conn: sqlite3.Connection, *, job_id: int, processed: int, success: int, failed: int, last_error: str) -> None:
    """Finalize the job row.

    Status: 'failed' when nothing succeeded but something failed,
    'partial_failed' on a mix, otherwise 'success'.
    """
    stamp = now_sec()
    if failed > 0:
        status = "failed" if success == 0 else "partial_failed"
    else:
        status = "success"
    sql = """
        UPDATE import_jobs
        SET status=?, processed_count=?, success_count=?, failed_count=?,
            last_error=?, finished_at=?, updated_at=?
        WHERE id=?
    """
    conn.execute(sql, (status, processed, success, failed, last_error[:1000], stamp, stamp, job_id))
    conn.commit()
def clear_existing_records(
    conn: sqlite3.Connection,
    *,
    clear_all_problems: bool,
    clear_existing: bool,
    clear_existing_source_prefix: str,
) -> int:
    """Optionally delete prior problem rows; return how many were removed.

    clear_all_problems wipes the whole table; clear_existing removes only
    rows whose source starts with "<prefix>:". Neither flag set => no-op.
    """
    if clear_all_problems:
        (count,) = conn.execute("SELECT COUNT(1) FROM problems").fetchone()
        conn.execute("DELETE FROM problems")
        conn.commit()
        return int(count or 0)
    if not clear_existing:
        return 0
    pattern = f"{clear_existing_source_prefix}:%"
    (count,) = conn.execute(
        "SELECT COUNT(1) FROM problems WHERE source LIKE ?", (pattern,)
    ).fetchone()
    conn.execute("DELETE FROM problems WHERE source LIKE ?", (pattern,))
    conn.commit()
    return int(count or 0)
def upsert_single_record(conn: sqlite3.Connection, rec: ProblemRecord) -> tuple[int, bool]:
    """Insert or update one problem row keyed by slug, then rebuild its tags.

    Returns (problem_id, inserted); inserted is True for a brand-new row.
    created_at is set only on insert and never touched on update.
    """
    cur = conn.cursor()
    cur.execute("SELECT id FROM problems WHERE slug=?", (rec.slug,))
    existing = cur.fetchone()
    inserted = existing is None
    if inserted:
        cur.execute(
            """
            INSERT INTO problems(
                slug, title, statement_md, difficulty, source, statement_url, llm_profile_json,
                sample_input, sample_output, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                rec.slug,
                rec.title,
                rec.statement_md,
                rec.difficulty,
                rec.source,
                rec.statement_url,
                rec.llm_profile_json,
                rec.sample_input,
                rec.sample_output,
                now_sec(),
            ),
        )
        problem_id = int(cur.lastrowid)
    else:
        problem_id = int(existing[0])
        cur.execute(
            """
            UPDATE problems
            SET title=?, statement_md=?, difficulty=?, source=?, statement_url=?,
                llm_profile_json=?, sample_input=?, sample_output=?
            WHERE id=?
            """,
            (
                rec.title,
                rec.statement_md,
                rec.difficulty,
                rec.source,
                rec.statement_url,
                rec.llm_profile_json,
                rec.sample_input,
                rec.sample_output,
                problem_id,
            ),
        )
    # Tags are fully replaced on every upsert.
    cur.execute("DELETE FROM problem_tags WHERE problem_id=?", (problem_id,))
    cur.executemany(
        "INSERT OR IGNORE INTO problem_tags(problem_id, tag) VALUES (?, ?)",
        [(problem_id, tag) for tag in rec.tags],
    )
    conn.commit()
    return problem_id, inserted
def worker_build_record(
    *,
    db_path: str,
    job_id: int,
    owner: str,
    repo: str,
    branch: str,
    path: str,
    cache_dir: Path,
    pdf_retry_max: int,
    pdf_retry_sleep: float,
    llm_enabled: bool,
    pdf_text_max_pages: int,
    pdf_text_max_chars: int,
) -> WorkResult:
    """Thread-pool worker: mark the item running, then build its record.

    Never raises: every failure is captured into WorkResult.error so the
    main loop can record it against the job item.
    """
    try:
        mark_item_running(db_path, job_id, path)
        record, llm_ok = build_record(
            owner=owner,
            repo=repo,
            branch=branch,
            path=path,
            cache_dir=cache_dir,
            pdf_retry_max=pdf_retry_max,
            pdf_retry_sleep=pdf_retry_sleep,
            llm_enabled=llm_enabled,
            pdf_text_max_pages=pdf_text_max_pages,
            pdf_text_max_chars=pdf_text_max_chars,
        )
    except Exception as exc:
        return WorkResult(path=path, record=None, llm_ok=False, error=str(exc))
    return WorkResult(path=path, record=record, llm_ok=llm_ok, error="")
def main() -> int:
    """CLI entry point: discover, download, enrich, and upsert OI problems.

    Exits non-zero only via raised exceptions (missing DB, missing LLM
    config, lock contention, truncated GitHub tree); per-item failures are
    recorded in import_job_items instead of aborting the run.
    """
    parser = argparse.ArgumentParser(
        description="Import OI PDF catalog from winterant/oi via PDF download + LLM recognition"
    )
    parser.add_argument(
        "--db-path",
        default=os.getenv("CSP_DB_PATH", "/var/lib/docker/volumes/csp_csp_data/_data/csp.db"),
        help="SQLite DB path",
    )
    parser.add_argument("--owner", default=DEFAULT_OWNER)
    parser.add_argument("--repo", default=DEFAULT_REPO)
    parser.add_argument("--max-problems", type=int, default=0, help="0 means all candidates")
    parser.add_argument("--workers", type=int, default=int(os.getenv("OI_IMPORT_WORKERS", "3")))
    parser.add_argument("--skip-llm", action="store_true", help="Skip LLM parsing and use fallback metadata")
    parser.add_argument("--llm-limit", type=int, default=0, help="Max records allowed to call LLM, 0 means no limit")
    parser.add_argument("--job-trigger", default=os.getenv("OI_IMPORT_JOB_TRIGGER", "manual"))
    parser.add_argument("--clear-existing", action="store_true", help="Delete previous imported records before run")
    parser.add_argument(
        "--clear-existing-source-prefix",
        default=os.getenv("OI_IMPORT_CLEAR_SOURCE_PREFIX", SOURCE_PREFIX),
        help="Source prefix for --clear-existing, like winterant/oi",
    )
    parser.add_argument("--clear-all-problems", action="store_true", help="Delete all records in problems before run")
    parser.add_argument("--pdf-cache-dir", default=os.getenv("OI_PDF_CACHE_DIR", "/tmp/csp-oi-pdf-cache"))
    parser.add_argument("--pdf-text-max-pages", type=int, default=int(os.getenv("OI_PDF_TEXT_MAX_PAGES", "8")))
    parser.add_argument("--pdf-text-max-chars", type=int, default=int(os.getenv("OI_PDF_TEXT_MAX_CHARS", "10000")))
    parser.add_argument("--pdf-retry-max", type=int, default=int(os.getenv("OI_PDF_RETRY_MAX", "5")))
    parser.add_argument("--pdf-retry-sleep-sec", type=float, default=float(os.getenv("OI_PDF_RETRY_SLEEP_SEC", "1.5")))
    parser.add_argument("--lock-file", default=os.getenv("OI_IMPORT_LOCK_FILE", ""))
    args = parser.parse_args()
    if not os.path.exists(args.db_path):
        raise FileNotFoundError(f"database file not found: {args.db_path}")
    # Fail fast instead of importing a whole catalog with fallback metadata.
    if not args.skip_llm and not llm_is_enabled():
        raise RuntimeError("LLM config missing. Set OI_LLM_API_URL and OI_LLM_API_KEY, or pass --skip-llm")
    worker_count = max(1, args.workers)
    cache_dir = Path(args.pdf_cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    lock_file = args.lock_file.strip() or f"{args.db_path}.import.lock"
    # Handle must stay referenced for the whole run; flock is released on exit.
    lock_handle = acquire_import_lock(lock_file)
    _ = lock_handle
    conn = sqlite3.connect(args.db_path)
    conn.row_factory = sqlite3.Row
    ensure_problem_columns(conn)
    ensure_import_tables(conn)
    mark_stale_jobs(conn)
    branch, tree = load_repo_tree(args.owner, args.repo)
    candidate_paths = sorted(
        item["path"]
        for item in tree
        if item.get("type") == "blob" and looks_like_problem_file(item.get("path", ""))
    )
    if args.max_problems > 0:
        candidate_paths = candidate_paths[: args.max_problems]
    cleared_count = clear_existing_records(
        conn,
        clear_all_problems=bool(args.clear_all_problems),
        clear_existing=bool(args.clear_existing),
        clear_existing_source_prefix=args.clear_existing_source_prefix.strip() or SOURCE_PREFIX,
    )
    # Snapshot of effective options, persisted on the job row for the UI.
    options_json = json.dumps(
        {
            "owner": args.owner,
            "repo": args.repo,
            "branch": branch,
            "workers": worker_count,
            "skip_llm": bool(args.skip_llm),
            "llm_limit": max(0, args.llm_limit),
            "clear_existing": bool(args.clear_existing),
            "clear_all_problems": bool(args.clear_all_problems),
            "clear_existing_source_prefix": args.clear_existing_source_prefix,
            "cleared_count": cleared_count,
            "pdf_cache_dir": str(cache_dir),
            "lock_file": lock_file,
        },
        ensure_ascii=False,
    )
    job_id = create_job(conn, args.job_trigger, options_json, len(candidate_paths))
    seed_job_items(conn, job_id, candidate_paths)
    processed = 0
    success = 0
    failed = 0
    inserted = 0
    updated = 0
    llm_success = 0
    llm_failed = 0
    last_error = ""
    # None means "unlimited"; otherwise a countdown of LLM call budget.
    remaining_llm = None if args.llm_limit <= 0 else max(0, args.llm_limit)

    def can_use_llm() -> bool:
        # Consumes one unit of the budget per True result. Called only from
        # the submission loop (single thread), so no locking is needed.
        nonlocal remaining_llm
        if args.skip_llm:
            return False
        if remaining_llm is None:
            return True
        if remaining_llm <= 0:
            return False
        remaining_llm -= 1
        return True

    futures = {}
    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        for path in candidate_paths:
            use_llm = can_use_llm()
            fut = executor.submit(
                worker_build_record,
                db_path=args.db_path,
                job_id=job_id,
                owner=args.owner,
                repo=args.repo,
                branch=branch,
                path=path,
                cache_dir=cache_dir,
                pdf_retry_max=max(1, args.pdf_retry_max),
                pdf_retry_sleep=max(0.1, args.pdf_retry_sleep_sec),
                llm_enabled=use_llm,
                pdf_text_max_pages=max(1, args.pdf_text_max_pages),
                pdf_text_max_chars=max(2000, args.pdf_text_max_chars),
            )
            futures[fut] = use_llm
        total = len(futures)
        # All DB writes below happen on this (main) thread; workers only
        # touch the DB via their own connections in mark_item_running.
        for fut in as_completed(futures):
            use_llm = futures[fut]
            result = fut.result()
            processed += 1
            if result.record is not None:
                problem_id, is_inserted = upsert_single_record(conn, result.record)
                update_item_success(
                    conn,
                    job_id=job_id,
                    source_path=result.path,
                    title=result.record.title,
                    difficulty=result.record.difficulty,
                    problem_id=problem_id,
                )
                if is_inserted:
                    inserted += 1
                else:
                    updated += 1
                success += 1
                if use_llm:
                    if result.llm_ok:
                        llm_success += 1
                    else:
                        llm_failed += 1
                print(
                    f"[{processed}/{total}] {result.path} -> {result.record.title} "
                    f"(difficulty={result.record.difficulty}, llm={'ok' if result.llm_ok else 'fallback'})",
                    flush=True,
                )
            else:
                failed += 1
                last_error = result.error
                update_item_failed(
                    conn,
                    job_id=job_id,
                    source_path=result.path,
                    error_text=result.error,
                )
                print(f"[skip] {result.path}: {result.error}", flush=True)
            # Commit per item so the status page sees live progress.
            update_job_progress(
                conn,
                job_id=job_id,
                processed=processed,
                success=success,
                failed=failed,
            )
            conn.commit()
    finish_job(
        conn,
        job_id=job_id,
        processed=processed,
        success=success,
        failed=failed,
        last_error=last_error,
    )
    conn.close()
    # Machine-readable run summary on stdout for callers/cron logs.
    print(
        json.dumps(
            {
                "job_id": job_id,
                "db_path": args.db_path,
                "branch": branch,
                "workers": worker_count,
                "candidates": len(candidate_paths),
                "processed": processed,
                "success": success,
                "failed": failed,
                "inserted": inserted,
                "updated": updated,
                "cleared_count": cleared_count,
                "llm_enabled_default": not args.skip_llm,
                "llm_success": llm_success,
                "llm_failed_fallback": llm_failed,
                "pdf_cache_dir": str(cache_dir),
                "lock_file": lock_file,
            },
            ensure_ascii=False,
            indent=2,
        )
    )
    return 0
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())