#!/usr/bin/env python3
"""Asynchronously generate multiple solutions for a problem and store into SQLite.

The script is driven by a row in ``problem_solution_jobs``: it marks the job as
running, asks an LLM endpoint for several solution write-ups of one problem,
stores them in ``problem_solutions`` and finally marks the job as completed or
failed.  When the LLM call or its response is unusable, generic placeholder
solutions are stored instead.

Environment variables read by :func:`llm_request`:

* ``OI_LLM_API_URL`` -- endpoint URL (required).
* ``OI_LLM_API_KEY`` -- optional bearer token.
* ``OI_LLM_MODEL``   -- model name, defaults to ``qwen3-max``.
"""

from __future__ import annotations

import argparse
import json
import os
import re
import sqlite3
import sys
import time
from dataclasses import dataclass
from typing import Any

try:
    import requests
except ImportError:
    # Reported from llm_request() instead of crashing at import time, so that
    # main() can still update the job row and fall back to placeholders.
    requests = None  # type: ignore[assignment]

RETRYABLE_HTTP_CODES = {500, 502, 503, 504}

# Upper bound on how many solutions are requested from the LLM and stored.
MAX_SOLUTIONS_CAP = 5

_CODE_FENCE_OPEN_RE = re.compile(r"^```[a-zA-Z0-9_-]*")
_JSON_OBJECT_RE = re.compile(r"\{[\s\S]*\}")

# Placeholder program used by fallback_solutions().  The header name was
# missing from the include directive, which made the snippet invalid C++.
_FALLBACK_CODE_CPP = (
    "// TODO: 请根据题意补全\n"
    "#include <bits/stdc++.h>\n"
    "using namespace std;\n"
    "int main(){ios::sync_with_stdio(false);cin.tie(nullptr);return 0;}\n"
)


@dataclass
class Problem:
    """One row of the ``problems`` table, as read by :func:`load_problem`."""

    id: int
    title: str
    statement_md: str
    difficulty: int
    source: str
    sample_input: str
    sample_output: str


def now_sec() -> int:
    """Return the current Unix time truncated to whole seconds."""
    return int(time.time())


def extract_json_object(text: str) -> dict[str, Any] | None:
    """Extract a JSON object from *text*.

    The text is first parsed as a whole (after removing a surrounding
    Markdown code fence, if any).  If that does not yield a dict, the span
    from the first ``{`` to the last ``}`` of the original text is tried.

    Returns:
        The decoded dict, or ``None`` when no JSON object could be decoded.
    """
    raw = text.strip()
    if raw.startswith("```"):
        raw = _CODE_FENCE_OPEN_RE.sub("", raw).strip()
        raw = raw.removesuffix("```").strip()
    try:
        obj = json.loads(raw)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(obj, dict):
            return obj

    # Fall back to the outermost brace-delimited span of the original text.
    match = _JSON_OBJECT_RE.search(text)
    if not match:
        return None
    try:
        obj = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    return obj if isinstance(obj, dict) else None


def _message_content(payload: Any) -> str:
    """Return ``choices[0].message.content`` from a decoded response body.

    Raises:
        RuntimeError: if the payload lacks a non-empty choices list or a
            non-empty message content.
    """
    choices = payload.get("choices") if isinstance(payload, dict) else None
    if not isinstance(choices, list) or not choices:
        raise RuntimeError("llm response missing choices")
    first = choices[0]
    message = first.get("message") if isinstance(first, dict) else None
    content = message.get("content") if isinstance(message, dict) else None
    if not content:
        raise RuntimeError("llm response missing content")
    return str(content)


def llm_request(prompt: str, timeout: int, retries: int, sleep_sec: float) -> str:
    """POST *prompt* to the configured LLM endpoint and return the reply text.

    Args:
        prompt: User message sent after the fixed system message.
        timeout: Per-request timeout passed to ``requests.post``.
        retries: Total number of attempts; values below 1 still make one
            attempt.
        sleep_sec: Base delay; attempt ``i`` sleeps ``sleep_sec * i`` before
            the next try.

    Returns:
        The content of the first choice's message, as a string.

    Raises:
        RuntimeError: on missing configuration or dependency, on transport
            errors or retryable HTTP codes once attempts are exhausted, on
            any other HTTP status >= 400, or on an unusable response body.
    """
    url = os.getenv("OI_LLM_API_URL", "").strip()
    api_key = os.getenv("OI_LLM_API_KEY", "").strip()
    model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip()
    if not url:
        raise RuntimeError("missing OI_LLM_API_URL")
    if requests is None:
        raise RuntimeError("missing dependency: requests")

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    body = {
        "model": model,
        "stream": False,
        "temperature": 0.3,
        "messages": [
            {
                "role": "system",
                "content": "你是资深 OI/CSP 教练。严格输出 JSON,不要输出任何额外文本。",
            },
            {"role": "user", "content": prompt},
        ],
    }

    attempts = max(1, retries)
    for attempt in range(1, attempts + 1):
        is_last = attempt == attempts
        try:
            resp = requests.post(url, headers=headers, json=body, timeout=timeout)
        except requests.RequestException as exc:
            if is_last:
                raise RuntimeError(f"llm request failed: {exc}") from exc
            time.sleep(sleep_sec * attempt)
            continue

        if resp.status_code in RETRYABLE_HTTP_CODES:
            if is_last:
                raise RuntimeError(f"llm retry exhausted: HTTP {resp.status_code}")
            time.sleep(sleep_sec * attempt)
            continue
        if resp.status_code >= 400:
            raise RuntimeError(
                f"llm request failed: HTTP {resp.status_code}: {resp.text[:300]}"
            )

        try:
            payload = resp.json()
        except ValueError as exc:
            raise RuntimeError("llm response is not valid JSON") from exc
        return _message_content(payload)

    # Not reachable in practice: every iteration returns, raises or continues
    # to a later attempt, and the last attempt never continues.
    raise RuntimeError("llm request failed")


def fallback_solutions(max_solutions: int) -> list[dict[str, Any]]:
    """Return generic placeholder solutions used when the LLM is unavailable.

    Args:
        max_solutions: Desired number of entries; at least one and at most
            two (the number of placeholders defined) are returned.
    """
    base: list[dict[str, Any]] = [
        {
            "title": "解法一:直接模拟/枚举",
            "idea_md": "按题意拆分步骤,先写可过样例的直观解法,再补边界处理。",
            "explanation_md": "适用于数据范围较小或规则清晰的题。",
            "complexity": "时间复杂度依题而定,通常 O(n)~O(n^2)",
            "code_cpp": _FALLBACK_CODE_CPP,
            "tags": ["simulation", "implementation"],
        },
        {
            "title": "解法二:优化思路(前缀/贪心/DP 视题而定)",
            "idea_md": "分析状态与重复计算,尝试用前缀和、贪心或动态规划优化。",
            "explanation_md": "比直接模拟更稳定,通常能覆盖更大数据规模。",
            "complexity": "通常优于朴素解法",
            "code_cpp": _FALLBACK_CODE_CPP,
            "tags": ["optimization", "dp"],
        },
    ]
    return base[: max(1, max_solutions)]


def load_problem(conn: sqlite3.Connection, problem_id: int) -> Problem:
    """Load one problem by id from the ``problems`` table.

    NULL columns are replaced by empty strings (difficulty by ``1``).

    Raises:
        RuntimeError: if no row with *problem_id* exists.
    """
    row = conn.execute(
        "SELECT id,title,statement_md,difficulty,source,sample_input,sample_output "
        "FROM problems WHERE id=?",
        (problem_id,),
    ).fetchone()
    if row is None:
        raise RuntimeError(f"problem not found: {problem_id}")
    return Problem(
        id=int(row[0]),
        title=str(row[1] or ""),
        statement_md=str(row[2] or ""),
        difficulty=int(row[3] or 1),
        source=str(row[4] or ""),
        sample_input=str(row[5] or ""),
        sample_output=str(row[6] or ""),
    )


def update_job(conn: sqlite3.Connection, job_id: int, **fields: Any) -> None:
    """Update columns of one ``problem_solution_jobs`` row and commit.

    Args:
        conn: Open database connection.
        job_id: Primary key of the job row.
        **fields: Column name -> new value.  Nothing happens when empty.

    Raises:
        ValueError: if a column name is not a plain identifier.
    """
    if not fields:
        return
    for column in fields:
        # Column names are interpolated into the statement (only values can
        # be bound as parameters), so refuse anything but a bare identifier.
        if not column.isidentifier():
            raise ValueError(f"invalid column name: {column!r}")
    assignments = ", ".join(f"{column}=?" for column in fields)
    conn.execute(
        f"UPDATE problem_solution_jobs SET {assignments} WHERE id=?",
        (*fields.values(), job_id),
    )
    conn.commit()


def store_solutions(
    conn: sqlite3.Connection,
    problem_id: int,
    rows: list[dict[str, Any]],
    source: str,
) -> int:
    """Replace all stored solutions of a problem with *rows*.

    Rows whose (stripped) title repeats an earlier one are skipped.  The
    ``variant`` column holds the 1-based position of the row in *rows*, so
    skipped duplicates leave gaps.

    The delete and all inserts run in a single transaction: if anything
    fails, the previous solutions are left untouched instead of being
    half-replaced by a later commit on the same connection.

    Returns:
        The number of rows inserted.
    """
    ts = now_sec()
    saved = 0
    seen_titles: set[str] = set()
    with conn:  # commits on success, rolls back if the body raises
        conn.execute("DELETE FROM problem_solutions WHERE problem_id=?", (problem_id,))
        for idx, row in enumerate(rows, start=1):
            title = str(row.get("title") or f"解法 {idx}").strip()
            if title in seen_titles:
                continue
            seen_titles.add(title)

            tags = row.get("tags")
            if not isinstance(tags, list):
                tags = []
            conn.execute(
                """
                INSERT INTO problem_solutions(
                  problem_id,variant,title,idea_md,explanation_md,code_cpp,
                  complexity,tags_json,source,created_at,updated_at
                ) VALUES(?,?,?,?,?,?,?,?,?,?,?)
                """,
                (
                    problem_id,
                    idx,
                    title,
                    str(row.get("idea_md") or "").strip(),
                    str(row.get("explanation_md") or "").strip(),
                    str(row.get("code_cpp") or "").strip(),
                    str(row.get("complexity") or "").strip(),
                    json.dumps(tags, ensure_ascii=False),
                    source,
                    ts,
                    ts,
                ),
            )
            saved += 1
    return saved


def _build_prompt(problem: Problem, count: int) -> str:
    """Build the user prompt asking for *count* solutions to *problem*.

    The statement and samples are truncated to 12000 and 1200 characters.
    """
    return f"""
请为下面这道 CSP 题生成 {count} 种不同思路的题解(可从不同角度切入,例如模拟/贪心/DP/数据结构),并给出 C++ 参考代码。
输出 JSON,格式固定:
{{
  "solutions": [
    {{
      "title": "解法标题",
      "idea_md": "思路要点(Markdown)",
      "explanation_md": "详细讲解(Markdown)",
      "complexity": "时间/空间复杂度",
      "code_cpp": "完整 C++17 代码",
      "tags": ["标签1","标签2"]
    }}
  ]
}}

题目:{problem.title}
难度:{problem.difficulty}
来源:{problem.source}

题面:
{problem.statement_md[:12000]}

样例输入:
{problem.sample_input[:1200]}

样例输出:
{problem.sample_output[:1200]}
""".strip()


def _parse_solutions(content: str) -> list[dict[str, Any]]:
    """Return the dict entries of the ``solutions`` array in an LLM reply.

    Raises:
        RuntimeError: if the reply has no non-empty ``solutions`` list or
            none of its entries is a dict.
    """
    obj = extract_json_object(content)
    raw = obj.get("solutions") if isinstance(obj, dict) else None
    if not isinstance(raw, list) or not raw:
        raise RuntimeError("llm response missing solutions array")
    solutions = [item for item in raw if isinstance(item, dict)]
    if not solutions:
        raise RuntimeError("llm response has empty valid solutions")
    return solutions


def main() -> int:
    """Run one generation job described by the command-line arguments.

    Returns:
        ``0`` when solutions were stored (from the LLM or the fallback),
        ``1`` when the job failed and was marked as such.
    """
    parser = argparse.ArgumentParser(description="Generate multi-solution explanations")
    parser.add_argument("--db-path", required=True)
    parser.add_argument("--problem-id", type=int, required=True)
    parser.add_argument("--job-id", type=int, required=True)
    parser.add_argument("--max-solutions", type=int, default=3)
    parser.add_argument("--timeout", type=int, default=90)
    parser.add_argument("--retries", type=int, default=4)
    parser.add_argument("--retry-sleep-sec", type=float, default=1.5)
    args = parser.parse_args()

    limit = max(1, min(MAX_SOLUTIONS_CAP, args.max_solutions))

    conn = sqlite3.connect(args.db_path)
    try:
        conn.execute("PRAGMA foreign_keys=ON")
        conn.execute("PRAGMA busy_timeout=5000")

        ts = now_sec()
        update_job(
            conn,
            args.job_id,
            status="running",
            progress=1,
            message="starting",
            started_at=ts,
            updated_at=ts,
        )

        try:
            problem = load_problem(conn, args.problem_id)
            prompt = _build_prompt(problem, limit)

            update_job(
                conn,
                args.job_id,
                progress=25,
                message="requesting llm",
                updated_at=now_sec(),
            )

            source = "fallback"
            solutions: list[dict[str, Any]]
            try:
                content = llm_request(
                    prompt,
                    timeout=args.timeout,
                    retries=args.retries,
                    sleep_sec=args.retry_sleep_sec,
                )
                solutions = _parse_solutions(content)
                source = "llm"
            except Exception as exc:
                # Deliberate best-effort: any LLM problem degrades to the
                # placeholders, but the reason is no longer discarded.
                print(f"llm generation failed, using fallback: {exc}", file=sys.stderr)
                solutions = fallback_solutions(args.max_solutions)

            solutions = solutions[:limit]

            update_job(
                conn,
                args.job_id,
                progress=70,
                message="writing solutions",
                updated_at=now_sec(),
            )
            saved = store_solutions(conn, args.problem_id, solutions, source)

            update_job(
                conn,
                args.job_id,
                status="completed",
                progress=100,
                message=f"completed: {saved} solutions ({source})",
                finished_at=now_sec(),
                updated_at=now_sec(),
            )
            return 0
        except Exception as exc:
            # Drop any uncommitted work so the commit inside update_job()
            # records only the failure status.
            conn.rollback()
            update_job(
                conn,
                args.job_id,
                status="failed",
                progress=100,
                message=f"failed: {str(exc)[:400]}",
                finished_at=now_sec(),
                updated_at=now_sec(),
            )
            return 1
    finally:
        conn.close()


if __name__ == "__main__":
    raise SystemExit(main())