#!/usr/bin/env python3 """Asynchronously generate multiple solutions for a problem and store into SQLite.""" from __future__ import annotations import argparse import json import os import re import shutil import sqlite3 import subprocess import tempfile import time from dataclasses import dataclass from typing import Any import requests RETRYABLE_HTTP_CODES = {500, 502, 503, 504} CLANG_FORMAT_BIN = shutil.which("clang-format") GXX_BIN = shutil.which("g++") PLACEHOLDER_CODE_MARKERS = ( "todo", "to do", "请根据题意补全", "待补全", "自行补全", "省略", "your code here", ) CPP17_BANNED_PATTERNS: tuple[tuple[re.Pattern[str], str], ...] = ( (re.compile(r"\bif\s+constexpr\b"), "if constexpr"), (re.compile(r"\bstd::optional\b"), "std::optional"), (re.compile(r"\bstd::variant\b"), "std::variant"), (re.compile(r"\bstd::any\b"), "std::any"), (re.compile(r"\bstd::string_view\b"), "std::string_view"), (re.compile(r"\bstd::filesystem\b"), "std::filesystem"), (re.compile(r"\bstd::byte\b"), "std::byte"), (re.compile(r"\bstd::clamp\s*\("), "std::clamp"), (re.compile(r"\bstd::gcd\s*\("), "std::gcd"), (re.compile(r"\bstd::lcm\s*\("), "std::lcm"), (re.compile(r"#\s*include\s*<\s*(optional|variant|any|string_view|filesystem|charconv|execution)\s*>"), "C++17 header"), ( re.compile( r"\b(?:const\s+)?auto(?:\s*&|\s*&&)?\s*\[[^\]\n]+\]\s*=", flags=re.MULTILINE, ), "structured bindings", ), ) @dataclass class Problem: id: int title: str statement_md: str difficulty: int source: str sample_input: str sample_output: str def now_sec() -> int: return int(time.time()) def env_bool(key: str, default: bool) -> bool: raw = os.getenv(key, "").strip().lower() if not raw: return default if raw in {"1", "true", "yes", "on"}: return True if raw in {"0", "false", "no", "off"}: return False return default def extract_json_object(text: str) -> dict[str, Any] | None: raw = text.strip() if raw.startswith("```"): raw = re.sub(r"^```[a-zA-Z0-9_-]*", "", raw).strip() raw = raw.removesuffix("```").strip() try: obj = json.loads(raw) if isinstance(obj, dict): return obj except json.JSONDecodeError: pass match = re.search(r"\{[\s\S]*\}", text) if not match: return None try: obj = json.loads(match.group(0)) return obj if isinstance(obj, dict) else None except json.JSONDecodeError: return None def extract_message_text(content: Any) -> str: if isinstance(content, str): return content.strip() if isinstance(content, list): parts: list[str] = [] for item in content: if isinstance(item, dict): text = item.get("text") if isinstance(text, str) and text.strip(): parts.append(text.strip()) continue nested = item.get("content") if isinstance(nested, str) and nested.strip(): parts.append(nested.strip()) return "\n".join(parts).strip() if isinstance(content, dict): text = content.get("text") if isinstance(text, str): return text.strip() nested = content.get("content") if isinstance(nested, str): return nested.strip() return "" def iter_json_candidates(text: str) -> list[str]: if not text: return [] raw = text.strip() candidates: list[str] = [raw] if raw else [] for match in re.finditer(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.IGNORECASE): block = match.group(1).strip() if block: candidates.append(block) decoder = json.JSONDecoder() limit = min(len(text), 200000) sample = text[:limit] for idx, ch in enumerate(sample): if ch not in "{[": continue try: _, end = decoder.raw_decode(sample[idx:]) except json.JSONDecodeError: continue snippet = sample[idx : idx + end].strip() if snippet: candidates.append(snippet) seen: set[str] = set() deduped: list[str] = [] for cand in candidates: if cand in seen: continue seen.add(cand) deduped.append(cand) return deduped def extract_solution_rows(content: str) -> list[dict[str, Any]]: for candidate in iter_json_candidates(content): try: parsed = json.loads(candidate) except json.JSONDecodeError: continue rows: Any = None if isinstance(parsed, dict): rows = parsed.get("solutions") if rows is None and isinstance(parsed.get("data"), dict): rows = parsed["data"].get("solutions") elif isinstance(parsed, list): rows = parsed if isinstance(rows, list): filtered = [x for x in rows if isinstance(x, dict)] if filtered: return filtered return [] def is_placeholder_code(code: str) -> bool: lower = (code or "").lower() if any(marker in lower for marker in PLACEHOLDER_CODE_MARKERS): return True if "..." in code: return True return False def cpp14_violations(code: str) -> list[str]: hits: list[str] = [] for pattern, label in CPP17_BANNED_PATTERNS: if pattern.search(code): hits.append(label) return hits def compiles_under_cpp14(code: str) -> tuple[bool, str]: if not GXX_BIN: return True, "" with tempfile.TemporaryDirectory(prefix="csp_sol_cpp14_") as tmp: src_path = os.path.join(tmp, "main.cpp") with open(src_path, "w", encoding="utf-8") as f: f.write(code if code.endswith("\n") else f"{code}\n") proc = subprocess.run( [GXX_BIN, "-std=gnu++14", "-O2", "-Wall", "-Wextra", "-Wpedantic", "-fsyntax-only", src_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False, timeout=12, ) if proc.returncode == 0: return True, "" err = proc.stderr.decode("utf-8", errors="ignore").strip() return False, err[:400] def normalize_solutions(rows: list[dict[str, Any]], max_solutions: int) -> tuple[list[dict[str, Any]], list[str]]: normalized: list[dict[str, Any]] = [] rejected: list[str] = [] for row in rows: title = str(row.get("title") or "").strip() idea_md = str(row.get("idea_md") or "").strip() explanation_md = str(row.get("explanation_md") or "").strip() complexity = str(row.get("complexity") or "").strip() code_cpp = str(row.get("code_cpp") or "") tags = row.get("tags") if isinstance(row.get("tags"), list) else [] if not code_cpp.strip(): rejected.append("empty code_cpp") continue if "main(" not in code_cpp: rejected.append("missing main()") continue if is_placeholder_code(code_cpp): rejected.append("placeholder code") continue violations = cpp14_violations(code_cpp) if violations: rejected.append(f"C++17+ feature: {', '.join(violations[:3])}") continue ok_cpp14, compile_msg = compiles_under_cpp14(code_cpp) if not ok_cpp14: rejected.append(f"cannot compile with -std=gnu++14: {compile_msg}") continue normalized.append( { "title": title, "idea_md": idea_md, "explanation_md": explanation_md, "complexity": complexity, "code_cpp": code_cpp, "tags": tags, } ) if len(normalized) >= max_solutions: break return normalized, rejected def build_prompt(problem: Problem, max_solutions: int) -> str: return f""" 请为下面这道 CSP 题生成 {max_solutions} 种不同思路的题解(可从模拟/贪心/DP/图论/数据结构等不同角度切入),并给出可直接提交的 C++14 参考代码。 硬性要求: 1. 必须只输出一个 JSON 对象,不能有任何 JSON 外文本。 2. JSON 必须符合下面格式,且 solutions 数组长度应为 {max_solutions}。 3. 每个 code_cpp 必须是完整、可编译、可运行的 C++14 程序(包含 main 函数),不能出现 TODO、伪代码、占位注释、省略号。 4. 必须兼容 GCC 4.9/5.4 + -std=gnu++14:严禁使用 C++17 及以上特性(如 structured bindings、if constexpr、std::optional、std::variant、std::any、std::string_view、)。 5. 建议使用标准头文件(如 // 等),不要使用 。 6. main 必须是 int main(),并且 return 0;。若使用 scanf/printf 处理 long long,格式符必须用 %lld,不要用 %I64d。 7. 代码风格清晰,变量命名可读,注释简洁。 输出 JSON,格式固定: {{ "solutions": [ {{ "title": "解法标题", "idea_md": "思路要点(Markdown)", "explanation_md": "详细讲解(Markdown)", "complexity": "时间/空间复杂度", "code_cpp": "完整 C++14 代码", "tags": ["标签1","标签2"] }} ] }} 题目信息: - 题目:{problem.title} - 难度:{problem.difficulty} - 来源:{problem.source} 完整题面(原文,不做截断): {problem.statement_md} 样例输入(原文): {problem.sample_input} 样例输出(原文): {problem.sample_output} """.strip() def parse_solutions_or_raise(content: str, max_solutions: int) -> list[dict[str, Any]]: rows = extract_solution_rows(content) if not rows: raise RuntimeError("llm response missing valid solutions array") normalized, rejected = normalize_solutions(rows, max_solutions=max_solutions) if not normalized: reason = f"; rejected sample: {rejected[0][:180]}" if rejected else "" raise RuntimeError(f"llm response contains no runnable full code{reason}") return normalized def generate_solutions_with_llm( prompt: str, max_solutions: int, timeout: int, retries: int, sleep_sec: float, ) -> list[dict[str, Any]]: first_content = llm_request(prompt, timeout=timeout, retries=retries, sleep_sec=sleep_sec) try: return parse_solutions_or_raise(first_content, max_solutions=max_solutions) except Exception as first_exc: repair_prompt = ( "你上一条回复不符合要求,原因是:" f"{str(first_exc)[:240]}。请只输出合法 JSON,并确保 code_cpp 是完整可运行 C++14 代码(兼容 -std=gnu++14)。\n\n" + prompt ) second_content = llm_request( repair_prompt, timeout=timeout, retries=retries, sleep_sec=sleep_sec, ) try: return parse_solutions_or_raise(second_content, max_solutions=max_solutions) except Exception as second_exc: raise RuntimeError( f"parse failed after retry: first={str(first_exc)[:200]}; second={str(second_exc)[:200]}" ) from second_exc def llm_request(prompt: str, timeout: int, retries: int, sleep_sec: float) -> str: url = os.getenv("OI_LLM_API_URL", "").strip() api_key = os.getenv("OI_LLM_API_KEY", "").strip() model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip() if not url: raise RuntimeError("missing OI_LLM_API_URL") headers = {"Content-Type": "application/json"} if api_key: headers["Authorization"] = f"Bearer {api_key}" body = { "model": model, "stream": False, "temperature": 0.3, "messages": [ { "role": "system", "content": "你是资深 OI/CSP 教练。严格输出 JSON,不要输出任何额外文本。", }, {"role": "user", "content": prompt}, ], } last_error: Exception | None = None for i in range(1, retries + 1): try: resp = requests.post(url, headers=headers, json=body, timeout=timeout) except requests.RequestException as exc: last_error = exc if i < retries: time.sleep(sleep_sec * i) continue raise RuntimeError(f"llm request failed: {exc}") from exc if resp.status_code in RETRYABLE_HTTP_CODES: if i < retries: time.sleep(sleep_sec * i) continue raise RuntimeError(f"llm retry exhausted: HTTP {resp.status_code}") if resp.status_code >= 400: raise RuntimeError(f"llm request failed: HTTP {resp.status_code}: {resp.text[:300]}") payload = resp.json() choices = payload.get("choices") or [] if not choices: raise RuntimeError("llm response missing choices") choice0 = choices[0] or {} message = choice0.get("message") or {} content = extract_message_text(message.get("content")) if not content: content = extract_message_text(choice0.get("text")) if not content: raise RuntimeError("llm response missing content") return content if last_error: raise RuntimeError(f"llm request failed: {last_error}") from last_error raise RuntimeError("llm request failed") def fallback_solutions(max_solutions: int) -> list[dict[str, Any]]: base = [ { "title": "解法一:直接模拟/枚举", "idea_md": "按题意拆分步骤,先写可过样例的直观解法,再补边界处理。", "explanation_md": "适用于数据范围较小或规则清晰的题。", "complexity": "时间复杂度依题而定,通常 O(n)~O(n^2)", "code_cpp": """ #include #include #include using namespace std; int main() { ios::sync_with_stdio(false); cin.tie(nullptr); // TODO: 请根据题意补全读入、核心逻辑与输出。 return 0; } """.strip(), "tags": ["simulation", "implementation"], }, { "title": "解法二:优化思路(前缀/贪心/DP 视题而定)", "idea_md": "分析状态与重复计算,尝试用前缀和、贪心或动态规划优化。", "explanation_md": "比直接模拟更稳定,通常能覆盖更大数据规模。", "complexity": "通常优于朴素解法", "code_cpp": """ #include #include #include using namespace std; int main() { ios::sync_with_stdio(false); cin.tie(nullptr); // TODO: 请根据题意补全优化版思路实现。 return 0; } """.strip(), "tags": ["optimization", "dp"], }, ] return base[: max(1, max_solutions)] def format_cpp_code(raw: str) -> str: code = (raw or "").replace("\r\n", "\n").replace("\r", "\n") if not code.strip(): return "" if not code.endswith("\n"): code += "\n" if not CLANG_FORMAT_BIN: return code try: proc = subprocess.run( [CLANG_FORMAT_BIN, "--style={BasedOnStyle: Google, IndentWidth: 2, ColumnLimit: 0}"], input=code.encode("utf-8"), stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False, timeout=6, ) if proc.returncode == 0 and proc.stdout: out = proc.stdout.decode("utf-8", errors="ignore") if out.strip(): return out if out.endswith("\n") else f"{out}\n" except Exception: pass return code def load_problem(conn: sqlite3.Connection, problem_id: int) -> Problem: cur = conn.execute( "SELECT id,title,statement_md,difficulty,source,sample_input,sample_output FROM problems WHERE id=?", (problem_id,), ) row = cur.fetchone() if row is None: raise RuntimeError(f"problem not found: {problem_id}") return Problem( id=int(row[0]), title=str(row[1] or ""), statement_md=str(row[2] or ""), difficulty=int(row[3] or 1), source=str(row[4] or ""), sample_input=str(row[5] or ""), sample_output=str(row[6] or ""), ) def update_job(conn: sqlite3.Connection, job_id: int, **fields: Any) -> None: if not fields: return keys = [] vals: list[Any] = [] for k, v in fields.items(): keys.append(f"{k}=?") vals.append(v) vals.append(job_id) conn.execute( f"UPDATE problem_solution_jobs SET {', '.join(keys)} WHERE id=?", tuple(vals), ) conn.commit() def store_solutions(conn: sqlite3.Connection, problem_id: int, rows: list[dict[str, Any]], source: str) -> int: ts = now_sec() conn.execute("DELETE FROM problem_solutions WHERE problem_id=?", (problem_id,)) saved = 0 seen_titles: set[str] = set() for idx, row in enumerate(rows, start=1): title = str(row.get("title") or f"解法 {idx}").strip() if title in seen_titles: continue seen_titles.add(title) idea_md = str(row.get("idea_md") or "").strip() explanation_md = str(row.get("explanation_md") or "").strip() code_cpp = format_cpp_code(str(row.get("code_cpp") or "")) complexity = str(row.get("complexity") or "").strip() tags = row.get("tags") if isinstance(row.get("tags"), list) else [] conn.execute( """ INSERT INTO problem_solutions( problem_id,variant,title,idea_md,explanation_md,code_cpp,complexity,tags_json,source,created_at,updated_at ) VALUES(?,?,?,?,?,?,?,?,?,?,?) """, ( problem_id, idx, title, idea_md, explanation_md, code_cpp, complexity, json.dumps(tags, ensure_ascii=False), source, ts, ts, ), ) saved += 1 conn.commit() return saved def main() -> int: parser = argparse.ArgumentParser(description="Generate multi-solution explanations") parser.add_argument("--db-path", required=True) parser.add_argument("--problem-id", type=int, required=True) parser.add_argument("--job-id", type=int, required=True) parser.add_argument("--max-solutions", type=int, default=3) parser.add_argument("--timeout", type=int, default=90) parser.add_argument("--retries", type=int, default=4) parser.add_argument("--retry-sleep-sec", type=float, default=1.5) args = parser.parse_args() conn = sqlite3.connect(args.db_path) conn.execute("PRAGMA foreign_keys=ON") conn.execute("PRAGMA busy_timeout=5000") ts = now_sec() update_job( conn, args.job_id, status="running", progress=1, message="starting", started_at=ts, updated_at=ts, ) try: problem = load_problem(conn, args.problem_id) requested_solutions = max(1, min(5, args.max_solutions)) allow_fallback = env_bool("CSP_SOLUTION_ALLOW_FALLBACK", False) prompt = build_prompt(problem, max_solutions=requested_solutions) update_job(conn, args.job_id, progress=25, message="requesting llm", updated_at=now_sec()) source = "llm" solutions: list[dict[str, Any]] try: solutions = generate_solutions_with_llm( prompt=prompt, max_solutions=requested_solutions, timeout=args.timeout, retries=args.retries, sleep_sec=args.retry_sleep_sec, ) except Exception as exc: if not allow_fallback: raise RuntimeError(f"llm generation failed: {str(exc)[:280]}") from exc source = "fallback" solutions = fallback_solutions(args.max_solutions) solutions = solutions[:requested_solutions] update_job(conn, args.job_id, progress=70, message="writing solutions", updated_at=now_sec()) saved = store_solutions(conn, args.problem_id, solutions, source) done_msg = f"completed: {saved} solutions ({source})" update_job( conn, args.job_id, status="completed", progress=100, message=done_msg, finished_at=now_sec(), updated_at=now_sec(), ) conn.close() return 0 except Exception as exc: update_job( conn, args.job_id, status="failed", progress=100, message=f"failed: {str(exc)[:400]}", finished_at=now_sec(), updated_at=now_sec(), ) conn.close() return 1 if __name__ == "__main__": raise SystemExit(main())