Fix training plan generation flow
这个提交包含在:
@@ -33,6 +33,11 @@ server {
|
||||
location / {
|
||||
proxy_pass http://127.0.0.1:3002;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_request_buffering off;
|
||||
proxy_connect_timeout 300s;
|
||||
proxy_read_timeout 3600s;
|
||||
proxy_send_timeout 3600s;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_set_header Host $host;
|
||||
|
||||
@@ -146,9 +146,9 @@ export async function getUserTrainingPlans(userId: number) {
|
||||
|
||||
export async function getActivePlan(userId: number) {
|
||||
const db = await getDb();
|
||||
if (!db) return undefined;
|
||||
if (!db) return null;
|
||||
const result = await db.select().from(trainingPlans).where(and(eq(trainingPlans.userId, userId), eq(trainingPlans.isActive, 1))).limit(1);
|
||||
return result.length > 0 ? result[0] : undefined;
|
||||
return result.length > 0 ? result[0] : null;
|
||||
}
|
||||
|
||||
export async function updateTrainingPlan(planId: number, data: Partial<InsertTrainingPlan>) {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { appRouter } from "./routers";
|
||||
import { COOKIE_NAME } from "../shared/const";
|
||||
import type { TrpcContext } from "./_core/context";
|
||||
import * as db from "./db";
|
||||
|
||||
type AuthenticatedUser = NonNullable<TrpcContext["user"]>;
|
||||
|
||||
@@ -209,6 +210,17 @@ describe("plan.active", () => {
|
||||
const caller = appRouter.createCaller(ctx);
|
||||
await expect(caller.plan.active()).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("returns null when the user has no active plan", async () => {
|
||||
const user = createTestUser();
|
||||
const { ctx } = createMockContext(user);
|
||||
const caller = appRouter.createCaller(ctx);
|
||||
const getActivePlanSpy = vi.spyOn(db, "getActivePlan").mockResolvedValueOnce(null);
|
||||
|
||||
await expect(caller.plan.active()).resolves.toBeNull();
|
||||
|
||||
getActivePlanSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe("plan.adjust input validation", () => {
|
||||
|
||||
@@ -8,6 +8,50 @@ import { invokeLLM } from "./_core/llm";
|
||||
import { storagePut } from "./storage";
|
||||
import * as db from "./db";
|
||||
import { nanoid } from "nanoid";
|
||||
import {
|
||||
normalizeAdjustedPlanResponse,
|
||||
normalizeTrainingPlanResponse,
|
||||
} from "./trainingPlan";
|
||||
|
||||
async function invokeStructuredPlan<T>(params: {
|
||||
baseMessages: Array<{ role: "system" | "user"; content: string }>;
|
||||
responseFormat: {
|
||||
type: "json_schema";
|
||||
json_schema: {
|
||||
name: string;
|
||||
strict: true;
|
||||
schema: Record<string, unknown>;
|
||||
};
|
||||
};
|
||||
parse: (content: unknown) => T;
|
||||
}) {
|
||||
let lastError: unknown;
|
||||
|
||||
for (let attempt = 0; attempt < 3; attempt++) {
|
||||
const retryHint =
|
||||
attempt === 0 || !(lastError instanceof Error)
|
||||
? []
|
||||
: [{
|
||||
role: "user" as const,
|
||||
content:
|
||||
`上一次输出无法被系统解析,错误是:${lastError.message}。` +
|
||||
"请只返回一个合法、完整、可解析的 JSON 对象,不要包含额外说明、注释或 Markdown 代码块。",
|
||||
}];
|
||||
|
||||
const response = await invokeLLM({
|
||||
messages: [...params.baseMessages, ...retryHint],
|
||||
response_format: params.responseFormat,
|
||||
});
|
||||
|
||||
try {
|
||||
return params.parse(response.choices[0]?.message?.content);
|
||||
} catch (error) {
|
||||
lastError = error;
|
||||
}
|
||||
}
|
||||
|
||||
throw lastError instanceof Error ? lastError : new Error("Failed to parse structured LLM response");
|
||||
}
|
||||
|
||||
export const appRouter = router({
|
||||
system: systemRouter,
|
||||
@@ -85,12 +129,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
|
||||
|
||||
请返回JSON格式,包含每天的训练内容。`;
|
||||
|
||||
const response = await invokeLLM({
|
||||
messages: [
|
||||
const parsed = await invokeStructuredPlan({
|
||||
baseMessages: [
|
||||
{ role: "system", content: "你是网球训练计划生成器。返回严格的JSON格式。" },
|
||||
{ role: "user", content: prompt },
|
||||
],
|
||||
response_format: {
|
||||
responseFormat: {
|
||||
type: "json_schema",
|
||||
json_schema: {
|
||||
name: "training_plan",
|
||||
@@ -123,12 +167,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
|
||||
},
|
||||
},
|
||||
},
|
||||
parse: (content) => normalizeTrainingPlanResponse({
|
||||
content,
|
||||
fallbackTitle: `${input.durationDays}天训练计划`,
|
||||
}),
|
||||
});
|
||||
|
||||
const content = response.choices[0]?.message?.content;
|
||||
const parsed = typeof content === "string" ? JSON.parse(content) : null;
|
||||
if (!parsed) throw new Error("Failed to generate training plan");
|
||||
|
||||
const planId = await db.createTrainingPlan({
|
||||
userId: user.id,
|
||||
title: parsed.title,
|
||||
@@ -173,12 +217,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
|
||||
|
||||
请根据分析结果调整训练计划,增加针对薄弱环节的训练,返回与原计划相同格式的JSON。`;
|
||||
|
||||
const response = await invokeLLM({
|
||||
messages: [
|
||||
{ role: "system", content: "你是网球评分生成器。返回严格的JSON格式。" },
|
||||
const parsed = await invokeStructuredPlan({
|
||||
baseMessages: [
|
||||
{ role: "system", content: "你是网球训练计划调整器。返回严格的JSON格式。" },
|
||||
{ role: "user", content: prompt },
|
||||
],
|
||||
response_format: {
|
||||
responseFormat: {
|
||||
type: "json_schema",
|
||||
json_schema: {
|
||||
name: "adjusted_plan",
|
||||
@@ -212,12 +256,12 @@ ${recentScores.length > 0 ? `- 用户最近的分析数据: ${JSON.stringify(rec
|
||||
},
|
||||
},
|
||||
},
|
||||
parse: (content) => normalizeAdjustedPlanResponse({
|
||||
content,
|
||||
fallbackTitle: currentPlan.title,
|
||||
}),
|
||||
});
|
||||
|
||||
const content = response.choices[0]?.message?.content;
|
||||
const parsed = typeof content === "string" ? JSON.parse(content) : null;
|
||||
if (!parsed) throw new Error("Failed to adjust plan");
|
||||
|
||||
await db.updateTrainingPlan(input.planId, {
|
||||
exercises: parsed.exercises,
|
||||
adjustmentNotes: parsed.adjustmentNotes,
|
||||
|
||||
96
server/trainingPlan.test.ts
普通文件
96
server/trainingPlan.test.ts
普通文件
@@ -0,0 +1,96 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
normalizeAdjustedPlanResponse,
|
||||
normalizeTrainingPlanResponse,
|
||||
} from "./trainingPlan";
|
||||
|
||||
describe("normalizeTrainingPlanResponse", () => {
|
||||
it("accepts canonical title/exercises output", () => {
|
||||
const result = normalizeTrainingPlanResponse({
|
||||
content: JSON.stringify({
|
||||
title: "7天训练计划",
|
||||
exercises: [
|
||||
{
|
||||
day: 1,
|
||||
name: "正手影子挥拍",
|
||||
category: "影子挥拍",
|
||||
duration: 15,
|
||||
description: "完成正手挥拍练习",
|
||||
tips: "保持重心稳定",
|
||||
sets: 3,
|
||||
reps: 12,
|
||||
},
|
||||
],
|
||||
}),
|
||||
fallbackTitle: "fallback",
|
||||
});
|
||||
|
||||
expect(result.title).toBe("7天训练计划");
|
||||
expect(result.exercises).toHaveLength(1);
|
||||
expect(result.exercises[0]?.category).toBe("影子挥拍");
|
||||
});
|
||||
|
||||
it("normalizes qwen day map output into plan exercises", () => {
|
||||
const result = normalizeTrainingPlanResponse({
|
||||
content: JSON.stringify({
|
||||
day_1: {
|
||||
duration_minutes: 45,
|
||||
focus: "基础握拍与正手影子挥拍",
|
||||
exercises: [
|
||||
{
|
||||
name: "握拍方式学习",
|
||||
description: "学习大陆式与东方式握拍",
|
||||
duration_minutes: 10,
|
||||
},
|
||||
{
|
||||
name: "原地小碎步热身与放松",
|
||||
description: "30秒快速小碎步 + 30秒休息",
|
||||
duration_minutes: 10,
|
||||
},
|
||||
],
|
||||
},
|
||||
}),
|
||||
fallbackTitle: "7天训练计划",
|
||||
});
|
||||
|
||||
expect(result.title).toBe("7天训练计划");
|
||||
expect(result.exercises).toHaveLength(2);
|
||||
expect(result.exercises[0]).toMatchObject({
|
||||
day: 1,
|
||||
name: "握拍方式学习",
|
||||
duration: 10,
|
||||
sets: 3,
|
||||
reps: 10,
|
||||
});
|
||||
expect(result.exercises[1]?.category).toBe("脚步移动");
|
||||
});
|
||||
});
|
||||
|
||||
describe("normalizeAdjustedPlanResponse", () => {
|
||||
it("fills missing adjustment notes for day map output", () => {
|
||||
const result = normalizeAdjustedPlanResponse({
|
||||
content: JSON.stringify({
|
||||
day_1: {
|
||||
duration_minutes: 30,
|
||||
focus: "脚步移动",
|
||||
exercises: [
|
||||
{
|
||||
name: "交叉步移动",
|
||||
description: "左右移动并快速回位",
|
||||
duration_minutes: 12,
|
||||
},
|
||||
],
|
||||
},
|
||||
}),
|
||||
fallbackTitle: "当前训练计划",
|
||||
});
|
||||
|
||||
expect(result.title).toBe("当前训练计划");
|
||||
expect(result.adjustmentNotes).toContain("已根据最近分析结果调整");
|
||||
expect(result.exercises[0]).toMatchObject({
|
||||
day: 1,
|
||||
name: "交叉步移动",
|
||||
category: "脚步移动",
|
||||
});
|
||||
});
|
||||
});
|
||||
200
server/trainingPlan.ts
普通文件
200
server/trainingPlan.ts
普通文件
@@ -0,0 +1,200 @@
|
||||
import { z } from "zod";
|
||||
|
||||
const exerciseSchema = z.object({
|
||||
day: z.number().int().min(1),
|
||||
name: z.string().min(1),
|
||||
category: z.string().min(1),
|
||||
duration: z.number().positive(),
|
||||
description: z.string().min(1),
|
||||
tips: z.string().min(1),
|
||||
sets: z.number().int().positive(),
|
||||
reps: z.number().int().positive(),
|
||||
});
|
||||
|
||||
const normalizedPlanSchema = z.object({
|
||||
title: z.string().min(1),
|
||||
exercises: z.array(exerciseSchema).min(1),
|
||||
});
|
||||
|
||||
const normalizedAdjustedPlanSchema = normalizedPlanSchema.extend({
|
||||
adjustmentNotes: z.string().min(1),
|
||||
});
|
||||
|
||||
type NormalizedExercise = z.infer<typeof exerciseSchema>;
|
||||
type NormalizedPlan = z.infer<typeof normalizedPlanSchema>;
|
||||
type NormalizedAdjustedPlan = z.infer<typeof normalizedAdjustedPlanSchema>;
|
||||
|
||||
const dayKeyPattern = /^day[_\s-]?(\d+)$/i;
|
||||
|
||||
function extractTextContent(content: unknown) {
|
||||
if (typeof content === "string") {
|
||||
return content;
|
||||
}
|
||||
|
||||
if (Array.isArray(content)) {
|
||||
const text = content
|
||||
.map(item => (item && typeof item === "object" && "text" in item ? String((item as { text?: unknown }).text ?? "") : ""))
|
||||
.join("")
|
||||
.trim();
|
||||
|
||||
return text.length > 0 ? text : null;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseJsonContent(content: unknown) {
|
||||
const text = extractTextContent(content);
|
||||
if (!text) {
|
||||
throw new Error("LLM did not return text content");
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.parse(text) as Record<string, unknown>;
|
||||
} catch (error) {
|
||||
throw new Error(`LLM returned invalid JSON: ${error instanceof Error ? error.message : "unknown error"}`);
|
||||
}
|
||||
}
|
||||
|
||||
function toPositiveNumber(value: unknown, fallback: number) {
|
||||
const parsed = typeof value === "number" ? value : Number(value);
|
||||
return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
|
||||
}
|
||||
|
||||
function toPositiveInteger(value: unknown, fallback: number) {
|
||||
const parsed = typeof value === "number" ? value : Number.parseInt(String(value), 10);
|
||||
return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
|
||||
}
|
||||
|
||||
function inferCategory(...values: Array<unknown>) {
|
||||
const text = values
|
||||
.filter((value): value is string => typeof value === "string")
|
||||
.join(" ");
|
||||
|
||||
if (/(墙|wall)/i.test(text)) return "墙壁练习";
|
||||
if (/(步|移动|footwork|shuffle|split step)/i.test(text)) return "脚步移动";
|
||||
if (/(挥拍|shadow|正手|反手|发球|截击)/i.test(text)) return "影子挥拍";
|
||||
return "体能训练";
|
||||
}
|
||||
|
||||
function normalizeExercise(
|
||||
day: number,
|
||||
exercise: Record<string, unknown>,
|
||||
section?: Record<string, unknown>
|
||||
): NormalizedExercise {
|
||||
const name =
|
||||
typeof exercise.name === "string" && exercise.name.trim().length > 0
|
||||
? exercise.name.trim()
|
||||
: typeof section?.focus === "string" && section.focus.trim().length > 0
|
||||
? section.focus.trim()
|
||||
: `第${day}天训练项目`;
|
||||
|
||||
const description =
|
||||
typeof exercise.description === "string" && exercise.description.trim().length > 0
|
||||
? exercise.description.trim()
|
||||
: typeof section?.focus === "string" && section.focus.trim().length > 0
|
||||
? section.focus.trim()
|
||||
: `${name}训练`;
|
||||
|
||||
const tips =
|
||||
typeof exercise.tips === "string" && exercise.tips.trim().length > 0
|
||||
? exercise.tips.trim()
|
||||
: typeof section?.focus === "string" && section.focus.trim().length > 0
|
||||
? `重点关注:${section.focus.trim()}`
|
||||
: "保持动作稳定,注意训练节奏。";
|
||||
|
||||
return {
|
||||
day,
|
||||
name,
|
||||
category: inferCategory(exercise.category, name, description, section?.focus),
|
||||
duration: toPositiveNumber(
|
||||
exercise.duration ?? exercise.duration_minutes,
|
||||
toPositiveNumber(section?.duration_minutes, 10)
|
||||
),
|
||||
description,
|
||||
tips,
|
||||
sets: toPositiveInteger(exercise.sets, 3),
|
||||
reps: toPositiveInteger(exercise.reps, 10),
|
||||
};
|
||||
}
|
||||
|
||||
function normalizeDayMapPlan(
|
||||
raw: Record<string, unknown>,
|
||||
fallbackTitle: string
|
||||
): NormalizedPlan {
|
||||
const exercises: NormalizedExercise[] = [];
|
||||
|
||||
for (const [key, value] of Object.entries(raw)) {
|
||||
const match = key.match(dayKeyPattern);
|
||||
if (!match || !value || typeof value !== "object" || Array.isArray(value)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const day = Number.parseInt(match[1] ?? "", 10);
|
||||
if (!Number.isFinite(day) || day <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const section = value as Record<string, unknown>;
|
||||
const sectionExercises = Array.isArray(section.exercises)
|
||||
? section.exercises.filter(
|
||||
(item): item is Record<string, unknown> =>
|
||||
Boolean(item) && typeof item === "object" && !Array.isArray(item)
|
||||
)
|
||||
: [];
|
||||
|
||||
for (const exercise of sectionExercises) {
|
||||
exercises.push(normalizeExercise(day, exercise, section));
|
||||
}
|
||||
}
|
||||
|
||||
return normalizedPlanSchema.parse({
|
||||
title:
|
||||
typeof raw.title === "string" && raw.title.trim().length > 0
|
||||
? raw.title.trim()
|
||||
: fallbackTitle,
|
||||
exercises,
|
||||
});
|
||||
}
|
||||
|
||||
export function normalizeTrainingPlanResponse(params: {
|
||||
content: unknown;
|
||||
fallbackTitle: string;
|
||||
}): NormalizedPlan {
|
||||
const raw = parseJsonContent(params.content);
|
||||
|
||||
if (Array.isArray(raw.exercises)) {
|
||||
return normalizedPlanSchema.parse(raw);
|
||||
}
|
||||
|
||||
return normalizeDayMapPlan(raw, params.fallbackTitle);
|
||||
}
|
||||
|
||||
export function normalizeAdjustedPlanResponse(params: {
|
||||
content: unknown;
|
||||
fallbackTitle: string;
|
||||
}): NormalizedAdjustedPlan {
|
||||
const raw = parseJsonContent(params.content);
|
||||
|
||||
if (Array.isArray(raw.exercises)) {
|
||||
return normalizedAdjustedPlanSchema.parse({
|
||||
...raw,
|
||||
adjustmentNotes:
|
||||
typeof raw.adjustmentNotes === "string" && raw.adjustmentNotes.trim().length > 0
|
||||
? raw.adjustmentNotes.trim()
|
||||
: "已根据最近分析结果调整训练内容。",
|
||||
});
|
||||
}
|
||||
|
||||
const normalized = normalizeDayMapPlan(raw, params.fallbackTitle);
|
||||
|
||||
return normalizedAdjustedPlanSchema.parse({
|
||||
...normalized,
|
||||
adjustmentNotes:
|
||||
typeof raw.adjustmentNotes === "string" && raw.adjustmentNotes.trim().length > 0
|
||||
? raw.adjustmentNotes.trim()
|
||||
: typeof raw.adjustment_notes === "string" && raw.adjustment_notes.trim().length > 0
|
||||
? raw.adjustment_notes.trim()
|
||||
: "已根据最近分析结果调整训练内容。",
|
||||
});
|
||||
}
|
||||
@@ -23,6 +23,8 @@ test("training page shows plan generation flow", async ({ page }) => {
|
||||
await page.goto("/training");
|
||||
await expect(page.getByTestId("training-title")).toBeVisible();
|
||||
await expect(page.getByTestId("training-generate-button")).toBeVisible();
|
||||
await page.getByTestId("training-generate-button").click();
|
||||
await expect(page.getByText("TestPlayer 的训练计划")).toBeVisible();
|
||||
});
|
||||
|
||||
test("videos page renders video library items", async ({ page }) => {
|
||||
|
||||
@@ -59,6 +59,24 @@ type MockAppState = {
|
||||
user: MockUser;
|
||||
videos: any[];
|
||||
analyses: any[];
|
||||
activePlan: {
|
||||
id: number;
|
||||
title: string;
|
||||
skillLevel: string;
|
||||
durationDays: number;
|
||||
exercises: Array<{
|
||||
day: number;
|
||||
name: string;
|
||||
category: string;
|
||||
duration: number;
|
||||
description: string;
|
||||
tips: string;
|
||||
sets: number;
|
||||
reps: number;
|
||||
}>;
|
||||
version: number;
|
||||
adjustmentNotes: string | null;
|
||||
} | null;
|
||||
mediaSession: MockMediaSession | null;
|
||||
nextVideoId: number;
|
||||
authMeNullResponsesAfterLogin: number;
|
||||
@@ -166,9 +184,41 @@ async function handleTrpc(route: Route, state: MockAppState) {
|
||||
case "profile.stats":
|
||||
return trpcResult(buildStats(state.user));
|
||||
case "plan.active":
|
||||
return trpcResult(null);
|
||||
return trpcResult(state.activePlan);
|
||||
case "plan.list":
|
||||
return trpcResult([]);
|
||||
return trpcResult(state.activePlan ? [state.activePlan] : []);
|
||||
case "plan.generate":
|
||||
state.activePlan = {
|
||||
id: 200,
|
||||
title: `${state.user.name} 的训练计划`,
|
||||
skillLevel: "beginner",
|
||||
durationDays: 7,
|
||||
version: 1,
|
||||
adjustmentNotes: null,
|
||||
exercises: [
|
||||
{
|
||||
day: 1,
|
||||
name: "正手影子挥拍",
|
||||
category: "影子挥拍",
|
||||
duration: 15,
|
||||
description: "练习完整引拍和收拍动作。",
|
||||
tips: "保持重心稳定,击球点在身体前侧。",
|
||||
sets: 3,
|
||||
reps: 12,
|
||||
},
|
||||
{
|
||||
day: 1,
|
||||
name: "交叉步移动",
|
||||
category: "脚步移动",
|
||||
duration: 12,
|
||||
description: "强化启动和回位节奏。",
|
||||
tips: "每次移动后快速回到准备姿势。",
|
||||
sets: 4,
|
||||
reps: 10,
|
||||
},
|
||||
],
|
||||
};
|
||||
return trpcResult({ planId: state.activePlan.id, plan: state.activePlan });
|
||||
case "video.list":
|
||||
return trpcResult(state.videos);
|
||||
case "analysis.list":
|
||||
@@ -316,6 +366,7 @@ export async function installAppMocks(
|
||||
createdAt: nowIso(),
|
||||
},
|
||||
],
|
||||
activePlan: null,
|
||||
mediaSession: null,
|
||||
nextVideoId: 100,
|
||||
authMeNullResponsesAfterLogin: options?.authMeNullResponsesAfterLogin ?? 0,
|
||||
|
||||
在新工单中引用
屏蔽一个用户