476 行
16 KiB
Python
476 行
16 KiB
Python
#!/usr/bin/env python3
|
||
"""Generate new CSP-J problems with RAG + dedupe checks."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import math
|
||
import os
|
||
import random
|
||
import re
|
||
import sqlite3
|
||
import time
|
||
from dataclasses import dataclass
|
||
from difflib import SequenceMatcher
|
||
from typing import Any
|
||
from urllib.parse import quote
|
||
|
||
import requests
|
||
|
||
# Luogu host scraped for CSP-J reference titles.
DEFAULT_BASE_URL = "https://www.luogu.com.cn"
# Luogu tag ids used to filter the problem list: CSP-J and NOIP junior.
DEFAULT_TAG_IDS = [343, 82]
# HTTP statuses treated as transient (rate limiting / upstream failures) and retried.
RETRYABLE_HTTP_CODES = {429, 500, 502, 503, 504}
# Captures the body of the <script id="lentille-context"> tag, which is parsed as JSON.
# DOTALL lets the lazy capture span newlines inside the script body.
CONTEXT_RE = re.compile(
    r'<script[^>]*id="lentille-context"[^>]*>(.*?)</script>',
    flags=re.DOTALL,
)
|
||
|
||
|
||
@dataclass()
class ExistingProblem:
    """Lightweight view of a row from the ``problems`` table.

    Used for keyword mining and for duplicate detection against newly
    generated problems.
    """

    id: int  # primary key of the row in ``problems``
    title: str  # problem title
    statement_md: str  # Markdown statement text
|
||
|
||
|
||
def now_sec() -> int:
    """Return the current Unix time truncated to whole seconds."""
    seconds = time.time()
    return int(seconds)
|
||
|
||
|
||
def normalize(text: str) -> str:
    """Canonicalize text for fuzzy comparison.

    Lower-cases the input and replaces every run of characters outside
    ASCII digits, ASCII lowercase letters, CJK unified ideographs
    (U+4E00..U+9FFF) and the plain space with a single space. It then
    collapses repeated spaces and trims both ends.
    """
    lowered = text.lower()
    # Tabs, newlines and other whitespace are outside the kept class, so they
    # become plain spaces here too.
    cleaned = re.sub(r"[^0-9a-z\u4e00-\u9fff ]+", " ", lowered)
    return " ".join(cleaned.split())
|
||
|
||
|
||
def similarity(a: str, b: str) -> float:
    """Return a 0..1 similarity ratio between two texts after normalize().

    If either side is empty *after* normalization, 0.0 is returned. This
    covers texts made only of punctuation, or of characters normalize()
    strips such as full-width letters. Checking only the raw strings would
    let two such texts both become "". SequenceMatcher scores two empty
    sequences as 1.0, so they would be reported as perfect duplicates.
    """
    if not a or not b:
        return 0.0
    left = normalize(a)
    right = normalize(b)
    if not left or not right:
        return 0.0
    return SequenceMatcher(None, left, right).ratio()
|
||
|
||
|
||
def requests_with_retry(url: str, timeout: int, retries: int, sleep_sec: float) -> str:
    """GET ``url`` and return the response body text, retrying transient failures.

    Network errors and statuses in RETRYABLE_HTTP_CODES are retried with a
    linear backoff (``attempt * sleep_sec`` seconds). Any other status >= 400
    fails immediately. At least one attempt is always made, even when
    ``retries`` < 1. Previously that case raised without ever sending a
    request.

    Raises:
        RuntimeError: when the final attempt fails or a non-retryable HTTP
            error status is returned.
    """
    attempts = max(1, retries)
    for attempt in range(1, attempts + 1):
        last_attempt = attempt == attempts
        try:
            resp = requests.get(url, timeout=timeout)
        except requests.RequestException as exc:
            if last_attempt:
                raise RuntimeError(f"request failed: {exc}") from exc
            time.sleep(attempt * sleep_sec)
            continue

        if resp.status_code in RETRYABLE_HTTP_CODES:
            if last_attempt:
                raise RuntimeError(f"request failed: HTTP {resp.status_code}")
            time.sleep(attempt * sleep_sec)
            continue
        if resp.status_code >= 400:
            raise RuntimeError(f"request failed: HTTP {resp.status_code}")
        return resp.text

    # Defensive only: the final attempt above always returns or raises.
    raise RuntimeError("request failed")
|
||
|
||
|
||
def extract_context_json(html_text: str) -> dict[str, Any]:
    """Parse the JSON embedded in a page's ``lentille-context`` script tag.

    Args:
        html_text: full HTML of the fetched page.

    Returns:
        The decoded JSON object.

    Raises:
        RuntimeError: if the script tag is missing, its body is not valid
            JSON, or the JSON is not an object. Callers call ``.get`` on the
            result, so anything other than a dict would fail later with a
            less helpful error.
    """
    match = CONTEXT_RE.search(html_text)
    if not match:
        raise RuntimeError("lentille-context script not found")
    try:
        ctx = json.loads(match.group(1))
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"lentille-context is not valid JSON: {exc}") from exc
    if not isinstance(ctx, dict):
        raise RuntimeError("lentille-context JSON is not an object")
    return ctx
|
||
|
||
|
||
def crawl_luogu_titles(base_url: str, timeout: int, retries: int, sleep_sec: float) -> list[str]:
    """Fetch page 1 of Luogu's problem list filtered by DEFAULT_TAG_IDS and return its titles.

    The titles are read from ``data.problems.result`` in the page's embedded
    context JSON. Rows that are not objects, or whose title is missing or
    blank, are skipped. Network and parse errors propagate from
    requests_with_retry / extract_context_json.
    """
    tag_filter = quote(",".join(map(str, DEFAULT_TAG_IDS)))
    list_url = f"{base_url}/problem/list?type=all&tag={tag_filter}&page=1"
    page_html = requests_with_retry(list_url, timeout=timeout, retries=retries, sleep_sec=sleep_sec)
    context = extract_context_json(page_html)

    data = context.get("data") or {}
    problems = data.get("problems") or {}
    rows = problems.get("result") or []

    stripped = (str(row.get("title") or "").strip() for row in rows if isinstance(row, dict))
    return [title for title in stripped if title]
|
||
|
||
|
||
def load_existing(conn: sqlite3.Connection) -> list[ExistingProblem]:
    """Read every row of ``problems`` as ExistingProblem records.

    NULL title or statement columns are mapped to empty strings.
    """
    records = conn.execute("SELECT id,title,statement_md FROM problems").fetchall()
    return [
        ExistingProblem(
            id=int(pid),
            title=str(title or ""),
            statement_md=str(statement or ""),
        )
        for pid, title, statement in records
    ]
|
||
|
||
|
||
def collect_keywords(existing: list[ExistingProblem], luogu_titles: list[str]) -> list[str]:
    """Rank title fragments by weighted frequency and return the top 40.

    Titles are split on whitespace and common punctuation/brackets. Each
    fragment is normalize()d and kept only if it has at least two characters
    and is not purely numeric. Fragments from local problems score 1 and
    fragments from Luogu titles score 2. Equal scores keep first-seen order.
    """
    splitter = r"[\s,/|+()\[\]【】-]+"
    weighted_titles = [(p.title, 1) for p in existing] + [(t, 2) for t in luogu_titles]

    scores: dict[str, int] = {}
    for title, weight in weighted_titles:
        for fragment in re.split(splitter, title):
            word = normalize(fragment)
            if len(word) < 2 or word.isdigit():
                continue
            scores[word] = scores.get(word, 0) + weight

    # sorted() is stable (also with reverse=True), so ties stay in insertion order.
    top = sorted(scores, key=scores.__getitem__, reverse=True)
    return top[:40]
|
||
|
||
|
||
def llm_generate_problem(prompt: str, timeout: int, retries: int, sleep_sec: float) -> dict[str, Any]:
|
||
url = os.getenv("OI_LLM_API_URL", "").strip()
|
||
api_key = os.getenv("OI_LLM_API_KEY", "").strip()
|
||
model = os.getenv("OI_LLM_MODEL", "qwen3-max").strip()
|
||
if not url:
|
||
raise RuntimeError("missing OI_LLM_API_URL")
|
||
|
||
headers = {"Content-Type": "application/json"}
|
||
if api_key:
|
||
headers["Authorization"] = f"Bearer {api_key}"
|
||
|
||
body = {
|
||
"model": model,
|
||
"stream": False,
|
||
"temperature": 0.7,
|
||
"messages": [
|
||
{
|
||
"role": "system",
|
||
"content": "你是 CSP-J 出题人。只输出 JSON,不输出额外解释。",
|
||
},
|
||
{"role": "user", "content": prompt},
|
||
],
|
||
}
|
||
|
||
for i in range(1, retries + 1):
|
||
try:
|
||
resp = requests.post(url, headers=headers, json=body, timeout=timeout)
|
||
except requests.RequestException as exc:
|
||
if i < retries:
|
||
time.sleep(i * sleep_sec)
|
||
continue
|
||
raise RuntimeError(f"llm failed: {exc}") from exc
|
||
|
||
if resp.status_code in RETRYABLE_HTTP_CODES:
|
||
if i < retries:
|
||
time.sleep(i * sleep_sec)
|
||
continue
|
||
raise RuntimeError(f"llm failed: HTTP {resp.status_code}")
|
||
if resp.status_code >= 400:
|
||
raise RuntimeError(f"llm failed: HTTP {resp.status_code}: {resp.text[:200]}")
|
||
|
||
payload = resp.json()
|
||
content = (((payload.get("choices") or [{}])[0].get("message") or {}).get("content") or "")
|
||
text = str(content).strip()
|
||
if text.startswith("```"):
|
||
text = re.sub(r"^```[a-zA-Z0-9_-]*", "", text).strip()
|
||
text = text.removesuffix("```").strip()
|
||
try:
|
||
obj = json.loads(text)
|
||
if isinstance(obj, dict):
|
||
return obj
|
||
except json.JSONDecodeError:
|
||
match = re.search(r"\{[\s\S]*\}", text)
|
||
if match:
|
||
obj = json.loads(match.group(0))
|
||
if isinstance(obj, dict):
|
||
return obj
|
||
raise RuntimeError("llm returned non-json content")
|
||
raise RuntimeError("llm failed")
|
||
|
||
|
||
def fallback_generate_problem(sampled_keywords: list[str], llm_error: str) -> dict[str, Any]:
|
||
seed = now_sec()
|
||
n = 5 + (seed % 6)
|
||
m = 7 + (seed % 9)
|
||
title = f"CSP-J 训练题·余数统计 {seed}"
|
||
statement_md = f"""
|
||
# 题目描述
|
||
给定一个长度为 {n} 的整数序列,你需要统计有多少个连续子段的元素和对 {m} 取模后等于 0。
|
||
|
||
## 输入格式
|
||
第一行一个整数 n。
|
||
第二行 n 个整数 a_i。
|
||
|
||
## 输出格式
|
||
输出一个整数,表示满足条件的连续子段数量。
|
||
|
||
## 数据范围
|
||
- 1 <= n <= 2e5
|
||
- |a_i| <= 1e9
|
||
|
||
## 提示
|
||
可以使用前缀和与计数哈希优化到 O(n)。
|
||
""".strip()
|
||
sample_input = "6\n1 2 3 4 5 6\n"
|
||
sample_output = "3\n"
|
||
return {
|
||
"title": title,
|
||
"difficulty": 3,
|
||
"statement_md": statement_md,
|
||
"sample_input": sample_input,
|
||
"sample_output": sample_output,
|
||
"answer": "统计前缀和模 m 的相同值配对数量",
|
||
"explanation": "维护 prefix % m 的出现次数,当前值为 x 时,答案增加 cnt[x],再令 cnt[x]++。",
|
||
"knowledge_points": ["前缀和", "哈希计数", "同余"],
|
||
"tags": ["csp-j", "prefix-sum", "hash"],
|
||
"llm_error": llm_error[:200],
|
||
"rag_keywords": sampled_keywords,
|
||
}
|
||
|
||
|
||
def build_problem_md(obj: dict[str, Any]) -> tuple[str, str, str]:
    """Return ``(statement_md, sample_input, sample_output)`` from a generated problem dict.

    Uses ``statement_md`` when it is non-blank. Otherwise it builds a
    Markdown statement from ``description``, ``input_format`` and
    ``output_format``, each under its own heading and separated by blank
    lines. All values are stringified and stripped; missing or falsy fields
    become "".
    """

    def field(key: str) -> str:
        return str(obj.get(key) or "").strip()

    statement = field("statement_md")
    if not statement:
        sections = (
            ("# 题目描述", "description"),
            ("## 输入格式", "input_format"),
            ("## 输出格式", "output_format"),
        )
        pieces = [part for heading, key in sections for part in (heading, field(key))]
        statement = "\n\n".join(pieces).strip()

    return statement, field("sample_input"), field("sample_output")
|
||
|
||
|
||
def maybe_duplicate(existing: list[ExistingProblem], title: str, statement_md: str, threshold: float) -> tuple[bool, int | None, float]:
|
||
best_id = None
|
||
best_score = 0.0
|
||
for p in existing:
|
||
t_sim = similarity(title, p.title)
|
||
s_sim = similarity(statement_md[:1200], p.statement_md[:1200])
|
||
score = max(t_sim, s_sim * 0.9 + t_sim * 0.1)
|
||
if score > best_score:
|
||
best_score = score
|
||
best_id = p.id
|
||
return best_score >= threshold, best_id, best_score
|
||
|
||
|
||
def insert_problem(conn: sqlite3.Connection, title: str, statement_md: str, sample_input: str, sample_output: str, difficulty: int, profile_json: str, tags: list[str]) -> int:
    """Insert a generated problem and its tags in one transaction; return the new row id.

    Slug:
        Built from the ASCII letters/digits of the normalized title (other
        characters, e.g. CJK, become single dashes), capped at 50 chars,
        then suffixed with the current Unix time. If that slug already
        exists, "-2", "-3", ... is appended. Otherwise Chinese-only titles,
        which all fall back to "cspj-generated", would collide when several
        problems are inserted in the same second.

    Other fields:
        Difficulty is clamped to 1..10. Tags are normalized; empty and
        repeated tags are skipped.

    Atomicity:
        The problem row and its tags are written inside ``with conn``. The
        transaction is rolled back on any error, so a half-inserted problem
        cannot be committed by a later call on the same connection.
    """
    ts = now_sec()
    slug_base = re.sub(r"[^a-z0-9]+", "-", normalize(title)).strip("-")[:50].strip("-")
    if not slug_base:
        slug_base = "cspj-generated"
    slug = f"{slug_base}-{ts}"
    suffix = 2
    while conn.execute("SELECT 1 FROM problems WHERE slug = ? LIMIT 1", (slug,)).fetchone():
        slug = f"{slug_base}-{ts}-{suffix}"
        suffix += 1

    # dict.fromkeys de-duplicates while keeping the first-seen order.
    clean_tags = list(dict.fromkeys(t for t in (normalize(tag) for tag in tags) if t))

    with conn:  # commits on success, rolls back if any statement raises
        cur = conn.execute(
            """
            INSERT INTO problems(
                slug,title,statement_md,difficulty,source,statement_url,llm_profile_json,sample_input,sample_output,created_at
            ) VALUES(?,?,?,?,?,?,?,?,?,?)
            """,
            (
                slug,
                title,
                statement_md,
                max(1, min(10, difficulty)),
                "llm:cspj-generated",
                "",
                profile_json,
                sample_input,
                sample_output,
                ts,
            ),
        )
        problem_id = int(cur.lastrowid)
        conn.executemany(
            "INSERT OR IGNORE INTO problem_tags(problem_id,tag) VALUES(?,?)",
            [(problem_id, tag) for tag in clean_tags],
        )
    return problem_id
|
||
|
||
|
||
def _build_prompt(sampled_keywords: list[str]) -> str:
    """Build the user prompt asking the LLM for one CSP-J problem in JSON form."""
    return "\n".join(
        [
            "请生成一道原创 CSP-J 风格编程题,难度 2~4,禁止与常见模板题同构。",
            f"结合关键词:{', '.join(sampled_keywords)}",
            "",
            "输出 JSON:",
            "{",
            '  "title": "题目标题",',
            '  "difficulty": 2,',
            '  "statement_md": "Markdown 题面(含描述、输入格式、输出格式、数据范围)",',
            '  "sample_input": "样例输入",',
            '  "sample_output": "样例输出",',
            '  "answer": "简要答案关键点",',
            '  "explanation": "讲解",',
            '  "knowledge_points": ["知识点1","知识点2"],',
            '  "tags": ["csp-j","入门","..."]',
            "}",
        ]
    )


def _parse_difficulty(raw: Any, default: int = 2) -> int:
    """Coerce the generated ``difficulty`` field to an int.

    Falsy values give ``default``. When int() rejects the value (e.g. "2~4"
    or "2.5"), the first run of digits is used instead, so one sloppy field
    does not fail the whole problem.
    """
    if not raw:
        return default
    try:
        return int(raw)
    except (TypeError, ValueError):
        digits = re.search(r"\d+", str(raw))
        return int(digits.group()) if digits else default


def _prepare_tags(raw_tags: Any, limit: int = 12) -> list[str]:
    """Return at most ``limit`` tag strings, guaranteeing a "csp-j" tag.

    Presence is tested on normalize()d forms, so "CSP-J" or "csp j" count.
    Comparing normalized tags against the literal "csp-j" never matched,
    because normalize() turns "-" into a space. The required tag is placed
    first so the length cap cannot drop it.
    """
    tags = [str(tag) for tag in raw_tags] if isinstance(raw_tags, list) else []
    required = "csp-j"
    if normalize(required) not in {normalize(tag) for tag in tags}:
        tags.insert(0, required)
    return tags[:limit]


def _generate_batch(conn: sqlite3.Connection, args: argparse.Namespace) -> dict[str, Any]:
    """Run one generation round and return the JSON-serializable summary.

    Steps:
    1. Mine keywords from local and Luogu titles.
    2. Generate ``max(1, args.count)`` problems (LLM first, fallback on error).
    3. Skip near-duplicates, checked before and after a fresh DB re-read.
    4. Insert the survivors.
    """
    existing = load_existing(conn)
    luogu_error = ""
    try:
        luogu_titles = crawl_luogu_titles(
            args.base_url, timeout=args.timeout, retries=args.retries, sleep_sec=args.retry_sleep_sec
        )
    except Exception as exc:  # best-effort: keywords can still come from the local DB
        luogu_titles = []
        luogu_error = str(exc)

    keywords = collect_keywords(existing, luogu_titles)
    if not keywords:
        keywords = ["模拟", "枚举", "前缀和", "字符串", "贪心", "搜索"]

    requested = max(1, args.count)
    inserted = 0
    skipped_duplicate = 0
    failed = 0
    details: list[dict[str, Any]] = []

    for _ in range(requested):
        sampled_keywords = random.sample(keywords, k=min(8, len(keywords)))
        prompt = _build_prompt(sampled_keywords)

        source = "llm"
        llm_error = ""
        try:
            obj = llm_generate_problem(
                prompt, timeout=args.timeout, retries=args.retries, sleep_sec=args.retry_sleep_sec
            )
        except Exception as exc:
            source = "fallback"
            llm_error = str(exc)
            obj = fallback_generate_problem(sampled_keywords, llm_error)

        try:
            title = str(obj.get("title") or "").strip()
            if not title:
                raise RuntimeError("generated title is empty")
            difficulty = _parse_difficulty(obj.get("difficulty"))
            statement_md, sample_input, sample_output = build_problem_md(obj)

            pre_dup, dup_id, dup_score = maybe_duplicate(
                existing, title, statement_md, args.dedupe_threshold
            )
            if pre_dup:
                skipped_duplicate += 1
                details.append(
                    {
                        "title": title,
                        "status": "skip_pre_duplicate",
                        "source": source,
                        "similar_problem_id": dup_id,
                        "similarity": round(dup_score, 4),
                    }
                )
                continue

            profile = {
                "schema_version": 1,
                "platform": "llm-generated" if source == "llm" else "fallback-generated",
                "difficulty": difficulty,
                "answer": str(obj.get("answer") or ""),
                "explanation": str(obj.get("explanation") or ""),
                "knowledge_points": obj.get("knowledge_points") if isinstance(obj.get("knowledge_points"), list) else [],
                "tags": obj.get("tags") if isinstance(obj.get("tags"), list) else [],
                "generated_at": now_sec(),
                "rag_keywords": sampled_keywords,
            }
            if llm_error:
                profile["llm_error"] = llm_error[:300]

            # Post-check against a fresh read of the DB, which may have been
            # changed by another writer since the last load.
            existing = load_existing(conn)
            post_dup, post_dup_id, post_dup_score = maybe_duplicate(
                existing, title, statement_md, args.dedupe_threshold
            )
            if post_dup:
                skipped_duplicate += 1
                details.append(
                    {
                        "title": title,
                        "status": "skip_post_duplicate",
                        "source": source,
                        "similar_problem_id": post_dup_id,
                        "similarity": round(post_dup_score, 4),
                    }
                )
                continue

            problem_id = insert_problem(
                conn,
                title=title,
                statement_md=statement_md,
                sample_input=sample_input,
                sample_output=sample_output,
                difficulty=difficulty,
                profile_json=json.dumps(profile, ensure_ascii=False),
                tags=_prepare_tags(profile["tags"]),
            )
            inserted += 1
            details.append(
                {"title": title, "status": "inserted", "source": source, "problem_id": problem_id}
            )
            existing.append(ExistingProblem(problem_id, title, statement_md))
        except Exception as exc:
            failed += 1
            details.append({"status": "failed", "source": source, "error": str(exc)})

    summary: dict[str, Any] = {
        "db_path": args.db_path,
        "requested_count": requested,
        "inserted": inserted,
        "skipped_duplicate": skipped_duplicate,
        "failed": failed,
        "details": details,
        "keyword_sample_size": len(keywords),
    }
    if luogu_error:
        # Surface the crawl failure instead of swallowing it silently.
        summary["luogu_error"] = luogu_error[:300]
    return summary


def main() -> int:
    """CLI entry point: generate problems into ``--db-path`` and print a JSON summary.

    Returns 0 when no item failed and 1 otherwise. The database connection is
    closed even if the run raises.
    """
    parser = argparse.ArgumentParser(description="RAG generate CSP-J problems")
    parser.add_argument("--db-path", required=True)
    parser.add_argument("--count", type=int, default=1, help="generate count each run")
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL)
    parser.add_argument("--timeout", type=int, default=60)
    parser.add_argument("--retries", type=int, default=4)
    parser.add_argument("--retry-sleep-sec", type=float, default=1.5)
    parser.add_argument("--dedupe-threshold", type=float, default=0.72)
    args = parser.parse_args()

    conn = sqlite3.connect(args.db_path)
    try:
        conn.execute("PRAGMA foreign_keys=ON")
        conn.execute("PRAGMA busy_timeout=5000")
        summary = _generate_batch(conn, args)
    finally:
        conn.close()

    print(json.dumps(summary, ensure_ascii=False, indent=2))
    return 0 if summary["failed"] == 0 else 1


if __name__ == "__main__":
    raise SystemExit(main())
|