文件
csp/scripts/generate_cspj_problem_rag.py

476 行
16 KiB
Python
原始文件 Blame 文件历史

此文件含有模棱两可的 Unicode 字符
此文件含有可能会与其他字符混淆的 Unicode 字符。如果这是您有意为之,可以放心忽略该警告。使用 Escape 按钮显示它们。
#!/usr/bin/env python3
"""Generate new CSP-J problems with RAG + dedupe checks."""
from __future__ import annotations
import argparse
import json
import math
import os
import random
import re
import sqlite3
import time
from dataclasses import dataclass
from difflib import SequenceMatcher
from typing import Any
from urllib.parse import quote
import requests
# Default site root; used as the default value of the --base-url option.
DEFAULT_BASE_URL = "https://www.luogu.com.cn"
# Tag ids joined with "," and sent as the `tag` query parameter of the
# problem-list request.
DEFAULT_TAG_IDS = [343, 82] # CSP-J + NOIP junior
# HTTP status codes that the retry loops in this file treat as transient.
RETRYABLE_HTTP_CODES = {429, 500, 502, 503, 504}
# Captures the body of the <script id="lentille-context"> element; DOTALL lets
# the body span multiple lines.
CONTEXT_RE = re.compile(
r'<script[^>]*id="lentille-context"[^>]*>(.*?)</script>', re.DOTALL
)
@dataclass
class ExistingProblem:
    """In-memory copy of one row of the ``problems`` table.

    Only the columns needed for keyword mining and duplicate detection are
    kept. Field order matters: instances are also built positionally.
    """

    # Value of the ``id`` column.
    id: int
    # Value of the ``title`` column ("" when the column was NULL/empty).
    title: str
    # Value of the ``statement_md`` column ("" when the column was NULL/empty).
    statement_md: str
def now_sec() -> int:
    """Return the current Unix timestamp truncated to whole seconds."""
    current = time.time()
    return int(current)
def normalize(text: str) -> str:
    """Return a canonical form of *text* for fuzzy comparison.

    The text is lower-cased and trimmed, every run of characters other than
    ASCII digits, ASCII lower-case letters, CJK ideographs (U+4E00..U+9FFF)
    and spaces is replaced by a single space, and whitespace runs are
    collapsed. The result may be the empty string.
    """
    lowered = text.lower().strip()
    lowered = re.sub(r"\s+", " ", lowered)
    cleaned = re.sub(r"[^0-9a-z\u4e00-\u9fff ]+", " ", lowered)
    return re.sub(r"\s+", " ", cleaned).strip()


def similarity(a: str, b: str) -> float:
    """Return the ``SequenceMatcher`` ratio of the normalized inputs.

    Returns 0.0 when either input is empty, either before or after
    normalization.
    """
    if not a or not b:
        return 0.0
    norm_a = normalize(a)
    norm_b = normalize(b)
    # SequenceMatcher reports 1.0 for two empty sequences, so strings made
    # only of characters that normalize() strips (punctuation, kana, ...)
    # would otherwise look like perfect duplicates of each other.
    if not norm_a or not norm_b:
        return 0.0
    return SequenceMatcher(None, norm_a, norm_b).ratio()
def requests_with_retry(url: str, timeout: int, retries: int, sleep_sec: float) -> str:
    """GET *url* and return the response text, retrying transient failures.

    Up to *retries* attempts are made. A ``requests.RequestException`` or a
    status in ``RETRYABLE_HTTP_CODES`` triggers a retry after sleeping
    ``attempt * sleep_sec`` seconds; any other status >= 400 fails at once.

    Raises:
        RuntimeError: when the request cannot be completed successfully,
            including when *retries* is less than 1.
    """
    most_recent_failure: Exception | None = None
    attempt = 0
    while attempt < retries:
        attempt += 1
        may_retry = attempt < retries
        try:
            response = requests.get(url, timeout=timeout)
        except requests.RequestException as exc:
            most_recent_failure = exc
            if not may_retry:
                raise RuntimeError(f"request failed: {exc}") from exc
            time.sleep(attempt * sleep_sec)
            continue
        status = response.status_code
        is_transient = status in RETRYABLE_HTTP_CODES
        if is_transient and may_retry:
            # Linear back-off before the next attempt.
            time.sleep(attempt * sleep_sec)
            continue
        if is_transient or status >= 400:
            raise RuntimeError(f"request failed: HTTP {status}")
        return response.text
    # Only reached when the loop body never ran to a return/raise.
    if most_recent_failure:
        raise RuntimeError(str(most_recent_failure))
    raise RuntimeError("request failed")
def extract_context_json(html_text: str) -> dict[str, Any]:
    """Decode the JSON embedded in the page's ``lentille-context`` script tag.

    Raises:
        RuntimeError: if ``CONTEXT_RE`` finds no such script element.
    """
    found = CONTEXT_RE.search(html_text)
    if found is None:
        raise RuntimeError("lentille-context script not found")
    embedded = found.group(1)
    return json.loads(embedded)
def crawl_luogu_titles(base_url: str, timeout: int, retries: int, sleep_sec: float) -> list[str]:
    """Fetch page 1 of the tagged problem list and return its non-empty titles.

    The tag filter is ``DEFAULT_TAG_IDS`` joined with commas and URL-quoted.
    Rows that are not dicts, or whose title is empty after stripping, are
    skipped. Errors from the request or from context extraction propagate.
    """
    tag_param = quote(",".join(map(str, DEFAULT_TAG_IDS)))
    page_url = f"{base_url}/problem/list?type=all&tag={tag_param}&page=1"
    html = requests_with_retry(page_url, timeout=timeout, retries=retries, sleep_sec=sleep_sec)
    context = extract_context_json(html)
    # Walk data -> problems -> result, tolerating missing/null levels.
    data = context.get("data") or {}
    problems = data.get("problems") or {}
    rows = problems.get("result") or []
    return [
        title
        for row in rows
        if isinstance(row, dict) and (title := str(row.get("title") or "").strip())
    ]
def load_existing(conn: sqlite3.Connection) -> list[ExistingProblem]:
    """Read id, title and statement of every row in ``problems``.

    NULL/empty titles and statements are returned as empty strings.
    """
    cursor = conn.execute("SELECT id,title,statement_md FROM problems")
    return [
        ExistingProblem(
            id=int(problem_id),
            title=str(title or ""),
            statement_md=str(statement or ""),
        )
        for problem_id, title, statement in cursor.fetchall()
    ]
def collect_keywords(existing: list[ExistingProblem], luogu_titles: list[str]) -> list[str]:
    """Return up to 40 title fragments ranked by weighted frequency.

    Titles are split on whitespace and common separators; each fragment is
    normalized and dropped when shorter than two characters or purely
    numeric. Fragments from local problems weigh 1, fragments from crawled
    titles weigh 2. Ties keep first-seen order.
    """
    counts: dict[str, int] = {}
    splitter = re.compile(r"[\s,/|+()\[\]【】-]+")

    def tally(raw: str, weight: int) -> None:
        word = normalize(raw)
        if len(word) < 2 or word.isdigit():
            return
        counts[word] = counts.get(word, 0) + weight

    # Local titles first, then crawled ones, so insertion order is preserved.
    weighted_titles = [(p.title, 1) for p in existing]
    weighted_titles += [(t, 2) for t in luogu_titles]
    for title, weight in weighted_titles:
        for piece in splitter.split(title):
            tally(piece, weight)

    by_weight = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
    return [word for word, _ in by_weight[:40]]
def _llm_message_text(payload: Any) -> str:
    """Return ``choices[0].message.content`` of *payload* as stripped text.

    Any missing or wrongly-typed level yields "" instead of raising.
    """
    if not isinstance(payload, dict):
        return ""
    choices = payload.get("choices")
    first = choices[0] if isinstance(choices, list) and choices else None
    message = first.get("message") if isinstance(first, dict) else None
    content = message.get("content") if isinstance(message, dict) else None
    return str(content or "").strip()


def _parse_json_object(text: str) -> dict[str, Any] | None:
    """Parse *text* as a JSON object, tolerating code fences and chatter.

    Tries the whole text first, then the span from the first ``{`` to the
    last ``}``. Returns ``None`` when neither yields a JSON object.
    """
    if text.startswith("```"):
        # Drop a leading ```lang fence and the matching trailing fence.
        text = re.sub(r"^```[a-zA-Z0-9_-]*", "", text).strip()
        text = text.removesuffix("```").strip()
    candidates = [text]
    braces = re.search(r"\{[\s\S]*\}", text)
    if braces:
        candidates.append(braces.group(0))
    for candidate in candidates:
        try:
            obj = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict):
            return obj
    return None


def llm_generate_problem(prompt: str, timeout: int, retries: int, sleep_sec: float) -> dict[str, Any]:
    """Ask the configured LLM endpoint for a problem and return it as a dict.

    Configuration comes from the environment: ``OI_LLM_API_URL`` (required),
    ``OI_LLM_API_KEY`` (optional, sent as a Bearer token) and
    ``OI_LLM_MODEL`` (default ``qwen3-max``). The request is a
    chat-completions style JSON body (presumably an OpenAI-compatible API;
    confirm against the deployed endpoint).

    Transport errors and statuses in ``RETRYABLE_HTTP_CODES`` are retried up
    to *retries* attempts with a linear ``attempt * sleep_sec`` back-off. An
    unparsable answer is not retried.

    Raises:
        RuntimeError: on missing URL, exhausted retries, a non-retryable HTTP
            error, a non-JSON response body, or content that does not
            contain a JSON object.
    """
    url = os.getenv("OI_LLM_API_URL", "").strip()
    api_key = os.getenv("OI_LLM_API_KEY", "").strip()
    model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip()
    if not url:
        raise RuntimeError("missing OI_LLM_API_URL")
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    body = {
        "model": model,
        "stream": False,
        "temperature": 0.7,
        "messages": [
            {
                "role": "system",
                "content": "你是 CSP-J 出题人。只输出 JSON,不输出额外解释。",
            },
            {"role": "user", "content": prompt},
        ],
    }
    for attempt in range(1, retries + 1):
        try:
            resp = requests.post(url, headers=headers, json=body, timeout=timeout)
        except requests.RequestException as exc:
            if attempt < retries:
                time.sleep(attempt * sleep_sec)
                continue
            raise RuntimeError(f"llm failed: {exc}") from exc
        if resp.status_code in RETRYABLE_HTTP_CODES:
            if attempt < retries:
                time.sleep(attempt * sleep_sec)
                continue
            raise RuntimeError(f"llm failed: HTTP {resp.status_code}")
        if resp.status_code >= 400:
            raise RuntimeError(f"llm failed: HTTP {resp.status_code}: {resp.text[:200]}")
        try:
            payload = resp.json()
        except ValueError as exc:
            # Every JSON decode error requests can raise here is a ValueError.
            raise RuntimeError("llm returned non-json response body") from exc
        obj = _parse_json_object(_llm_message_text(payload))
        if obj is None:
            raise RuntimeError("llm returned non-json content")
        return obj
    # Only reached when retries < 1 and the loop body never ran.
    raise RuntimeError("llm failed")
def fallback_generate_problem(sampled_keywords: list[str], llm_error: str) -> dict[str, Any]:
    """Build a deterministic template problem for when the LLM call fails.

    The current Unix time seeds the title, the sequence length ``n`` (5..10)
    and the modulus ``m`` (7..15) embedded in the statement. The sample is
    derived from those values so that it always agrees with the statement:
    the sequence ``1..n`` and the number of its contiguous segments whose sum
    is divisible by ``m``.

    Args:
        sampled_keywords: stored verbatim under ``rag_keywords``.
        llm_error: stored under ``llm_error``, truncated to 200 characters.

    Returns:
        A dict with the same keys the LLM is asked to produce, plus
        ``llm_error`` and ``rag_keywords``.
    """
    # Same value now_sec() yields; computed inline to keep this block
    # self-contained.
    seed = int(time.time())
    n = 5 + (seed % 6)
    m = 7 + (seed % 9)
    title = f"CSP-J 训练题·余数统计 {seed}"
    statement_md = f"""
# 题目描述
给定一个长度为 {n} 的整数序列,你需要统计有多少个连续子段的元素和对 {m} 取模后等于 0。
## 输入格式
第一行一个整数 n。
第二行 n 个整数 a_i。
## 输出格式
输出一个整数,表示满足条件的连续子段数量。
## 数据范围
- 1 <= n <= 2e5
- |a_i| <= 1e9
## 提示
可以使用前缀和与计数哈希优化到 O(n)。
""".strip()

    # Solve the sample with the approach the hint describes: equal prefix
    # remainders delimit a segment whose sum is divisible by m.
    sample_values = list(range(1, n + 1))
    remainder_counts = {0: 1}
    prefix = 0
    segment_count = 0
    for value in sample_values:
        prefix = (prefix + value) % m
        seen = remainder_counts.get(prefix, 0)
        segment_count += seen
        remainder_counts[prefix] = seen + 1

    joined_values = " ".join(map(str, sample_values))
    sample_input = f"{n}\n{joined_values}\n"
    sample_output = f"{segment_count}\n"
    return {
        "title": title,
        "difficulty": 3,
        "statement_md": statement_md,
        "sample_input": sample_input,
        "sample_output": sample_output,
        "answer": "统计前缀和模 m 的相同值配对数量",
        "explanation": "维护 prefix % m 的出现次数,当前值为 x 时,答案增加 cnt[x],再令 cnt[x]++。",
        "knowledge_points": ["前缀和", "哈希计数", "同余"],
        "tags": ["csp-j", "prefix-sum", "hash"],
        "llm_error": llm_error[:200],
        "rag_keywords": sampled_keywords,
    }
def _field_text(value: Any) -> str:
    """Return *value* as stripped text, mapping missing/empty values to "".

    Real numbers are kept even when zero, so a JSON ``0`` survives as "0"
    instead of being treated as missing.
    """
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return str(value).strip()
    return str(value or "").strip()


def build_problem_md(obj: dict[str, Any]) -> tuple[str, str, str]:
    """Extract ``(statement_md, sample_input, sample_output)`` from *obj*.

    When ``statement_md`` is missing or blank, a statement is assembled from
    ``description``, ``input_format`` and ``output_format`` under fixed
    Markdown headings. All returned strings are stripped.
    """
    statement = _field_text(obj.get("statement_md"))
    if not statement:
        sections = [
            "# 题目描述",
            _field_text(obj.get("description")),
            "## 输入格式",
            _field_text(obj.get("input_format")),
            "## 输出格式",
            _field_text(obj.get("output_format")),
        ]
        statement = "\n\n".join(sections).strip()
    sample_input = _field_text(obj.get("sample_input"))
    sample_output = _field_text(obj.get("sample_output"))
    return statement, sample_input, sample_output
def maybe_duplicate(existing: list[ExistingProblem], title: str, statement_md: str, threshold: float) -> tuple[bool, int | None, float]:
    """Find the existing problem most similar to the candidate.

    Each existing problem is scored as the larger of its title similarity
    and a 90/10 blend of statement and title similarity; only the first 1200
    characters of each statement are compared.

    Returns:
        ``(is_duplicate, best_id, best_score)`` where ``is_duplicate`` is
        ``best_score >= threshold`` and ``best_id`` is ``None`` when no
        problem scored above 0.
    """
    candidate_head = statement_md[:1200]
    top_id: int | None = None
    top_score = 0.0
    for problem in existing:
        title_score = similarity(title, problem.title)
        body_score = similarity(candidate_head, problem.statement_md[:1200])
        blended = body_score * 0.9 + title_score * 0.1
        combined = max(title_score, blended)
        if combined > top_score:
            top_id, top_score = problem.id, combined
    return top_score >= threshold, top_id, top_score
def insert_problem(conn: sqlite3.Connection, title: str, statement_md: str, sample_input: str, sample_output: str, difficulty: int, profile_json: str, tags: list[str]) -> int:
    """Insert one generated problem and its tags; return the new row id.

    The slug is built from the ASCII letters and digits of the normalized
    title (at most 50 characters, hyphen-separated) followed by the current
    Unix time; titles with no such characters use ``cspj-generated``.
    ``difficulty`` is clamped to 1..10. Tags are normalized before insertion,
    and tags that normalize to "" or repeat an earlier tag are skipped.

    The problem row and its tags are written in one transaction: on any
    error everything is rolled back and the exception propagates.

    Raises:
        RuntimeError: if the INSERT did not report a row id.
        sqlite3.Error: on database failures.
    """
    ts = now_sec()
    slug_base = normalize(title).replace(" ", "-")
    slug_base = re.sub(r"[^a-z0-9-]+", "", slug_base)
    # Removing CJK characters can leave runs of hyphens or a hyphen-only
    # slug; collapse and trim them so the fallback below can take effect.
    slug_base = re.sub(r"-{2,}", "-", slug_base).strip("-")
    if not slug_base:
        slug_base = "cspj-generated"
    slug = f"{slug_base[:50].rstrip('-')}-{ts}"

    # The connection context manager commits on success and rolls back on
    # error, so a failed tag insert cannot leave a half-written problem.
    with conn:
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO problems(
                slug,title,statement_md,difficulty,source,statement_url,llm_profile_json,sample_input,sample_output,created_at
            ) VALUES(?,?,?,?,?,?,?,?,?,?)
            """,
            (
                slug,
                title,
                statement_md,
                max(1, min(10, difficulty)),
                "llm:cspj-generated",
                "",
                profile_json,
                sample_input,
                sample_output,
                ts,
            ),
        )
        if cur.lastrowid is None:
            raise RuntimeError("insert did not return a row id")
        problem_id = int(cur.lastrowid)
        seen_tags: set[str] = set()
        for tag in tags:
            tag_value = normalize(tag)
            if not tag_value or tag_value in seen_tags:
                continue
            seen_tags.add(tag_value)
            cur.execute(
                "INSERT OR IGNORE INTO problem_tags(problem_id,tag) VALUES(?,?)",
                (problem_id, tag_value),
            )
    return problem_id
def main() -> int:
    """Generate ``--count`` problems, dedupe them and store the new ones.

    Steps: load existing problems from the SQLite database, crawl Luogu
    titles on a best-effort basis, derive keywords, then for each requested
    problem ask the LLM (falling back to a template on any error), check for
    duplicates before and immediately before inserting, and insert.

    A JSON summary is printed to stdout.

    Returns:
        0 when no generation attempt failed, 1 otherwise.
    """
    import sys  # local import: only needed for the stderr warning below

    parser = argparse.ArgumentParser(description="RAG generate CSP-J problems")
    parser.add_argument("--db-path", required=True)
    parser.add_argument("--count", type=int, default=1, help="generate count each run")
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL)
    parser.add_argument("--timeout", type=int, default=60)
    parser.add_argument("--retries", type=int, default=4)
    parser.add_argument("--retry-sleep-sec", type=float, default=1.5)
    parser.add_argument("--dedupe-threshold", type=float, default=0.72)
    args = parser.parse_args()

    inserted = 0
    skipped_duplicate = 0
    failed = 0
    details: list[dict[str, Any]] = []
    # normalize() turns "-" into a space, so tags must be compared against
    # the normalized form of the mandatory tag, not the literal "csp-j".
    required_tag = "csp-j"
    required_tag_norm = normalize(required_tag)

    conn = sqlite3.connect(args.db_path)
    try:
        conn.execute("PRAGMA foreign_keys=ON")
        conn.execute("PRAGMA busy_timeout=5000")
        existing = load_existing(conn)

        luogu_titles: list[str] = []
        try:
            luogu_titles = crawl_luogu_titles(
                args.base_url, timeout=args.timeout, retries=args.retries, sleep_sec=args.retry_sleep_sec
            )
        except Exception as exc:
            # Best effort: remote titles only enrich the keyword pool.
            # stdout is reserved for the JSON summary, so warn on stderr.
            print(f"warning: luogu crawl failed, continuing without it: {exc}", file=sys.stderr)
            luogu_titles = []

        keywords = collect_keywords(existing, luogu_titles)
        if not keywords:
            keywords = ["模拟", "枚举", "前缀和", "字符串", "贪心", "搜索"]

        for _ in range(max(1, args.count)):
            sampled_keywords = random.sample(keywords, k=min(8, len(keywords)))
            prompt = f"""
请生成一道原创 CSP-J 风格编程题,难度 2~4,禁止与常见模板题同构。
结合关键词:{', '.join(sampled_keywords)}
输出 JSON
{{
"title": "题目标题",
"difficulty": 2,
"statement_md": "Markdown 题面(含描述、输入格式、输出格式、数据范围)",
"sample_input": "样例输入",
"sample_output": "样例输出",
"answer": "简要答案关键点",
"explanation": "讲解",
"knowledge_points": ["知识点1","知识点2"],
"tags": ["csp-j","入门","..."]
}}
""".strip()
            source = "llm"
            llm_error = ""
            try:
                obj = llm_generate_problem(
                    prompt, timeout=args.timeout, retries=args.retries, sleep_sec=args.retry_sleep_sec
                )
            except Exception as exc:
                # Any LLM failure degrades to the built-in template problem.
                source = "fallback"
                llm_error = str(exc)
                obj = fallback_generate_problem(sampled_keywords, llm_error)

            try:
                title = str(obj.get("title") or "").strip()
                if not title:
                    raise RuntimeError("generated title is empty")
                difficulty = int(obj.get("difficulty") or 2)
                statement_md, sample_input, sample_output = build_problem_md(obj)

                pre_dup, dup_id, dup_score = maybe_duplicate(
                    existing, title, statement_md, args.dedupe_threshold
                )
                if pre_dup:
                    skipped_duplicate += 1
                    details.append(
                        {
                            "title": title,
                            "status": "skip_pre_duplicate",
                            "source": source,
                            "similar_problem_id": dup_id,
                            "similarity": round(dup_score, 4),
                        }
                    )
                    continue

                profile = {
                    "schema_version": 1,
                    "platform": "llm-generated" if source == "llm" else "fallback-generated",
                    "difficulty": difficulty,
                    "answer": str(obj.get("answer") or ""),
                    "explanation": str(obj.get("explanation") or ""),
                    "knowledge_points": obj.get("knowledge_points") if isinstance(obj.get("knowledge_points"), list) else [],
                    "tags": obj.get("tags") if isinstance(obj.get("tags"), list) else [],
                    "generated_at": now_sec(),
                    "rag_keywords": sampled_keywords,
                }
                if llm_error:
                    profile["llm_error"] = llm_error[:300]

                # Post-check against fresh existing corpus before insert.
                existing_latest = load_existing(conn)
                post_dup, post_dup_id, post_dup_score = maybe_duplicate(
                    existing_latest, title, statement_md, args.dedupe_threshold
                )
                if post_dup:
                    skipped_duplicate += 1
                    details.append(
                        {
                            "title": title,
                            "status": "skip_post_duplicate",
                            "source": source,
                            "similar_problem_id": post_dup_id,
                            "similarity": round(post_dup_score, 4),
                        }
                    )
                    continue

                tags = profile["tags"] if isinstance(profile["tags"], list) else []
                if required_tag_norm not in {normalize(str(x)) for x in tags}:
                    tags = [*tags, required_tag]
                tags = [str(x) for x in tags][:12]

                problem_id = insert_problem(
                    conn,
                    title=title,
                    statement_md=statement_md,
                    sample_input=sample_input,
                    sample_output=sample_output,
                    difficulty=difficulty,
                    profile_json=json.dumps(profile, ensure_ascii=False),
                    tags=tags,
                )
                inserted += 1
                details.append(
                    {"title": title, "status": "inserted", "source": source, "problem_id": problem_id}
                )
                # Keep the in-memory corpus current for the next pre-check.
                existing.append(ExistingProblem(problem_id, title, statement_md))
            except Exception as exc:
                # Drop any half-finished write so a later commit on this
                # connection cannot persist it.
                conn.rollback()
                failed += 1
                details.append({"status": "failed", "source": source, "error": str(exc)})
    finally:
        conn.close()

    print(
        json.dumps(
            {
                "db_path": args.db_path,
                "requested_count": max(1, args.count),
                "inserted": inserted,
                "skipped_duplicate": skipped_duplicate,
                "failed": failed,
                "details": details,
                "keyword_sample_size": len(keywords),
            },
            ensure_ascii=False,
            indent=2,
        )
    )
    return 0 if failed == 0 else 1


if __name__ == "__main__":
    raise SystemExit(main())