#!/usr/bin/env python3 """Generate new CSP-J problems with RAG + dedupe checks.""" from __future__ import annotations import argparse import json import math import os import random import re import sqlite3 import time from dataclasses import dataclass from difflib import SequenceMatcher from typing import Any from urllib.parse import quote import requests DEFAULT_BASE_URL = "https://www.luogu.com.cn" DEFAULT_TAG_IDS = [343, 82] # CSP-J + NOIP junior RETRYABLE_HTTP_CODES = {429, 500, 502, 503, 504} CONTEXT_RE = re.compile( r']*id="lentille-context"[^>]*>(.*?)', re.DOTALL ) @dataclass class ExistingProblem: id: int title: str statement_md: str def now_sec() -> int: return int(time.time()) def normalize(text: str) -> str: text = text.lower().strip() text = re.sub(r"\s+", " ", text) text = re.sub(r"[^0-9a-z\u4e00-\u9fff ]+", " ", text) return re.sub(r"\s+", " ", text).strip() def similarity(a: str, b: str) -> float: if not a or not b: return 0.0 return SequenceMatcher(None, normalize(a), normalize(b)).ratio() def requests_with_retry(url: str, timeout: int, retries: int, sleep_sec: float) -> str: last_error: Exception | None = None for i in range(1, retries + 1): try: resp = requests.get(url, timeout=timeout) except requests.RequestException as exc: last_error = exc if i < retries: time.sleep(i * sleep_sec) continue raise RuntimeError(f"request failed: {exc}") from exc if resp.status_code in RETRYABLE_HTTP_CODES: if i < retries: time.sleep(i * sleep_sec) continue raise RuntimeError(f"request failed: HTTP {resp.status_code}") if resp.status_code >= 400: raise RuntimeError(f"request failed: HTTP {resp.status_code}") return resp.text if last_error: raise RuntimeError(str(last_error)) raise RuntimeError("request failed") def extract_context_json(html_text: str) -> dict[str, Any]: match = CONTEXT_RE.search(html_text) if not match: raise RuntimeError("lentille-context script not found") return json.loads(match.group(1)) def crawl_luogu_titles(base_url: str, timeout: int, retries: int, sleep_sec: float) -> list[str]: tags_csv = ",".join(str(x) for x in DEFAULT_TAG_IDS) url = f"{base_url}/problem/list?type=all&tag={quote(tags_csv)}&page=1" text = requests_with_retry(url, timeout=timeout, retries=retries, sleep_sec=sleep_sec) ctx = extract_context_json(text) result = (((ctx.get("data") or {}).get("problems") or {}).get("result") or []) titles: list[str] = [] for row in result: if not isinstance(row, dict): continue title = str(row.get("title") or "").strip() if title: titles.append(title) return titles def load_existing(conn: sqlite3.Connection) -> list[ExistingProblem]: cur = conn.execute("SELECT id,title,statement_md FROM problems") rows: list[ExistingProblem] = [] for row in cur.fetchall(): rows.append( ExistingProblem( id=int(row[0]), title=str(row[1] or ""), statement_md=str(row[2] or ""), ) ) return rows def collect_keywords(existing: list[ExistingProblem], luogu_titles: list[str]) -> list[str]: bucket: dict[str, int] = {} def add_word(w: str, weight: int = 1) -> None: w = normalize(w) if not w or len(w) < 2: return if w.isdigit(): return bucket[w] = bucket.get(w, 0) + weight for p in existing: parts = re.split(r"[\s,/|+()\[\]【】-]+", p.title) for part in parts: add_word(part, 1) for t in luogu_titles: parts = re.split(r"[\s,/|+()\[\]【】-]+", t) for part in parts: add_word(part, 2) ranked = sorted(bucket.items(), key=lambda x: x[1], reverse=True) return [k for k, _ in ranked[:40]] def llm_generate_problem(prompt: str, timeout: int, retries: int, sleep_sec: float) -> dict[str, Any]: url = os.getenv("OI_LLM_API_URL", "").strip() api_key = os.getenv("OI_LLM_API_KEY", "").strip() model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip() if not url: raise RuntimeError("missing OI_LLM_API_URL") headers = {"Content-Type": "application/json"} if api_key: headers["Authorization"] = f"Bearer {api_key}" body = { "model": model, "stream": False, "temperature": 0.7, "messages": [ { "role": "system", "content": "你是 CSP-J 出题人。只输出 JSON,不输出额外解释。", }, {"role": "user", "content": prompt}, ], } for i in range(1, retries + 1): try: resp = requests.post(url, headers=headers, json=body, timeout=timeout) except requests.RequestException as exc: if i < retries: time.sleep(i * sleep_sec) continue raise RuntimeError(f"llm failed: {exc}") from exc if resp.status_code in RETRYABLE_HTTP_CODES: if i < retries: time.sleep(i * sleep_sec) continue raise RuntimeError(f"llm failed: HTTP {resp.status_code}") if resp.status_code >= 400: raise RuntimeError(f"llm failed: HTTP {resp.status_code}: {resp.text[:200]}") payload = resp.json() content = (((payload.get("choices") or [{}])[0].get("message") or {}).get("content") or "") text = str(content).strip() if text.startswith("```"): text = re.sub(r"^```[a-zA-Z0-9_-]*", "", text).strip() text = text.removesuffix("```").strip() try: obj = json.loads(text) if isinstance(obj, dict): return obj except json.JSONDecodeError: match = re.search(r"\{[\s\S]*\}", text) if match: obj = json.loads(match.group(0)) if isinstance(obj, dict): return obj raise RuntimeError("llm returned non-json content") raise RuntimeError("llm failed") def fallback_generate_problem(sampled_keywords: list[str], llm_error: str) -> dict[str, Any]: seed = now_sec() n = 5 + (seed % 6) m = 7 + (seed % 9) title = f"CSP-J 训练题·余数统计 {seed}" statement_md = f""" # 题目描述 给定一个长度为 {n} 的整数序列,你需要统计有多少个连续子段的元素和对 {m} 取模后等于 0。 ## 输入格式 第一行一个整数 n。 第二行 n 个整数 a_i。 ## 输出格式 输出一个整数,表示满足条件的连续子段数量。 ## 数据范围 - 1 <= n <= 2e5 - |a_i| <= 1e9 ## 提示 可以使用前缀和与计数哈希优化到 O(n)。 """.strip() sample_input = "6\n1 2 3 4 5 6\n" sample_output = "3\n" return { "title": title, "difficulty": 3, "statement_md": statement_md, "sample_input": sample_input, "sample_output": sample_output, "answer": "统计前缀和模 m 的相同值配对数量", "explanation": "维护 prefix % m 的出现次数,当前值为 x 时,答案增加 cnt[x],再令 cnt[x]++。", "knowledge_points": ["前缀和", "哈希计数", "同余"], "tags": ["csp-j", "prefix-sum", "hash"], "llm_error": llm_error[:200], "rag_keywords": sampled_keywords, } def build_problem_md(obj: dict[str, Any]) -> tuple[str, str, str]: statement = str(obj.get("statement_md") or "").strip() if not statement: desc = str(obj.get("description") or "").strip() in_fmt = str(obj.get("input_format") or "").strip() out_fmt = str(obj.get("output_format") or "").strip() statement = "\n\n".join( [ "# 题目描述", desc, "## 输入格式", in_fmt, "## 输出格式", out_fmt, ] ).strip() sample_input = str(obj.get("sample_input") or "").strip() sample_output = str(obj.get("sample_output") or "").strip() return statement, sample_input, sample_output def maybe_duplicate(existing: list[ExistingProblem], title: str, statement_md: str, threshold: float) -> tuple[bool, int | None, float]: best_id = None best_score = 0.0 for p in existing: t_sim = similarity(title, p.title) s_sim = similarity(statement_md[:1200], p.statement_md[:1200]) score = max(t_sim, s_sim * 0.9 + t_sim * 0.1) if score > best_score: best_score = score best_id = p.id return best_score >= threshold, best_id, best_score def insert_problem(conn: sqlite3.Connection, title: str, statement_md: str, sample_input: str, sample_output: str, difficulty: int, profile_json: str, tags: list[str]) -> int: ts = now_sec() slug_base = normalize(title).replace(" ", "-") slug_base = re.sub(r"[^a-z0-9\\-]+", "", slug_base) if not slug_base: slug_base = "cspj-generated" slug = f"{slug_base[:50]}-{ts}" cur = conn.cursor() cur.execute( """ INSERT INTO problems( slug,title,statement_md,difficulty,source,statement_url,llm_profile_json,sample_input,sample_output,created_at ) VALUES(?,?,?,?,?,?,?,?,?,?) """, ( slug, title, statement_md, max(1, min(10, difficulty)), "llm:cspj-generated", "", profile_json, sample_input, sample_output, ts, ), ) problem_id = int(cur.lastrowid) for tag in tags: cur.execute( "INSERT OR IGNORE INTO problem_tags(problem_id,tag) VALUES(?,?)", (problem_id, normalize(tag)), ) conn.commit() return problem_id def main() -> int: parser = argparse.ArgumentParser(description="RAG generate CSP-J problems") parser.add_argument("--db-path", required=True) parser.add_argument("--count", type=int, default=1, help="generate count each run") parser.add_argument("--base-url", default=DEFAULT_BASE_URL) parser.add_argument("--timeout", type=int, default=60) parser.add_argument("--retries", type=int, default=4) parser.add_argument("--retry-sleep-sec", type=float, default=1.5) parser.add_argument("--dedupe-threshold", type=float, default=0.72) args = parser.parse_args() conn = sqlite3.connect(args.db_path) conn.execute("PRAGMA foreign_keys=ON") conn.execute("PRAGMA busy_timeout=5000") existing = load_existing(conn) luogu_titles: list[str] = [] try: luogu_titles = crawl_luogu_titles( args.base_url, timeout=args.timeout, retries=args.retries, sleep_sec=args.retry_sleep_sec ) except Exception: luogu_titles = [] keywords = collect_keywords(existing, luogu_titles) if not keywords: keywords = ["模拟", "枚举", "前缀和", "字符串", "贪心", "搜索"] inserted = 0 skipped_duplicate = 0 failed = 0 details: list[dict[str, Any]] = [] for _ in range(max(1, args.count)): sampled_keywords = random.sample(keywords, k=min(8, len(keywords))) prompt = f""" 请生成一道原创 CSP-J 风格编程题,难度 2~4,禁止与常见模板题同构。 结合关键词:{', '.join(sampled_keywords)} 输出 JSON: {{ "title": "题目标题", "difficulty": 2, "statement_md": "Markdown 题面(含描述、输入格式、输出格式、数据范围)", "sample_input": "样例输入", "sample_output": "样例输出", "answer": "简要答案关键点", "explanation": "讲解", "knowledge_points": ["知识点1","知识点2"], "tags": ["csp-j","入门","..."] }} """.strip() source = "llm" llm_error = "" try: obj = llm_generate_problem( prompt, timeout=args.timeout, retries=args.retries, sleep_sec=args.retry_sleep_sec ) except Exception as exc: source = "fallback" llm_error = str(exc) obj = fallback_generate_problem(sampled_keywords, llm_error) try: title = str(obj.get("title") or "").strip() if not title: raise RuntimeError("generated title is empty") difficulty = int(obj.get("difficulty") or 2) statement_md, sample_input, sample_output = build_problem_md(obj) pre_dup, dup_id, dup_score = maybe_duplicate( existing, title, statement_md, args.dedupe_threshold ) if pre_dup: skipped_duplicate += 1 details.append( { "title": title, "status": "skip_pre_duplicate", "source": source, "similar_problem_id": dup_id, "similarity": round(dup_score, 4), } ) continue profile = { "schema_version": 1, "platform": "llm-generated" if source == "llm" else "fallback-generated", "difficulty": difficulty, "answer": str(obj.get("answer") or ""), "explanation": str(obj.get("explanation") or ""), "knowledge_points": obj.get("knowledge_points") if isinstance(obj.get("knowledge_points"), list) else [], "tags": obj.get("tags") if isinstance(obj.get("tags"), list) else [], "generated_at": now_sec(), "rag_keywords": sampled_keywords, } if llm_error: profile["llm_error"] = llm_error[:300] # Post-check against fresh existing corpus before insert. existing_latest = load_existing(conn) post_dup, post_dup_id, post_dup_score = maybe_duplicate( existing_latest, title, statement_md, args.dedupe_threshold ) if post_dup: skipped_duplicate += 1 details.append( { "title": title, "status": "skip_post_duplicate", "source": source, "similar_problem_id": post_dup_id, "similarity": round(post_dup_score, 4), } ) continue tags = profile["tags"] if isinstance(profile["tags"], list) else [] if "csp-j" not in [normalize(str(x)) for x in tags]: tags = [*tags, "csp-j"] tags = [str(x) for x in tags][:12] problem_id = insert_problem( conn, title=title, statement_md=statement_md, sample_input=sample_input, sample_output=sample_output, difficulty=difficulty, profile_json=json.dumps(profile, ensure_ascii=False), tags=tags, ) inserted += 1 details.append( {"title": title, "status": "inserted", "source": source, "problem_id": problem_id} ) existing.append(ExistingProblem(problem_id, title, statement_md)) except Exception as exc: failed += 1 details.append({"status": "failed", "source": source, "error": str(exc)}) conn.close() print( json.dumps( { "db_path": args.db_path, "requested_count": max(1, args.count), "inserted": inserted, "skipped_duplicate": skipped_duplicate, "failed": failed, "details": details, "keyword_sample_size": len(keywords), }, ensure_ascii=False, indent=2, ) ) return 0 if failed == 0 else 1 if __name__ == "__main__": raise SystemExit(main())