feat: rebuild CSP practice workflow, UX and automation
这个提交包含在:
475
scripts/generate_cspj_problem_rag.py
普通文件
475
scripts/generate_cspj_problem_rag.py
普通文件
@@ -0,0 +1,475 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate new CSP-J problems with RAG + dedupe checks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sqlite3
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
|
||||
# Luogu site root used when crawling the public problem list.
DEFAULT_BASE_URL = "https://www.luogu.com.cn"
DEFAULT_TAG_IDS = [343, 82]  # CSP-J + NOIP junior
# HTTP status codes treated as transient; requests are retried on these.
RETRYABLE_HTTP_CODES = {429, 500, 502, 503, 504}
# Matches the JSON blob Luogu embeds in its "lentille-context" <script> tag.
CONTEXT_RE = re.compile(
    r'<script[^>]*id="lentille-context"[^>]*>(.*?)</script>', re.DOTALL
)
|
||||
|
||||
|
||||
@dataclass
class ExistingProblem:
    """A problem row already stored in the database, used for dedupe checks."""

    # Primary key of the problems row.
    id: int
    # Problem title; empty string when the column is NULL.
    title: str
    # Markdown statement body; empty string when the column is NULL.
    statement_md: str
|
||||
|
||||
|
||||
def now_sec() -> int:
    """Return the current wall-clock time as whole Unix seconds."""
    seconds = time.time()
    return int(seconds)
|
||||
|
||||
|
||||
def normalize(text: str) -> str:
    """Lowercase *text*, keep only ASCII alphanumerics / CJK, collapse whitespace."""
    lowered = re.sub(r"\s+", " ", text.lower().strip())
    cleaned = re.sub(r"[^0-9a-z\u4e00-\u9fff ]+", " ", lowered)
    # Splitting on whitespace and re-joining collapses runs and trims the ends.
    return " ".join(cleaned.split())
|
||||
|
||||
|
||||
def similarity(a: str, b: str) -> float:
    """Fuzzy match ratio in [0, 1] between the normalized forms of *a* and *b*.

    Returns 0.0 when either input is empty.
    """
    if a and b:
        return SequenceMatcher(None, normalize(a), normalize(b)).ratio()
    return 0.0
|
||||
|
||||
|
||||
def requests_with_retry(url: str, timeout: int, retries: int, sleep_sec: float) -> str:
    """GET *url* with retries and a linearly growing backoff.

    Network errors and the status codes in RETRYABLE_HTTP_CODES are retried
    up to *retries* attempts, sleeping ``attempt * sleep_sec`` seconds between
    tries. Any other 4xx/5xx status fails immediately.

    Returns the response body text; raises RuntimeError on permanent failure.
    """
    last_error: Exception | None = None
    for attempt in range(1, retries + 1):
        can_retry = attempt < retries
        try:
            resp = requests.get(url, timeout=timeout)
        except requests.RequestException as exc:
            last_error = exc
            if not can_retry:
                raise RuntimeError(f"request failed: {exc}") from exc
            time.sleep(attempt * sleep_sec)
            continue
        code = resp.status_code
        if code in RETRYABLE_HTTP_CODES:
            if not can_retry:
                raise RuntimeError(f"request failed: HTTP {code}")
            time.sleep(attempt * sleep_sec)
            continue
        if code >= 400:
            # Non-retryable client/server error: fail without further attempts.
            raise RuntimeError(f"request failed: HTTP {code}")
        return resp.text
    # Defensive fallback; every loop path above returns or raises.
    if last_error:
        raise RuntimeError(str(last_error))
    raise RuntimeError("request failed")
|
||||
|
||||
|
||||
def extract_context_json(html_text: str) -> dict[str, Any]:
    """Extract and parse the JSON embedded in Luogu's lentille-context script tag.

    Raises RuntimeError when the tag is absent from *html_text*.
    """
    found = CONTEXT_RE.search(html_text)
    if found is None:
        raise RuntimeError("lentille-context script not found")
    return json.loads(found.group(1))
|
||||
|
||||
|
||||
def crawl_luogu_titles(base_url: str, timeout: int, retries: int, sleep_sec: float) -> list[str]:
    """Fetch page 1 of Luogu's problem list for DEFAULT_TAG_IDS and return the titles."""
    tag_param = quote(",".join(str(tag_id) for tag_id in DEFAULT_TAG_IDS))
    listing_url = f"{base_url}/problem/list?type=all&tag={tag_param}&page=1"
    html_text = requests_with_retry(listing_url, timeout=timeout, retries=retries, sleep_sec=sleep_sec)
    ctx = extract_context_json(html_text)
    # Defensive navigation: any missing level yields an empty result list.
    rows = (((ctx.get("data") or {}).get("problems") or {}).get("result") or [])
    collected: list[str] = []
    for entry in rows:
        if not isinstance(entry, dict):
            continue
        candidate = str(entry.get("title") or "").strip()
        if candidate:
            collected.append(candidate)
    return collected
|
||||
|
||||
|
||||
def load_existing(conn: sqlite3.Connection) -> list[ExistingProblem]:
    """Read every stored problem (id, title, statement) from the problems table."""
    cursor = conn.execute("SELECT id,title,statement_md FROM problems")
    # NULL columns are coerced to empty strings so downstream code can
    # treat title/statement_md as plain str.
    return [
        ExistingProblem(
            id=int(record[0]),
            title=str(record[1] or ""),
            statement_md=str(record[2] or ""),
        )
        for record in cursor.fetchall()
    ]
|
||||
|
||||
|
||||
def collect_keywords(existing: list[ExistingProblem], luogu_titles: list[str]) -> list[str]:
    """Build a weighted keyword pool from local titles (w=1) and Luogu titles (w=2).

    Title fragments are normalized; fragments shorter than 2 chars or purely
    numeric are dropped. Returns up to the 40 highest-weighted keywords.
    """
    splitter = re.compile(r"[\s,/|+()\[\]【】-]+")
    weights: dict[str, int] = {}

    def _tally(raw: str, weight: int = 1) -> None:
        # Normalize first so weights accumulate on the canonical form.
        word = normalize(raw)
        if len(word) < 2 or word.isdigit():
            return
        weights[word] = weights.get(word, 0) + weight

    for problem in existing:
        for piece in splitter.split(problem.title):
            _tally(piece, 1)

    # Freshly crawled titles weigh double: they reflect current exam style.
    for luogu_title in luogu_titles:
        for piece in splitter.split(luogu_title):
            _tally(piece, 2)

    ordered = sorted(weights.items(), key=lambda kv: kv[1], reverse=True)
    return [word for word, _ in ordered[:40]]
|
||||
|
||||
|
||||
def llm_generate_problem(prompt: str, timeout: int, retries: int, sleep_sec: float) -> dict[str, Any]:
    """Ask the configured chat-completions endpoint for one problem as JSON.

    Reads OI_LLM_API_URL / OI_LLM_API_KEY / OI_LLM_MODEL from the environment
    (OpenAI-compatible API shape assumed). Retries transient failures with a
    linear backoff, strips Markdown code fences from the reply, and salvages
    the first {...} span when the whole reply is not valid JSON.

    Raises RuntimeError when the URL is unset, the request permanently fails,
    or the reply cannot be parsed into a JSON object.
    """
    url = os.getenv("OI_LLM_API_URL", "").strip()
    api_key = os.getenv("OI_LLM_API_KEY", "").strip()
    model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip()
    if not url:
        raise RuntimeError("missing OI_LLM_API_URL")

    headers = {"Content-Type": "application/json"}
    if api_key:
        # Auth header is optional: some self-hosted gateways need no key.
        headers["Authorization"] = f"Bearer {api_key}"

    body = {
        "model": model,
        "stream": False,
        "temperature": 0.7,
        "messages": [
            {
                "role": "system",
                "content": "你是 CSP-J 出题人。只输出 JSON,不输出额外解释。",
            },
            {"role": "user", "content": prompt},
        ],
    }

    for i in range(1, retries + 1):
        try:
            resp = requests.post(url, headers=headers, json=body, timeout=timeout)
        except requests.RequestException as exc:
            # Network-level failure: back off and retry unless out of attempts.
            if i < retries:
                time.sleep(i * sleep_sec)
                continue
            raise RuntimeError(f"llm failed: {exc}") from exc

        if resp.status_code in RETRYABLE_HTTP_CODES:
            # Transient HTTP error: same backoff-and-retry policy.
            if i < retries:
                time.sleep(i * sleep_sec)
                continue
            raise RuntimeError(f"llm failed: HTTP {resp.status_code}")
        if resp.status_code >= 400:
            # Non-retryable error: surface a truncated body for diagnosis.
            raise RuntimeError(f"llm failed: HTTP {resp.status_code}: {resp.text[:200]}")

        payload = resp.json()
        # Defensive navigation of the chat-completions response shape.
        content = (((payload.get("choices") or [{}])[0].get("message") or {}).get("content") or "")
        text = str(content).strip()
        if text.startswith("```"):
            # Drop a leading ```lang fence and a trailing ``` fence.
            text = re.sub(r"^```[a-zA-Z0-9_-]*", "", text).strip()
            text = text.removesuffix("```").strip()
        try:
            obj = json.loads(text)
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            # Salvage attempt: parse the first brace-delimited span.
            # NOTE(review): a JSONDecodeError from this inner loads propagates
            # uncaught — presumably acceptable since the caller catches broadly.
            match = re.search(r"\{[\s\S]*\}", text)
            if match:
                obj = json.loads(match.group(0))
                if isinstance(obj, dict):
                    return obj
        # Parsed successfully but not a dict, or salvage failed.
        raise RuntimeError("llm returned non-json content")
    raise RuntimeError("llm failed")
|
||||
|
||||
|
||||
def fallback_generate_problem(sampled_keywords: list[str], llm_error: str) -> dict[str, Any]:
    """Build a deterministic backup problem when the LLM call fails.

    Derives a small prefix-sum counting problem from the current timestamp:
    sequence length n in 5..10 and modulus m in 7..15.

    Bug fix: the previous version hard-coded the sample as
    ``6\\n1 2 3 4 5 6`` -> ``3`` while m varied with the timestamp, so the
    printed sample answer was wrong for most m (e.g. m=8 yields 0, m=10
    yields 2). The sample is now generated from the actual n and its answer
    is computed with the same prefix-sum rule the statement describes.

    Args:
        sampled_keywords: keywords used for this generation attempt (recorded).
        llm_error: error text from the failed LLM call (truncated to 200 chars).

    Returns:
        A problem dict with the same keys the LLM path produces.
    """
    seed = int(time.time())
    n = 5 + (seed % 6)  # sequence length, 5..10
    m = 7 + (seed % 9)  # modulus, 7..15
    title = f"CSP-J 训练题·余数统计 {seed}"
    statement_md = f"""
# 题目描述
给定一个长度为 {n} 的整数序列,你需要统计有多少个连续子段的元素和对 {m} 取模后等于 0。

## 输入格式
第一行一个整数 n。
第二行 n 个整数 a_i。

## 输出格式
输出一个整数,表示满足条件的连续子段数量。

## 数据范围
- 1 <= n <= 2e5
- |a_i| <= 1e9

## 提示
可以使用前缀和与计数哈希优化到 O(n)。
""".strip()

    # Deterministic sample sequence (values 1..10) derived from the seed,
    # so successive runs produce different but reproducible samples.
    sample_values = [(seed + 7 * i) % 10 + 1 for i in range(n)]

    # Compute the correct answer with the intended O(n) prefix-sum method:
    # count pairs of equal prefix residues (the empty prefix included).
    residue_counts = {0: 1}
    prefix = 0
    answer_count = 0
    for value in sample_values:
        prefix = (prefix + value) % m
        answer_count += residue_counts.get(prefix, 0)
        residue_counts[prefix] = residue_counts.get(prefix, 0) + 1

    sample_input = f"{n}\n" + " ".join(str(v) for v in sample_values) + "\n"
    sample_output = f"{answer_count}\n"
    return {
        "title": title,
        "difficulty": 3,
        "statement_md": statement_md,
        "sample_input": sample_input,
        "sample_output": sample_output,
        "answer": "统计前缀和模 m 的相同值配对数量",
        "explanation": "维护 prefix % m 的出现次数,当前值为 x 时,答案增加 cnt[x],再令 cnt[x]++。",
        "knowledge_points": ["前缀和", "哈希计数", "同余"],
        "tags": ["csp-j", "prefix-sum", "hash"],
        "llm_error": llm_error[:200],
        "rag_keywords": sampled_keywords,
    }
|
||||
|
||||
|
||||
def build_problem_md(obj: dict[str, Any]) -> tuple[str, str, str]:
    """Extract (statement_md, sample_input, sample_output) from a payload dict.

    When ``statement_md`` is missing or blank, a Markdown statement is
    assembled from the ``description`` / ``input_format`` / ``output_format``
    fields instead. All values are stringified and stripped; absent keys
    become empty strings.
    """

    def _text(key: str) -> str:
        return str(obj.get(key) or "").strip()

    statement = _text("statement_md")
    if not statement:
        sections = [
            "# 题目描述",
            _text("description"),
            "## 输入格式",
            _text("input_format"),
            "## 输出格式",
            _text("output_format"),
        ]
        statement = "\n\n".join(sections).strip()
    return statement, _text("sample_input"), _text("sample_output")
|
||||
|
||||
|
||||
def maybe_duplicate(existing: list[ExistingProblem], title: str, statement_md: str, threshold: float) -> tuple[bool, int | None, float]:
    """Score the candidate against every stored problem and flag near-duplicates.

    The score per stored problem is the max of the title similarity and a
    blend weighted toward statement similarity (0.9 body + 0.1 title), with
    statements truncated to their first 1200 characters.

    Returns (is_duplicate, closest_problem_id, best_score).
    """
    best_id: int | None = None
    best = 0.0
    snippet = statement_md[:1200]
    for candidate in existing:
        title_sim = similarity(title, candidate.title)
        body_sim = similarity(snippet, candidate.statement_md[:1200])
        combined = max(title_sim, body_sim * 0.9 + title_sim * 0.1)
        if combined > best:
            best = combined
            best_id = candidate.id
    return best >= threshold, best_id, best
|
||||
|
||||
|
||||
def insert_problem(conn: sqlite3.Connection, title: str, statement_md: str, sample_input: str, sample_output: str, difficulty: int, profile_json: str, tags: list[str]) -> int:
    """Insert a generated problem plus its tags; return the new problem id.

    The slug is the normalized title restricted to ASCII alphanumerics and
    hyphens, truncated to 50 chars, with a timestamp suffix for uniqueness.
    Difficulty is clamped to 1..10. Tags are normalized; duplicates are
    absorbed by the INSERT OR IGNORE on problem_tags. Commits before return.
    """
    ts = now_sec()
    slug_base = normalize(title).replace(" ", "-")
    # Fix: the previous pattern r"[^a-z0-9\\-]+" carried a doubled backslash
    # (an escaping slip) that also whitelisted a literal backslash in slugs;
    # the intended class is just lowercase alphanumerics and '-'.
    slug_base = re.sub(r"[^a-z0-9-]+", "", slug_base)
    if not slug_base:
        # Fully non-ASCII titles normalize to nothing; use a stable fallback.
        slug_base = "cspj-generated"
    slug = f"{slug_base[:50]}-{ts}"

    cur = conn.cursor()
    cur.execute(
        """
        INSERT INTO problems(
            slug,title,statement_md,difficulty,source,statement_url,llm_profile_json,sample_input,sample_output,created_at
        ) VALUES(?,?,?,?,?,?,?,?,?,?)
        """,
        (
            slug,
            title,
            statement_md,
            max(1, min(10, difficulty)),  # clamp difficulty to the 1..10 scale
            "llm:cspj-generated",
            "",
            profile_json,
            sample_input,
            sample_output,
            ts,
        ),
    )
    problem_id = int(cur.lastrowid)
    for tag in tags:
        cur.execute(
            "INSERT OR IGNORE INTO problem_tags(problem_id,tag) VALUES(?,?)",
            (problem_id, normalize(tag)),
        )
    conn.commit()
    return problem_id
|
||||
|
||||
|
||||
def main() -> int:
    """CLI entry point: crawl keywords, generate problems, dedupe, and store.

    Pipeline per requested problem: sample RAG keywords -> ask the LLM (with a
    deterministic fallback) -> pre-insert dedupe against the in-memory corpus
    -> post-insert dedupe against a fresh DB read -> insert with tags. Prints
    a JSON run report to stdout and returns 0 when no attempt failed, else 1.
    """
    parser = argparse.ArgumentParser(description="RAG generate CSP-J problems")
    parser.add_argument("--db-path", required=True)
    parser.add_argument("--count", type=int, default=1, help="generate count each run")
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL)
    parser.add_argument("--timeout", type=int, default=60)
    parser.add_argument("--retries", type=int, default=4)
    parser.add_argument("--retry-sleep-sec", type=float, default=1.5)
    parser.add_argument("--dedupe-threshold", type=float, default=0.72)
    args = parser.parse_args()

    conn = sqlite3.connect(args.db_path)
    conn.execute("PRAGMA foreign_keys=ON")
    conn.execute("PRAGMA busy_timeout=5000")

    existing = load_existing(conn)
    luogu_titles: list[str] = []
    try:
        luogu_titles = crawl_luogu_titles(
            args.base_url, timeout=args.timeout, retries=args.retries, sleep_sec=args.retry_sleep_sec
        )
    except Exception:
        # RAG context is best-effort: generation proceeds without Luogu titles.
        luogu_titles = []

    keywords = collect_keywords(existing, luogu_titles)
    if not keywords:
        # Cold-start fallback keyword pool.
        keywords = ["模拟", "枚举", "前缀和", "字符串", "贪心", "搜索"]

    inserted = 0
    skipped_duplicate = 0
    failed = 0
    details: list[dict[str, Any]] = []

    for _ in range(max(1, args.count)):
        sampled_keywords = random.sample(keywords, k=min(8, len(keywords)))
        prompt = f"""
请生成一道原创 CSP-J 风格编程题,难度 2~4,禁止与常见模板题同构。
结合关键词:{', '.join(sampled_keywords)}

输出 JSON:
{{
"title": "题目标题",
"difficulty": 2,
"statement_md": "Markdown 题面(含描述、输入格式、输出格式、数据范围)",
"sample_input": "样例输入",
"sample_output": "样例输出",
"answer": "简要答案关键点",
"explanation": "讲解",
"knowledge_points": ["知识点1","知识点2"],
"tags": ["csp-j","入门","..."]
}}
""".strip()

        source = "llm"
        llm_error = ""
        try:
            obj = llm_generate_problem(
                prompt, timeout=args.timeout, retries=args.retries, sleep_sec=args.retry_sleep_sec
            )
        except Exception as exc:
            # Any LLM failure falls back to the deterministic generator.
            source = "fallback"
            llm_error = str(exc)
            obj = fallback_generate_problem(sampled_keywords, llm_error)

        try:
            title = str(obj.get("title") or "").strip()
            if not title:
                raise RuntimeError("generated title is empty")
            difficulty = int(obj.get("difficulty") or 2)
            statement_md, sample_input, sample_output = build_problem_md(obj)

            # Pre-check against the corpus loaded at startup (plus anything
            # inserted earlier in this run).
            pre_dup, dup_id, dup_score = maybe_duplicate(
                existing, title, statement_md, args.dedupe_threshold
            )
            if pre_dup:
                skipped_duplicate += 1
                details.append(
                    {
                        "title": title,
                        "status": "skip_pre_duplicate",
                        "source": source,
                        "similar_problem_id": dup_id,
                        "similarity": round(dup_score, 4),
                    }
                )
                continue

            profile = {
                "schema_version": 1,
                "platform": "llm-generated" if source == "llm" else "fallback-generated",
                "difficulty": difficulty,
                "answer": str(obj.get("answer") or ""),
                "explanation": str(obj.get("explanation") or ""),
                "knowledge_points": obj.get("knowledge_points") if isinstance(obj.get("knowledge_points"), list) else [],
                "tags": obj.get("tags") if isinstance(obj.get("tags"), list) else [],
                "generated_at": now_sec(),
                "rag_keywords": sampled_keywords,
            }
            if llm_error:
                profile["llm_error"] = llm_error[:300]

            # Post-check against fresh existing corpus before insert.
            existing_latest = load_existing(conn)
            post_dup, post_dup_id, post_dup_score = maybe_duplicate(
                existing_latest, title, statement_md, args.dedupe_threshold
            )
            if post_dup:
                skipped_duplicate += 1
                details.append(
                    {
                        "title": title,
                        "status": "skip_post_duplicate",
                        "source": source,
                        "similar_problem_id": post_dup_id,
                        "similarity": round(post_dup_score, 4),
                    }
                )
                continue

            tags = profile["tags"] if isinstance(profile["tags"], list) else []
            # Fix: the old check compared the raw string "csp-j" against
            # normalize()d tags, but normalize() maps "csp-j" to "csp j", so
            # the membership test was always false-negative and "csp-j" was
            # appended even when already present. Compare normalized to
            # normalized instead.
            if not any(normalize(str(x)) == normalize("csp-j") for x in tags):
                tags = [*tags, "csp-j"]
            tags = [str(x) for x in tags][:12]

            problem_id = insert_problem(
                conn,
                title=title,
                statement_md=statement_md,
                sample_input=sample_input,
                sample_output=sample_output,
                difficulty=difficulty,
                profile_json=json.dumps(profile, ensure_ascii=False),
                tags=tags,
            )
            inserted += 1
            details.append(
                {"title": title, "status": "inserted", "source": source, "problem_id": problem_id}
            )
            # Keep the in-memory corpus current for this run's later pre-checks.
            existing.append(ExistingProblem(problem_id, title, statement_md))
        except Exception as exc:
            failed += 1
            details.append({"status": "failed", "source": source, "error": str(exc)})

    conn.close()
    print(
        json.dumps(
            {
                "db_path": args.db_path,
                "requested_count": max(1, args.count),
                "inserted": inserted,
                "skipped_duplicate": skipped_duplicate,
                "failed": failed,
                "details": details,
                "keyword_sample_size": len(keywords),
            },
            ensure_ascii=False,
            indent=2,
        )
    )
    return 0 if failed == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s exit status (0 ok / 1 had failures) to the shell.
    raise SystemExit(main())
|
||||
在新工单中引用
屏蔽一个用户