Skip to content

Commit 37227c8

Browse files
authored
refactoring
- added defaults - removed oneof - added Times()
1 parent 184b00d commit 37227c8

3 files changed

Lines changed: 233 additions & 62 deletions

File tree

httpmockserver.go

Lines changed: 99 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,32 @@ package httpmockserver
33
import (
44
"bytes"
55
"fmt"
6-
"testing"
76
"io/ioutil"
7+
"net"
88
"net/http"
9-
"sync"
109
"net/http/httptest"
11-
"net"
10+
"sync"
11+
"testing"
1212
)
1313

14-
func New(ssl bool, t *testing.T) *MockServer {
15-
return NewWithPort("0", ssl, t)
14+
type Opts struct {
15+
Port string
16+
UseSSL bool
1617
}
1718

18-
func NewWithPort(port string, ssl bool, t *testing.T) *MockServer {
19+
func (o *Opts) validate() error {
20+
return nil
21+
}
22+
23+
func New(t *testing.T) *MockServer {
24+
return NewWithOpts(t, Opts{})
25+
}
26+
27+
func NewWithOpts(t *testing.T, opts Opts) *MockServer {
28+
err := opts.validate()
29+
if err != nil {
30+
t.Fatal(err)
31+
}
1932

2033
mockServer := &MockServer{
2134
t: t,
@@ -25,19 +38,18 @@ func NewWithPort(port string, ssl bool, t *testing.T) *MockServer {
2538
mockServer.server = httptest.NewUnstartedServer(mockServer)
2639
mockServer.server.Config.SetKeepAlivesEnabled(false)
2740

28-
if port != "0" {
41+
if opts.Port != "0" && opts.Port != "" {
2942
mockServer.server.Listener.Close()
30-
l, err := net.Listen("tcp", "127.0.0.1:"+port)
43+
l, err := net.Listen("tcp", "127.0.0.1:"+opts.Port)
3144
if err != nil {
32-
panic(fmt.Sprintf("httptest: failed to listen on 127.0.0.1:%v: %v", port, err))
45+
t.Fatalf("httpmock: failed to listen on 127.0.0.1:%v: %v", opts.Port, err)
3346
}
3447
mockServer.server.Listener = l
3548
}
3649

37-
if (ssl) {
50+
if opts.UseSSL {
3851
mockServer.server.StartTLS()
3952
} else {
40-
4153
mockServer.server.Start()
4254
}
4355

@@ -52,28 +64,21 @@ type MockServer struct {
5264
handlerMutex sync.Mutex
5365

5466
every []*requestExpectation
55-
one []*requestExpectation
5667
expectations []*requestExpectation
68+
defaults []*requestExpectation
5769
}
5870

59-
func (s *MockServer) GetURL() string {
71+
func (s *MockServer) URL() string {
6072
return s.server.URL
6173
}
6274

63-
// TODO: should not be public
6475
func (s *MockServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
65-
// only one request at a time
6676
s.handlerMutex.Lock()
6777
defer s.handlerMutex.Unlock()
6878

69-
// check if we have expectations
70-
if len(s.expectations) == 0 {
71-
s.t.Fatalf("Missing expectation for %v %v", r.Method, r.URL.Path)
72-
}
73-
7479
body, err := ioutil.ReadAll(r.Body)
7580
if err != nil {
76-
s.t.Fatal("request validation failed: could not read incoming request body")
81+
s.t.Fatal("request validation failed: could not read incoming request body: ", err.Error())
7782
}
7883

7984
incomingRequest := &IncomingRequest{
@@ -85,49 +90,60 @@ func (s *MockServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
8590
for _, every := range s.every {
8691
for _, everyExp := range every.requestValidations {
8792
if err := everyExp.validation(incomingRequest); err != nil {
88-
s.t.Fatalf("EVERY expectation failed: %v", err.Error())
93+
s.t.Errorf("expectation failed: %v", err)
8994
}
9095
}
9196
}
9297

93-
// check ONEOF expectations
94-
oneMatch := true
95-
for _, one := range s.one {
96-
oneMatch = true
97-
for _, oneExp := range one.requestValidations {
98-
if oneExp.validation(incomingRequest) != nil {
99-
oneMatch = false
100-
break
101-
}
98+
var matchedExpectation *requestExpectation
99+
// check if call matches an expectation
100+
outerExp:
101+
for _, exp := range s.expectations {
102+
if exp.count >= exp.max {
103+
continue
102104
}
103-
if oneMatch {
104-
break
105+
106+
for _, reqVal := range exp.requestValidations {
107+
if err := reqVal.validation(incomingRequest); err != nil {
108+
continue outerExp
109+
}
105110
}
106-
}
107111

108-
if !oneMatch {
109-
s.t.Fatalf("request validation failed: no match in ONEOF constraint list for %v %v", r.Method, r.URL.Path)
112+
matchedExpectation = exp
113+
matchedExpectation.count++
114+
break
110115
}
111116

112-
exp := s.expectations[0]
113-
s.expectations = s.expectations[1:]
117+
// if not matched any of the expectations
118+
if matchedExpectation == nil {
119+
// check if call matches a default
120+
outerDefaults:
121+
for _, exp := range s.defaults {
122+
for _, reqVal := range exp.requestValidations {
123+
if err := reqVal.validation(incomingRequest); err != nil {
124+
continue outerDefaults
125+
}
126+
}
114127

115-
// check request validations
116-
for _, reqVal := range exp.requestValidations {
117-
if err := reqVal.validation(incomingRequest); err != nil {
118-
s.t.Fatalf(err.Error())
128+
matchedExpectation = exp
129+
break
119130
}
120131
}
121132

133+
// if no default found log request and return default code
134+
if matchedExpectation == nil {
135+
s.t.Fatalf("Unexpected call:\nMethod: %v\nURL: %v\nHeaders: %v\nBody: %v", r.Method, r.URL.Path, r.Header, string(body))
136+
}
137+
122138
// build response
123-
for key, value := range exp.response.Headers {
139+
for key, value := range matchedExpectation.response.Headers {
124140
w.Header().Set(key, value)
125141
}
126142

127-
w.WriteHeader(exp.response.Code)
143+
w.WriteHeader(matchedExpectation.response.Code)
128144

129-
if exp.response.Body != nil {
130-
w.Write(exp.response.Body)
145+
if matchedExpectation.response.Body != nil {
146+
w.Write(matchedExpectation.response.Body)
131147
}
132148
}
133149

@@ -140,40 +156,56 @@ func (s *MockServer) EVERY() RequestExpectation {
140156
return exp
141157
}
142158

143-
func (s *MockServer) ONEOF() RequestExpectation {
144-
exp := new(requestExpectation)
145-
exp.t = s.t
159+
func (s *MockServer) EXPECT() RequestExpectation {
160+
exp := &requestExpectation{
161+
t: s.t,
162+
count: 0,
163+
min: 1,
164+
max: 1,
165+
}
146166

147-
s.one = append(s.one, exp)
167+
// TODO: default response
168+
exp.response = &MockResponse{
169+
Code: 404,
170+
Headers: make(map[string]string),
171+
}
172+
173+
s.expectations = append(s.expectations, exp)
148174
return exp
149175
}
150176

151-
func (s *MockServer) EXPECT() RequestExpectation {
152-
exp := new(requestExpectation)
153-
exp.t = s.t
177+
func (s *MockServer) DEFAULT() RequestExpectation {
178+
exp := &requestExpectation{
179+
t: s.t,
180+
}
154181

155-
// default response
182+
// TODO: default response
156183
exp.response = &MockResponse{
157184
Code: 404,
158185
Headers: make(map[string]string),
159186
}
160187

161-
s.expectations = append(s.expectations, exp)
188+
s.defaults = append(s.defaults, exp)
162189
return exp
163190
}
164191

165192
func (s *MockServer) Finish() {
166-
if len(s.expectations) != 0 {
167-
var buf bytes.Buffer
193+
var buf bytes.Buffer
168194

169-
for i, exp := range s.expectations {
195+
unsatisfied := false
196+
for i, exp := range s.expectations {
197+
if exp.count < exp.min || exp.count > exp.max {
198+
unsatisfied = true
170199
buf.WriteString(fmt.Sprintf("%v. Expectation\n", i+1))
171-
172200
for _, val := range exp.requestValidations {
173201
buf.WriteString(fmt.Sprintf("----- %v\n", val.description))
174202
}
203+
175204
}
176-
s.t.Fatalf("\nexpectations not satisfied:\n%v", buf.String())
205+
}
206+
207+
if unsatisfied {
208+
s.t.Fatalf("\nexpectation(s) not satisfied:\n%v", buf.String())
177209
}
178210
}
179211

@@ -183,3 +215,10 @@ func (s *MockServer) Shutdown() {
183215

184216
s.server.Close()
185217
}
218+
219+
type request struct {
220+
Method string
221+
Headers map[string][]string
222+
URL string
223+
Body []byte
224+
}

httpmockserver_test.go

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,96 @@
1-
package httpmockserver
1+
package httpmockserver_test
2+
3+
import (
4+
"bytes"
5+
"github.com/stretchr/testify/assert"
6+
"github.com/ybbus/httpmockserver"
7+
"net/http"
8+
"testing"
9+
)
10+
11+
func TestMockServer_EXPECT(t *testing.T) {
12+
check := assert.New(t)
13+
14+
tests := []struct {
15+
Description string
16+
Run func(mockServer *httpmockserver.MockServer, url string)
17+
}{
18+
{
19+
Description: "simple get on /hello",
20+
Run: func(mockServer *httpmockserver.MockServer, url string) {
21+
mockServer.EXPECT().Get("/hello").Response(200)
22+
23+
res, err := Get(url+"/hello", nil)
24+
check.NoError(err)
25+
check.Equal(200, res.StatusCode)
26+
},
27+
},
28+
{
29+
Description: "simple get on /hello with header",
30+
Run: func(mockServer *httpmockserver.MockServer, url string) {
31+
mockServer.EXPECT().Get("/hello").Header("Test", "123").Response(200)
32+
33+
res, err := Get(url+"/hello", map[string]string{"Test": "123"})
34+
check.NoError(err)
35+
check.Equal(200, res.StatusCode)
36+
},
37+
},
38+
{
39+
Description: "default calls",
40+
Run: func(mockServer *httpmockserver.MockServer, url string) {
41+
mockServer.DEFAULT().AnyRequest().Response(201)
42+
mockServer.EXPECT().Get("/hello").Header("Test", "123").Response(200)
43+
44+
res, err := Get(url+"/hello", map[string]string{"Test": "123"})
45+
check.NoError(err)
46+
check.Equal(200, res.StatusCode)
47+
48+
res, err = Get(url+"/hello", map[string]string{"Test": "123"})
49+
check.NoError(err)
50+
check.Equal(201, res.StatusCode)
51+
52+
res, err = Post(url+"/test", map[string]string{"Test": "123"}, []byte("Hello World"))
53+
check.NoError(err)
54+
check.Equal(201, res.StatusCode)
55+
},
56+
},
57+
}
58+
59+
for _, test := range tests {
60+
func() {
61+
server := httpmockserver.New(t)
62+
defer server.Shutdown()
63+
64+
test.Run(server, server.URL())
65+
66+
server.Finish()
67+
}()
68+
}
69+
70+
}
71+
72+
func Get(url string, headers map[string]string) (*http.Response, error) {
73+
c := http.Client{}
74+
req, err := http.NewRequest("GET", url, nil)
75+
if err != nil {
76+
panic(err)
77+
}
78+
for k, v := range headers {
79+
req.Header.Set(k, v)
80+
}
81+
82+
return c.Do(req)
83+
}
84+
85+
func Post(url string, headers map[string]string, body []byte) (*http.Response, error) {
86+
c := http.Client{}
87+
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
88+
if err != nil {
89+
panic(err)
90+
}
91+
for k, v := range headers {
92+
req.Header.Set(k, v)
93+
}
94+
95+
return c.Do(req)
96+
}

0 commit comments

Comments
 (0)