文件
csp/scripts/generate_problem_solutions.py

649 行
21 KiB
Python
原始文件 Blame 文件历史

此文件含有模棱两可的 Unicode 字符
此文件含有可能会与其他字符混淆的 Unicode 字符。 如果您是想特意这样的,可以安全地忽略该警告。 使用 Escape 按钮显示他们。
#!/usr/bin/env python3
"""Asynchronously generate multiple solutions for a problem and store into SQLite."""
from __future__ import annotations
import argparse
import json
import os
import re
import shutil
import sqlite3
import subprocess
import tempfile
import time
from dataclasses import dataclass
from typing import Any
import requests
# HTTP status codes that llm_request treats as transient and retries.
RETRYABLE_HTTP_CODES = {500, 502, 503, 504}

# Optional external tools; each is None when the executable is not on PATH.
CLANG_FORMAT_BIN = shutil.which("clang-format")
GXX_BIN = shutil.which("g++")

# Substrings (compared against lower-cased code) that mark a stub solution.
PLACEHOLDER_CODE_MARKERS = (
    "todo",
    "to do",
    "请根据题意补全",
    "待补全",
    "自行补全",
    "省略",
    "your code here",
)

# (pattern, label) pairs; a pattern hit means the code uses a C++17-only
# feature and the label is what gets reported in the rejection reason.
CPP17_BANNED_PATTERNS: tuple[tuple[re.Pattern[str], str], ...] = (
    (re.compile(r"\bif\s+constexpr\b"), "if constexpr"),
    (re.compile(r"\bstd::optional\b"), "std::optional"),
    (re.compile(r"\bstd::variant\b"), "std::variant"),
    (re.compile(r"\bstd::any\b"), "std::any"),
    (re.compile(r"\bstd::string_view\b"), "std::string_view"),
    (re.compile(r"\bstd::filesystem\b"), "std::filesystem"),
    (re.compile(r"\bstd::byte\b"), "std::byte"),
    (re.compile(r"\bstd::clamp\s*\("), "std::clamp"),
    (re.compile(r"\bstd::gcd\s*\("), "std::gcd"),
    (re.compile(r"\bstd::lcm\s*\("), "std::lcm"),
    (
        re.compile(
            r"#\s*include\s*<\s*(optional|variant|any|string_view|filesystem|charconv|execution)\s*>"
        ),
        "C++17 header",
    ),
    (
        # The bracketed name list may be followed by "=", "{" or "(" in a
        # declaration, or by a single ":" in a range-based for loop
        # ("::" is excluded so scope resolution never matches).
        re.compile(
            r"\b(?:const\s+)?auto(?:\s*&&?)?\s*\[[^\]\n]+\]\s*(?:=|\{|\(|:(?!:))"
        ),
        "structured bindings",
    ),
)
@dataclass
class Problem:
    """One row of the ``problems`` table, as selected by ``load_problem``."""

    # Primary key used to look the problem up.
    id: int
    title: str
    # Statement text; embedded verbatim into the LLM prompt.
    statement_md: str
    difficulty: int
    source: str
    # Sample I/O; embedded verbatim into the LLM prompt.
    sample_input: str
    sample_output: str
def now_sec() -> int:
    """Return the current Unix timestamp truncated to whole seconds."""
    current = time.time()
    return int(current)
def env_bool(key: str, default: bool) -> bool:
    """Read environment variable *key* as a boolean flag.

    The value is trimmed and compared case-insensitively.  ``1/true/yes/on``
    yield True, ``0/false/no/off`` yield False; an unset, empty or
    unrecognised value yields *default*.
    """
    value = os.getenv(key, "").strip().lower()
    truthy = {"1", "true", "yes", "on"}
    falsy = {"0", "false", "no", "off"}
    if value in truthy:
        return True
    if value in falsy:
        return False
    # Covers both the empty string and any unrecognised spelling.
    return default
def extract_json_object(text: str) -> dict[str, Any] | None:
    """Extract a JSON object (dict) from *text*, or return None.

    First the whole text is tried, after peeling off a surrounding
    Markdown code fence if one is present.  If that does not yield a dict,
    the span from the first ``{`` to the last ``}`` in the original text is
    tried instead.
    """

    def parse_dict(payload: str) -> dict[str, Any] | None:
        # Only a successfully decoded dict counts; everything else is None.
        try:
            decoded = json.loads(payload)
        except json.JSONDecodeError:
            return None
        return decoded if isinstance(decoded, dict) else None

    stripped = text.strip()
    if stripped.startswith("```"):
        stripped = re.sub(r"^```[a-zA-Z0-9_-]*", "", stripped).strip()
        stripped = stripped.removesuffix("```").strip()

    direct = parse_dict(stripped)
    if direct is not None:
        return direct

    braces = re.search(r"\{[\s\S]*\}", text)
    return parse_dict(braces.group(0)) if braces else None
def extract_message_text(content: Any) -> str:
    """Flatten a chat-completion ``content`` value into plain text.

    * ``str``  -> the stripped string.
    * ``dict`` -> the stripped ``"text"`` value if it is a string, otherwise
      the stripped ``"content"`` value if that is a string.
    * ``list`` -> for every dict item, the first non-blank string among its
      ``"text"`` and ``"content"`` values; parts are joined with newlines.
      Non-dict items are ignored.
    * anything else -> ``""``.
    """
    if isinstance(content, str):
        return content.strip()

    if isinstance(content, dict):
        for key in ("text", "content"):
            value = content.get(key)
            # Any string wins here, even an empty one.
            if isinstance(value, str):
                return value.strip()
        return ""

    if not isinstance(content, list):
        return ""

    pieces: list[str] = []
    for entry in content:
        if not isinstance(entry, dict):
            continue
        for key in ("text", "content"):
            value = entry.get(key)
            # Inside a list only non-blank strings count.
            if isinstance(value, str) and value.strip():
                pieces.append(value.strip())
                break
    return "\n".join(pieces).strip()
def iter_json_candidates(text: str) -> list[str]:
    """Collect substrings of *text* that may be a JSON document.

    Candidates are gathered in this order, then de-duplicated while keeping
    the first occurrence:

    1. the whole text, stripped;
    2. the body of every fenced code block (with or without a ``json`` tag);
    3. every span starting at a ``{`` or ``[`` that decodes as JSON, looking
       only at the first 200000 characters.

    Returns an empty list for empty or whitespace-only input.
    """
    if not text:
        return []

    candidates: list[str] = []
    whole = text.strip()
    if whole:
        candidates.append(whole)

    for match in re.finditer(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.IGNORECASE):
        block = match.group(1).strip()
        if block:
            candidates.append(block)

    decoder = json.JSONDecoder()
    sample = text[:200000]
    for idx, ch in enumerate(sample):
        if ch not in "{[":
            continue
        try:
            # Decode in place from offset idx; slicing sample[idx:] would
            # copy the remainder of the text for every bracket found.
            _, end = decoder.raw_decode(sample, idx)
        except json.JSONDecodeError:
            continue
        snippet = sample[idx:end].strip()
        if snippet:
            candidates.append(snippet)

    # dict keys keep insertion order, so this is an order-preserving dedup.
    return list(dict.fromkeys(candidates))
def extract_solution_rows(content: str) -> list[dict[str, Any]]:
    """Return the first non-empty list of solution dicts found in *content*.

    Every candidate from ``iter_json_candidates`` is decoded in turn.  A
    decoded list is used directly; a decoded dict contributes its
    ``"solutions"`` value, or ``data.solutions`` when ``"solutions"`` is
    absent and ``"data"`` is a dict.  Non-dict entries are dropped.  Returns
    an empty list when no candidate yields at least one dict.
    """
    for candidate in iter_json_candidates(content):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue

        rows: Any
        if isinstance(parsed, list):
            rows = parsed
        elif isinstance(parsed, dict):
            rows = parsed.get("solutions")
            if rows is None:
                data = parsed.get("data")
                if isinstance(data, dict):
                    rows = data.get("solutions")
        else:
            rows = None

        if not isinstance(rows, list):
            continue
        dict_rows = [entry for entry in rows if isinstance(entry, dict)]
        if dict_rows:
            return dict_rows
    return []
def is_placeholder_code(code: str) -> bool:
    """Return True when *code* looks like an unfinished stub.

    A literal ``...`` or any entry of ``PLACEHOLDER_CODE_MARKERS`` (matched
    case-insensitively as a substring) marks the code as a placeholder.
    A None or empty *code* is treated as an empty string and is not a
    placeholder.
    """
    # Normalise once so both checks tolerate None the same way.
    text = code or ""
    if "..." in text:
        return True
    lowered = text.lower()
    return any(marker in lowered for marker in PLACEHOLDER_CODE_MARKERS)
def cpp14_violations(code: str) -> list[str]:
    """List the labels of all ``CPP17_BANNED_PATTERNS`` entries that match *code*.

    Labels are returned in table order; an empty list means no banned
    construct was found.
    """
    return [label for pattern, label in CPP17_BANNED_PATTERNS if pattern.search(code)]
def compiles_under_cpp14(code: str) -> tuple[bool, str]:
    """Syntax-check *code* with ``g++ -std=gnu++14 -fsyntax-only``.

    Returns:
        ``(True, "")`` when the check passes, or when no ``g++`` was found on
        PATH (the check is then skipped).  Otherwise ``(False, message)``,
        where *message* is the first 400 characters of the compiler's stderr
        or a timeout notice.
    """
    if not GXX_BIN:
        return True, ""

    source = code if code.endswith("\n") else f"{code}\n"
    with tempfile.TemporaryDirectory(prefix="csp_sol_cpp14_") as tmp:
        src_path = os.path.join(tmp, "main.cpp")
        with open(src_path, "w", encoding="utf-8") as f:
            f.write(source)
        try:
            proc = subprocess.run(
                [GXX_BIN, "-std=gnu++14", "-O2", "-Wall", "-Wextra", "-Wpedantic", "-fsyntax-only", src_path],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                check=False,
                timeout=12,
            )
        except subprocess.TimeoutExpired:
            # Reject only this candidate; letting the exception escape would
            # abort validation of every other solution in the batch.
            return False, "g++ syntax check timed out after 12s"

    if proc.returncode == 0:
        return True, ""
    err = proc.stderr.decode("utf-8", errors="ignore").strip()
    return False, err[:400]
def normalize_solutions(rows: list[dict[str, Any]], max_solutions: int) -> tuple[list[dict[str, Any]], list[str]]:
    """Validate raw solution rows and keep the acceptable ones.

    A row is rejected when its ``code_cpp`` is blank, defines no ``main``
    function, looks like a placeholder, uses a banned C++17 construct, or
    fails the g++ C++14 syntax check.

    Args:
        rows: Decoded solution dicts from the LLM response.
        max_solutions: Stop after this many rows have been accepted.

    Returns:
        ``(normalized, rejected)``: the accepted rows reduced to the keys
        ``title``, ``idea_md``, ``explanation_md``, ``complexity``,
        ``code_cpp`` and ``tags``, plus one reason string per rejected row.
    """
    normalized: list[dict[str, Any]] = []
    rejected: list[str] = []
    for row in rows:
        code_cpp = str(row.get("code_cpp") or "")
        if not code_cpp.strip():
            rejected.append("empty code_cpp")
            continue
        # Word-boundary match: accepts "int main ()" and does not mistake
        # identifiers such as "domain(" for an entry point.
        if not re.search(r"\bmain\s*\(", code_cpp):
            rejected.append("missing main()")
            continue
        if is_placeholder_code(code_cpp):
            rejected.append("placeholder code")
            continue
        violations = cpp14_violations(code_cpp)
        if violations:
            rejected.append(f"C++17+ feature: {', '.join(violations[:3])}")
            continue
        ok_cpp14, compile_msg = compiles_under_cpp14(code_cpp)
        if not ok_cpp14:
            rejected.append(f"cannot compile with -std=gnu++14: {compile_msg}")
            continue

        raw_tags = row.get("tags")
        normalized.append(
            {
                "title": str(row.get("title") or "").strip(),
                "idea_md": str(row.get("idea_md") or "").strip(),
                "explanation_md": str(row.get("explanation_md") or "").strip(),
                "complexity": str(row.get("complexity") or "").strip(),
                "code_cpp": code_cpp,
                "tags": raw_tags if isinstance(raw_tags, list) else [],
            }
        )
        if len(normalized) >= max_solutions:
            break
    return normalized, rejected
def build_prompt(problem: Problem, max_solutions: int) -> str:
    """Render the user prompt asking for *max_solutions* C++14 solutions.

    The prompt states the hard output requirements, shows the expected JSON
    shape, and embeds the problem's title, difficulty, source, statement and
    sample input/output verbatim.  Surrounding whitespace is stripped.
    """
    prompt = f"""
请为下面这道 CSP 题生成 {max_solutions} 种不同思路的题解(可从模拟/贪心/DP/图论/数据结构等不同角度切入),并给出可直接提交的 C++14 参考代码。
硬性要求:
1. 必须只输出一个 JSON 对象,不能有任何 JSON 外文本。
2. JSON 必须符合下面格式,且 solutions 数组长度应为 {max_solutions}
3. 每个 code_cpp 必须是完整、可编译、可运行的 C++14 程序(包含 main 函数),不能出现 TODO、伪代码、占位注释、省略号。
4. 必须兼容 GCC 4.9/5.4 + -std=gnu++14严禁使用 C++17 及以上特性(如 structured bindings、if constexpr、std::optional、std::variant、std::any、std::string_view、<filesystem>)。
5. 建议使用标准头文件(如 <iostream>/<vector>/<algorithm> 等),不要使用 <bits/stdc++.h>。
6. main 必须是 int main(),并且 return 0;。若使用 scanf/printf 处理 long long,格式符必须用 %lld,不要用 %I64d。
7. 代码风格清晰,变量命名可读,注释简洁。
输出 JSON,格式固定
{{
"solutions": [
{{
"title": "解法标题",
"idea_md": "思路要点Markdown",
"explanation_md": "详细讲解Markdown",
"complexity": "时间/空间复杂度",
"code_cpp": "完整 C++14 代码",
"tags": ["标签1","标签2"]
}}
]
}}
题目信息:
- 题目:{problem.title}
- 难度:{problem.difficulty}
- 来源:{problem.source}
完整题面(原文,不做截断):
{problem.statement_md}
样例输入(原文):
{problem.sample_input}
样例输出(原文):
{problem.sample_output}
"""
    return prompt.strip()
def parse_solutions_or_raise(content: str, max_solutions: int) -> list[dict[str, Any]]:
    """Parse and validate the solutions contained in an LLM reply.

    Raises:
        RuntimeError: when no solutions array can be extracted, or when
            every extracted row is rejected.  In the latter case the first
            rejection reason (truncated to 180 characters) is appended.
    """
    rows = extract_solution_rows(content)
    if not rows:
        raise RuntimeError("llm response missing valid solutions array")

    normalized, rejected = normalize_solutions(rows, max_solutions=max_solutions)
    if normalized:
        return normalized

    reason = ""
    if rejected:
        reason = f"; rejected sample: {rejected[0][:180]}"
    raise RuntimeError(f"llm response contains no runnable full code{reason}")
def generate_solutions_with_llm(
    prompt: str,
    max_solutions: int,
    timeout: int,
    retries: int,
    sleep_sec: float,
) -> list[dict[str, Any]]:
    """Ask the LLM for solutions, with one repair round on a bad reply.

    If the first reply cannot be parsed or validated, a second request is
    sent whose prompt is prefixed with the failure reason.

    Raises:
        RuntimeError: when the second reply also fails; the message carries
            both failure reasons, each truncated to 200 characters.
    """
    request_opts = {"timeout": timeout, "retries": retries, "sleep_sec": sleep_sec}

    first_content = llm_request(prompt, **request_opts)
    try:
        return parse_solutions_or_raise(first_content, max_solutions=max_solutions)
    except Exception as first_exc:
        # Tell the model what was wrong, then repeat the original prompt.
        feedback = (
            "你上一条回复不符合要求,原因是:"
            f"{str(first_exc)[:240]}。请只输出合法 JSON,并确保 code_cpp 是完整可运行 C++14 代码(兼容 -std=gnu++14\n\n"
        )
        repair_prompt = feedback + prompt
        second_content = llm_request(repair_prompt, **request_opts)
        try:
            return parse_solutions_or_raise(second_content, max_solutions=max_solutions)
        except Exception as second_exc:
            summary = f"parse failed after retry: first={str(first_exc)[:200]}; second={str(second_exc)[:200]}"
            raise RuntimeError(summary) from second_exc
def llm_request(prompt: str, timeout: int, retries: int, sleep_sec: float) -> str:
    """Send *prompt* to the configured chat-completions endpoint.

    Configuration comes from the environment: ``OI_LLM_API_URL`` (required),
    ``OI_LLM_API_KEY`` (optional bearer token) and ``OI_LLM_MODEL``
    (default ``qwen3-max``).

    Args:
        prompt: User message content.
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        retries: Total number of attempts; values below 1 are treated as 1.
        sleep_sec: Base back-off; attempt *i* sleeps ``sleep_sec * i`` before
            the next attempt.

    Returns:
        The non-empty text of the first choice.

    Raises:
        RuntimeError: on missing configuration, transport failure or a
            retryable status after the last attempt, any other HTTP status
            of 400 or above, a body that is not JSON, or a reply without
            choices or content.
    """
    url = os.getenv("OI_LLM_API_URL", "").strip()
    api_key = os.getenv("OI_LLM_API_KEY", "").strip()
    model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip()
    if not url:
        raise RuntimeError("missing OI_LLM_API_URL")

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    body = {
        "model": model,
        "stream": False,
        "temperature": 0.3,
        "messages": [
            {
                "role": "system",
                "content": "你是资深 OI/CSP 教练。严格输出 JSON,不要输出任何额外文本。",
            },
            {"role": "user", "content": prompt},
        ],
    }

    # Always try at least once, even if the caller passes retries <= 0.
    attempts = max(1, retries)
    for attempt in range(1, attempts + 1):
        is_last = attempt == attempts
        try:
            resp = requests.post(url, headers=headers, json=body, timeout=timeout)
        except requests.RequestException as exc:
            if is_last:
                raise RuntimeError(f"llm request failed: {exc}") from exc
            time.sleep(sleep_sec * attempt)
            continue

        if resp.status_code in RETRYABLE_HTTP_CODES:
            if is_last:
                raise RuntimeError(f"llm retry exhausted: HTTP {resp.status_code}")
            time.sleep(sleep_sec * attempt)
            continue
        if resp.status_code >= 400:
            raise RuntimeError(f"llm request failed: HTTP {resp.status_code}: {resp.text[:300]}")

        try:
            payload = resp.json()
        except ValueError as exc:
            # A success status with a non-JSON body (e.g. a proxy error page).
            raise RuntimeError(f"llm response is not valid JSON: {resp.text[:300]}") from exc

        choices = payload.get("choices") if isinstance(payload, dict) else None
        if not choices or not isinstance(choices, list):
            raise RuntimeError("llm response missing choices")
        choice0 = choices[0] if isinstance(choices[0], dict) else {}
        message = choice0.get("message")
        if not isinstance(message, dict):
            message = {}
        # Prefer the chat-style message content; fall back to the legacy
        # completion-style "text" field.
        content = extract_message_text(message.get("content"))
        if not content:
            content = extract_message_text(choice0.get("text"))
        if not content:
            raise RuntimeError("llm response missing content")
        return content

    # Unreachable: every iteration returns, raises, or continues to a later
    # attempt, and the last attempt never continues.
    raise RuntimeError("llm request failed")
def fallback_solutions(max_solutions: int) -> list[dict[str, Any]]:
    """Return up to two generic template solutions.

    The templates contain a C++ skeleton with a TODO comment instead of real
    logic.  At least one template is always returned, even when
    *max_solutions* is zero or negative.
    """

    def skeleton(todo_comment: str) -> str:
        # Both templates share every line except the TODO comment.
        return "\n".join(
            [
                "#include <iostream>",
                "#include <vector>",
                "#include <algorithm>",
                "using namespace std;",
                "int main() {",
                "ios::sync_with_stdio(false);",
                "cin.tie(nullptr);",
                todo_comment,
                "return 0;",
                "}",
            ]
        )

    simulation = {
        "title": "解法一:直接模拟/枚举",
        "idea_md": "按题意拆分步骤,先写可过样例的直观解法,再补边界处理。",
        "explanation_md": "适用于数据范围较小或规则清晰的题。",
        "complexity": "时间复杂度依题而定,通常 O(n)~O(n^2)",
        "code_cpp": skeleton("// TODO: 请根据题意补全读入、核心逻辑与输出。"),
        "tags": ["simulation", "implementation"],
    }
    optimized = {
        "title": "解法二:优化思路(前缀/贪心/DP 视题而定)",
        "idea_md": "分析状态与重复计算,尝试用前缀和、贪心或动态规划优化。",
        "explanation_md": "比直接模拟更稳定,通常能覆盖更大数据规模。",
        "complexity": "通常优于朴素解法",
        "code_cpp": skeleton("// TODO: 请根据题意补全优化版思路实现。"),
        "tags": ["optimization", "dp"],
    }
    count = max(1, max_solutions)
    return [simulation, optimized][:count]
def format_cpp_code(raw: str) -> str:
    """Normalise line endings and, when available, run clang-format.

    Returns ``""`` for blank input.  Otherwise the result always ends with a
    newline.  If clang-format is missing, fails, or produces blank output,
    the newline-normalised input is returned unchanged.
    """
    normalized = (raw or "").replace("\r\n", "\n").replace("\r", "\n")
    if not normalized.strip():
        return ""
    if not normalized.endswith("\n"):
        normalized = f"{normalized}\n"
    if not CLANG_FORMAT_BIN:
        return normalized

    command = [CLANG_FORMAT_BIN, "--style={BasedOnStyle: Google, IndentWidth: 2, ColumnLimit: 0}"]
    try:
        proc = subprocess.run(
            command,
            input=normalized.encode("utf-8"),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
            timeout=6,
        )
        if proc.returncode == 0 and proc.stdout:
            formatted = proc.stdout.decode("utf-8", errors="ignore")
            if formatted.strip():
                return formatted if formatted.endswith("\n") else f"{formatted}\n"
    except Exception:
        # Formatting is best-effort: any failure falls back to the
        # unformatted code below.
        pass
    return normalized
def load_problem(conn: sqlite3.Connection, problem_id: int) -> Problem:
    """Fetch one row from ``problems`` and wrap it in a ``Problem``.

    NULL text columns become ``""`` and a NULL or zero difficulty becomes 1.

    Raises:
        RuntimeError: when no row has the given id.
    """
    query = "SELECT id,title,statement_md,difficulty,source,sample_input,sample_output FROM problems WHERE id=?"
    row = conn.execute(query, (problem_id,)).fetchone()
    if row is None:
        raise RuntimeError(f"problem not found: {problem_id}")

    pid, title, statement_md, difficulty, source, sample_input, sample_output = row
    return Problem(
        id=int(pid),
        title=str(title or ""),
        statement_md=str(statement_md or ""),
        difficulty=int(difficulty or 1),
        source=str(source or ""),
        sample_input=str(sample_input or ""),
        sample_output=str(sample_output or ""),
    )
def update_job(conn: sqlite3.Connection, job_id: int, **fields: Any) -> None:
    """Set the given columns on one ``problem_solution_jobs`` row and commit.

    Keyword names are used as column names and are interpolated into the SQL
    text; values are bound as parameters.  Does nothing when no fields are
    given.
    """
    if not fields:
        return
    assignments = ", ".join(f"{column}=?" for column in fields)
    params = (*fields.values(), job_id)
    conn.execute(
        f"UPDATE problem_solution_jobs SET {assignments} WHERE id=?",
        params,
    )
    conn.commit()
def store_solutions(conn: sqlite3.Connection, problem_id: int, rows: list[dict[str, Any]], source: str) -> int:
    """Replace all stored solutions of a problem with *rows*.

    The delete and all inserts run in one transaction: it is committed when
    every insert succeeds and rolled back if anything raises, so a failure
    never leaves the problem with a half-replaced solution set.

    Rows whose (stripped) title repeats an earlier one are skipped.  The
    ``variant`` column holds the row's 1-based position in *rows*, so it may
    have gaps after duplicates are skipped.

    Args:
        conn: Open SQLite connection.
        problem_id: Problem whose solutions are replaced.
        rows: Normalised solution dicts.
        source: Stored in the ``source`` column of every inserted row.

    Returns:
        The number of rows inserted.
    """
    ts = now_sec()
    saved = 0
    seen_titles: set[str] = set()
    # The connection context manager commits on success and rolls back when
    # the block raises.
    with conn:
        conn.execute("DELETE FROM problem_solutions WHERE problem_id=?", (problem_id,))
        for idx, row in enumerate(rows, start=1):
            title = str(row.get("title") or f"解法 {idx}").strip()
            if title in seen_titles:
                continue
            seen_titles.add(title)
            raw_tags = row.get("tags")
            tags = raw_tags if isinstance(raw_tags, list) else []
            conn.execute(
                """
                INSERT INTO problem_solutions(
                    problem_id,variant,title,idea_md,explanation_md,code_cpp,complexity,tags_json,source,created_at,updated_at
                ) VALUES(?,?,?,?,?,?,?,?,?,?,?)
                """,
                (
                    problem_id,
                    idx,
                    title,
                    str(row.get("idea_md") or "").strip(),
                    str(row.get("explanation_md") or "").strip(),
                    format_cpp_code(str(row.get("code_cpp") or "")),
                    str(row.get("complexity") or "").strip(),
                    json.dumps(tags, ensure_ascii=False),
                    source,
                    ts,
                    ts,
                ),
            )
            saved += 1
    return saved
def main() -> int:
    """Run one solution-generation job from the command line.

    Marks the job as running, loads the problem, requests solutions from the
    LLM (or uses the templates when ``CSP_SOLUTION_ALLOW_FALLBACK`` is
    enabled and the LLM path fails), stores them and records the outcome on
    the job row.

    Returns:
        0 when the job completed, 1 when it failed.
    """
    parser = argparse.ArgumentParser(description="Generate multi-solution explanations")
    parser.add_argument("--db-path", required=True)
    parser.add_argument("--problem-id", type=int, required=True)
    parser.add_argument("--job-id", type=int, required=True)
    parser.add_argument("--max-solutions", type=int, default=3)
    parser.add_argument("--timeout", type=int, default=90)
    parser.add_argument("--retries", type=int, default=4)
    parser.add_argument("--retry-sleep-sec", type=float, default=1.5)
    args = parser.parse_args()

    conn = sqlite3.connect(args.db_path)
    try:
        conn.execute("PRAGMA foreign_keys=ON")
        conn.execute("PRAGMA busy_timeout=5000")
        ts = now_sec()
        update_job(
            conn,
            args.job_id,
            status="running",
            progress=1,
            message="starting",
            started_at=ts,
            updated_at=ts,
        )
        try:
            problem = load_problem(conn, args.problem_id)
            # Clamp the requested count to the range 1..5.
            requested_solutions = max(1, min(5, args.max_solutions))
            allow_fallback = env_bool("CSP_SOLUTION_ALLOW_FALLBACK", False)
            prompt = build_prompt(problem, max_solutions=requested_solutions)
            update_job(conn, args.job_id, progress=25, message="requesting llm", updated_at=now_sec())

            source = "llm"
            solutions: list[dict[str, Any]]
            try:
                solutions = generate_solutions_with_llm(
                    prompt=prompt,
                    max_solutions=requested_solutions,
                    timeout=args.timeout,
                    retries=args.retries,
                    sleep_sec=args.retry_sleep_sec,
                )
            except Exception as exc:
                if not allow_fallback:
                    raise RuntimeError(f"llm generation failed: {str(exc)[:280]}") from exc
                source = "fallback"
                solutions = fallback_solutions(requested_solutions)
            solutions = solutions[:requested_solutions]

            update_job(conn, args.job_id, progress=70, message="writing solutions", updated_at=now_sec())
            saved = store_solutions(conn, args.problem_id, solutions, source)
            finished = now_sec()
            update_job(
                conn,
                args.job_id,
                status="completed",
                progress=100,
                message=f"completed: {saved} solutions ({source})",
                finished_at=finished,
                updated_at=finished,
            )
            return 0
        except Exception as exc:
            # Discard any uncommitted writes first: update_job commits, and
            # would otherwise persist a partially replaced solution set.
            conn.rollback()
            finished = now_sec()
            update_job(
                conn,
                args.job_id,
                status="failed",
                progress=100,
                message=f"failed: {str(exc)[:400]}",
                finished_at=finished,
                updated_at=finished,
            )
            return 1
    finally:
        # Close on every path, including when a status update itself raises.
        conn.close()
# Script entry point: the value returned by main() becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())