Fix training plan generation flow

这个提交包含在:
cryptocommuniums-afk
2026-03-14 23:16:19 +08:00
父节点 6943754838
当前提交 1cc863e60e
修改 8 个文件,包含 429 行新增19 行删除

查看文件

@@ -33,6 +33,11 @@ server {
location / { location / {
proxy_pass http://127.0.0.1:3002; proxy_pass http://127.0.0.1:3002;
proxy_http_version 1.1; proxy_http_version 1.1;
proxy_buffering off;
proxy_request_buffering off;
proxy_connect_timeout 300s;
proxy_read_timeout 3600s;
proxy_send_timeout 3600s;
proxy_set_header Upgrade $http_upgrade; proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade"; proxy_set_header Connection "upgrade";
proxy_set_header Host $host; proxy_set_header Host $host;

查看文件

@@ -146,9 +146,9 @@ export async function getUserTrainingPlans(userId: number) {
export async function getActivePlan(userId: number) { export async function getActivePlan(userId: number) {
const db = await getDb(); const db = await getDb();
if (!db) return undefined; if (!db) return null;
const result = await db.select().from(trainingPlans).where(and(eq(trainingPlans.userId, userId), eq(trainingPlans.isActive, 1))).limit(1); const result = await db.select().from(trainingPlans).where(and(eq(trainingPlans.userId, userId), eq(trainingPlans.isActive, 1))).limit(1);
return result.length > 0 ? result[0] : undefined; return result.length > 0 ? result[0] : null;
} }
export async function updateTrainingPlan(planId: number, data: Partial<InsertTrainingPlan>) { export async function updateTrainingPlan(planId: number, data: Partial<InsertTrainingPlan>) {

查看文件

@@ -2,6 +2,7 @@ import { describe, expect, it, vi, beforeEach } from "vitest";
import { appRouter } from "./routers"; import { appRouter } from "./routers";
import { COOKIE_NAME } from "../shared/const"; import { COOKIE_NAME } from "../shared/const";
import type { TrpcContext } from "./_core/context"; import type { TrpcContext } from "./_core/context";
import * as db from "./db";
type AuthenticatedUser = NonNullable<TrpcContext["user"]>; type AuthenticatedUser = NonNullable<TrpcContext["user"]>;
@@ -209,6 +210,17 @@ describe("plan.active", () => {
const caller = appRouter.createCaller(ctx); const caller = appRouter.createCaller(ctx);
await expect(caller.plan.active()).rejects.toThrow(); await expect(caller.plan.active()).rejects.toThrow();
}); });
it("returns null when the user has no active plan", async () => {
const user = createTestUser();
const { ctx } = createMockContext(user);
const caller = appRouter.createCaller(ctx);
const getActivePlanSpy = vi.spyOn(db, "getActivePlan").mockResolvedValueOnce(null);
await expect(caller.plan.active()).resolves.toBeNull();
getActivePlanSpy.mockRestore();
});
}); });
describe("plan.adjust input validation", () => { describe("plan.adjust input validation", () => {

查看文件

@@ -8,6 +8,50 @@ import { invokeLLM } from "./_core/llm";
import { storagePut } from "./storage"; import { storagePut } from "./storage";
import * as db from "./db"; import * as db from "./db";
import { nanoid } from "nanoid"; import { nanoid } from "nanoid";
import {
normalizeAdjustedPlanResponse,
normalizeTrainingPlanResponse,
} from "./trainingPlan";
async function invokeStructuredPlan<T>(params: {
baseMessages: Array<{ role: "system" | "user"; content: string }>;
responseFormat: {
type: "json_schema";
json_schema: {
name: string;
strict: true;
schema: Record<string, unknown>;
};
};
parse: (content: unknown) => T;
}) {
let lastError: unknown;
for (let attempt = 0; attempt < 3; attempt++) {
const retryHint =
attempt === 0 || !(lastError instanceof Error)
? []
: [{
role: "user" as const,
content:
`上一次输出无法被系统解析,错误是:${lastError.message}` +
"请只返回一个合法、完整、可解析的 JSON 对象,不要包含额外说明、注释或 Markdown 代码块。",
}];
const response = await invokeLLM({
messages: [...params.baseMessages, ...retryHint],
response_format: params.responseFormat,
});
try {
return params.parse(response.choices[0]?.message?.content);
} catch (error) {
lastError = error;
}
}
throw lastError instanceof Error ? lastError : new Error("Failed to parse structured LLM response");
}
export const appRouter = router({ export const appRouter = router({
system: systemRouter, system: systemRouter,
@@ -85,12 +129,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
请返回JSON格式,包含每天的训练内容。`; 请返回JSON格式,包含每天的训练内容。`;
const response = await invokeLLM({ const parsed = await invokeStructuredPlan({
messages: [ baseMessages: [
{ role: "system", content: "你是网球训练计划生成器。返回严格的JSON格式。" }, { role: "system", content: "你是网球训练计划生成器。返回严格的JSON格式。" },
{ role: "user", content: prompt }, { role: "user", content: prompt },
], ],
response_format: { responseFormat: {
type: "json_schema", type: "json_schema",
json_schema: { json_schema: {
name: "training_plan", name: "training_plan",
@@ -123,12 +167,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
}, },
}, },
}, },
parse: (content) => normalizeTrainingPlanResponse({
content,
fallbackTitle: `${input.durationDays}天训练计划`,
}),
}); });
const content = response.choices[0]?.message?.content;
const parsed = typeof content === "string" ? JSON.parse(content) : null;
if (!parsed) throw new Error("Failed to generate training plan");
const planId = await db.createTrainingPlan({ const planId = await db.createTrainingPlan({
userId: user.id, userId: user.id,
title: parsed.title, title: parsed.title,
@@ -173,12 +217,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
请根据分析结果调整训练计划,增加针对薄弱环节的训练,返回与原计划相同格式的JSON。`; 请根据分析结果调整训练计划,增加针对薄弱环节的训练,返回与原计划相同格式的JSON。`;
const response = await invokeLLM({ const parsed = await invokeStructuredPlan({
messages: [ baseMessages: [
{ role: "system", content: "你是网球评分生成器。返回严格的JSON格式。" }, { role: "system", content: "你是网球训练计划调整器。返回严格的JSON格式。" },
{ role: "user", content: prompt }, { role: "user", content: prompt },
], ],
response_format: { responseFormat: {
type: "json_schema", type: "json_schema",
json_schema: { json_schema: {
name: "adjusted_plan", name: "adjusted_plan",
@@ -212,12 +256,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
}, },
}, },
}, },
parse: (content) => normalizeAdjustedPlanResponse({
content,
fallbackTitle: currentPlan.title,
}),
}); });
const content = response.choices[0]?.message?.content;
const parsed = typeof content === "string" ? JSON.parse(content) : null;
if (!parsed) throw new Error("Failed to adjust plan");
await db.updateTrainingPlan(input.planId, { await db.updateTrainingPlan(input.planId, {
exercises: parsed.exercises, exercises: parsed.exercises,
adjustmentNotes: parsed.adjustmentNotes, adjustmentNotes: parsed.adjustmentNotes,

96
server/trainingPlan.test.ts 普通文件
查看文件

@@ -0,0 +1,96 @@
import { describe, expect, it } from "vitest";
import {
normalizeAdjustedPlanResponse,
normalizeTrainingPlanResponse,
} from "./trainingPlan";
describe("normalizeTrainingPlanResponse", () => {
it("accepts canonical title/exercises output", () => {
const result = normalizeTrainingPlanResponse({
content: JSON.stringify({
title: "7天训练计划",
exercises: [
{
day: 1,
name: "正手影子挥拍",
category: "影子挥拍",
duration: 15,
description: "完成正手挥拍练习",
tips: "保持重心稳定",
sets: 3,
reps: 12,
},
],
}),
fallbackTitle: "fallback",
});
expect(result.title).toBe("7天训练计划");
expect(result.exercises).toHaveLength(1);
expect(result.exercises[0]?.category).toBe("影子挥拍");
});
it("normalizes qwen day map output into plan exercises", () => {
const result = normalizeTrainingPlanResponse({
content: JSON.stringify({
day_1: {
duration_minutes: 45,
focus: "基础握拍与正手影子挥拍",
exercises: [
{
name: "握拍方式学习",
description: "学习大陆式与东方式握拍",
duration_minutes: 10,
},
{
name: "原地小碎步热身与放松",
description: "30秒快速小碎步 + 30秒休息",
duration_minutes: 10,
},
],
},
}),
fallbackTitle: "7天训练计划",
});
expect(result.title).toBe("7天训练计划");
expect(result.exercises).toHaveLength(2);
expect(result.exercises[0]).toMatchObject({
day: 1,
name: "握拍方式学习",
duration: 10,
sets: 3,
reps: 10,
});
expect(result.exercises[1]?.category).toBe("脚步移动");
});
});
describe("normalizeAdjustedPlanResponse", () => {
it("fills missing adjustment notes for day map output", () => {
const result = normalizeAdjustedPlanResponse({
content: JSON.stringify({
day_1: {
duration_minutes: 30,
focus: "脚步移动",
exercises: [
{
name: "交叉步移动",
description: "左右移动并快速回位",
duration_minutes: 12,
},
],
},
}),
fallbackTitle: "当前训练计划",
});
expect(result.title).toBe("当前训练计划");
expect(result.adjustmentNotes).toContain("已根据最近分析结果调整");
expect(result.exercises[0]).toMatchObject({
day: 1,
name: "交叉步移动",
category: "脚步移动",
});
});
});

200
server/trainingPlan.ts 普通文件
查看文件

@@ -0,0 +1,200 @@
import { z } from "zod";
const exerciseSchema = z.object({
day: z.number().int().min(1),
name: z.string().min(1),
category: z.string().min(1),
duration: z.number().positive(),
description: z.string().min(1),
tips: z.string().min(1),
sets: z.number().int().positive(),
reps: z.number().int().positive(),
});
const normalizedPlanSchema = z.object({
title: z.string().min(1),
exercises: z.array(exerciseSchema).min(1),
});
const normalizedAdjustedPlanSchema = normalizedPlanSchema.extend({
adjustmentNotes: z.string().min(1),
});
type NormalizedExercise = z.infer<typeof exerciseSchema>;
type NormalizedPlan = z.infer<typeof normalizedPlanSchema>;
type NormalizedAdjustedPlan = z.infer<typeof normalizedAdjustedPlanSchema>;
const dayKeyPattern = /^day[_\s-]?(\d+)$/i;
function extractTextContent(content: unknown) {
if (typeof content === "string") {
return content;
}
if (Array.isArray(content)) {
const text = content
.map(item => (item && typeof item === "object" && "text" in item ? String((item as { text?: unknown }).text ?? "") : ""))
.join("")
.trim();
return text.length > 0 ? text : null;
}
return null;
}
function parseJsonContent(content: unknown) {
const text = extractTextContent(content);
if (!text) {
throw new Error("LLM did not return text content");
}
try {
return JSON.parse(text) as Record<string, unknown>;
} catch (error) {
throw new Error(`LLM returned invalid JSON: ${error instanceof Error ? error.message : "unknown error"}`);
}
}
function toPositiveNumber(value: unknown, fallback: number) {
const parsed = typeof value === "number" ? value : Number(value);
return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
}
function toPositiveInteger(value: unknown, fallback: number) {
const parsed = typeof value === "number" ? value : Number.parseInt(String(value), 10);
return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
}
function inferCategory(...values: Array<unknown>) {
const text = values
.filter((value): value is string => typeof value === "string")
.join(" ");
if (/(墙|wall)/i.test(text)) return "墙壁练习";
if (/(步|移动|footwork|shuffle|split step)/i.test(text)) return "脚步移动";
if (/(挥拍|shadow|正手|反手|发球|截击)/i.test(text)) return "影子挥拍";
return "体能训练";
}
function normalizeExercise(
day: number,
exercise: Record<string, unknown>,
section?: Record<string, unknown>
): NormalizedExercise {
const name =
typeof exercise.name === "string" && exercise.name.trim().length > 0
? exercise.name.trim()
: typeof section?.focus === "string" && section.focus.trim().length > 0
? section.focus.trim()
: `${day}天训练项目`;
const description =
typeof exercise.description === "string" && exercise.description.trim().length > 0
? exercise.description.trim()
: typeof section?.focus === "string" && section.focus.trim().length > 0
? section.focus.trim()
: `${name}训练`;
const tips =
typeof exercise.tips === "string" && exercise.tips.trim().length > 0
? exercise.tips.trim()
: typeof section?.focus === "string" && section.focus.trim().length > 0
? `重点关注:${section.focus.trim()}`
: "保持动作稳定,注意训练节奏。";
return {
day,
name,
category: inferCategory(exercise.category, name, description, section?.focus),
duration: toPositiveNumber(
exercise.duration ?? exercise.duration_minutes,
toPositiveNumber(section?.duration_minutes, 10)
),
description,
tips,
sets: toPositiveInteger(exercise.sets, 3),
reps: toPositiveInteger(exercise.reps, 10),
};
}
function normalizeDayMapPlan(
raw: Record<string, unknown>,
fallbackTitle: string
): NormalizedPlan {
const exercises: NormalizedExercise[] = [];
for (const [key, value] of Object.entries(raw)) {
const match = key.match(dayKeyPattern);
if (!match || !value || typeof value !== "object" || Array.isArray(value)) {
continue;
}
const day = Number.parseInt(match[1] ?? "", 10);
if (!Number.isFinite(day) || day <= 0) {
continue;
}
const section = value as Record<string, unknown>;
const sectionExercises = Array.isArray(section.exercises)
? section.exercises.filter(
(item): item is Record<string, unknown> =>
Boolean(item) && typeof item === "object" && !Array.isArray(item)
)
: [];
for (const exercise of sectionExercises) {
exercises.push(normalizeExercise(day, exercise, section));
}
}
return normalizedPlanSchema.parse({
title:
typeof raw.title === "string" && raw.title.trim().length > 0
? raw.title.trim()
: fallbackTitle,
exercises,
});
}
export function normalizeTrainingPlanResponse(params: {
content: unknown;
fallbackTitle: string;
}): NormalizedPlan {
const raw = parseJsonContent(params.content);
if (Array.isArray(raw.exercises)) {
return normalizedPlanSchema.parse(raw);
}
return normalizeDayMapPlan(raw, params.fallbackTitle);
}
export function normalizeAdjustedPlanResponse(params: {
content: unknown;
fallbackTitle: string;
}): NormalizedAdjustedPlan {
const raw = parseJsonContent(params.content);
if (Array.isArray(raw.exercises)) {
return normalizedAdjustedPlanSchema.parse({
...raw,
adjustmentNotes:
typeof raw.adjustmentNotes === "string" && raw.adjustmentNotes.trim().length > 0
? raw.adjustmentNotes.trim()
: "已根据最近分析结果调整训练内容。",
});
}
const normalized = normalizeDayMapPlan(raw, params.fallbackTitle);
return normalizedAdjustedPlanSchema.parse({
...normalized,
adjustmentNotes:
typeof raw.adjustmentNotes === "string" && raw.adjustmentNotes.trim().length > 0
? raw.adjustmentNotes.trim()
: typeof raw.adjustment_notes === "string" && raw.adjustment_notes.trim().length > 0
? raw.adjustment_notes.trim()
: "已根据最近分析结果调整训练内容。",
});
}

查看文件

@@ -23,6 +23,8 @@ test("training page shows plan generation flow", async ({ page }) => {
await page.goto("/training"); await page.goto("/training");
await expect(page.getByTestId("training-title")).toBeVisible(); await expect(page.getByTestId("training-title")).toBeVisible();
await expect(page.getByTestId("training-generate-button")).toBeVisible(); await expect(page.getByTestId("training-generate-button")).toBeVisible();
await page.getByTestId("training-generate-button").click();
await expect(page.getByText("TestPlayer 的训练计划")).toBeVisible();
}); });
test("videos page renders video library items", async ({ page }) => { test("videos page renders video library items", async ({ page }) => {

查看文件

@@ -59,6 +59,24 @@ type MockAppState = {
user: MockUser; user: MockUser;
videos: any[]; videos: any[];
analyses: any[]; analyses: any[];
activePlan: {
id: number;
title: string;
skillLevel: string;
durationDays: number;
exercises: Array<{
day: number;
name: string;
category: string;
duration: number;
description: string;
tips: string;
sets: number;
reps: number;
}>;
version: number;
adjustmentNotes: string | null;
} | null;
mediaSession: MockMediaSession | null; mediaSession: MockMediaSession | null;
nextVideoId: number; nextVideoId: number;
authMeNullResponsesAfterLogin: number; authMeNullResponsesAfterLogin: number;
@@ -166,9 +184,41 @@ async function handleTrpc(route: Route, state: MockAppState) {
case "profile.stats": case "profile.stats":
return trpcResult(buildStats(state.user)); return trpcResult(buildStats(state.user));
case "plan.active": case "plan.active":
return trpcResult(null); return trpcResult(state.activePlan);
case "plan.list": case "plan.list":
return trpcResult([]); return trpcResult(state.activePlan ? [state.activePlan] : []);
case "plan.generate":
state.activePlan = {
id: 200,
title: `${state.user.name} 的训练计划`,
skillLevel: "beginner",
durationDays: 7,
version: 1,
adjustmentNotes: null,
exercises: [
{
day: 1,
name: "正手影子挥拍",
category: "影子挥拍",
duration: 15,
description: "练习完整引拍和收拍动作。",
tips: "保持重心稳定,击球点在身体前侧。",
sets: 3,
reps: 12,
},
{
day: 1,
name: "交叉步移动",
category: "脚步移动",
duration: 12,
description: "强化启动和回位节奏。",
tips: "每次移动后快速回到准备姿势。",
sets: 4,
reps: 10,
},
],
};
return trpcResult({ planId: state.activePlan.id, plan: state.activePlan });
case "video.list": case "video.list":
return trpcResult(state.videos); return trpcResult(state.videos);
case "analysis.list": case "analysis.list":
@@ -316,6 +366,7 @@ export async function installAppMocks(
createdAt: nowIso(), createdAt: nowIso(),
}, },
], ],
activePlan: null,
mediaSession: null, mediaSession: null,
nextVideoId: 100, nextVideoId: 100,
authMeNullResponsesAfterLogin: options?.authMeNullResponsesAfterLogin ?? 0, authMeNullResponsesAfterLogin: options?.authMeNullResponsesAfterLogin ?? 0,