Fix training plan generation flow

这个提交包含在:
cryptocommuniums-afk
2026-03-14 23:16:19 +08:00
父节点 6943754838
当前提交 1cc863e60e
修改 8 个文件,包含 429 行新增19 行删除

查看文件

@@ -146,9 +146,9 @@ export async function getUserTrainingPlans(userId: number) {
export async function getActivePlan(userId: number) {
const db = await getDb();
if (!db) return undefined;
if (!db) return null;
const result = await db.select().from(trainingPlans).where(and(eq(trainingPlans.userId, userId), eq(trainingPlans.isActive, 1))).limit(1);
return result.length > 0 ? result[0] : undefined;
return result.length > 0 ? result[0] : null;
}
export async function updateTrainingPlan(planId: number, data: Partial<InsertTrainingPlan>) {

查看文件

@@ -2,6 +2,7 @@ import { describe, expect, it, vi, beforeEach } from "vitest";
import { appRouter } from "./routers";
import { COOKIE_NAME } from "../shared/const";
import type { TrpcContext } from "./_core/context";
import * as db from "./db";
type AuthenticatedUser = NonNullable<TrpcContext["user"]>;
@@ -209,6 +210,17 @@ describe("plan.active", () => {
const caller = appRouter.createCaller(ctx);
await expect(caller.plan.active()).rejects.toThrow();
});
it("returns null when the user has no active plan", async () => {
const user = createTestUser();
const { ctx } = createMockContext(user);
const caller = appRouter.createCaller(ctx);
const getActivePlanSpy = vi.spyOn(db, "getActivePlan").mockResolvedValueOnce(null);
await expect(caller.plan.active()).resolves.toBeNull();
getActivePlanSpy.mockRestore();
});
});
describe("plan.adjust input validation", () => {

查看文件

@@ -8,6 +8,50 @@ import { invokeLLM } from "./_core/llm";
import { storagePut } from "./storage";
import * as db from "./db";
import { nanoid } from "nanoid";
import {
normalizeAdjustedPlanResponse,
normalizeTrainingPlanResponse,
} from "./trainingPlan";
async function invokeStructuredPlan<T>(params: {
baseMessages: Array<{ role: "system" | "user"; content: string }>;
responseFormat: {
type: "json_schema";
json_schema: {
name: string;
strict: true;
schema: Record<string, unknown>;
};
};
parse: (content: unknown) => T;
}) {
let lastError: unknown;
for (let attempt = 0; attempt < 3; attempt++) {
const retryHint =
attempt === 0 || !(lastError instanceof Error)
? []
: [{
role: "user" as const,
content:
`上一次输出无法被系统解析,错误是:${lastError.message}` +
"请只返回一个合法、完整、可解析的 JSON 对象,不要包含额外说明、注释或 Markdown 代码块。",
}];
const response = await invokeLLM({
messages: [...params.baseMessages, ...retryHint],
response_format: params.responseFormat,
});
try {
return params.parse(response.choices[0]?.message?.content);
} catch (error) {
lastError = error;
}
}
throw lastError instanceof Error ? lastError : new Error("Failed to parse structured LLM response");
}
export const appRouter = router({
system: systemRouter,
@@ -85,12 +129,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
请返回JSON格式,包含每天的训练内容。`;
const response = await invokeLLM({
messages: [
const parsed = await invokeStructuredPlan({
baseMessages: [
{ role: "system", content: "你是网球训练计划生成器。返回严格的JSON格式。" },
{ role: "user", content: prompt },
],
response_format: {
responseFormat: {
type: "json_schema",
json_schema: {
name: "training_plan",
@@ -123,12 +167,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
},
},
},
parse: (content) => normalizeTrainingPlanResponse({
content,
fallbackTitle: `${input.durationDays}天训练计划`,
}),
});
const content = response.choices[0]?.message?.content;
const parsed = typeof content === "string" ? JSON.parse(content) : null;
if (!parsed) throw new Error("Failed to generate training plan");
const planId = await db.createTrainingPlan({
userId: user.id,
title: parsed.title,
@@ -173,12 +217,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
请根据分析结果调整训练计划,增加针对薄弱环节的训练,返回与原计划相同格式的JSON。`;
const response = await invokeLLM({
messages: [
{ role: "system", content: "你是网球评分生成器。返回严格的JSON格式。" },
const parsed = await invokeStructuredPlan({
baseMessages: [
{ role: "system", content: "你是网球训练计划调整器。返回严格的JSON格式。" },
{ role: "user", content: prompt },
],
response_format: {
responseFormat: {
type: "json_schema",
json_schema: {
name: "adjusted_plan",
@@ -212,12 +256,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
},
},
},
parse: (content) => normalizeAdjustedPlanResponse({
content,
fallbackTitle: currentPlan.title,
}),
});
const content = response.choices[0]?.message?.content;
const parsed = typeof content === "string" ? JSON.parse(content) : null;
if (!parsed) throw new Error("Failed to adjust plan");
await db.updateTrainingPlan(input.planId, {
exercises: parsed.exercises,
adjustmentNotes: parsed.adjustmentNotes,

96
server/trainingPlan.test.ts 普通文件
查看文件

@@ -0,0 +1,96 @@
import { describe, expect, it } from "vitest";
import {
normalizeAdjustedPlanResponse,
normalizeTrainingPlanResponse,
} from "./trainingPlan";
describe("normalizeTrainingPlanResponse", () => {
it("accepts canonical title/exercises output", () => {
const result = normalizeTrainingPlanResponse({
content: JSON.stringify({
title: "7天训练计划",
exercises: [
{
day: 1,
name: "正手影子挥拍",
category: "影子挥拍",
duration: 15,
description: "完成正手挥拍练习",
tips: "保持重心稳定",
sets: 3,
reps: 12,
},
],
}),
fallbackTitle: "fallback",
});
expect(result.title).toBe("7天训练计划");
expect(result.exercises).toHaveLength(1);
expect(result.exercises[0]?.category).toBe("影子挥拍");
});
it("normalizes qwen day map output into plan exercises", () => {
const result = normalizeTrainingPlanResponse({
content: JSON.stringify({
day_1: {
duration_minutes: 45,
focus: "基础握拍与正手影子挥拍",
exercises: [
{
name: "握拍方式学习",
description: "学习大陆式与东方式握拍",
duration_minutes: 10,
},
{
name: "原地小碎步热身与放松",
description: "30秒快速小碎步 + 30秒休息",
duration_minutes: 10,
},
],
},
}),
fallbackTitle: "7天训练计划",
});
expect(result.title).toBe("7天训练计划");
expect(result.exercises).toHaveLength(2);
expect(result.exercises[0]).toMatchObject({
day: 1,
name: "握拍方式学习",
duration: 10,
sets: 3,
reps: 10,
});
expect(result.exercises[1]?.category).toBe("脚步移动");
});
});
describe("normalizeAdjustedPlanResponse", () => {
it("fills missing adjustment notes for day map output", () => {
const result = normalizeAdjustedPlanResponse({
content: JSON.stringify({
day_1: {
duration_minutes: 30,
focus: "脚步移动",
exercises: [
{
name: "交叉步移动",
description: "左右移动并快速回位",
duration_minutes: 12,
},
],
},
}),
fallbackTitle: "当前训练计划",
});
expect(result.title).toBe("当前训练计划");
expect(result.adjustmentNotes).toContain("已根据最近分析结果调整");
expect(result.exercises[0]).toMatchObject({
day: 1,
name: "交叉步移动",
category: "脚步移动",
});
});
});

200
server/trainingPlan.ts 普通文件
查看文件

@@ -0,0 +1,200 @@
import { z } from "zod";
const exerciseSchema = z.object({
day: z.number().int().min(1),
name: z.string().min(1),
category: z.string().min(1),
duration: z.number().positive(),
description: z.string().min(1),
tips: z.string().min(1),
sets: z.number().int().positive(),
reps: z.number().int().positive(),
});
const normalizedPlanSchema = z.object({
title: z.string().min(1),
exercises: z.array(exerciseSchema).min(1),
});
const normalizedAdjustedPlanSchema = normalizedPlanSchema.extend({
adjustmentNotes: z.string().min(1),
});
type NormalizedExercise = z.infer<typeof exerciseSchema>;
type NormalizedPlan = z.infer<typeof normalizedPlanSchema>;
type NormalizedAdjustedPlan = z.infer<typeof normalizedAdjustedPlanSchema>;
const dayKeyPattern = /^day[_\s-]?(\d+)$/i;
function extractTextContent(content: unknown) {
if (typeof content === "string") {
return content;
}
if (Array.isArray(content)) {
const text = content
.map(item => (item && typeof item === "object" && "text" in item ? String((item as { text?: unknown }).text ?? "") : ""))
.join("")
.trim();
return text.length > 0 ? text : null;
}
return null;
}
function parseJsonContent(content: unknown) {
const text = extractTextContent(content);
if (!text) {
throw new Error("LLM did not return text content");
}
try {
return JSON.parse(text) as Record<string, unknown>;
} catch (error) {
throw new Error(`LLM returned invalid JSON: ${error instanceof Error ? error.message : "unknown error"}`);
}
}
function toPositiveNumber(value: unknown, fallback: number) {
const parsed = typeof value === "number" ? value : Number(value);
return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
}
function toPositiveInteger(value: unknown, fallback: number) {
const parsed = typeof value === "number" ? value : Number.parseInt(String(value), 10);
return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
}
function inferCategory(...values: Array<unknown>) {
const text = values
.filter((value): value is string => typeof value === "string")
.join(" ");
if (/(墙|wall)/i.test(text)) return "墙壁练习";
if (/(步|移动|footwork|shuffle|split step)/i.test(text)) return "脚步移动";
if (/(挥拍|shadow|正手|反手|发球|截击)/i.test(text)) return "影子挥拍";
return "体能训练";
}
function normalizeExercise(
day: number,
exercise: Record<string, unknown>,
section?: Record<string, unknown>
): NormalizedExercise {
const name =
typeof exercise.name === "string" && exercise.name.trim().length > 0
? exercise.name.trim()
: typeof section?.focus === "string" && section.focus.trim().length > 0
? section.focus.trim()
: `${day}天训练项目`;
const description =
typeof exercise.description === "string" && exercise.description.trim().length > 0
? exercise.description.trim()
: typeof section?.focus === "string" && section.focus.trim().length > 0
? section.focus.trim()
: `${name}训练`;
const tips =
typeof exercise.tips === "string" && exercise.tips.trim().length > 0
? exercise.tips.trim()
: typeof section?.focus === "string" && section.focus.trim().length > 0
? `重点关注:${section.focus.trim()}`
: "保持动作稳定,注意训练节奏。";
return {
day,
name,
category: inferCategory(exercise.category, name, description, section?.focus),
duration: toPositiveNumber(
exercise.duration ?? exercise.duration_minutes,
toPositiveNumber(section?.duration_minutes, 10)
),
description,
tips,
sets: toPositiveInteger(exercise.sets, 3),
reps: toPositiveInteger(exercise.reps, 10),
};
}
function normalizeDayMapPlan(
raw: Record<string, unknown>,
fallbackTitle: string
): NormalizedPlan {
const exercises: NormalizedExercise[] = [];
for (const [key, value] of Object.entries(raw)) {
const match = key.match(dayKeyPattern);
if (!match || !value || typeof value !== "object" || Array.isArray(value)) {
continue;
}
const day = Number.parseInt(match[1] ?? "", 10);
if (!Number.isFinite(day) || day <= 0) {
continue;
}
const section = value as Record<string, unknown>;
const sectionExercises = Array.isArray(section.exercises)
? section.exercises.filter(
(item): item is Record<string, unknown> =>
Boolean(item) && typeof item === "object" && !Array.isArray(item)
)
: [];
for (const exercise of sectionExercises) {
exercises.push(normalizeExercise(day, exercise, section));
}
}
return normalizedPlanSchema.parse({
title:
typeof raw.title === "string" && raw.title.trim().length > 0
? raw.title.trim()
: fallbackTitle,
exercises,
});
}
export function normalizeTrainingPlanResponse(params: {
content: unknown;
fallbackTitle: string;
}): NormalizedPlan {
const raw = parseJsonContent(params.content);
if (Array.isArray(raw.exercises)) {
return normalizedPlanSchema.parse(raw);
}
return normalizeDayMapPlan(raw, params.fallbackTitle);
}
export function normalizeAdjustedPlanResponse(params: {
content: unknown;
fallbackTitle: string;
}): NormalizedAdjustedPlan {
const raw = parseJsonContent(params.content);
if (Array.isArray(raw.exercises)) {
return normalizedAdjustedPlanSchema.parse({
...raw,
adjustmentNotes:
typeof raw.adjustmentNotes === "string" && raw.adjustmentNotes.trim().length > 0
? raw.adjustmentNotes.trim()
: "已根据最近分析结果调整训练内容。",
});
}
const normalized = normalizeDayMapPlan(raw, params.fallbackTitle);
return normalizedAdjustedPlanSchema.parse({
...normalized,
adjustmentNotes:
typeof raw.adjustmentNotes === "string" && raw.adjustmentNotes.trim().length > 0
? raw.adjustmentNotes.trim()
: typeof raw.adjustment_notes === "string" && raw.adjustment_notes.trim().length > 0
? raw.adjustment_notes.trim()
: "已根据最近分析结果调整训练内容。",
});
}