1303 行
40 KiB
Python
1303 行
40 KiB
Python
#!/usr/bin/env python3
|
|
"""Import OI statement PDFs from winterant/oi with PDF + LLM pipeline.
|
|
|
|
Pipeline:
|
|
1) Discover candidate PDF files from GitHub tree.
|
|
2) Optionally clear old problem set.
|
|
3) Download each PDF to local cache.
|
|
4) Extract text via pdftotext.
|
|
5) Call LLM with retry on HTTP 500/502/503/504.
|
|
6) Upsert into SQLite problems/problem_tags.
|
|
7) Persist progress/result into import_jobs/import_job_items for UI status pages.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import fcntl
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import re
|
|
import shutil
|
|
import sqlite3
|
|
import subprocess
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from urllib.parse import quote
|
|
|
|
import requests
|
|
|
|
|
|
# GitHub repository that hosts the statement PDFs.
DEFAULT_OWNER = "winterant"

DEFAULT_REPO = "oi"

# Prefix recorded in problems.source so imported rows can be found/cleared later.
SOURCE_PREFIX = "winterant/oi"

# HTTP statuses treated as transient; requests with these codes are retried.
RETRYABLE_HTTP_CODES = {500, 502, 503, 504}

# Only PDF statements are imported.
PDF_EXTS = {".pdf"}

# Top-level repo directories considered contest material.
ROOT_PREFIXES = ("ioi/", "noi/", "noip/", "csp-j/", "csp-s/", "csp-x")

# Substrings (matched case-insensitively against the whole lowercased path)
# that mark non-statement files: solutions, answers, test data, templates...
# NOTE(review): these are plain substring matches, so short keys such as
# "sol" and "data" may also skip unrelated file names — confirm intended.
SKIP_KEYWORDS = (
    "solution",
    "sol",
    "answer",
    "answers",
    "readme",
    "sample",
    "data",
    "testdata",
    "overview",
    "answer-sheet",
    "standard",
    "templates",
    "模板",
    "答案",
    "题解",
    "解析",
    "答题",
)
|
|
|
|
|
|
@dataclass
class ProblemRecord:
    """A fully-prepared problem row plus its tags, ready for DB upsert."""

    slug: str              # unique slug derived from the repo path (make_slug)
    title: str             # display title (LLM result or path-derived fallback)
    statement_md: str      # rendered markdown statement
    difficulty: int        # 1..5
    source: str            # "winterant/oi:<path>" provenance marker
    statement_url: str     # GitHub blob URL of the original PDF
    llm_profile_json: str  # JSON blob with the full LLM/fallback profile
    sample_input: str
    sample_output: str
    tags: list[str]        # lowercase tags written to problem_tags
|
|
|
|
|
|
@dataclass
class WorkResult:
    """Outcome of processing a single PDF path in a worker."""

    path: str                     # repo-relative PDF path
    record: ProblemRecord | None  # None when processing failed
    llm_ok: bool                  # True when the LLM produced usable JSON
    error: str                    # error message; empty on success
|
|
|
|
|
|
def now_sec() -> int:
    """Current Unix time truncated to whole seconds."""
    return int(time.time())
|
|
|
|
|
|
def acquire_import_lock(lock_file: str):
    """Take an exclusive, non-blocking flock on lock_file.

    Returns the open file handle — keep it alive to hold the lock.
    Raises RuntimeError when another process already holds the lock.
    """
    target = Path(lock_file)
    target.parent.mkdir(parents=True, exist_ok=True)
    fh = target.open("a+")
    try:
        fcntl.flock(fh.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
    except OSError as exc:
        # Close our handle before reporting; the other holder keeps the lock.
        fh.close()
        raise RuntimeError(
            f"another import process is running (lock_file={lock_file}): {exc}"
        ) from exc
    return fh
|
|
|
|
|
|
def sleep_backoff(base_sleep: float, attempt: int) -> None:
    """Linear backoff: sleep base_sleep * attempt seconds."""
    delay = base_sleep * attempt
    time.sleep(delay)
|
|
|
|
|
|
def to_bool(value: str | None, default: bool = False) -> bool:
    """Parse common truthy/falsy strings; anything unrecognized yields default."""
    if value is None:
        return default
    token = value.strip().lower()
    truthy = {"1", "true", "yes", "on"}
    falsy = {"0", "false", "no", "off"}
    if token in truthy:
        return True
    if token in falsy:
        return False
    return default
|
|
|
|
|
|
def requests_with_retry(
    method: str,
    url: str,
    *,
    headers: dict[str, str] | None = None,
    json_body: dict[str, Any] | None = None,
    timeout: int = 90,
    stream: bool = False,
    retries: int = 5,
    base_sleep: float = 1.5,
) -> requests.Response:
    """Issue an HTTP request with retry on transport errors and 5xx statuses.

    Each attempt may try up to two connection modes:
      - "direct": a fresh Session with trust_env=False, i.e. ignoring proxy
        environment variables;
      - "default": plain requests.request, honoring the environment.
    OI_IMPORT_PREFER_DIRECT / OI_IMPORT_DIRECT_FALLBACK select which modes
    run and in what order. Statuses in RETRYABLE_HTTP_CODES trigger a
    linear backoff and another attempt; any other response is returned as-is.

    Raises RuntimeError once all attempts are exhausted.
    """
    allow_direct_fallback = to_bool(os.getenv("OI_IMPORT_DIRECT_FALLBACK"), True)
    prefer_direct = to_bool(os.getenv("OI_IMPORT_PREFER_DIRECT"), True)
    # Entries: True = direct (proxy-bypassing) mode, False = env-honoring mode.
    request_order = [True, False] if allow_direct_fallback and prefer_direct else [False, True]
    if not allow_direct_fallback:
        # Direct mode disabled entirely: only the environment-honoring path.
        request_order = [False]

    def do_request(use_direct: bool) -> requests.Response:
        # One physical request in the chosen connection mode.
        if not use_direct:
            return requests.request(
                method,
                url,
                headers=headers,
                json=json_body,
                timeout=timeout,
                stream=stream,
            )
        with requests.Session() as session:
            # trust_env=False makes the session ignore HTTP(S)_PROXY etc.
            session.trust_env = False
            return session.request(
                method,
                url,
                headers=headers,
                json=json_body,
                timeout=timeout,
                stream=stream,
            )

    last_error: Exception | None = None
    last_resp: requests.Response | None = None

    for attempt in range(1, retries + 1):
        seen_retryable = False
        for use_direct in request_order:
            try:
                resp = do_request(use_direct)
                last_resp = resp
            except requests.RequestException as exc:
                last_error = exc
                continue

            if resp.status_code in RETRYABLE_HTTP_CODES:
                seen_retryable = True
                continue
            # Non-retryable response (success or hard error): hand it back.
            return resp

        # Back off only when this round saw a retryable status, or some
        # round has seen a transport exception; otherwise give up.
        if attempt < retries and seen_retryable:
            sleep_backoff(base_sleep, attempt)
            continue
        if attempt < retries and last_error is not None:
            sleep_backoff(base_sleep, attempt)
            continue
        break

    if last_error:
        raise RuntimeError(f"request failed after retry: {url}: {last_error}") from last_error
    if last_resp is not None:
        raise RuntimeError(
            f"request failed after retry: {url}: status={last_resp.status_code} "
            f"body={last_resp.text[:300]}"
        )
    raise RuntimeError(f"request failed after retry: {url}: unknown error")
|
|
|
|
|
|
def github_headers() -> dict[str, str]:
    """GitHub API headers, attaching a bearer token when GITHUB_TOKEN is set."""
    result: dict[str, str] = {"Accept": "application/vnd.github+json"}
    token = os.getenv("GITHUB_TOKEN", "").strip()
    if token:
        result["Authorization"] = f"Bearer {token}"
    return result
|
|
|
|
|
|
def api_get_json(url: str, *, timeout: int = 30, retries: int = 5, base_sleep: float = 1.5) -> dict:
    """GET url as JSON via the shared retry helper; raise on HTTP >= 400."""
    response = requests_with_retry(
        "GET",
        url,
        headers=github_headers(),
        timeout=timeout,
        retries=retries,
        base_sleep=base_sleep,
    )
    if response.status_code >= 400:
        raise RuntimeError(f"GET {url} failed: {response.status_code}: {response.text[:200]}")
    return response.json()
|
|
|
|
|
|
def load_repo_tree(owner: str, repo: str) -> tuple[str, list[dict]]:
    """Resolve the default branch and fetch the full recursive git tree."""
    meta = api_get_json(f"https://api.github.com/repos/{owner}/{repo}")
    branch = meta.get("default_branch") or "master"
    tree_url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{quote(branch)}?recursive=1"
    payload = api_get_json(tree_url)
    if payload.get("truncated"):
        # A truncated tree would silently drop candidates — refuse to import.
        raise RuntimeError("github tree API truncated response; aborting import")
    return branch, payload.get("tree", [])
|
|
|
|
|
|
def looks_like_problem_file(path: str) -> bool:
    """True when path looks like a contest statement PDF worth importing."""
    lowered = path.lower()
    if os.path.splitext(lowered)[1] not in PDF_EXTS:
        return False
    if not lowered.startswith(ROOT_PREFIXES):
        return False
    return not any(keyword in lowered for keyword in SKIP_KEYWORDS)
|
|
|
|
|
|
def normalize_title(path: str) -> str:
    """Derive a human-readable title from a repo file path.

    Fix: runs of consecutive `_`/`-` separators previously produced multiple
    spaces in the title (e.g. "a__b-c" -> "a  b c"); whitespace is now
    collapsed to single spaces.
    """
    base = os.path.splitext(os.path.basename(path))[0]
    # Turn separators into spaces, then collapse any whitespace runs.
    base = " ".join(base.replace("_", " ").replace("-", " ").split())
    parent = os.path.basename(os.path.dirname(path)).strip()
    # Generic file names (day1.pdf, problem.pdf, ...) are meaningless alone;
    # prefix the parent directory for context. Same for very short names.
    if base.lower() in {"day0", "day1", "day2", "problem", "task", "test"}:
        return f"{parent} {base}".strip()
    if len(base) <= 2:
        return f"{parent} {base}".strip()
    return base
|
|
|
|
|
|
def estimate_difficulty(path: str) -> int:
    """Heuristic 1-5 difficulty derived from the contest directory."""
    lowered = path.lower()
    if lowered.startswith("ioi/"):
        return 5
    if lowered.startswith("noi/"):
        return 4
    if "提高" in path or lowered.startswith("csp-s/"):
        return 3
    if "普及" in path or lowered.startswith(("csp-j/", "csp-x")):
        return 2
    return 3
|
|
|
|
|
|
def build_tags(path: str) -> list[str]:
    """Derive sorted tags from the path: root dir, day/round, level, year."""
    lowered = path.lower()
    tags = {path.split("/", 1)[0].lower()}

    for marker in ("day1", "day2", "round1", "round2"):
        if marker in lowered:
            tags.add(marker)
    if "提高" in path:
        tags.add("senior")
    if "普及" in path:
        tags.add("junior")

    # First path component that is a plausible year (19xx/20xx) becomes a tag.
    for part in path.split("/"):
        if re.fullmatch(r"(19|20)\d{2}", part):
            tags.add(part)
            break
    return sorted(tags)
|
|
|
|
|
|
def make_slug(path: str) -> str:
    """Stable slug: sanitized path stem plus an 8-char sha1 suffix for uniqueness."""
    stem = re.sub(r"[^a-z0-9]+", "-", os.path.splitext(path)[0].lower()).strip("-")
    stem = stem or "oi"
    if len(stem) > 54:
        stem = stem[:54].rstrip("-")
    suffix = hashlib.sha1(path.encode("utf-8")).hexdigest()[:8]
    return f"{stem}-{suffix}"
|
|
|
|
|
|
def build_urls(owner: str, repo: str, branch: str, path: str) -> tuple[str, str]:
    """Return (GitHub blob page URL, raw download URL) for a file on a branch."""
    encoded_path = quote(path, safe="/")
    encoded_branch = quote(branch, safe="")
    return (
        f"https://github.com/{owner}/{repo}/blob/{encoded_branch}/{encoded_path}",
        f"https://raw.githubusercontent.com/{owner}/{repo}/{encoded_branch}/{encoded_path}",
    )
|
|
|
|
|
|
def download_pdf(
    owner: str,
    repo: str,
    branch: str,
    path: str,
    cache_dir: Path,
    retry_max: int,
    retry_sleep: float,
) -> tuple[Path, str, str]:
    """Download one statement PDF into the local cache.

    Tries requests (with retry) first; if that fails for any reason and a
    `curl` binary is available, falls back to shelling out to curl.

    Returns (local file path, GitHub blob URL, raw download URL).
    Raises RuntimeError on an HTTP error or an empty download; re-raises the
    original error when no curl fallback exists.
    """
    blob_url, raw_url = build_urls(owner, repo, branch, path)
    # Mirror the repo layout under cache_dir so repeated runs reuse paths.
    local_path = cache_dir / Path(path)
    local_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        resp = requests_with_retry(
            "GET",
            raw_url,
            timeout=45,
            stream=True,
            retries=retry_max,
            base_sleep=retry_sleep,
        )
        if resp.status_code >= 400:
            raise RuntimeError(f"download pdf failed: {raw_url}: {resp.status_code}")
        # Stream to disk in 64 KiB chunks to bound memory usage.
        with local_path.open("wb") as f:
            for chunk in resp.iter_content(chunk_size=1024 * 64):
                if chunk:
                    f.write(chunk)
    except Exception:
        # Best-effort fallback: let curl handle its own retries/redirects.
        curl_bin = shutil.which("curl")
        if not curl_bin:
            raise
        cmd = [
            curl_bin,
            "-fL",  # fail on HTTP errors, follow redirects
            "--retry",
            str(max(1, retry_max)),
            "--retry-delay",
            str(max(1, int(retry_sleep))),
            "--connect-timeout",
            "10",
            "--max-time",
            "120",
            "-o",
            str(local_path),
            raw_url,
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    # Guard against servers replying 200 with an empty body.
    if local_path.stat().st_size <= 0:
        raise RuntimeError(f"downloaded empty pdf: {raw_url}")
    return local_path, blob_url, raw_url
|
|
|
|
|
|
def extract_pdf_text(pdf_path: Path, max_pages: int, max_chars: int) -> str:
    """Run pdftotext over the first max_pages pages and return normalized text.

    Best-effort: any failure (missing binary, timeout, unreadable output)
    yields an empty string. Output is newline-normalized, blank-line-squeezed
    and capped at max_chars.
    """
    txt_path = pdf_path.with_suffix(pdf_path.suffix + ".txt")
    timeout_sec = float(os.getenv("OI_PDFTOTEXT_TIMEOUT_SEC", "20"))
    command = [
        "pdftotext",
        "-enc",
        "UTF-8",
        "-f",
        "1",
        "-l",
        str(max_pages),
        str(pdf_path),
        str(txt_path),
    ]
    try:
        subprocess.run(
            command,
            check=False,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=max(1.0, timeout_sec),
        )
    except Exception:
        return ""

    if not txt_path.exists():
        return ""
    try:
        raw = txt_path.read_text(encoding="utf-8", errors="ignore")
    except Exception:
        return ""
    normalized = re.sub(r"\r\n?", "\n", raw)
    normalized = re.sub(r"\n{3,}", "\n\n", normalized).strip()
    return normalized[:max_chars]
|
|
|
|
|
|
def extract_first_json(text: str) -> dict | None:
    """Parse a JSON object out of LLM output, tolerating code fences and prose.

    Returns the dict on success, None otherwise (including when the text is
    valid JSON but not an object).
    """
    cleaned = text.strip()
    if not cleaned:
        return None
    # Strip surrounding ``` / ```json fences if present.
    cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
    cleaned = re.sub(r"\s*```$", "", cleaned)

    try:
        direct = json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    else:
        return direct if isinstance(direct, dict) else None

    # Fall back to the widest {...} span embedded in surrounding prose.
    match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
    if match is None:
        return None
    try:
        embedded = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    return embedded if isinstance(embedded, dict) else None
|
|
|
|
|
|
def llm_is_enabled() -> bool:
    """LLM parsing requires both OI_LLM_API_URL and OI_LLM_API_KEY to be set."""
    url = os.getenv("OI_LLM_API_URL", "").strip()
    key = os.getenv("OI_LLM_API_KEY", "").strip()
    return bool(url) and bool(key)
|
|
|
|
|
|
def read_stream_content(resp: requests.Response) -> str:
    """Concatenate message content from an OpenAI-style SSE streaming response."""
    pieces: list[str] = []
    for raw_line in resp.iter_lines(decode_unicode=True):
        if not raw_line:
            continue
        line = raw_line.decode("utf-8", errors="ignore") if isinstance(raw_line, bytes) else raw_line
        if not line.startswith("data:"):
            continue
        payload = line[5:].strip()
        if payload == "[DONE]":
            break
        try:
            event = json.loads(payload)
        except json.JSONDecodeError:
            continue
        choice = (event.get("choices") or [{}])[0]
        # Streaming chunks carry `delta.content`; some servers also emit
        # full `message.content` — accept both, in that order.
        for fragment in (
            (choice.get("delta") or {}).get("content"),
            (choice.get("message") or {}).get("content"),
        ):
            if isinstance(fragment, str):
                pieces.append(fragment)
    return "".join(pieces).strip()
|
|
|
|
|
|
def call_llm_with_retry(prompt: str) -> dict | None:
    """POST the prompt to an OpenAI-compatible chat endpoint; return parsed JSON.

    Configuration is environment-driven (OI_LLM_API_URL, OI_LLM_API_KEY,
    OI_LLM_MODEL, OI_LLM_RETRY_MAX, OI_LLM_RETRY_SLEEP_SEC, OI_LLM_STREAM).
    Retries on transport errors, retryable 5xx statuses and unparsable
    responses. Returns None when disabled, on non-retryable HTTP errors, or
    after exhausting retries — callers fall back to path-derived metadata.
    """
    api_url = os.getenv("OI_LLM_API_URL", "").strip()
    api_key = os.getenv("OI_LLM_API_KEY", "").strip()
    model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip()
    retries = int(os.getenv("OI_LLM_RETRY_MAX", "5"))
    base_sleep = float(os.getenv("OI_LLM_RETRY_SLEEP_SEC", "1.5"))
    stream_enabled = to_bool(os.getenv("OI_LLM_STREAM"), False)
    if not api_url or not api_key:
        return None

    payload = {
        "model": model,
        "messages": [
            {
                "role": "system",
                # Schema the model must emit; output is validated downstream
                # by extract_first_json and the as_* coercion helpers.
                "content": (
                    "Return strict JSON only. "
                    "Schema: {"
                    "\"title\":string,"
                    "\"difficulty\":1-5,"
                    "\"answer\":string,"
                    "\"explanation\":string,"
                    "\"knowledge_points\":[string],"
                    "\"tags\":[string],"
                    "\"statement_summary_md\":string,"
                    "\"sample_input\":string,"
                    "\"sample_output\":string"
                    "}."
                ),
            },
            {"role": "user", "content": prompt},
        ],
        "stream": stream_enabled,
        "temperature": 0.2,
    }
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    for attempt in range(1, retries + 1):
        try:
            # retries=1: this loop owns the retry policy, not the helper.
            resp = requests_with_retry(
                "POST",
                api_url,
                headers=headers,
                json_body=payload,
                timeout=120,
                stream=stream_enabled,
                retries=1,
                base_sleep=base_sleep,
            )
        except Exception:
            if attempt == retries:
                return None
            sleep_backoff(base_sleep, attempt)
            continue

        if resp.status_code in RETRYABLE_HTTP_CODES:
            if attempt == retries:
                return None
            sleep_backoff(base_sleep, attempt)
            continue
        if resp.status_code >= 400:
            # Client-side errors (auth, bad request) will not improve on retry.
            return None

        try:
            if stream_enabled:
                content = read_stream_content(resp)
            else:
                body = resp.json()
                content = body["choices"][0]["message"]["content"]
        except Exception:
            if attempt == retries:
                return None
            sleep_backoff(base_sleep, attempt)
            continue

        parsed = extract_first_json(content)
        if parsed is not None:
            return parsed

        # The model replied but not with parsable JSON; try again.
        if attempt < retries:
            sleep_backoff(base_sleep, attempt)
    return None
|
|
|
|
|
|
def as_str(v: Any, default: str = "") -> str:
    """Coerce any value to a stripped string; None maps to default."""
    if v is None:
        return default
    if isinstance(v, str):
        return v.strip()
    return str(v).strip()
|
|
|
|
|
|
def as_difficulty(v: Any, default: int) -> int:
    """Coerce to int and clamp into the valid 1..5 difficulty range."""
    try:
        value = int(v)
    except Exception:
        value = default
    return min(5, max(1, value))
|
|
|
|
|
|
def as_list_of_str(v: Any, *, max_items: int, max_len: int) -> list[str]:
    """Sanitize an arbitrary value into a capped, de-duplicated string list.

    Non-list inputs yield []. Items are stripped, truncated to max_len and
    de-duplicated case-insensitively, preserving first-occurrence order;
    at most max_items are kept.
    """
    if not isinstance(v, list):
        return []
    result: list[str] = []
    seen_keys: set[str] = set()
    for raw in v:
        cleaned = as_str(raw)
        if not cleaned:
            continue
        if len(cleaned) > max_len:
            cleaned = cleaned[:max_len].strip()
        dedupe_key = cleaned.lower()
        if dedupe_key in seen_keys:
            continue
        seen_keys.add(dedupe_key)
        result.append(cleaned)
        if len(result) >= max_items:
            break
    return result
|
|
|
|
|
|
def build_llm_prompt(
    *,
    path: str,
    fallback_title: str,
    fallback_difficulty: int,
    blob_url: str,
    raw_url: str,
    extracted_text: str,
) -> str:
    """Compose the user prompt sent to the LLM for one PDF."""
    snippet = extracted_text.strip() or "(empty; rely on URL and metadata)"
    sections = [
        "You are an olympiad/CSP exam analyst.\n",
        "Read this PDF metadata and OCR text, then return strict JSON only.\n",
        "Keep answer and explanation concise.\n",
        "Difficulty range is 1..5.\n\n",
        f"File path: {path}\n",
        f"Accessible PDF URL: {raw_url}\n",
        f"GitHub page URL: {blob_url}\n",
        f"Fallback title: {fallback_title}\n",
        f"Fallback difficulty: {fallback_difficulty}\n\n",
        "Extracted text snippet:\n",
        f"{snippet}\n",
    ]
    return "".join(sections)
|
|
|
|
|
|
def build_statement_markdown(
    *,
    title: str,
    path: str,
    blob_url: str,
    summary_md: str,
    answer: str,
    explanation: str,
    knowledge_points: list[str],
) -> str:
    """Assemble the markdown statement from LLM/fallback fields.

    Empty sections (summary/knowledge points/answer/explanation) are omitted.
    """
    parts: list[str] = [
        f"# {title}",
        "",
        f"- Source repo: `{SOURCE_PREFIX}`",
        f"- Source file: `{path}`",
        f"- Original statement PDF: {blob_url}",
    ]
    if summary_md:
        parts.extend(["", "## Summary", "", summary_md])
    if knowledge_points:
        parts.extend(["", "## Knowledge Points", ""])
        parts.extend(f"- {point}" for point in knowledge_points)
    if answer:
        parts.extend(["", "## Answer", "", answer])
    if explanation:
        parts.extend(["", "## Explanation", "", explanation])
    return "\n".join(parts).strip()
|
|
|
|
|
|
def fallback_summary_from_text(text: str, max_chars: int = 500) -> str:
    """Build a short summary from the first OCR lines when the LLM fails."""
    meaningful = [ln.strip() for ln in text.splitlines() if ln.strip()]
    if not meaningful:
        return "题面 OCR 摘要生成失败,请参考原始 PDF。"
    joined = " ".join(meaningful[:8]).strip()
    if len(joined) > max_chars:
        return joined[:max_chars].rstrip() + "..."
    return joined
|
|
|
|
|
|
def fallback_knowledge_points(path: str) -> list[str]:
    """Heuristic knowledge-point tags by contest track when the LLM fails."""
    lowered = path.lower()
    if lowered.startswith("csp-j/"):
        candidates = ["模拟", "基础算法"]
    elif lowered.startswith("csp-s/"):
        candidates = ["数据结构", "算法设计"]
    elif lowered.startswith(("noi/", "ioi/")):
        candidates = ["高级算法", "复杂度分析"]
    else:
        candidates = ["算法基础"]

    if "round2" in lowered or "day2" in lowered:
        candidates.append("综合应用")
    if "round1" in lowered or "day1" in lowered:
        candidates.append("基础能力")

    # De-duplicate case-insensitively, keep order, cap at 6 entries.
    unique: list[str] = []
    seen: set[str] = set()
    for point in candidates:
        key = point.lower()
        if key in seen:
            continue
        seen.add(key)
        unique.append(point)
        if len(unique) >= 6:
            break
    return unique
|
|
|
|
|
|
def build_record(
    *,
    owner: str,
    repo: str,
    branch: str,
    path: str,
    cache_dir: Path,
    pdf_retry_max: int,
    pdf_retry_sleep: float,
    llm_enabled: bool,
    pdf_text_max_pages: int,
    pdf_text_max_chars: int,
) -> tuple[ProblemRecord, bool]:
    """Download + parse one PDF and assemble a ProblemRecord.

    Returns (record, llm_ok) where llm_ok reports whether the LLM produced
    usable JSON. Every LLM-derived field has a path-derived fallback, so a
    record is produced either way. Propagates download failures.
    """
    fallback_title = normalize_title(path)
    fallback_difficulty = estimate_difficulty(path)
    base_tags = build_tags(path)

    pdf_path, blob_url, raw_url = download_pdf(
        owner, repo, branch, path, cache_dir, pdf_retry_max, pdf_retry_sleep
    )
    extracted_text = extract_pdf_text(pdf_path, pdf_text_max_pages, pdf_text_max_chars)

    llm_ok = False
    llm_data: dict[str, Any] = {}
    if llm_enabled:
        prompt = build_llm_prompt(
            path=path,
            fallback_title=fallback_title,
            fallback_difficulty=fallback_difficulty,
            blob_url=blob_url,
            raw_url=raw_url,
            extracted_text=extracted_text,
        )
        parsed = call_llm_with_retry(prompt)
        if isinstance(parsed, dict):
            llm_data = parsed
            llm_ok = True

    # Trust LLM fields only when the call fully succeeded.
    title = as_str(llm_data.get("title"), fallback_title) if llm_ok else fallback_title
    title = title[:180] if title else fallback_title
    difficulty = as_difficulty(llm_data.get("difficulty"), fallback_difficulty) if llm_ok else fallback_difficulty
    answer = as_str(llm_data.get("answer"), "") if llm_ok else ""
    explanation = as_str(llm_data.get("explanation"), "") if llm_ok else ""
    summary_md = as_str(llm_data.get("statement_summary_md"), "") if llm_ok else ""
    sample_input = as_str(llm_data.get("sample_input"), "")
    sample_output = as_str(llm_data.get("sample_output"), "")

    knowledge_points = as_list_of_str(llm_data.get("knowledge_points"), max_items=12, max_len=120)
    llm_tags = as_list_of_str(llm_data.get("tags"), max_items=16, max_len=48)
    # Merge path-derived and LLM tags, lowercased, sorted, capped at 20.
    tags = sorted({*(t.lower() for t in base_tags), *(t.lower() for t in llm_tags)})
    tags = [t for t in tags if t][:20]

    # Fallbacks for anything the LLM did not provide.
    if not summary_md:
        summary_md = fallback_summary_from_text(extracted_text)
    if not knowledge_points:
        knowledge_points = fallback_knowledge_points(path)
    if not answer:
        answer = "LLM 识别失败;请参考原始 PDF 与样例输出(待人工复核)。"
    if not explanation:
        explanation = "已保留原题 PDF 与 OCR 摘要,可据此补充标准解法。"

    statement_md = build_statement_markdown(
        title=title,
        path=path,
        blob_url=blob_url,
        summary_md=summary_md,
        answer=answer,
        explanation=explanation,
        knowledge_points=knowledge_points,
    )

    # Full provenance + parse metadata stored alongside the problem row.
    profile = {
        "schema_version": 1,
        "title": title,
        "difficulty": difficulty,
        "answer": answer,
        "explanation": explanation,
        "knowledge_points": knowledge_points,
        "tags": tags,
        "statement_summary_md": summary_md,
        "sample_input": sample_input,
        "sample_output": sample_output,
        "source": {
            "repo": SOURCE_PREFIX,
            "path": path,
            "blob_url": blob_url,
            "raw_pdf_url": raw_url,
        },
        "llm": {
            "enabled": llm_enabled,
            "success": llm_ok,
            "model": os.getenv("OI_LLM_MODEL", "qwen3-max").strip(),
            "generated_at": now_sec(),
        },
    }

    record = ProblemRecord(
        slug=make_slug(path),
        title=title,
        statement_md=statement_md,
        difficulty=difficulty,
        source=f"{SOURCE_PREFIX}:{path}",
        statement_url=blob_url,
        llm_profile_json=json.dumps(profile, ensure_ascii=False),
        sample_input=sample_input,
        sample_output=sample_output,
        tags=tags,
    )
    return record, llm_ok
|
|
|
|
|
|
def ensure_problem_columns(conn: sqlite3.Connection) -> None:
    """Add the columns this importer writes if the problems table predates them."""
    cursor = conn.cursor()
    cursor.execute("PRAGMA table_info(problems)")
    present = {str(row[1]) for row in cursor.fetchall()}
    migrations = {
        "sample_input": "ALTER TABLE problems ADD COLUMN sample_input TEXT NOT NULL DEFAULT ''",
        "sample_output": "ALTER TABLE problems ADD COLUMN sample_output TEXT NOT NULL DEFAULT ''",
        "statement_url": "ALTER TABLE problems ADD COLUMN statement_url TEXT NOT NULL DEFAULT ''",
        "llm_profile_json": "ALTER TABLE problems ADD COLUMN llm_profile_json TEXT NOT NULL DEFAULT '{}'",
    }
    for column, ddl in migrations.items():
        if column not in present:
            cursor.execute(ddl)
    conn.commit()
|
|
|
|
|
|
def ensure_import_tables(conn: sqlite3.Connection) -> None:
    """Create the job/status tables and their indexes if they do not exist."""
    cursor = conn.cursor()
    cursor.executescript(
        """
        CREATE TABLE IF NOT EXISTS import_jobs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            status TEXT NOT NULL,
            trigger TEXT NOT NULL DEFAULT 'manual',
            total_count INTEGER NOT NULL DEFAULT 0,
            processed_count INTEGER NOT NULL DEFAULT 0,
            success_count INTEGER NOT NULL DEFAULT 0,
            failed_count INTEGER NOT NULL DEFAULT 0,
            options_json TEXT NOT NULL DEFAULT '{}',
            last_error TEXT NOT NULL DEFAULT '',
            started_at INTEGER NOT NULL,
            finished_at INTEGER,
            updated_at INTEGER NOT NULL,
            created_at INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS import_job_items (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            job_id INTEGER NOT NULL,
            source_path TEXT NOT NULL,
            status TEXT NOT NULL DEFAULT 'queued',
            title TEXT NOT NULL DEFAULT '',
            difficulty INTEGER NOT NULL DEFAULT 0,
            problem_id INTEGER,
            error_text TEXT NOT NULL DEFAULT '',
            started_at INTEGER,
            finished_at INTEGER,
            updated_at INTEGER NOT NULL,
            created_at INTEGER NOT NULL,
            UNIQUE(job_id, source_path)
        );

        CREATE INDEX IF NOT EXISTS idx_import_jobs_created_at ON import_jobs(created_at DESC);
        CREATE INDEX IF NOT EXISTS idx_import_jobs_status ON import_jobs(status, updated_at DESC);
        CREATE INDEX IF NOT EXISTS idx_import_job_items_job_status
            ON import_job_items(job_id, status, updated_at DESC);
        """
    )
    conn.commit()
|
|
|
|
|
|
def mark_stale_jobs(conn: sqlite3.Connection) -> None:
    """Fail any job still marked running from an earlier, interrupted process."""
    stamp = now_sec()
    conn.execute(
        """
        UPDATE import_jobs
        SET status='failed',
            last_error=CASE
                WHEN last_error='' THEN 'interrupted before finish'
                ELSE last_error
            END,
            finished_at=?,
            updated_at=?
        WHERE status='running'
        """,
        (stamp, stamp),
    )
    conn.commit()
|
|
|
|
|
|
def create_job(conn: sqlite3.Connection, trigger: str, options_json: str, total_count: int) -> int:
    """Insert a fresh 'running' import_jobs row and return its id."""
    stamp = now_sec()
    cursor = conn.cursor()
    cursor.execute(
        """
        INSERT INTO import_jobs(
            status, trigger, total_count, processed_count, success_count, failed_count,
            options_json, last_error, started_at, finished_at, updated_at, created_at
        ) VALUES(?, ?, ?, 0, 0, 0, ?, '', ?, NULL, ?, ?)
        """,
        ("running", trigger, total_count, options_json, stamp, stamp, stamp),
    )
    conn.commit()
    return int(cursor.lastrowid)
|
|
|
|
|
|
def seed_job_items(conn: sqlite3.Connection, job_id: int, paths: list[str]) -> None:
    """Pre-create one 'queued' import_job_items row per source path.

    INSERT OR IGNORE keeps duplicates harmless via the (job_id, source_path)
    unique constraint.
    """
    stamp = now_sec()
    rows = [(job_id, source_path, stamp, stamp) for source_path in paths]
    conn.executemany(
        """
        INSERT OR IGNORE INTO import_job_items(
            job_id, source_path, status, title, difficulty, problem_id, error_text,
            started_at, finished_at, updated_at, created_at
        ) VALUES(?, ?, 'queued', '', 0, NULL, '', NULL, NULL, ?, ?)
        """,
        rows,
    )
    conn.commit()
|
|
|
|
|
|
def mark_item_running(db_path: str, job_id: int, source_path: str) -> None:
    """Flip one job item to 'running'.

    Opens its own short-lived connection rather than reusing a shared one
    (this is called from worker_build_record).
    """
    stamp = now_sec()
    with sqlite3.connect(db_path) as conn:
        conn.execute(
            """
            UPDATE import_job_items
            SET status='running', started_at=?, updated_at=?
            WHERE job_id=? AND source_path=?
            """,
            (stamp, stamp, job_id, source_path),
        )
        conn.commit()
|
|
|
|
|
|
def update_item_success(
    conn: sqlite3.Connection,
    *,
    job_id: int,
    source_path: str,
    title: str,
    difficulty: int,
    problem_id: int,
) -> None:
    """Record a successful item: resolved title/difficulty/problem id.

    Note: does not commit; the caller controls the transaction.
    """
    stamp = now_sec()
    conn.execute(
        """
        UPDATE import_job_items
        SET status='success', title=?, difficulty=?, problem_id=?, error_text='',
            finished_at=?, updated_at=?
        WHERE job_id=? AND source_path=?
        """,
        (title, difficulty, problem_id, stamp, stamp, job_id, source_path),
    )
|
|
|
|
|
|
def update_item_failed(conn: sqlite3.Connection, *, job_id: int, source_path: str, error_text: str) -> None:
    """Record a failed item, keeping at most 1000 chars of the error text.

    Note: does not commit; the caller controls the transaction.
    """
    stamp = now_sec()
    conn.execute(
        """
        UPDATE import_job_items
        SET status='failed', error_text=?, finished_at=?, updated_at=?
        WHERE job_id=? AND source_path=?
        """,
        (error_text[:1000], stamp, stamp, job_id, source_path),
    )
|
|
|
|
|
|
def update_job_progress(
    conn: sqlite3.Connection,
    *,
    job_id: int,
    processed: int,
    success: int,
    failed: int,
) -> None:
    """Refresh the running counters on a job row.

    Note: does not commit; the caller controls the transaction.
    """
    stamp = now_sec()
    conn.execute(
        """
        UPDATE import_jobs
        SET processed_count=?, success_count=?, failed_count=?, updated_at=?
        WHERE id=?
        """,
        (processed, success, failed, stamp, job_id),
    )
|
|
|
|
|
|
def finish_job(conn: sqlite3.Connection, *, job_id: int, processed: int, success: int, failed: int, last_error: str) -> None:
    """Write final counters and derive the terminal job status, then commit.

    Status: 'failed' when nothing succeeded but something failed,
    'partial_failed' when some items failed, otherwise 'success'.
    """
    stamp = now_sec()
    if failed > 0 and success == 0:
        status = "failed"
    elif failed > 0:
        status = "partial_failed"
    else:
        status = "success"

    conn.execute(
        """
        UPDATE import_jobs
        SET status=?, processed_count=?, success_count=?, failed_count=?,
            last_error=?, finished_at=?, updated_at=?
        WHERE id=?
        """,
        (status, processed, success, failed, last_error[:1000], stamp, stamp, job_id),
    )
    conn.commit()
|
|
|
|
|
|
def clear_existing_records(
    conn: sqlite3.Connection,
    *,
    clear_all_problems: bool,
    clear_existing: bool,
    clear_existing_source_prefix: str,
) -> int:
    """Delete old problems per the requested mode; return how many were removed.

    clear_all_problems wipes the whole table; otherwise clear_existing removes
    only rows whose source starts with the given prefix; otherwise no-op.
    """
    if clear_all_problems:
        total = int(conn.execute("SELECT COUNT(1) FROM problems").fetchone()[0] or 0)
        conn.execute("DELETE FROM problems")
        conn.commit()
        return total

    if clear_existing:
        pattern = f"{clear_existing_source_prefix}:%"
        counted = conn.execute("SELECT COUNT(1) FROM problems WHERE source LIKE ?", (pattern,))
        total = int(counted.fetchone()[0] or 0)
        conn.execute("DELETE FROM problems WHERE source LIKE ?", (pattern,))
        conn.commit()
        return total
    return 0
|
|
|
|
|
|
def upsert_single_record(conn: sqlite3.Connection, rec: ProblemRecord) -> tuple[int, bool]:
    """Insert or update one problem keyed by slug, then rewrite its tags.

    Returns (problem_id, inserted) where inserted is True for a fresh row.
    Commits before returning.
    """
    cur = conn.cursor()
    cur.execute("SELECT id FROM problems WHERE slug=?", (rec.slug,))
    row = cur.fetchone()
    if row is None:
        # New problem: full insert with creation timestamp.
        cur.execute(
            """
            INSERT INTO problems(
                slug, title, statement_md, difficulty, source, statement_url, llm_profile_json,
                sample_input, sample_output, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                rec.slug,
                rec.title,
                rec.statement_md,
                rec.difficulty,
                rec.source,
                rec.statement_url,
                rec.llm_profile_json,
                rec.sample_input,
                rec.sample_output,
                now_sec(),
            ),
        )
        problem_id = int(cur.lastrowid)
        inserted = True
    else:
        # Existing slug: refresh all mutable fields; created_at is preserved.
        problem_id = int(row[0])
        cur.execute(
            """
            UPDATE problems
            SET title=?, statement_md=?, difficulty=?, source=?, statement_url=?,
                llm_profile_json=?, sample_input=?, sample_output=?
            WHERE id=?
            """,
            (
                rec.title,
                rec.statement_md,
                rec.difficulty,
                rec.source,
                rec.statement_url,
                rec.llm_profile_json,
                rec.sample_input,
                rec.sample_output,
                problem_id,
            ),
        )
        inserted = False

    # Tags are replaced wholesale so stale tags from earlier imports vanish.
    cur.execute("DELETE FROM problem_tags WHERE problem_id=?", (problem_id,))
    for tag in rec.tags:
        cur.execute("INSERT OR IGNORE INTO problem_tags(problem_id, tag) VALUES (?, ?)", (problem_id, tag))
    conn.commit()
    return problem_id, inserted
|
|
|
|
|
|
def worker_build_record(
    *,
    db_path: str,
    job_id: int,
    owner: str,
    repo: str,
    branch: str,
    path: str,
    cache_dir: Path,
    pdf_retry_max: int,
    pdf_retry_sleep: float,
    llm_enabled: bool,
    pdf_text_max_pages: int,
    pdf_text_max_chars: int,
) -> WorkResult:
    """Process one PDF path; capture any exception in the WorkResult.

    Errors are returned rather than raised so the caller can collect
    per-path outcomes.
    """
    try:
        mark_item_running(db_path, job_id, path)
        record, llm_ok = build_record(
            owner=owner,
            repo=repo,
            branch=branch,
            path=path,
            cache_dir=cache_dir,
            pdf_retry_max=pdf_retry_max,
            pdf_retry_sleep=pdf_retry_sleep,
            llm_enabled=llm_enabled,
            pdf_text_max_pages=pdf_text_max_pages,
            pdf_text_max_chars=pdf_text_max_chars,
        )
    except Exception as exc:
        return WorkResult(path=path, record=None, llm_ok=False, error=str(exc))
    return WorkResult(path=path, record=record, llm_ok=llm_ok, error="")
|
|
|
|
|
|
def main() -> int:
    """CLI entry point: discover, download, parse, and upsert OI PDF problems.

    Phases (see module docstring): parse CLI/env options, validate
    preconditions, acquire the cross-process import lock, prepare the DB,
    enumerate candidate PDFs from the GitHub tree, optionally clear old
    rows, create a job row, fan work out to a thread pool, and fold results
    back into problems/import_job_items.

    Returns:
        0 on completion; failures of individual items are recorded in the
        job tables rather than aborting the run.

    Raises:
        FileNotFoundError: if the SQLite DB file does not exist.
        RuntimeError: if LLM parsing is requested but unconfigured.
    """
    # --- CLI options; most defaults are overridable via environment. ---
    parser = argparse.ArgumentParser(
        description="Import OI PDF catalog from winterant/oi via PDF download + LLM recognition"
    )
    parser.add_argument(
        "--db-path",
        default=os.getenv("CSP_DB_PATH", "/var/lib/docker/volumes/csp_csp_data/_data/csp.db"),
        help="SQLite DB path",
    )
    parser.add_argument("--owner", default=DEFAULT_OWNER)
    parser.add_argument("--repo", default=DEFAULT_REPO)
    parser.add_argument("--max-problems", type=int, default=0, help="0 means all candidates")
    parser.add_argument("--workers", type=int, default=int(os.getenv("OI_IMPORT_WORKERS", "3")))
    parser.add_argument("--skip-llm", action="store_true", help="Skip LLM parsing and use fallback metadata")
    parser.add_argument("--llm-limit", type=int, default=0, help="Max records allowed to call LLM, 0 means no limit")
    parser.add_argument("--job-trigger", default=os.getenv("OI_IMPORT_JOB_TRIGGER", "manual"))
    parser.add_argument("--clear-existing", action="store_true", help="Delete previous imported records before run")
    parser.add_argument(
        "--clear-existing-source-prefix",
        default=os.getenv("OI_IMPORT_CLEAR_SOURCE_PREFIX", SOURCE_PREFIX),
        help="Source prefix for --clear-existing, like winterant/oi",
    )
    parser.add_argument("--clear-all-problems", action="store_true", help="Delete all records in problems before run")
    parser.add_argument("--pdf-cache-dir", default=os.getenv("OI_PDF_CACHE_DIR", "/tmp/csp-oi-pdf-cache"))
    parser.add_argument("--pdf-text-max-pages", type=int, default=int(os.getenv("OI_PDF_TEXT_MAX_PAGES", "8")))
    parser.add_argument("--pdf-text-max-chars", type=int, default=int(os.getenv("OI_PDF_TEXT_MAX_CHARS", "10000")))
    parser.add_argument("--pdf-retry-max", type=int, default=int(os.getenv("OI_PDF_RETRY_MAX", "5")))
    parser.add_argument("--pdf-retry-sleep-sec", type=float, default=float(os.getenv("OI_PDF_RETRY_SLEEP_SEC", "1.5")))
    parser.add_argument("--lock-file", default=os.getenv("OI_IMPORT_LOCK_FILE", ""))
    args = parser.parse_args()

    # --- Preconditions: fail fast before taking the lock or touching data. ---
    if not os.path.exists(args.db_path):
        raise FileNotFoundError(f"database file not found: {args.db_path}")
    if not args.skip_llm and not llm_is_enabled():
        raise RuntimeError("LLM config missing. Set OI_LLM_API_URL and OI_LLM_API_KEY, or pass --skip-llm")

    worker_count = max(1, args.workers)
    cache_dir = Path(args.pdf_cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Default lock file lives next to the DB so concurrent imports against
    # the same DB exclude each other.
    lock_file = args.lock_file.strip() or f"{args.db_path}.import.lock"
    lock_handle = acquire_import_lock(lock_file)
    # Keep a live reference so the lock's file handle is not GC'd mid-run.
    _ = lock_handle

    # --- Prepare DB schema and reconcile any jobs left over from crashes. ---
    conn = sqlite3.connect(args.db_path)
    conn.row_factory = sqlite3.Row
    ensure_problem_columns(conn)
    ensure_import_tables(conn)
    mark_stale_jobs(conn)

    # --- Candidate discovery from the GitHub repo tree (sorted for
    # deterministic ordering across runs). ---
    branch, tree = load_repo_tree(args.owner, args.repo)
    candidate_paths = sorted(
        item["path"]
        for item in tree
        if item.get("type") == "blob" and looks_like_problem_file(item.get("path", ""))
    )
    if args.max_problems > 0:
        candidate_paths = candidate_paths[: args.max_problems]

    cleared_count = clear_existing_records(
        conn,
        clear_all_problems=bool(args.clear_all_problems),
        clear_existing=bool(args.clear_existing),
        clear_existing_source_prefix=args.clear_existing_source_prefix.strip() or SOURCE_PREFIX,
    )

    # Snapshot of effective options, stored on the job row for the UI.
    options_json = json.dumps(
        {
            "owner": args.owner,
            "repo": args.repo,
            "branch": branch,
            "workers": worker_count,
            "skip_llm": bool(args.skip_llm),
            "llm_limit": max(0, args.llm_limit),
            "clear_existing": bool(args.clear_existing),
            "clear_all_problems": bool(args.clear_all_problems),
            "clear_existing_source_prefix": args.clear_existing_source_prefix,
            "cleared_count": cleared_count,
            "pdf_cache_dir": str(cache_dir),
            "lock_file": lock_file,
        },
        ensure_ascii=False,
    )
    job_id = create_job(conn, args.job_trigger, options_json, len(candidate_paths))
    seed_job_items(conn, job_id, candidate_paths)

    # --- Counters for the job summary. ---
    processed = 0
    success = 0
    failed = 0
    inserted = 0
    updated = 0
    llm_success = 0
    llm_failed = 0
    last_error = ""
    # None means "no LLM budget limit"; otherwise a decrementing quota.
    remaining_llm = None if args.llm_limit <= 0 else max(0, args.llm_limit)

    def can_use_llm() -> bool:
        # Consume one unit of the LLM quota per call; once exhausted,
        # remaining items fall back to non-LLM metadata.
        nonlocal remaining_llm
        if args.skip_llm:
            return False
        if remaining_llm is None:
            return True
        if remaining_llm <= 0:
            return False
        remaining_llm -= 1
        return True

    # Map future -> whether that item was granted an LLM call, so the
    # result loop can attribute llm_success/llm_failed correctly.
    futures = {}
    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        for path in candidate_paths:
            use_llm = can_use_llm()
            fut = executor.submit(
                worker_build_record,
                db_path=args.db_path,
                job_id=job_id,
                owner=args.owner,
                repo=args.repo,
                branch=branch,
                path=path,
                cache_dir=cache_dir,
                pdf_retry_max=max(1, args.pdf_retry_max),
                pdf_retry_sleep=max(0.1, args.pdf_retry_sleep_sec),
                llm_enabled=use_llm,
                pdf_text_max_pages=max(1, args.pdf_text_max_pages),
                pdf_text_max_chars=max(2000, args.pdf_text_max_chars),
            )
            futures[fut] = use_llm

        total = len(futures)
        # All DB writes happen here on the main thread; workers only read
        # and report, so `conn` is never shared across threads.
        for fut in as_completed(futures):
            use_llm = futures[fut]
            result = fut.result()
            processed += 1

            if result.record is not None:
                problem_id, is_inserted = upsert_single_record(conn, result.record)
                update_item_success(
                    conn,
                    job_id=job_id,
                    source_path=result.path,
                    title=result.record.title,
                    difficulty=result.record.difficulty,
                    problem_id=problem_id,
                )
                if is_inserted:
                    inserted += 1
                else:
                    updated += 1
                success += 1
                # Only items that were granted an LLM call count toward
                # the LLM success/fallback tallies.
                if use_llm:
                    if result.llm_ok:
                        llm_success += 1
                    else:
                        llm_failed += 1
                print(
                    f"[{processed}/{total}] {result.path} -> {result.record.title} "
                    f"(difficulty={result.record.difficulty}, llm={'ok' if result.llm_ok else 'fallback'})",
                    flush=True,
                )
            else:
                failed += 1
                last_error = result.error
                update_item_failed(
                    conn,
                    job_id=job_id,
                    source_path=result.path,
                    error_text=result.error,
                )
                print(f"[skip] {result.path}: {result.error}", flush=True)

            # Commit after every item so the status UI sees live progress.
            update_job_progress(
                conn,
                job_id=job_id,
                processed=processed,
                success=success,
                failed=failed,
            )
            conn.commit()

    finish_job(
        conn,
        job_id=job_id,
        processed=processed,
        success=success,
        failed=failed,
        last_error=last_error,
    )
    conn.close()

    # Machine-readable run summary on stdout for callers/cron logs.
    print(
        json.dumps(
            {
                "job_id": job_id,
                "db_path": args.db_path,
                "branch": branch,
                "workers": worker_count,
                "candidates": len(candidate_paths),
                "processed": processed,
                "success": success,
                "failed": failed,
                "inserted": inserted,
                "updated": updated,
                "cleared_count": cleared_count,
                "llm_enabled_default": not args.skip_llm,
                "llm_success": llm_success,
                "llm_failed_fallback": llm_failed,
                "pdf_cache_dir": str(cache_dir),
                "lock_file": lock_file,
            },
            ensure_ascii=False,
            indent=2,
        )
    )
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())