package image import ( "bytes" "context" "encoding/json" "io" "mime" "mime/multipart" "net/http" "net/http/httptest" "net/textproto" "path/filepath" "strings" "sync/atomic" "testing" "time" "scenemint/internal/quota" "github.com/labstack/echo/v5" "github.com/sunls24/gox/network/client" "github.com/sunls24/gox/openai" "github.com/sunls24/gox/server" ) func TestSubmitGenerationTaskUsesJSON(t *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/image-tasks/generations" { t.Fatalf("unexpected path %q", r.URL.Path) } if got := r.Header.Get("Authorization"); got != "Bearer test-key" { t.Fatalf("Authorization = %q, want Bearer test-key", got) } if got := r.Header.Get("Content-Type"); got != "application/json" { t.Fatalf("Content-Type = %q, want application/json", got) } var body taskSubmitRequest if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Fatalf("Decode request body: %v", err) } if body.ClientTaskID != "task-1" || body.Prompt != "quiet studio scene" || body.Model != "gpt-image-2" || body.Size != "1:1" { t.Fatalf("unexpected request body: %+v", body) } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(upstreamTask{ ID: "task-1", Status: "queued", Mode: "generation", Model: "gpt-image-2", Size: "1:1", }); err != nil { t.Fatalf("Encode: %v", err) } })) defer ts.Close() c := &Client{ taskAPIRoot: ts.URL, apiKey: "test-key", http: client.New(), } task, err := c.submitGenerationTask(context.Background(), taskSubmitRequest{ ClientTaskID: "task-1", Prompt: "quiet studio scene", Model: "gpt-image-2", Size: "1:1", }) if err != nil { t.Fatalf("submitGenerationTask returned error: %v", err) } if task.ID != "task-1" || task.Status != "queued" || task.Mode != "generation" { t.Fatalf("unexpected task: %+v", task) } } func testReferenceUpload(data []byte) *referenceUpload { return &referenceUpload{ Reader: bytes.NewReader(data), Filename: "reference.png", ContentType: "image/png", } } func TestSubmitEditTaskUsesMultipartFile(t *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/image-tasks/edits" { t.Fatalf("unexpected path %q", r.URL.Path) } if got := r.Header.Get("Authorization"); got != "Bearer test-key" { t.Fatalf("Authorization = %q, want Bearer test-key", got) } if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data; boundary=") { t.Fatalf("Content-Type = %q, want multipart/form-data", got) } if r.ContentLength <= 0 { t.Fatalf("ContentLength = %d, want known positive length", r.ContentLength) } if len(r.TransferEncoding) != 0 { t.Fatalf("TransferEncoding = %v, want no chunked transfer encoding", r.TransferEncoding) } if err := r.ParseMultipartForm(20 << 20); err != nil { t.Fatalf("ParseMultipartForm: %v", err) } wantFields := map[string]string{ "client_task_id": "task-1", "prompt": "replace the background", "model": "gpt-image-2", "size": "1:1", } for name, want := range wantFields { if got := r.FormValue(name); got != want { t.Fatalf("%s = %q, want %q", name, got, want) } } files := r.MultipartForm.File["image"] if len(files) != 1 { t.Fatalf("image files = %d, want 1", len(files)) } if got := files[0].Filename; got != "reference.png" { t.Fatalf("image filename = %q, want reference.png", got) } if got := files[0].Header.Get("Content-Type"); got != "image/png" { t.Fatalf("image Content-Type = %q, want image/png", got) } file, err := files[0].Open() if err != nil { t.Fatalf("Open image file: %v", err) } defer file.Close() data, err := io.ReadAll(file) if err != nil { t.Fatalf("Read image file: %v", err) } if string(data) != "\x89PNG\r\n\x1a\n" { t.Fatalf("image bytes = %q, want PNG header", string(data)) } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(upstreamTask{ ID: "task-1", Status: "queued", Mode: "edit", Model: "gpt-image-2", Size: "1:1", }); err != nil { t.Fatalf("Encode: %v", err) } })) defer ts.Close() c := &Client{ taskAPIRoot: ts.URL, apiKey: "test-key", rawHTTP: ts.Client(), } task, err := c.submitEditTask(context.Background(), taskSubmitRequest{ ClientTaskID: "task-1", Prompt: "replace the background", Model: "gpt-image-2", Size: "1:1", imageUpload: testReferenceUpload([]byte("\x89PNG\r\n\x1a\n")), }) if err != nil { t.Fatalf("submitEditTask returned error: %v", err) } if task.ID != "task-1" || task.Status != "queued" || task.Mode != "edit" { t.Fatalf("unexpected task: %+v", task) } } func TestSubmitEditTaskSendsLargeReferenceImageAsFile(t *testing.T) { t.Parallel() largeImage := bytes.Repeat([]byte{0xab}, (1024*1024)+1) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data; boundary=") { t.Fatalf("Content-Type = %q, want multipart/form-data", got) } if r.ContentLength <= 0 { t.Fatalf("ContentLength = %d, want known positive length", r.ContentLength) } if len(r.TransferEncoding) != 0 { t.Fatalf("TransferEncoding = %v, want no chunked transfer encoding", r.TransferEncoding) } reader, err := r.MultipartReader() if err != nil { t.Fatalf("MultipartReader: %v", err) } var foundImage bool for { part, err := reader.NextPart() if err == io.EOF { break } if err != nil { t.Fatalf("NextPart: %v", err) } if part.FormName() != "image" { _, _ = io.Copy(io.Discard, part) continue } foundImage = true if part.FileName() == "" { t.Fatal("image part was sent as a form field, want file part") } if got := part.Header.Get("Content-Type"); got != "image/png" { t.Fatalf("image Content-Type = %q, want image/png", got) } got, err := io.ReadAll(part) if err != nil { t.Fatalf("Read image part: %v", err) } if !bytes.Equal(got, largeImage) { t.Fatalf("image payload length = %d, want %d", len(got), len(largeImage)) } } if !foundImage { t.Fatal("image part was not sent") } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(upstreamTask{ ID: "task-1", Status: "queued", Mode: "edit", Model: "gpt-image-2", Size: "1:1", }); err != nil { t.Fatalf("Encode: %v", err) } })) defer ts.Close() c := &Client{ taskAPIRoot: ts.URL, apiKey: "test-key", rawHTTP: ts.Client(), } _, err := c.submitEditTask(context.Background(), taskSubmitRequest{ ClientTaskID: "task-1", Prompt: "replace the background", Model: "gpt-image-2", Size: "1:1", imageUpload: testReferenceUpload(largeImage), }) if err != nil { t.Fatalf("submitEditTask returned error: %v", err) } } func TestSubmitEditTaskRequiresReferenceUpload(t *testing.T) { t.Parallel() tests := []struct { name string body taskSubmitRequest }{ { name: "missing upload", body: taskSubmitRequest{}, }, { name: "nil reader", body: taskSubmitRequest{ imageUpload: &referenceUpload{ Filename: "reference.png", ContentType: "image/png", }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := (&Client{}).submitEditTask(context.Background(), tt.body) if err == nil { t.Fatal("submitEditTask returned nil error") } if got := err.Error(); !strings.Contains(got, "参考图不能为空") { t.Fatalf("error = %q, want reference upload message", got) } }) } } func TestSubmitEditTaskReturnsUpstreamErrorBody(t *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.Copy(io.Discard, r.Body) http.Error(w, `{"detail":"bad request"}`, http.StatusUnprocessableEntity) })) defer ts.Close() c := &Client{ taskAPIRoot: ts.URL, apiKey: "test-key", rawHTTP: ts.Client(), } _, err := c.submitEditTask(context.Background(), taskSubmitRequest{ ClientTaskID: "task-1", Prompt: "replace the background", imageUpload: testReferenceUpload([]byte("\x89PNG\r\n\x1a\n")), }) if err == nil { t.Fatal("submitEditTask returned nil error") } if got := err.Error(); !strings.Contains(got, "422 Unprocessable Entity") || !strings.Contains(got, "bad request") { t.Fatalf("error = %q, want status and body", got) } } func TestGenerateSpendsCreditAfterSuccessfulSubmit(t *testing.T) { t.Parallel() store := newQuotaStore(t) if _, err := store.ApplyCheckIn("fingerprint123"); err != nil { t.Fatalf("ApplyCheckIn returned error: %v", err) } ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/image-tasks/generations" { t.Fatalf("unexpected path %q", r.URL.Path) } var body map[string]any if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Fatalf("Decode request body: %v", err) } if got := body["prompt"]; got != "quiet studio scene" { t.Fatalf("prompt = %q, want original prompt", got) } if _, ok := body["style"]; ok { t.Fatalf("request body should not include style: %+v", body) } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(upstreamTask{ ID: "task-1", Status: "queued", Mode: "generation", Model: "gpt-image-2", Size: "1:1", }); err != nil { t.Fatalf("Encode: %v", err) } })) defer ts.Close() c := &Client{ openAIBaseURL: ts.URL + "/v1", taskAPIRoot: ts.URL, apiKey: "test-key", model: "gpt-image-2", quota: store, http: client.New(), } resp, err := c.Generate(context.Background(), GenerateRequest{ Prompt: "quiet studio scene", Size: "1:1", Fingerprint: "fingerprint123", }) if err != nil { t.Fatalf("Generate returned error: %v", err) } if resp.RemainingCredits == nil || *resp.RemainingCredits != quota.DailyGrant-1 { t.Fatalf("RemainingCredits = %v, want %d", resp.RemainingCredits, quota.DailyGrant-1) } status, err := store.Get("fingerprint123") if err != nil { t.Fatalf("quota Get returned error: %v", err) } if status.Balance != quota.DailyGrant-1 { t.Fatalf("quota balance = %d, want %d", status.Balance, quota.DailyGrant-1) } } const ( testGenerateHTTPFingerprint = "fingerprint123" testGenerateHTTPPrompt = "replace the background" testGenerateHTTPSize = "1:1" testGenerateHTTPModel = "gpt-image-2" testGenerateHTTPReferenceName = "reference.png" testGenerateHTTPReferenceType = "image/png" testGenerateHTTPGenerationsPath = "/api/image-tasks/generations" testGenerateHTTPEditsPath = "/api/image-tasks/edits" testGenerateHTTPGeneratePath = "/api/images/generate" testGenerateHTTPAuthorization = "Bearer test-key" testGenerateHTTPReferenceBytes = "\x89PNG\r\n\x1a\n" ) func newGenerateHTTPQuotaStore(t *testing.T) *quota.Store { t.Helper() store := newQuotaStore(t) if _, err := store.ApplyCheckIn(testGenerateHTTPFingerprint); err != nil { t.Fatalf("ApplyCheckIn returned error: %v", err) } return store } func newGenerateHTTPTestClient(store *quota.Store, ts *httptest.Server) *Client { return &Client{ openAIBaseURL: ts.URL + "/v1", taskAPIRoot: ts.URL, apiKey: "test-key", model: testGenerateHTTPModel, quota: store, http: client.New(client.WithClient(ts.Client())), rawHTTP: ts.Client(), } } func newJSONGenerateRequest(t *testing.T) *http.Request { t.Helper() var body bytes.Buffer err := json.NewEncoder(&body).Encode(GenerateRequest{ Prompt: testGenerateHTTPPrompt, Size: testGenerateHTTPSize, Fingerprint: testGenerateHTTPFingerprint, }) if err != nil { t.Fatalf("Encode JSON request: %v", err) } req := httptest.NewRequest(http.MethodPost, testGenerateHTTPGeneratePath, &body) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) return req } func newMultipartGenerateRequest(t *testing.T, imageBytes []byte) *http.Request { t.Helper() var body bytes.Buffer writer := multipart.NewWriter(&body) fields := [][2]string{ {"prompt", testGenerateHTTPPrompt}, {"size", testGenerateHTTPSize}, {"fingerprint", testGenerateHTTPFingerprint}, } for _, field := range fields { if err := writer.WriteField(field[0], field[1]); err != nil { t.Fatalf("WriteField %s: %v", field[0], err) } } partHeader := make(textproto.MIMEHeader) partHeader.Set("Content-Disposition", mime.FormatMediaType("form-data", map[string]string{ "name": "image", "filename": testGenerateHTTPReferenceName, })) partHeader.Set("Content-Type", testGenerateHTTPReferenceType) part, err := writer.CreatePart(partHeader) if err != nil { t.Fatalf("CreateFormFile: %v", err) } if _, err := part.Write(imageBytes); err != nil { t.Fatalf("Write image: %v", err) } if err := writer.Close(); err != nil { t.Fatalf("Close writer: %v", err) } req := httptest.NewRequest(http.MethodPost, testGenerateHTTPGeneratePath, &body) req.Header.Set(echo.HeaderContentType, writer.FormDataContentType()) return req } func serveGenerateHTTP(t *testing.T, c *Client, req *http.Request) *httptest.ResponseRecorder { t.Helper() srv := server.New(func(srv *server.Server) { srv.Echo.POST(testGenerateHTTPGeneratePath, server.WrapReplyResp(c.GenerateReply)) }) rec := httptest.NewRecorder() srv.Echo.ServeHTTP(rec, req) return rec } func TestGenerateHTTPAcceptsJSONTextGeneration(t *testing.T) { t.Parallel() store := newGenerateHTTPQuotaStore(t) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != testGenerateHTTPGenerationsPath { t.Fatalf("unexpected path %q", r.URL.Path) } if got := r.Header.Get("Authorization"); got != testGenerateHTTPAuthorization { t.Fatalf("Authorization = %q, want %q", got, testGenerateHTTPAuthorization) } if got := r.Header.Get("Content-Type"); got != echo.MIMEApplicationJSON { t.Fatalf("Content-Type = %q, want %q", got, echo.MIMEApplicationJSON) } var body taskSubmitRequest if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Fatalf("Decode request body: %v", err) } if body.ClientTaskID == "" { t.Fatal("client_task_id is empty") } if body.Prompt != testGenerateHTTPPrompt || body.Model != testGenerateHTTPModel || body.Size != testGenerateHTTPSize { t.Fatalf("unexpected request body: %+v", body) } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(upstreamTask{ ID: "task-1", Status: "queued", Mode: "generation", Model: testGenerateHTTPModel, Size: testGenerateHTTPSize, }); err != nil { t.Fatalf("Encode: %v", err) } })) defer ts.Close() c := newGenerateHTTPTestClient(store, ts) rec := serveGenerateHTTP(t, c, newJSONGenerateRequest(t)) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) } var env struct { Code int `json:"code"` Data GenerateResponse `json:"data"` } if err := json.NewDecoder(rec.Body).Decode(&env); err != nil { t.Fatalf("Decode response: %v", err) } if env.Code != 0 { t.Fatalf("response code = %d, want 0", env.Code) } if env.Data.Mode != "text" || env.Data.Status != "queued" { t.Fatalf("unexpected response data: %+v", env.Data) } if env.Data.RemainingCredits == nil || *env.Data.RemainingCredits != quota.DailyGrant-1 { t.Fatalf("RemainingCredits = %v, want %d", env.Data.RemainingCredits, quota.DailyGrant-1) } } func TestGenerateHTTPFallsBackToSubmittedTaskFields(t *testing.T) { t.Parallel() tests := []struct { name string request func(t *testing.T) *http.Request path string wantMode string }{ { name: "json text generation", request: newJSONGenerateRequest, path: testGenerateHTTPGenerationsPath, wantMode: "text", }, { name: "multipart reference image", request: func(t *testing.T) *http.Request { t.Helper() return newMultipartGenerateRequest(t, []byte(testGenerateHTTPReferenceBytes)) }, path: testGenerateHTTPEditsPath, wantMode: "image", }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() store := newGenerateHTTPQuotaStore(t) submittedIDs := make(chan string, 1) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != tt.path { t.Fatalf("unexpected path %q", r.URL.Path) } var submittedID string if tt.path == testGenerateHTTPGenerationsPath { var body taskSubmitRequest if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Fatalf("Decode request body: %v", err) } submittedID = body.ClientTaskID } else { if err := r.ParseMultipartForm(20 << 20); err != nil { t.Fatalf("ParseMultipartForm: %v", err) } submittedID = r.FormValue("client_task_id") } if strings.TrimSpace(submittedID) == "" { t.Fatal("client_task_id is empty") } submittedIDs <- submittedID w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(upstreamTask{ Status: "queued", }); err != nil { t.Fatalf("Encode: %v", err) } })) defer ts.Close() c := newGenerateHTTPTestClient(store, ts) rec := serveGenerateHTTP(t, c, tt.request(t)) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) } var env struct { Code int `json:"code"` Data GenerateResponse `json:"data"` } if err := json.NewDecoder(rec.Body).Decode(&env); err != nil { t.Fatalf("Decode response: %v", err) } if env.Code != 0 { t.Fatalf("response code = %d, want 0", env.Code) } var submittedID string select { case submittedID = <-submittedIDs: case <-time.After(time.Second): t.Fatal("upstream did not receive submitted task id") } if env.Data.ID != submittedID { t.Fatalf("response id = %q, want submitted client_task_id %q", env.Data.ID, submittedID) } if env.Data.Mode != tt.wantMode { t.Fatalf("response mode = %q, want %q", env.Data.Mode, tt.wantMode) } if env.Data.Size != testGenerateHTTPSize { t.Fatalf("response size = %q, want %q", env.Data.Size, testGenerateHTTPSize) } if _, err := time.Parse(time.RFC3339, env.Data.CreatedAt); err != nil { t.Fatalf("response createdAt = %q, want RFC3339: %v", env.Data.CreatedAt, err) } }) } } func TestGenerateHTTPAcceptsMultipartReferenceImage(t *testing.T) { t.Parallel() store := newGenerateHTTPQuotaStore(t) imageBytes := []byte(testGenerateHTTPReferenceBytes) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != testGenerateHTTPEditsPath { t.Fatalf("unexpected path %q", r.URL.Path) } if got := r.Header.Get("Authorization"); got != testGenerateHTTPAuthorization { t.Fatalf("Authorization = %q, want %q", got, testGenerateHTTPAuthorization) } if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "multipart/form-data; boundary=") { t.Fatalf("Content-Type = %q, want multipart/form-data", got) } if err := r.ParseMultipartForm(20 << 20); err != nil { t.Fatalf("ParseMultipartForm: %v", err) } if got := r.FormValue("client_task_id"); strings.TrimSpace(got) == "" { t.Fatal("client_task_id is empty") } wantFields := map[string]string{ "prompt": testGenerateHTTPPrompt, "model": testGenerateHTTPModel, "size": testGenerateHTTPSize, } for name, want := range wantFields { if got := r.FormValue(name); got != want { t.Fatalf("%s = %q, want %q", name, got, want) } } files := r.MultipartForm.File["image"] if len(files) != 1 { t.Fatalf("image files = %d, want 1", len(files)) } if got := files[0].Filename; got != testGenerateHTTPReferenceName { t.Fatalf("image filename = %q, want %q", got, testGenerateHTTPReferenceName) } if got := files[0].Header.Get("Content-Type"); got != testGenerateHTTPReferenceType { t.Fatalf("image Content-Type = %q, want %q", got, testGenerateHTTPReferenceType) } file, err := files[0].Open() if err != nil { t.Fatalf("Open image file: %v", err) } defer file.Close() data, err := io.ReadAll(file) if err != nil { t.Fatalf("Read image file: %v", err) } if !bytes.Equal(data, imageBytes) { t.Fatalf("image bytes = %q, want %q", data, imageBytes) } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(upstreamTask{ ID: "task-1", Status: "queued", Mode: "edit", Model: testGenerateHTTPModel, Size: testGenerateHTTPSize, }); err != nil { t.Fatalf("Encode: %v", err) } })) defer ts.Close() c := newGenerateHTTPTestClient(store, ts) rec := serveGenerateHTTP(t, c, newMultipartGenerateRequest(t, imageBytes)) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) } var env struct { Code int `json:"code"` Data GenerateResponse `json:"data"` } if err := json.NewDecoder(rec.Body).Decode(&env); err != nil { t.Fatalf("Decode response: %v", err) } if env.Code != 0 { t.Fatalf("response code = %d, want 0", env.Code) } if env.Data.Mode != "image" || env.Data.Status != "queued" { t.Fatalf("unexpected response data: %+v", env.Data) } if env.Data.RemainingCredits == nil || *env.Data.RemainingCredits != quota.DailyGrant-1 { t.Fatalf("RemainingCredits = %v, want %d", env.Data.RemainingCredits, quota.DailyGrant-1) } } func TestGenerateHTTPRejectsInvalidMultipartReferenceImage(t *testing.T) { t.Parallel() tests := []struct { name string imageBytes []byte want string }{ { name: "empty image", imageBytes: []byte{}, want: "参考图不能为空", }, { name: "oversized image", imageBytes: bytes.Repeat([]byte{0xab}, maxReferenceUploadBytes+1), want: "参考图不能超过 10MB", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { rec := serveGenerateHTTP( t, &Client{}, newMultipartGenerateRequest(t, tt.imageBytes), ) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) } var env struct { Code int `json:"code"` Message string `json:"message"` } if err := json.NewDecoder(rec.Body).Decode(&env); err != nil { t.Fatalf("Decode response: %v", err) } if env.Code == 0 || !strings.Contains(env.Message, tt.want) { t.Fatalf("unexpected response envelope: %+v, want message containing %q", env, tt.want) } }) } } func TestGenerateHTTPRefundsCreditWhenMultipartSubmitFails(t *testing.T) { t.Parallel() store := newGenerateHTTPQuotaStore(t) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != testGenerateHTTPEditsPath { t.Fatalf("unexpected path %q", r.URL.Path) } http.Error(w, `{"detail":"upstream down"}`, http.StatusBadGateway) })) defer ts.Close() c := newGenerateHTTPTestClient(store, ts) rec := serveGenerateHTTP( t, c, newMultipartGenerateRequest(t, []byte(testGenerateHTTPReferenceBytes)), ) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) } var env struct { Code int `json:"code"` Message string `json:"message"` } if err := json.NewDecoder(rec.Body).Decode(&env); err != nil { t.Fatalf("Decode response: %v", err) } if env.Code == 0 || !strings.Contains(env.Message, "图片任务提交失败") { t.Fatalf("unexpected response envelope: %+v", env) } status, err := store.Get(testGenerateHTTPFingerprint) if err != nil { t.Fatalf("quota Get returned error: %v", err) } if status.Balance != quota.DailyGrant { t.Fatalf("quota balance = %d, want refunded balance %d", status.Balance, quota.DailyGrant) } } func TestGenerateRejectsZeroQuota(t *testing.T) { t.Parallel() var called atomic.Bool ts := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called.Store(true) })) defer ts.Close() c := &Client{ openAIBaseURL: ts.URL + "/v1", taskAPIRoot: ts.URL, apiKey: "test-key", model: "gpt-image-2", quota: newQuotaStore(t), http: client.New(), } _, err := c.Generate(context.Background(), GenerateRequest{ Prompt: "quiet studio scene", Fingerprint: "fingerprint123", }) if err == nil { t.Fatal("Generate returned nil error") } if !strings.Contains(err.Error(), "额度不足") { t.Fatalf("Generate error = %q, want quota message", err.Error()) } if called.Load() { t.Fatal("upstream was called despite zero quota") } } func TestGenerateRefundsCreditWhenSubmitFails(t *testing.T) { t.Parallel() store := newQuotaStore(t) if _, err := store.ApplyCheckIn("fingerprint123"); err != nil { t.Fatalf("ApplyCheckIn returned error: %v", err) } ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.Copy(io.Discard, r.Body) http.Error(w, `{"detail":"upstream down"}`, http.StatusBadGateway) })) defer ts.Close() c := &Client{ openAIBaseURL: ts.URL + "/v1", taskAPIRoot: ts.URL, apiKey: "test-key", model: "gpt-image-2", quota: store, http: client.New(), } _, err := c.Generate(context.Background(), GenerateRequest{ Prompt: "quiet studio scene", Fingerprint: "fingerprint123", }) if err == nil { t.Fatal("Generate returned nil error") } status, err := store.Get("fingerprint123") if err != nil { t.Fatalf("quota Get returned error: %v", err) } if status.Balance != quota.DailyGrant { t.Fatalf("quota balance = %d, want refunded balance %d", status.Balance, quota.DailyGrant) } } func TestGenerateRequiresChatGPT2APIConfig(t *testing.T) { t.Parallel() tests := []struct { name string client Client want string }{ { name: "missing base url", client: Client{ apiKey: "test-key", }, want: "CHATGPT2API_BASE_URL 未配置", }, { name: "missing api key", client: Client{ openAIBaseURL: "http://127.0.0.1:3200/v1", taskAPIRoot: "http://127.0.0.1:3200", }, want: "CHATGPT2API_API_KEY 未配置", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := tt.client.Generate(context.Background(), GenerateRequest{Prompt: "quiet studio scene"}) if err == nil { t.Fatal("Generate returned nil error") } if got := err.Error(); !strings.Contains(got, tt.want) { t.Fatalf("error = %q, want %q", got, tt.want) } }) } } func TestNormalizeUsesAspectRatioSizes(t *testing.T) { t.Parallel() tests := []struct { name string size string want string }{ { name: "empty falls back to square ratio", size: "", want: "1:1", }, { name: "square ratio", size: "1:1", want: "1:1", }, { name: "landscape ratio", size: "16:9", want: "16:9", }, { name: "portrait ratio", size: "9:16", want: "9:16", }, { name: "legacy resolution falls back", size: "1024x1024", want: "1:1", }, { name: "auto falls back", size: "auto", want: "1:1", }, { name: "unknown ratio falls back", size: "4:3", want: "1:1", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := normalize(GenerateRequest{Size: tt.size}) if got.Size != tt.want { t.Fatalf("normalize(%q).Size = %q, want %q", tt.size, got.Size, tt.want) } }) } } func TestNormalizeBaseURLs(t *testing.T) { t.Parallel() tests := []struct { name string in string wantOpenAI string wantTask string }{ { name: "root", in: "http://127.0.0.1:3200", wantOpenAI: "http://127.0.0.1:3200/v1", wantTask: "http://127.0.0.1:3200", }, { name: "v1 suffix", in: "http://127.0.0.1:3200/v1", wantOpenAI: "http://127.0.0.1:3200/v1", wantTask: "http://127.0.0.1:3200", }, { name: "v1 suffix with trailing slash", in: " http://127.0.0.1:3200/v1/ ", wantOpenAI: "http://127.0.0.1:3200/v1", wantTask: "http://127.0.0.1:3200", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := normalizeOpenAIBaseURL(tt.in); got != tt.wantOpenAI { t.Fatalf("normalizeOpenAIBaseURL(%q) = %q, want %q", tt.in, got, tt.wantOpenAI) } if got := normalizeTaskAPIRoot(tt.in); got != tt.wantTask { t.Fatalf("normalizeTaskAPIRoot(%q) = %q, want %q", tt.in, got, tt.wantTask) } }) } } func TestEnhancePromptStreamsUpstreamSSE(t *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/chat/completions" { t.Fatalf("unexpected path %q", r.URL.Path) } if got := r.Header.Get("Authorization"); got != "Bearer test-key" { t.Fatalf("Authorization = %q, want Bearer test-key", got) } if got := r.Header.Get("Content-Type"); got != "application/json" { t.Fatalf("Content-Type = %q, want application/json", got) } var body openai.ChatRequest if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Fatalf("Decode request body: %v", err) } if body.Model != "gpt-5.5" { t.Fatalf("model = %q, want gpt-5.5", body.Model) } if body.Stream == nil || !*body.Stream { t.Fatalf("stream = %v, want true", body.Stream) } if body.Temperature == nil || *body.Temperature != 0.35 { t.Fatalf("temperature = %v, want 0.35", body.Temperature) } if len(body.Messages) != 2 { t.Fatalf("messages length = %d, want 2", len(body.Messages)) } systemPrompt := body.Messages[0].Content if body.Messages[0].Role != openai.RSystem || !strings.Contains(systemPrompt, "自适应增强") || !strings.Contains(systemPrompt, "完整输入只做轻量润色") || !strings.Contains(systemPrompt, "不要输出尺寸、比例或横竖幅") { t.Fatalf("unexpected system message: %+v", body.Messages[0]) } if body.Messages[1].Role != openai.RUser || !strings.Contains(body.Messages[1].Content, "雨天咖啡馆") { t.Fatalf("unexpected user message: %+v", body.Messages[1]) } w.Header().Set("Content-Type", "text/event-stream") _, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"增强\"}}]}\n\n")) _, _ = w.Write([]byte("data: [DONE]\n\n")) })) defer ts.Close() c := &Client{ openAIBaseURL: ts.URL + "/v1", taskAPIRoot: ts.URL, apiKey: "test-key", promptModel: "gpt-5.5", rawHTTP: ts.Client(), } e := echo.New() e.POST("/api/prompts/enhance", c.EnhancePrompt) req := httptest.NewRequest( http.MethodPost, "/api/prompts/enhance", strings.NewReader(`{"prompt":"雨天咖啡馆","direction":"details"}`), ) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) } if got := rec.Header().Get(echo.HeaderContentType); !strings.HasPrefix(got, "text/event-stream") { t.Fatalf("Content-Type = %q, want text/event-stream", got) } if got := rec.Header().Get(echo.HeaderCacheControl); got != "no-cache" { t.Fatalf("Cache-Control = %q, want no-cache", got) } if got := rec.Body.String(); got != "data: {\"choices\":[{\"delta\":{\"content\":\"增强\"}}]}\n\ndata: [DONE]\n\n" { t.Fatalf("stream body = %q", got) } } func TestEnhanceSystemPromptAdaptsByDirection(t *testing.T) { t.Parallel() tests := []struct { name string direction string want []string }{ { name: "details keeps complete prompts restrained", direction: enhanceDirectionDetails, want: []string{ "风格、媒介、画面类型", "自适应增强", "完整输入只做轻量润色", "不要输出尺寸、比例或横竖幅", "省略它们", "补足缺失的主体细节", "不新增与原意无关", }, }, { name: "creative adds imaginative but bounded changes", direction: enhanceDirectionCreative, want: []string{ "风格、媒介、画面类型", "自适应增强", "完整输入只做轻量润色", "不要输出尺寸、比例或横竖幅", "省略它们", "只选择 1-2 个创意变量", "叙事感", "不要跑题", "堆叠无关元素", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := enhanceSystemPrompt(tt.direction) for _, want := range tt.want { if !strings.Contains(got, want) { t.Fatalf("enhanceSystemPrompt(%q) missing %q in %q", tt.direction, want, got) } } }) } } func TestEnhancePromptValidatesRequest(t *testing.T) { t.Parallel() tests := []struct { name string client Client body string status int want string }{ { name: "empty prompt", client: Client{ openAIBaseURL: "http://127.0.0.1:3200/v1", taskAPIRoot: "http://127.0.0.1:3200", apiKey: "test-key", promptModel: "gpt-5.5", }, body: `{"prompt":" ","direction":"details"}`, status: http.StatusBadRequest, want: "请输入需要增强的提示词", }, { name: "unknown direction", client: Client{ openAIBaseURL: "http://127.0.0.1:3200/v1", taskAPIRoot: "http://127.0.0.1:3200", apiKey: "test-key", promptModel: "gpt-5.5", }, body: `{"prompt":"quiet studio","direction":"photo"}`, status: http.StatusBadRequest, want: "提示词增强方向不支持", }, { name: "missing prompt model", client: Client{ openAIBaseURL: "http://127.0.0.1:3200/v1", taskAPIRoot: "http://127.0.0.1:3200", apiKey: "test-key", }, body: `{"prompt":"quiet studio","direction":"details"}`, status: http.StatusBadGateway, want: "CHATGPT2API_PROMPT_MODEL 未配置", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := echo.New() e.POST("/api/prompts/enhance", tt.client.EnhancePrompt) req := httptest.NewRequest(http.MethodPost, "/api/prompts/enhance", strings.NewReader(tt.body)) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) if rec.Code != tt.status { t.Fatalf("status = %d, want %d; body: %s", rec.Code, tt.status, rec.Body.String()) } if got := rec.Body.String(); !strings.Contains(got, tt.want) { t.Fatalf("body = %q, want to contain %q", got, tt.want) } }) } } func newQuotaStore(t *testing.T) *quota.Store { t.Helper() store, err := quota.Open(filepath.Join(t.TempDir(), "quota.db")) if err != nil { t.Fatalf("quota Open returned error: %v", err) } t.Cleanup(func() { if err := store.Close(); err != nil { t.Fatalf("quota Close returned error: %v", err) } }) return store }