From c4ed5bf580281f62dcfc2919eb874abcd7b2524f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=B3=BD=E8=A5=BF=2E0059?= <xiyuliu@bytedance.com>
Date: Thu, 13 Mar 2025 21:57:58 +0800
Subject: [PATCH 1/3] feat[sse]: make method and body configurable

---
 sse.go      |  27 ++++++++-
 sse_test.go | 154 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 180 insertions(+), 1 deletion(-)

diff --git a/sse.go b/sse.go
index 440cabea..c2947b36 100644
--- a/sse.go
+++ b/sse.go
@@ -24,6 +24,7 @@ import (
 var (
 	defaultSseMaxBufSize = 1 << 15 // 32kb
 	defaultEventName     = "message"
+	defaultHTTPMethod    = MethodGet
 
 	headerID    = []byte("id:")
 	headerData  = []byte("data:")
@@ -63,7 +64,9 @@ type (
 	EventSource struct {
 		lock             *sync.RWMutex
 		url              string
+		method           string
 		header           http.Header
+		body             io.Reader
 		lastEventID      string
 		retryCount       int
 		retryWaitTime    time.Duration
@@ -126,6 +129,14 @@ func (es *EventSource) SetURL(url string) *EventSource {
 	return es
 }
 
+// SetMethod method sets a [EventSource] connection HTTP method in the instance
+//
+//	es.SetMethod("POST"), or es.SetMethod(resty.MethodPost)
+func (es *EventSource) SetMethod(method string) *EventSource {
+	es.method = method
+	return es
+}
+
 // SetHeader method sets a header and its value to the [EventSource] instance.
 // It overwrites the header value if the key already exists. These headers will be sent in
 // the request while establishing a connection to the event source
@@ -139,6 +150,14 @@ func (es *EventSource) SetHeader(header, value string) *EventSource {
 	return es
 }
 
+// SetBody method sets body value to the [EventSource] instance
+//
+//	es.SetBody([]byte(`{"test":"put_data"}`),)
+func (es *EventSource) SetBody(body io.Reader) *EventSource {
+	es.body = body
+	return es
+}
+
 // AddHeader method adds a header and its value to the [EventSource] instance.
 // If the header key already exists, it appends. These headers will be sent in
 // the request while establishing a connection to the event source
@@ -343,6 +362,12 @@ func (es *EventSource) Get() error {
 	if isStringEmpty(es.url) {
 		return fmt.Errorf("resty:sse: event source URL is required")
 	}
+	if isStringEmpty(es.method) {
+		// It is up to the user to choose which http method to use, depending on the specific code implementation. No restrictions are imposed here.
+		// Ensure compatibility, use GET as default http method
+		es.method = defaultHTTPMethod
+	}
+
 	if _, found := es.onEvent[defaultEventName]; !found {
 		return fmt.Errorf("resty:sse: OnMessage function is required")
 	}
@@ -401,7 +426,7 @@ func (es *EventSource) triggerOnError(err error) {
 }
 
 func (es *EventSource) createRequest() (*http.Request, error) {
-	req, err := http.NewRequest(MethodGet, es.url, nil)
+	req, err := http.NewRequest(es.method, es.url, es.body)
 	if err != nil {
 		return nil, err
 	}
diff --git a/sse_test.go b/sse_test.go
index f64be2bf..28ff6d31 100644
--- a/sse_test.go
+++ b/sse_test.go
@@ -46,6 +46,7 @@ func TestEventSourceSimpleFlow(t *testing.T) {
 	defer ts.Close()
 
 	es.SetURL(ts.URL)
+	es.SetMethod(MethodPost)
 	err := es.Get()
 	assertNil(t, err)
 	assertEqual(t, counter, messageCounter)
@@ -115,6 +116,7 @@ func TestEventSourceMultipleEventTypes(t *testing.T) {
 	defer ts.Close()
 
 	es.SetURL(ts.URL).
+		SetMethod(MethodPost).
 		AddEventListener("user_connect", userConnectFunc, userEvent{}).
 		AddEventListener("user_message", userMessageFunc, userEvent{})
 
@@ -354,6 +356,7 @@ func TestEventSourceCoverage(t *testing.T) {
 func createEventSource(t *testing.T, url string, fn EventMessageFunc, rt any) *EventSource {
 	es := NewEventSource().
 		SetURL(url).
+		SetMethod(MethodGet).
 		AddHeader("X-Test-Header-1", "test header 1").
 		SetHeader("X-Test-Header-2", "test header 2").
 		SetRetryCount(2).
@@ -406,3 +409,154 @@ func createSSETestServer(t *testing.T, ticker time.Duration, fn func(io.Writer)
 		}
 	})
 }
+
+func TestEventSourceWithDifferentMethods(t *testing.T) {
+	testCases := []struct {
+		name   string
+		method string
+		body   []byte
+	}{
+		{
+			name:   "GET Method",
+			method: MethodGet,
+			body:   nil,
+		},
+		{
+			name:   "POST Method",
+			method: MethodPost,
+			body:   []byte(`{"test":"post_data"}`),
+		},
+		{
+			name:   "PUT Method",
+			method: MethodPut,
+			body:   []byte(`{"test":"put_data"}`),
+		},
+		{
+			name:   "DELETE Method",
+			method: MethodDelete,
+			body:   nil,
+		},
+		{
+			name:   "PATCH Method",
+			method: MethodPatch,
+			body:   []byte(`{"test":"patch_data"}`),
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			messageCounter := 0
+			messageFunc := func(e any) {
+				event := e.(*Event)
+				assertEqual(t, strconv.Itoa(messageCounter), event.ID)
+				assertEqual(t, true, strings.HasPrefix(event.Data, fmt.Sprintf("%s method test:", tc.method)))
+				messageCounter++
+			}
+
+			counter := 0
+			methodVerified := false
+			bodyVerified := false
+
+			es := createEventSource(t, "", messageFunc, nil)
+			ts := createMethodVerifyingSSETestServer(
+				t,
+				10*time.Millisecond,
+				tc.method,
+				tc.body,
+				&methodVerified,
+				&bodyVerified,
+				func(w io.Writer) error {
+					if counter == 20 {
+						es.Close()
+						return fmt.Errorf("stop sending events")
+					}
+					_, err := fmt.Fprintf(w, "id: %v\ndata: %s method test: %s\n\n", counter, tc.method, time.Now().Format(time.RFC3339))
+					counter++
+					return err
+				},
+			)
+			defer ts.Close()
+
+			es.SetURL(ts.URL)
+			es.SetMethod(tc.method)
+
+			// 设置请求体(如果有)
+			if tc.body != nil {
+				es.SetBody(bytes.NewBuffer(tc.body))
+			}
+
+			err := es.Get()
+			assertNil(t, err)
+
+			// 验证接收到的消息数量
+			assertEqual(t, counter, messageCounter)
+
+			// check if server receive correct method and body
+			assertEqual(t, true, methodVerified)
+			if tc.body != nil {
+				assertEqual(t, true, bodyVerified)
+			}
+		})
+	}
+}
+
+// 创建一个验证请求方法和请求体的SSE测试服务器
+func createMethodVerifyingSSETestServer(
+	t *testing.T,
+	ticker time.Duration,
+	expectedMethod string,
+	expectedBody []byte,
+	methodVerified *bool,
+	bodyVerified *bool,
+	fn func(io.Writer) error,
+) *httptest.Server {
+	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// validate method
+		if r.Method == expectedMethod {
+			*methodVerified = true
+		} else {
+			t.Errorf("Expected method %s, got %s", expectedMethod, r.Method)
+		}
+
+		// validate body
+		if expectedBody != nil {
+			body, err := io.ReadAll(r.Body)
+			if err != nil {
+				t.Errorf("Failed to read request body: %v", err)
+			} else if string(body) == string(expectedBody) {
+				*bodyVerified = true
+			} else {
+				t.Errorf("Expected body %s, got %s", string(expectedBody), string(body))
+			}
+		}
+
+		// same as createSSETestServer
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Header().Set("Cache-Control", "no-cache")
+		w.Header().Set("Connection", "keep-alive")
+		w.Header().Set("Access-Control-Allow-Origin", "*")
+
+		clientGone := r.Context().Done()
+
+		rc := http.NewResponseController(w)
+		tick := time.NewTicker(ticker)
+		defer tick.Stop()
+
+		for {
+			select {
+			case <-clientGone:
+				t.Log("Client disconnected")
+				return
+			case <-tick.C:
+				if err := fn(w); err != nil {
+					t.Log(err)
+					return
+				}
+				if err := rc.Flush(); err != nil {
+					t.Log(err)
+					return
+				}
+			}
+		}
+	}))
+}

From 0aad9f44ff4221f87cc0442158604f7ee23733d3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=B3=BD=E8=A5=BF=2E0059?= <xiyuliu@bytedance.com>
Date: Thu, 13 Mar 2025 22:00:08 +0800
Subject: [PATCH 2/3] chore[sse]: edit comments

---
 sse_test.go | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sse_test.go b/sse_test.go
index 28ff6d31..70beaa56 100644
--- a/sse_test.go
+++ b/sse_test.go
@@ -480,7 +480,7 @@ func TestEventSourceWithDifferentMethods(t *testing.T) {
 			es.SetURL(ts.URL)
 			es.SetMethod(tc.method)
 
-			// 设置请求体(如果有)
+			// set body
 			if tc.body != nil {
 				es.SetBody(bytes.NewBuffer(tc.body))
 			}
@@ -488,7 +488,7 @@ func TestEventSourceWithDifferentMethods(t *testing.T) {
 			err := es.Get()
 			assertNil(t, err)
 
-			// 验证接收到的消息数量
+			// check the message count
 			assertEqual(t, counter, messageCounter)
 
 			// check if server receive correct method and body
@@ -500,7 +500,7 @@ func TestEventSourceWithDifferentMethods(t *testing.T) {
 	}
 }
 
-// 创建一个验证请求方法和请求体的SSE测试服务器
+// almost like create server before but add verifying method and body
 func createMethodVerifyingSSETestServer(
 	t *testing.T,
 	ticker time.Duration,

From 4aea323a7ef7c168b0d207f1741ff68f7cb5d476 Mon Sep 17 00:00:00 2001
From: xiyuliu <xiyuliu@bytedance.com>
Date: Thu, 20 Mar 2025 23:13:52 +0800
Subject: [PATCH 3/3] [refactor] use AddEventListener check instead of
 OnMessage check

---
 sse.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sse.go b/sse.go
index c2947b36..64a01166 100644
--- a/sse.go
+++ b/sse.go
@@ -368,8 +368,8 @@ func (es *EventSource) Get() error {
 		es.method = defaultHTTPMethod
 	}
 
-	if _, found := es.onEvent[defaultEventName]; !found {
-		return fmt.Errorf("resty:sse: OnMessage function is required")
+	if len(es.onEvent) == 0 {
+		return fmt.Errorf("resty:sse: At least one OnMessage/AddEventListener func is required")
 	}
 
 	// reset to begin