139 行
4.4 KiB
Python
139 行
4.4 KiB
Python
#!/usr/bin/env python3
|
||
"""Check whether a date is a China statutory holiday using LLM with fallback."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import datetime as dt
|
||
import json
|
||
import os
|
||
import time
|
||
from typing import Any, Dict, Optional
|
||
|
||
import requests
|
||
|
||
|
||
def env(name: str, default: str = "") -> str:
    """Read environment variable *name*, treating unset or blank values as missing.

    Surrounding whitespace is stripped from the value; if nothing remains
    (or the variable is not set at all), *default* is returned unchanged.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    cleaned = raw.strip()
    return cleaned or default
|
||
|
||
|
||
def parse_date(raw: str) -> dt.date:
    """Parse a ``YYYY-MM-DD`` string into a :class:`datetime.date`.

    Raises:
        ValueError: if *raw* does not match the format or is not a real date.
    """
    fmt = "%Y-%m-%d"
    parsed = dt.datetime.strptime(raw, fmt)
    return parsed.date()
|
||
|
||
|
||
def fallback(date_obj: dt.date, reason: str) -> Dict[str, Any]:
    """Build a rule-based verdict for *date_obj* without consulting the LLM.

    Saturday and Sunday (``weekday()`` 5 and 6) count as holidays; every other
    day does not. An empty *reason* is replaced by a default message that
    matches the verdict. The result is tagged with model name
    ``"fallback-rules"``.
    """
    weekend = date_obj.weekday() in (5, 6)
    if not reason:
        reason = "周末自动判定为假期" if weekend else "工作日默认学习日"
    return {"is_holiday": weekend, "reason": reason, "model_name": "fallback-rules"}
|
||
|
||
|
||
# HTTP statuses below 500 that are still worth retrying (timeout, rate limit).
_RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})


def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence (``` or ```json)."""
    txt = text.strip()
    if txt.startswith("```"):
        # Drop the opening fence line, including a language tag such as "json".
        # A single-line reply like ```{...}``` has no newline: just cut the ticks.
        txt = txt.split("\n", 1)[1] if "\n" in txt else txt[3:]
    if txt.endswith("```"):
        txt = txt[:-3]
    return txt.strip()


def _load_json_object(text: str) -> Dict[str, Any]:
    """Parse an LLM reply as a JSON object, tolerating fences and surrounding prose.

    Raises:
        ValueError: (``json.JSONDecodeError``) if no JSON can be decoded.
        RuntimeError: if the decoded JSON is not an object.
    """
    txt = _strip_code_fence(text)
    try:
        parsed = json.loads(txt)
    except json.JSONDecodeError:
        # Models sometimes wrap the object in a sentence; retry on the outermost {...}.
        start, end = txt.find("{"), txt.rfind("}")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(txt[start : end + 1])
    if not isinstance(parsed, dict):
        raise RuntimeError("llm output is not object")
    return parsed


def _coerce_bool(value: Any) -> bool:
    """Interpret an LLM-supplied flag; plain bool("false") would wrongly be True."""
    if isinstance(value, str):
        return value.strip().lower() in {"true", "1", "yes", "y", "是"}
    return bool(value)


def _parse_llm_reply(data: Any, model: str) -> Dict[str, Any]:
    """Turn a chat-completions style payload (``choices[0].message.content``) into a verdict."""
    content = data["choices"][0]["message"]["content"]
    if not isinstance(content, str):
        raise RuntimeError("llm returned no text content")
    parsed = _load_json_object(content)
    # `or ""` keeps a JSON null from turning into the literal string "None".
    reason = str(parsed.get("reason") or "").strip()
    model_name = str(parsed.get("model_name") or "").strip()
    return {
        "is_holiday": _coerce_bool(parsed.get("is_holiday", False)),
        "reason": reason or "LLM判定",
        "model_name": model_name or model,
    }


def call_llm(
    date_obj: dt.date,
    *,
    attempts: int = 3,
    timeout: float = 25.0,
) -> Dict[str, Any]:
    """Ask the configured LLM whether *date_obj* is a China statutory holiday.

    Configuration is read from the environment:
    ``OI_LLM_API_URL`` (or ``CSP_LLM_API_URL``) is the endpoint.
    ``OI_LLM_API_KEY`` (or ``CSP_LLM_API_KEY``) is optional and sent as a Bearer token.
    ``OI_LLM_MODEL`` selects the model and defaults to ``qwen3-max``.

    Args:
        date_obj: The day to classify.
        attempts: Maximum number of requests. Network errors, 5xx, 408/429 and
            unparseable replies are retried with a linear back-off.
        timeout: Per-request timeout in seconds passed to ``requests.post``.

    Returns:
        ``{"is_holiday": bool, "reason": str, "model_name": str}``.

    Raises:
        RuntimeError: if no API URL is configured, if the server answers with
            a non-retryable 4xx status, or if every attempt fails.
    """
    api_url = env("OI_LLM_API_URL") or env("CSP_LLM_API_URL")
    api_key = env("OI_LLM_API_KEY") or env("CSP_LLM_API_KEY")
    model = env("OI_LLM_MODEL", "qwen3-max")
    if not api_url:
        raise RuntimeError("missing OI_LLM_API_URL (or CSP_LLM_API_URL)")

    weekday_labels = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
    weekday_label = weekday_labels[date_obj.weekday()]

    system = (
        "你是中国节假日判定助手。给定日期后,只判断这一天是否属于中国法定节假日放假日。"
        "不需要解释历法推导,不需要输出多余文本。"
        "输出必须为纯 JSON,格式:"
        '{"is_holiday":true/false,"reason":"简短中文原因","model_name":"模型名"}'
    )
    user = {
        "date": date_obj.isoformat(),
        "weekday": weekday_label,
        "task": "判断该日期是否是中国法定节假日放假日。仅判断当天,不跨天推断。",
    }

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    body = {
        "model": model,
        "stream": False,
        "temperature": 0.0,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": json.dumps(user, ensure_ascii=False)},
        ],
    }

    last_err: Optional[Exception] = None
    for attempt in range(max(1, attempts)):
        if attempt:
            # Linear back-off between attempts; never sleep after the last one.
            time.sleep(0.6 * attempt)
        try:
            resp = requests.post(api_url, headers=headers, json=body, timeout=timeout)
        except requests.RequestException as exc:
            last_err = exc
            continue

        status = resp.status_code
        if status >= 500 or status in _RETRYABLE_CLIENT_STATUSES:
            last_err = RuntimeError(f"HTTP {status}")
            continue
        try:
            resp.raise_for_status()
        except requests.HTTPError as exc:
            # Bad credentials, unknown model, malformed request: retrying cannot help.
            raise RuntimeError(str(exc)) from exc

        try:
            return _parse_llm_reply(resp.json(), model)
        except (ValueError, KeyError, IndexError, TypeError, RuntimeError) as exc:
            # Non-JSON or unexpectedly shaped reply; the model may do better next time.
            last_err = exc

    raise RuntimeError(str(last_err) if last_err else "llm failed") from last_err
|
||
|
||
|
||
def main(argv: Optional[list[str]] = None) -> int:
    """CLI entry point: print a JSON holiday verdict for ``--date YYYY-MM-DD``.

    Exactly one JSON object ``{"is_holiday", "reason", "model_name"}`` is
    printed to stdout, and the return value is always 0. This holds even for
    an unparseable date or an LLM failure, so a caller can always parse the
    output line. The only exception is when argparse itself rejects the
    arguments, in which case argparse exits.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]`` (argparse default).

    Returns:
        The process exit status (always 0).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--date", required=True, help="Date in YYYY-MM-DD")
    args = ap.parse_args(argv)

    def emit(payload: Dict[str, Any]) -> int:
        """Print *payload* as one JSON line (non-ASCII kept readable) and return 0."""
        print(json.dumps(payload, ensure_ascii=False))
        return 0

    try:
        date_obj = parse_date(args.date)
    except ValueError:
        # strptime rejects the input; still answer with well-formed JSON.
        return emit(
            {"is_holiday": False, "reason": "invalid date format", "model_name": "fallback-rules"}
        )

    # Policy: Saturdays and Sundays are always reported as holidays without
    # consulting the LLM (weekend make-up workdays are not considered).
    if date_obj.weekday() >= 5:
        return emit(
            {"is_holiday": True, "reason": "周末自动判定为假期", "model_name": "calendar-weekend"}
        )

    try:
        out = call_llm(date_obj)
    except Exception as exc:  # noqa: BLE001 - any LLM failure degrades to the rule-based answer
        out = fallback(date_obj, f"工作日默认学习日(LLM失败: {exc})")
    return emit(out)
|
||
|
||
|
||
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|