Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[sse]: make method and body configurable #985

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
var (
defaultSseMaxBufSize = 1 << 15 // 32kb
defaultEventName = "message"
defaultHTTPMethod = MethodGet

headerID = []byte("id:")
headerData = []byte("data:")
Expand Down Expand Up @@ -63,7 +64,9 @@ type (
EventSource struct {
lock *sync.RWMutex
url string
method string
header http.Header
body io.Reader
lastEventID string
retryCount int
retryWaitTime time.Duration
Expand Down Expand Up @@ -126,6 +129,14 @@ func (es *EventSource) SetURL(url string) *EventSource {
return es
}

// SetMethod method sets a [EventSource] connection HTTP method in the instance
//
// es.SetMethod("POST"), or es.SetMethod(resty.MethodPost)
func (es *EventSource) SetMethod(method string) *EventSource {
es.method = method
return es
}

// SetHeader method sets a header and its value to the [EventSource] instance.
// It overwrites the header value if the key already exists. These headers will be sent in
// the request while establishing a connection to the event source
Expand All @@ -139,6 +150,14 @@ func (es *EventSource) SetHeader(header, value string) *EventSource {
return es
}

// SetBody method sets body value to the [EventSource] instance
//
// es.SetBody([]byte(`{"test":"put_data"}`),)
func (es *EventSource) SetBody(body io.Reader) *EventSource {
es.body = body
return es
}

// AddHeader method adds a header and its value to the [EventSource] instance.
// If the header key already exists, it appends. These headers will be sent in
// the request while establishing a connection to the event source
Expand Down Expand Up @@ -343,8 +362,14 @@ func (es *EventSource) Get() error {
if isStringEmpty(es.url) {
return fmt.Errorf("resty:sse: event source URL is required")
}
if _, found := es.onEvent[defaultEventName]; !found {
return fmt.Errorf("resty:sse: OnMessage function is required")
if isStringEmpty(es.method) {
// It is up to the user to choose which http method to use, depending on the specific code implementation. No restrictions are imposed here.
// Ensure compatibility, use GET as default http method
es.method = defaultHTTPMethod
}

if len(es.onEvent) == 0 {
return fmt.Errorf("resty:sse: At least one OnMessage/AddEventListener func is required")
}

// reset to begin
Expand Down Expand Up @@ -401,7 +426,7 @@ func (es *EventSource) triggerOnError(err error) {
}

func (es *EventSource) createRequest() (*http.Request, error) {
req, err := http.NewRequest(MethodGet, es.url, nil)
req, err := http.NewRequest(es.method, es.url, es.body)
if err != nil {
return nil, err
}
Expand Down
154 changes: 154 additions & 0 deletions sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func TestEventSourceSimpleFlow(t *testing.T) {
defer ts.Close()

es.SetURL(ts.URL)
es.SetMethod(MethodPost)
err := es.Get()
assertNil(t, err)
assertEqual(t, counter, messageCounter)
Expand Down Expand Up @@ -115,6 +116,7 @@ func TestEventSourceMultipleEventTypes(t *testing.T) {
defer ts.Close()

es.SetURL(ts.URL).
SetMethod(MethodPost).
AddEventListener("user_connect", userConnectFunc, userEvent{}).
AddEventListener("user_message", userMessageFunc, userEvent{})

Expand Down Expand Up @@ -354,6 +356,7 @@ func TestEventSourceCoverage(t *testing.T) {
func createEventSource(t *testing.T, url string, fn EventMessageFunc, rt any) *EventSource {
es := NewEventSource().
SetURL(url).
SetMethod(MethodGet).
AddHeader("X-Test-Header-1", "test header 1").
SetHeader("X-Test-Header-2", "test header 2").
SetRetryCount(2).
Expand Down Expand Up @@ -406,3 +409,154 @@ func createSSETestServer(t *testing.T, ticker time.Duration, fn func(io.Writer)
}
})
}

func TestEventSourceWithDifferentMethods(t *testing.T) {
testCases := []struct {
name string
method string
body []byte
}{
{
name: "GET Method",
method: MethodGet,
body: nil,
},
{
name: "POST Method",
method: MethodPost,
body: []byte(`{"test":"post_data"}`),
},
{
name: "PUT Method",
method: MethodPut,
body: []byte(`{"test":"put_data"}`),
},
{
name: "DELETE Method",
method: MethodDelete,
body: nil,
},
{
name: "PATCH Method",
method: MethodPatch,
body: []byte(`{"test":"patch_data"}`),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
messageCounter := 0
messageFunc := func(e any) {
event := e.(*Event)
assertEqual(t, strconv.Itoa(messageCounter), event.ID)
assertEqual(t, true, strings.HasPrefix(event.Data, fmt.Sprintf("%s method test:", tc.method)))
messageCounter++
}

counter := 0
methodVerified := false
bodyVerified := false

es := createEventSource(t, "", messageFunc, nil)
ts := createMethodVerifyingSSETestServer(
t,
10*time.Millisecond,
tc.method,
tc.body,
&methodVerified,
&bodyVerified,
func(w io.Writer) error {
if counter == 20 {
es.Close()
return fmt.Errorf("stop sending events")
}
_, err := fmt.Fprintf(w, "id: %v\ndata: %s method test: %s\n\n", counter, tc.method, time.Now().Format(time.RFC3339))
counter++
return err
},
)
defer ts.Close()

es.SetURL(ts.URL)
es.SetMethod(tc.method)

// set body
if tc.body != nil {
es.SetBody(bytes.NewBuffer(tc.body))
}

err := es.Get()
assertNil(t, err)

// check the message count
assertEqual(t, counter, messageCounter)

// check if server receive correct method and body
assertEqual(t, true, methodVerified)
if tc.body != nil {
assertEqual(t, true, bodyVerified)
}
})
}
}

// almost like create server before but add verifying method and body
func createMethodVerifyingSSETestServer(
t *testing.T,
ticker time.Duration,
expectedMethod string,
expectedBody []byte,
methodVerified *bool,
bodyVerified *bool,
fn func(io.Writer) error,
) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// validate method
if r.Method == expectedMethod {
*methodVerified = true
} else {
t.Errorf("Expected method %s, got %s", expectedMethod, r.Method)
}

// validate body
if expectedBody != nil {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed to read request body: %v", err)
} else if string(body) == string(expectedBody) {
*bodyVerified = true
} else {
t.Errorf("Expected body %s, got %s", string(expectedBody), string(body))
}
}

// same as createSSETestServer
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")

clientGone := r.Context().Done()

rc := http.NewResponseController(w)
tick := time.NewTicker(ticker)
defer tick.Stop()

for {
select {
case <-clientGone:
t.Log("Client disconnected")
return
case <-tick.C:
if err := fn(w); err != nil {
t.Log(err)
return
}
if err := rc.Flush(); err != nil {
t.Log(err)
return
}
}
}
}))
}