Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[sse]: make method and body configurable #985

Closed
Closed
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion sse.go
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ import (
var (
defaultSseMaxBufSize = 1 << 15 // 32kb
defaultEventName = "message"
defaultHTTPMethod = MethodGet

headerID = []byte("id:")
headerData = []byte("data:")
@@ -63,7 +64,9 @@ type (
EventSource struct {
lock *sync.RWMutex
url string
method string
header http.Header
body io.Reader
lastEventID string
retryCount int
retryWaitTime time.Duration
@@ -126,6 +129,14 @@ func (es *EventSource) SetURL(url string) *EventSource {
return es
}

// SetMethod method sets a [EventSource] connection HTTP method in the instance
//
// es.SetMethod("POST"), or es.SetMethod(resty.MethodPost)
func (es *EventSource) SetMethod(method string) *EventSource {
es.method = method
return es
}

// SetHeader method sets a header and its value to the [EventSource] instance.
// It overwrites the header value if the key already exists. These headers will be sent in
// the request while establishing a connection to the event source
@@ -139,6 +150,14 @@ func (es *EventSource) SetHeader(header, value string) *EventSource {
return es
}

// SetBody method sets body value to the [EventSource] instance
//
// es.SetBody([]byte(`{"test":"put_data"}`),)
func (es *EventSource) SetBody(body io.Reader) *EventSource {
es.body = body
return es
}

// AddHeader method adds a header and its value to the [EventSource] instance.
// If the header key already exists, it appends. These headers will be sent in
// the request while establishing a connection to the event source
@@ -343,6 +362,12 @@ func (es *EventSource) Get() error {
if isStringEmpty(es.url) {
return fmt.Errorf("resty:sse: event source URL is required")
}
if isStringEmpty(es.method) {
// It is up to the user to choose which http method to use, depending on the specific code implementation. No restrictions are imposed here.
// Ensure compatibility, use GET as default http method
es.method = defaultHTTPMethod
}

if _, found := es.onEvent[defaultEventName]; !found {
return fmt.Errorf("resty:sse: OnMessage function is required")
}
@@ -401,7 +426,7 @@ func (es *EventSource) triggerOnError(err error) {
}

func (es *EventSource) createRequest() (*http.Request, error) {
req, err := http.NewRequest(MethodGet, es.url, nil)
req, err := http.NewRequest(es.method, es.url, es.body)
if err != nil {
return nil, err
}
154 changes: 154 additions & 0 deletions sse_test.go
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@ func TestEventSourceSimpleFlow(t *testing.T) {
defer ts.Close()

es.SetURL(ts.URL)
es.SetMethod(MethodPost)
err := es.Get()
assertNil(t, err)
assertEqual(t, counter, messageCounter)
@@ -115,6 +116,7 @@ func TestEventSourceMultipleEventTypes(t *testing.T) {
defer ts.Close()

es.SetURL(ts.URL).
SetMethod(MethodPost).
AddEventListener("user_connect", userConnectFunc, userEvent{}).
AddEventListener("user_message", userMessageFunc, userEvent{})

@@ -354,6 +356,7 @@ func TestEventSourceCoverage(t *testing.T) {
func createEventSource(t *testing.T, url string, fn EventMessageFunc, rt any) *EventSource {
es := NewEventSource().
SetURL(url).
SetMethod(MethodGet).
AddHeader("X-Test-Header-1", "test header 1").
SetHeader("X-Test-Header-2", "test header 2").
SetRetryCount(2).
@@ -406,3 +409,154 @@ func createSSETestServer(t *testing.T, ticker time.Duration, fn func(io.Writer)
}
})
}

func TestEventSourceWithDifferentMethods(t *testing.T) {
testCases := []struct {
name string
method string
body []byte
}{
{
name: "GET Method",
method: MethodGet,
body: nil,
},
{
name: "POST Method",
method: MethodPost,
body: []byte(`{"test":"post_data"}`),
},
{
name: "PUT Method",
method: MethodPut,
body: []byte(`{"test":"put_data"}`),
},
{
name: "DELETE Method",
method: MethodDelete,
body: nil,
},
{
name: "PATCH Method",
method: MethodPatch,
body: []byte(`{"test":"patch_data"}`),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
messageCounter := 0
messageFunc := func(e any) {
event := e.(*Event)
assertEqual(t, strconv.Itoa(messageCounter), event.ID)
assertEqual(t, true, strings.HasPrefix(event.Data, fmt.Sprintf("%s method test:", tc.method)))
messageCounter++
}

counter := 0
methodVerified := false
bodyVerified := false

es := createEventSource(t, "", messageFunc, nil)
ts := createMethodVerifyingSSETestServer(
t,
10*time.Millisecond,
tc.method,
tc.body,
&methodVerified,
&bodyVerified,
func(w io.Writer) error {
if counter == 20 {
es.Close()
return fmt.Errorf("stop sending events")
}
_, err := fmt.Fprintf(w, "id: %v\ndata: %s method test: %s\n\n", counter, tc.method, time.Now().Format(time.RFC3339))
counter++
return err
},
)
defer ts.Close()

es.SetURL(ts.URL)
es.SetMethod(tc.method)

// set body
if tc.body != nil {
es.SetBody(bytes.NewBuffer(tc.body))
}

err := es.Get()
assertNil(t, err)

// check the message count
assertEqual(t, counter, messageCounter)

// check if server receive correct method and body
assertEqual(t, true, methodVerified)
if tc.body != nil {
assertEqual(t, true, bodyVerified)
}
})
}
}

// almost like create server before but add verifying method and body
func createMethodVerifyingSSETestServer(
t *testing.T,
ticker time.Duration,
expectedMethod string,
expectedBody []byte,
methodVerified *bool,
bodyVerified *bool,
fn func(io.Writer) error,
) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// validate method
if r.Method == expectedMethod {
*methodVerified = true
} else {
t.Errorf("Expected method %s, got %s", expectedMethod, r.Method)
}

// validate body
if expectedBody != nil {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed to read request body: %v", err)
} else if string(body) == string(expectedBody) {
*bodyVerified = true
} else {
t.Errorf("Expected body %s, got %s", string(expectedBody), string(body))
}
}

// same as createSSETestServer
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")

clientGone := r.Context().Done()

rc := http.NewResponseController(w)
tick := time.NewTicker(ticker)
defer tick.Stop()

for {
select {
case <-clientGone:
t.Log("Client disconnected")
return
case <-tick.C:
if err := fn(w); err != nil {
t.Log(err)
return
}
if err := rc.Flush(); err != nil {
t.Log(err)
return
}
}
}
}))
}