// File: tennis-training-hub/server/taskWorker.ts
// 2026-03-15 00:41:09 +08:00
// 518 lines, 15 KiB, TypeScript

import { nanoid } from "nanoid";
import { ENV } from "./_core/env";
import { invokeLLM, type Message } from "./_core/llm";
import * as db from "./db";
import { getRemoteMediaSession } from "./mediaService";
import {
buildAdjustedTrainingPlanPrompt,
buildMultimodalCorrectionPrompt,
buildTextCorrectionPrompt,
buildTrainingPlanPrompt,
multimodalCorrectionSchema,
renderMultimodalCorrectionMarkdown,
} from "./prompts";
import { toPublicUrl } from "./publicUrl";
import { storagePut } from "./storage";
import {
normalizeAdjustedPlanResponse,
normalizeTrainingPlanResponse,
} from "./trainingPlan";
/**
 * Resolved result type of `db.getBackgroundTaskById`. The task runners in this
 * file take `NonNullable<TaskRow>`, so this type presumably includes a nullish
 * "not found" case (confirm against `./db`).
 */
type TaskRow = Awaited<ReturnType<typeof db.getBackgroundTaskById>>;

/** One prompt message for a structured call; only system and user roles are allowed. */
type StructuredMessage = {
  role: "system" | "user";
  content: string | Message["content"];
};

/** Strict JSON-schema response format, forwarded to the LLM call as `response_format`. */
type StructuredResponseFormat = {
  type: "json_schema";
  json_schema: {
    name: string;
    strict: true;
    schema: Record<string, unknown>;
  };
};

/**
 * Input accepted by `invokeStructured`.
 *
 * - `model`: optional model identifier passed through to the LLM call.
 * - `baseMessages`: messages sent on every attempt.
 * - `responseFormat`: JSON-schema response format for the call.
 * - `parse`: converts the raw message content into `T`; throwing signals an unusable reply.
 */
type StructuredParams<T> = {
  model?: string;
  baseMessages: StructuredMessage[];
  responseFormat: StructuredResponseFormat;
  parse: (content: unknown) => T;
};
/**
 * Calls the LLM with a JSON-schema response format and parses the reply.
 *
 * The content of the first choice is handed to `params.parse`. When parsing
 * throws, the error is remembered and the call is repeated (three attempts in
 * total) with an extra user message that quotes the failure and asks for
 * plain JSON only. Errors thrown by `invokeLLM` itself are not caught here
 * and propagate immediately.
 *
 * @param params Model, base messages, response format and parser.
 * @returns The value produced by `params.parse`.
 * @throws The last parse error once every attempt has failed.
 */
async function invokeStructured<T>(params: StructuredParams<T>): Promise<T> {
  const maxAttempts = 3;
  // The vision endpoint and key are only supplied when the requested model is the vision model.
  // NOTE(review): if both `params.model` and `ENV.llmVisionModel` are undefined this
  // comparison is also true — confirm ENV always defines `llmVisionModel`.
  const usesVisionModel = params.model === ENV.llmVisionModel;
  let lastError: unknown;

  for (let attempt = 0; attempt < maxAttempts; attempt++) {
    // Any attempt after the first follows a parse failure, so always explain what went wrong,
    // including when the thrown value was not an Error instance.
    const failureReason = lastError instanceof Error ? lastError.message : String(lastError);
    const retryHint =
      attempt === 0
        ? []
        : [
            {
              role: "user" as const,
              content:
                // The newline keeps the quoted error separate from the instruction that follows.
                `上一次输出无法被系统解析,错误是:${failureReason}\n` +
                "请只返回合法完整的 JSON 对象,不要附加 Markdown 或说明。",
            },
          ];

    const response = await invokeLLM({
      apiUrl: usesVisionModel ? ENV.llmVisionApiUrl : undefined,
      apiKey: usesVisionModel ? ENV.llmVisionApiKey : undefined,
      model: params.model,
      messages: [...params.baseMessages, ...retryHint],
      response_format: params.responseFormat,
    });

    try {
      return params.parse(response.choices[0]?.message?.content);
    } catch (error) {
      lastError = error;
    }
  }

  throw lastError instanceof Error ? lastError : new Error("Failed to parse structured LLM response");
}
/**
 * Flattens LLM message content into a single plain-text string.
 *
 * A string is returned unchanged. Otherwise the content is treated as a list
 * of parts (a single part is wrapped into a list): text parts contribute
 * their text, image and file parts contribute a `[image] <url>` /
 * `[file] <url>` line, and any other part contributes nothing. Empty pieces
 * are skipped and the rest are joined with newlines.
 */
function contentToPlainText(content: Message["content"]) {
  if (typeof content === "string") {
    return content;
  }

  const lines: string[] = [];
  for (const piece of Array.isArray(content) ? content : [content]) {
    let rendered = "";
    if (typeof piece === "string") {
      rendered = piece;
    } else {
      switch (piece.type) {
        case "text":
          rendered = piece.text;
          break;
        case "image_url":
          rendered = `[image] ${piece.image_url.url}`;
          break;
        case "file_url":
          rendered = `[file] ${piece.file_url.url}`;
          break;
        default:
          break;
      }
    }
    // Falsy pieces (for example empty text) are left out of the output.
    if (rendered) {
      lines.push(rendered);
    }
  }
  return lines.join("\n");
}
/**
 * Splits a base64 `data:` URL into its content type and decoded bytes.
 *
 * @param input A string of the form `data:<content-type>;base64,<payload>`.
 * @returns The content type (everything between `data:` and the first
 *   `;base64,`) and the payload decoded into a Buffer.
 * @throws Error("Invalid image data URL") when the input does not have that form.
 */
function parseDataUrl(input: string) {
  const parsed = /^data:(.+?);base64,(.+)$/.exec(input);
  if (parsed === null) {
    throw new Error("Invalid image data URL");
  }
  const [, contentType, payload] = parsed;
  return { contentType, buffer: Buffer.from(payload, "base64") };
}
/**
 * Uploads inline base64 images to storage and returns their public URLs in
 * the same order as the input.
 *
 * Every data URL is decoded before the first upload starts, so a malformed
 * entry rejects the call without leaving partially uploaded files behind.
 * The uploads are independent of each other and run concurrently.
 *
 * @param userId Owner of the images; used as a path segment of the storage key.
 * @param imageDataUrls Base64 `data:` URLs to persist.
 * @returns Public URLs of the stored images, one per input entry.
 * @throws Whatever `parseDataUrl` throws for a malformed entry, or the first upload failure.
 */
async function persistInlineImages(userId: number, imageDataUrls: string[]) {
  const decodedImages = imageDataUrls.map((dataUrl) => parseDataUrl(dataUrl));

  return Promise.all(
    decodedImages.map(async ({ contentType, buffer }) => {
      const key = `analysis-images/${userId}/${nanoid()}.${imageExtensionFor(contentType)}`;
      const uploaded = await storagePut(key, buffer, contentType);
      return toPublicUrl(uploaded.url);
    }),
  );
}

/**
 * Picks the file extension for a stored image from its content type.
 * PNG, WebP and GIF keep their own extension (matched case-insensitively);
 * everything else falls back to "jpg", which was the previous default.
 */
function imageExtensionFor(contentType: string): string {
  const normalized = contentType.toLowerCase();
  if (normalized.includes("png")) {
    return "png";
  }
  if (normalized.includes("webp")) {
    return "webp";
  }
  if (normalized.includes("gif")) {
    return "gif";
  }
  return "jpg";
}
/**
 * Builds the list of image URLs used for a correction request.
 *
 * Already-hosted URLs are passed through `toPublicUrl`; inline data URLs,
 * when any are given, are uploaded via `persistInlineImages`. The result
 * lists the direct URLs first, followed by the uploaded ones.
 */
export async function prepareCorrectionImageUrls(input: {
  userId: number;
  imageUrls?: string[];
  imageDataUrls?: string[];
}) {
  const hostedUrls = (input.imageUrls ?? []).map((url) => toPublicUrl(url));

  const inlineImages = input.imageDataUrls ?? [];
  if (inlineImages.length === 0) {
    // Nothing to upload; the storage layer is not touched.
    return hostedUrls;
  }

  const storedUrls = await persistInlineImages(input.userId, inlineImages);
  return [...hostedUrls, ...storedUrls];
}
/**
 * Handles a training-plan generation task.
 *
 * Loads the user's analyses, summarises the first five returned, asks the LLM
 * for a plan constrained by a strict JSON schema, stores the plan as version
 * 1 with `isActive: 1`, and returns the new plan id together with the
 * normalized plan.
 */
async function runTrainingPlanGenerateTask(task: NonNullable<TaskRow>) {
  const payload = task.payload as {
    skillLevel: "beginner" | "intermediate" | "advanced";
    durationDays: number;
    focusAreas?: string[];
  };

  const history = await db.getUserAnalyses(task.userId);
  const recentScores = history.slice(0, 5).map((entry) => ({
    score: entry.overallScore ?? null,
    issues: entry.detectedIssues,
    exerciseType: entry.exerciseType ?? null,
    shotCount: entry.shotCount ?? null,
    strokeConsistency: entry.strokeConsistency ?? null,
    footworkScore: entry.footworkScore ?? null,
  }));

  const userPrompt = buildTrainingPlanPrompt({
    ...payload,
    recentScores,
  });

  // Shape of a single exercise entry inside the plan.
  const exerciseSchema = {
    type: "object",
    properties: {
      day: { type: "number" },
      name: { type: "string" },
      category: { type: "string" },
      duration: { type: "number" },
      description: { type: "string" },
      tips: { type: "string" },
      sets: { type: "number" },
      reps: { type: "number" },
    },
    required: ["day", "name", "category", "duration", "description", "tips", "sets", "reps"],
    additionalProperties: false,
  };

  // Top-level plan object: a title plus the list of exercises.
  const planSchema = {
    type: "object",
    properties: {
      title: { type: "string" },
      exercises: {
        type: "array",
        items: exerciseSchema,
      },
    },
    required: ["title", "exercises"],
    additionalProperties: false,
  };

  const plan = await invokeStructured({
    baseMessages: [
      { role: "system", content: "你是网球训练计划生成器。返回严格的 JSON 格式。" },
      { role: "user", content: userPrompt },
    ],
    responseFormat: {
      type: "json_schema",
      json_schema: {
        name: "training_plan",
        strict: true,
        schema: planSchema,
      },
    },
    parse: (content) =>
      normalizeTrainingPlanResponse({
        content,
        fallbackTitle: `${payload.durationDays}天训练计划`,
      }),
  });

  const planId = await db.createTrainingPlan({
    userId: task.userId,
    title: plan.title,
    skillLevel: payload.skillLevel,
    durationDays: payload.durationDays,
    exercises: plan.exercises,
    isActive: 1,
    version: 1,
  });

  return {
    kind: "training_plan_generate" as const,
    planId,
    plan,
  };
}
/**
 * Handles a training-plan adjustment task.
 *
 * Looks up the plan named in the payload among the user's plans, asks the LLM
 * for adjusted exercises based on the first five analyses returned for the
 * user, then stores the new exercises and adjustment notes and bumps the plan
 * version.
 *
 * @throws Error("Plan not found") when the user has no plan with the given id.
 */
async function runTrainingPlanAdjustTask(task: NonNullable<TaskRow>) {
  const { planId } = task.payload as { planId: number };

  const history = await db.getUserAnalyses(task.userId);
  const recentAnalyses = history.slice(0, 5);

  const userPlans = await db.getUserTrainingPlans(task.userId);
  const currentPlan = userPlans.find((candidate) => candidate.id === planId);
  if (!currentPlan) {
    throw new Error("Plan not found");
  }

  const userPrompt = buildAdjustedTrainingPlanPrompt({
    currentExercises: currentPlan.exercises,
    recentAnalyses: recentAnalyses.map((entry) => ({
      score: entry.overallScore ?? null,
      issues: entry.detectedIssues,
      corrections: entry.corrections,
      shotCount: entry.shotCount ?? null,
      strokeConsistency: entry.strokeConsistency ?? null,
      footworkScore: entry.footworkScore ?? null,
      fluidityScore: entry.fluidityScore ?? null,
    })),
  });

  // Shape of a single exercise entry inside the adjusted plan.
  const exerciseSchema = {
    type: "object",
    properties: {
      day: { type: "number" },
      name: { type: "string" },
      category: { type: "string" },
      duration: { type: "number" },
      description: { type: "string" },
      tips: { type: "string" },
      sets: { type: "number" },
      reps: { type: "number" },
    },
    required: ["day", "name", "category", "duration", "description", "tips", "sets", "reps"],
    additionalProperties: false,
  };

  // Adjusted plan: title, free-text notes about the adjustment, and the exercises.
  const adjustedPlanSchema = {
    type: "object",
    properties: {
      title: { type: "string" },
      adjustmentNotes: { type: "string" },
      exercises: {
        type: "array",
        items: exerciseSchema,
      },
    },
    required: ["title", "adjustmentNotes", "exercises"],
    additionalProperties: false,
  };

  const adjusted = await invokeStructured({
    baseMessages: [
      { role: "system", content: "你是网球训练计划调整器。返回严格的 JSON 格式。" },
      { role: "user", content: userPrompt },
    ],
    responseFormat: {
      type: "json_schema",
      json_schema: {
        name: "adjusted_plan",
        strict: true,
        schema: adjustedPlanSchema,
      },
    },
    parse: (content) =>
      normalizeAdjustedPlanResponse({
        content,
        fallbackTitle: currentPlan.title,
      }),
  });

  // A missing or zero version is treated as version 1 before incrementing.
  const nextVersion = (currentPlan.version || 1) + 1;
  await db.updateTrainingPlan(planId, {
    exercises: adjusted.exercises,
    adjustmentNotes: adjusted.adjustmentNotes,
    version: nextVersion,
  });

  return {
    kind: "training_plan_adjust" as const,
    planId,
    plan: adjusted,
    adjustmentNotes: adjusted.adjustmentNotes,
  };
}
/**
 * Handles a text-only correction task by forwarding the task payload
 * (exercise type, pose metrics, detected issues) to `createTextCorrectionResult`.
 */
async function runTextCorrectionTask(task: NonNullable<TaskRow>) {
  return createTextCorrectionResult(
    task.payload as {
      exerciseType: string;
      poseMetrics: unknown;
      detectedIssues: unknown;
    },
  );
}
/**
 * Asks the LLM for text-only coaching corrections and returns them as plain text.
 *
 * The reply content of the first choice is flattened with
 * `contentToPlainText`. When the flattened text is empty or whitespace only
 * (missing content, an empty parts list, a blank reply) the placeholder
 * "暂无建议" is returned instead, so callers never receive blank corrections.
 *
 * @param payload Exercise type, pose metrics and detected issues fed into the prompt.
 * @returns An `analysis_corrections` result carrying the corrections text.
 */
async function createTextCorrectionResult(payload: {
  exerciseType: string;
  poseMetrics: unknown;
  detectedIssues: unknown;
}) {
  const response = await invokeLLM({
    messages: [
      {
        role: "system",
        content: "你是一位专业网球技术教练。输出中文 Markdown,内容具体、克制、可执行。",
      },
      {
        role: "user",
        content: buildTextCorrectionPrompt(payload),
      },
    ],
  });

  // Decide on the placeholder after flattening: content that is truthy but renders
  // to nothing (for example an empty array of parts) must also fall back.
  const text = contentToPlainText(response.choices[0]?.message?.content ?? "");

  return {
    kind: "analysis_corrections" as const,
    corrections: text.trim() ? text : "暂无建议",
  };
}
/** Payload of a multimodal (image + pose metrics) correction task. */
type MultimodalCorrectionPayload = {
  exerciseType: string;
  poseMetrics: unknown;
  detectedIssues: unknown;
  imageUrls: string[];
};

/**
 * Handles a multimodal correction task.
 *
 * The vision model is asked for a structured report. If that request (or
 * rendering the report to Markdown) fails, the task falls back to text-only
 * corrections and records the failure message as a warning. Either way the
 * outcome is recorded through `db.completeVisionTestRun`.
 *
 * Only the vision step is guarded by the fallback: a failure while recording
 * a successful vision result propagates to the caller instead of being
 * reported as a vision failure and triggering a second LLM call.
 */
async function runMultimodalCorrectionTask(task: NonNullable<TaskRow>) {
  const payload = task.payload as MultimodalCorrectionPayload;

  let vision: Awaited<ReturnType<typeof requestVisionCorrection>>;
  try {
    vision = await requestVisionCorrection(payload);
  } catch (error) {
    const fallback = await createTextCorrectionResult(payload);
    const result = {
      kind: "pose_correction_multimodal" as const,
      imageUrls: payload.imageUrls,
      report: null,
      corrections: fallback.corrections,
      visionStatus: "fallback" as const,
      warning: error instanceof Error ? error.message : "Vision model unavailable",
    };
    await db.completeVisionTestRun(task.id, {
      visionStatus: "fallback",
      summary: null,
      corrections: result.corrections,
      report: null,
      warning: result.warning,
    });
    return result;
  }

  const result = {
    kind: "pose_correction_multimodal" as const,
    imageUrls: payload.imageUrls,
    report: vision.report,
    corrections: vision.corrections,
    visionStatus: "ok" as const,
  };
  await db.completeVisionTestRun(task.id, {
    visionStatus: "ok",
    summary: (vision.report as { summary?: string }).summary ?? null,
    corrections: result.corrections,
    report: vision.report,
    warning: null,
  });
  return result;
}

/**
 * Requests the structured vision report for the payload's images and renders
 * it to Markdown. Throws when the model call fails, when no JSON object can
 * be obtained within the retry budget of `invokeStructured`, or when
 * rendering fails.
 */
async function requestVisionCorrection(payload: MultimodalCorrectionPayload) {
  const report = await invokeStructured({
    model: ENV.llmVisionModel,
    baseMessages: [
      { role: "system", content: "你是专业网球教练。请基于图片和结构化姿态指标输出严格 JSON。" },
      {
        role: "user",
        content: [
          {
            type: "text",
            text: buildMultimodalCorrectionPrompt({
              exerciseType: payload.exerciseType,
              poseMetrics: payload.poseMetrics,
              detectedIssues: payload.detectedIssues,
              imageCount: payload.imageUrls.length,
            }),
          },
          ...payload.imageUrls.map((url) => ({
            type: "image_url" as const,
            image_url: {
              url,
              detail: "high" as const,
            },
          })),
        ],
      },
    ],
    responseFormat: {
      type: "json_schema",
      json_schema: {
        name: "pose_correction_multimodal",
        strict: true,
        schema: multimodalCorrectionSchema,
      },
    },
    parse: (content) => {
      const value = typeof content === "string" ? JSON.parse(content) : content;
      // Reject missing or non-object replies here so invokeStructured retries with a
      // hint, rather than letting an unusable value reach the Markdown renderer.
      if (value === null || typeof value !== "object" || Array.isArray(value)) {
        throw new Error("Vision model did not return a JSON object");
      }
      return value;
    },
  });

  return {
    report,
    corrections: renderMultimodalCorrectionMarkdown(
      report as Parameters<typeof renderMultimodalCorrectionMarkdown>[0],
    ),
  };
}
/**
 * Handles a media-finalize task for a recorded session.
 *
 * While the session archive is queued or processing, or playback is not yet
 * ready, the task is rescheduled with a 4 second delay and `null` is
 * returned. Once playback is ready, a video record is created for the
 * recording (MP4 preferred over WebM) unless one with the same file key
 * already exists for the user, in which case the existing record is returned.
 *
 * @throws When the session belongs to another user, the archive failed, or
 *   no playback URL is exposed.
 */
async function runMediaFinalizeTask(task: NonNullable<TaskRow>) {
  const request = task.payload as {
    sessionId: string;
    title: string;
    exerciseType?: string;
  };

  const session = await getRemoteMediaSession(request.sessionId);
  if (String(task.userId) !== session.userId) {
    throw new Error("Media session does not belong to the task user");
  }

  // Puts the task back in the queue with a progress update; the caller gets null.
  const deferTask = async (progress: number, message: string) => {
    await db.rescheduleBackgroundTask(task.id, { progress, message, delayMs: 4_000 });
    return null;
  };

  switch (session.archiveStatus) {
    case "queued":
      return deferTask(45, "录制文件已入队,等待归档");
    case "processing":
      return deferTask(78, "录制文件正在整理与转码");
    case "failed":
      throw new Error(session.lastError || "Media archive failed");
    default:
      break;
  }

  const { playback } = session;
  if (!playback.ready) {
    return deferTask(70, "等待回放文件就绪");
  }

  // MP4 wins whenever it is available; WebM is the fallback.
  const format = playback.mp4Url ? "mp4" : "webm";
  const sourceUrl = playback.mp4Url || playback.webmUrl;
  if (!sourceUrl) {
    throw new Error("Media session did not expose a playback URL");
  }

  const fileKey = `media/sessions/${session.id}/recording.${format}`;
  const existing = await db.getVideoByFileKey(task.userId, fileKey);
  if (existing) {
    // A record for this recording already exists; reuse it instead of creating another.
    return {
      kind: "media_finalize" as const,
      sessionId: session.id,
      videoId: existing.id,
      url: existing.url,
      fileKey,
      format,
    };
  }

  const publicUrl = toPublicUrl(sourceUrl);
  const fileSize = format === "mp4" ? (playback.mp4Size ?? null) : (playback.webmSize ?? null);
  const videoId = await db.createVideo({
    userId: task.userId,
    title: request.title || session.title,
    fileKey,
    url: publicUrl,
    format,
    fileSize,
    duration: null,
    exerciseType: request.exerciseType || "recording",
    analysisStatus: "completed",
  });

  return {
    kind: "media_finalize" as const,
    sessionId: session.id,
    videoId,
    url: publicUrl,
    fileKey,
    format,
  };
}
/**
 * Dispatches a background task to the runner that matches its `type` and
 * returns that runner's result.
 *
 * @throws Error when the task type has no runner.
 */
export async function processBackgroundTask(task: NonNullable<TaskRow>) {
  const { type } = task;

  if (type === "training_plan_generate") {
    return runTrainingPlanGenerateTask(task);
  }
  if (type === "training_plan_adjust") {
    return runTrainingPlanAdjustTask(task);
  }
  if (type === "analysis_corrections") {
    return runTextCorrectionTask(task);
  }
  if (type === "pose_correction_multimodal") {
    return runMultimodalCorrectionTask(task);
  }
  if (type === "media_finalize") {
    return runMediaFinalizeTask(task);
  }

  throw new Error(`Unsupported task type: ${String(type)}`);
}