diff --git a/sse.go b/sse.go index 440cabe..64a0116 100644 --- a/sse.go +++ b/sse.go @@ -24,6 +24,7 @@ import ( var ( defaultSseMaxBufSize = 1 << 15 // 32kb defaultEventName = "message" + defaultHTTPMethod = MethodGet headerID = []byte("id:") headerData = []byte("data:") @@ -63,7 +64,9 @@ type ( EventSource struct { lock *sync.RWMutex url string + method string header http.Header + body io.Reader lastEventID string retryCount int retryWaitTime time.Duration @@ -126,6 +129,14 @@ func (es *EventSource) SetURL(url string) *EventSource { return es } +// SetMethod method sets a [EventSource] connection HTTP method in the instance +// +// es.SetMethod("POST"), or es.SetMethod(resty.MethodPost) +func (es *EventSource) SetMethod(method string) *EventSource { + es.method = method + return es +} + // SetHeader method sets a header and its value to the [EventSource] instance. // It overwrites the header value if the key already exists. These headers will be sent in // the request while establishing a connection to the event source @@ -139,6 +150,14 @@ func (es *EventSource) SetHeader(header, value string) *EventSource { return es } +// SetBody method sets body value to the [EventSource] instance +// +// es.SetBody([]byte(`{"test":"put_data"}`),) +func (es *EventSource) SetBody(body io.Reader) *EventSource { + es.body = body + return es +} + // AddHeader method adds a header and its value to the [EventSource] instance. // If the header key already exists, it appends. These headers will be sent in // the request while establishing a connection to the event source @@ -343,8 +362,14 @@ func (es *EventSource) Get() error { if isStringEmpty(es.url) { return fmt.Errorf("resty:sse: event source URL is required") } - if _, found := es.onEvent[defaultEventName]; !found { - return fmt.Errorf("resty:sse: OnMessage function is required") + if isStringEmpty(es.method) { + // It is up to the user to choose which http method to use, depending on the specific code implementation. No restrictions are imposed here. + // Ensure compatibility, use GET as default http method + es.method = defaultHTTPMethod + } + + if len(es.onEvent) == 0 { + return fmt.Errorf("resty:sse: At least one OnMessage/AddEventListener func is required") } // reset to begin @@ -401,7 +426,7 @@ func (es *EventSource) triggerOnError(err error) { } func (es *EventSource) createRequest() (*http.Request, error) { - req, err := http.NewRequest(MethodGet, es.url, nil) + req, err := http.NewRequest(es.method, es.url, es.body) if err != nil { return nil, err } diff --git a/sse_test.go b/sse_test.go index f64be2b..70beaa5 100644 --- a/sse_test.go +++ b/sse_test.go @@ -46,6 +46,7 @@ func TestEventSourceSimpleFlow(t *testing.T) { defer ts.Close() es.SetURL(ts.URL) + es.SetMethod(MethodPost) err := es.Get() assertNil(t, err) assertEqual(t, counter, messageCounter) @@ -115,6 +116,7 @@ func TestEventSourceMultipleEventTypes(t *testing.T) { defer ts.Close() es.SetURL(ts.URL). + SetMethod(MethodPost). AddEventListener("user_connect", userConnectFunc, userEvent{}). AddEventListener("user_message", userMessageFunc, userEvent{}) @@ -354,6 +356,7 @@ func TestEventSourceCoverage(t *testing.T) { func createEventSource(t *testing.T, url string, fn EventMessageFunc, rt any) *EventSource { es := NewEventSource(). SetURL(url). + SetMethod(MethodGet). AddHeader("X-Test-Header-1", "test header 1"). SetHeader("X-Test-Header-2", "test header 2"). SetRetryCount(2). @@ -406,3 +409,154 @@ func createSSETestServer(t *testing.T, ticker time.Duration, fn func(io.Writer) } }) } + +func TestEventSourceWithDifferentMethods(t *testing.T) { + testCases := []struct { + name string + method string + body []byte + }{ + { + name: "GET Method", + method: MethodGet, + body: nil, + }, + { + name: "POST Method", + method: MethodPost, + body: []byte(`{"test":"post_data"}`), + }, + { + name: "PUT Method", + method: MethodPut, + body: []byte(`{"test":"put_data"}`), + }, + { + name: "DELETE Method", + method: MethodDelete, + body: nil, + }, + { + name: "PATCH Method", + method: MethodPatch, + body: []byte(`{"test":"patch_data"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + messageCounter := 0 + messageFunc := func(e any) { + event := e.(*Event) + assertEqual(t, strconv.Itoa(messageCounter), event.ID) + assertEqual(t, true, strings.HasPrefix(event.Data, fmt.Sprintf("%s method test:", tc.method))) + messageCounter++ + } + + counter := 0 + methodVerified := false + bodyVerified := false + + es := createEventSource(t, "", messageFunc, nil) + ts := createMethodVerifyingSSETestServer( + t, + 10*time.Millisecond, + tc.method, + tc.body, + &methodVerified, + &bodyVerified, + func(w io.Writer) error { + if counter == 20 { + es.Close() + return fmt.Errorf("stop sending events") + } + _, err := fmt.Fprintf(w, "id: %v\ndata: %s method test: %s\n\n", counter, tc.method, time.Now().Format(time.RFC3339)) + counter++ + return err + }, + ) + defer ts.Close() + + es.SetURL(ts.URL) + es.SetMethod(tc.method) + + // set body + if tc.body != nil { + es.SetBody(bytes.NewBuffer(tc.body)) + } + + err := es.Get() + assertNil(t, err) + + // check the message count + assertEqual(t, counter, messageCounter) + + // check if server receive correct method and body + assertEqual(t, true, methodVerified) + if tc.body != nil { + assertEqual(t, true, bodyVerified) + } + }) + } +} + +// almost like create server before but add verifying method and body +func createMethodVerifyingSSETestServer( + t *testing.T, + ticker time.Duration, + expectedMethod string, + expectedBody []byte, + methodVerified *bool, + bodyVerified *bool, + fn func(io.Writer) error, +) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // validate method + if r.Method == expectedMethod { + *methodVerified = true + } else { + t.Errorf("Expected method %s, got %s", expectedMethod, r.Method) + } + + // validate body + if expectedBody != nil { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Failed to read request body: %v", err) + } else if string(body) == string(expectedBody) { + *bodyVerified = true + } else { + t.Errorf("Expected body %s, got %s", string(expectedBody), string(body)) + } + } + + // same as createSSETestServer + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + clientGone := r.Context().Done() + + rc := http.NewResponseController(w) + tick := time.NewTicker(ticker) + defer tick.Stop() + + for { + select { + case <-clientGone: + t.Log("Client disconnected") + return + case <-tick.C: + if err := fn(w); err != nil { + t.Log(err) + return + } + if err := rc.Flush(); err != nil { + t.Log(err) + return + } + } + } + })) +}