diff --git a/echoprometheus/prometheus.go b/echoprometheus/prometheus.go index e8b93be..7cab72c 100644 --- a/echoprometheus/prometheus.go +++ b/echoprometheus/prometheus.go @@ -78,6 +78,9 @@ type MiddlewareConfig struct { // If DoNotUseRequestPathFor404 is true, all 404 responses (due to non-matching route) will have the same `url` label and // thus won't generate new metrics. DoNotUseRequestPathFor404 bool + + // StatusCodeResolver resolves err & context into http status code. Default is to use context.Response().Status + StatusCodeResolver func(c echo.Context, err error) int } type LabelValueFunc func(c echo.Context, err error) string @@ -167,6 +170,9 @@ func (conf MiddlewareConfig) ToMiddleware() (echo.MiddlewareFunc, error) { return opts } } + if conf.StatusCodeResolver == nil { + conf.StatusCodeResolver = defaultStatusResolver + } labelNames, customValuers := createLabels(conf.LabelFuncs) @@ -257,16 +263,7 @@ func (conf MiddlewareConfig) ToMiddleware() (echo.MiddlewareFunc, error) { url = c.Request().URL.Path } - status := c.Response().Status - if err != nil { - var httpError *echo.HTTPError - if errors.As(err, &httpError) { - status = httpError.Code - } - if status == 0 || status == http.StatusOK { - status = http.StatusInternalServerError - } - } + status := conf.StatusCodeResolver(c, err) values := make([]string, len(labelNames)) values[0] = strconv.Itoa(status) @@ -458,3 +455,18 @@ func WriteGatheredMetrics(writer io.Writer, gatherer prometheus.Gatherer) error } return nil } + +// defaultStatusResolver resolves http status code by referencing echo.HTTPError. +func defaultStatusResolver(c echo.Context, err error) int { + status := c.Response().Status + if err != nil { + var httpError *echo.HTTPError + if errors.As(err, &httpError) { + status = httpError.Code + } + if status == 0 || status == http.StatusOK { + status = http.StatusInternalServerError + } + } + return status +} diff --git a/echoprometheus/prometheus_test.go b/echoprometheus/prometheus_test.go index 24ee2fa..9edd012 100644 --- a/echoprometheus/prometheus_test.go +++ b/echoprometheus/prometheus_test.go @@ -161,6 +161,64 @@ func TestMiddlewareConfig_LabelFuncs(t *testing.T) { assert.Contains(t, body, `echo_request_duration_seconds_count{code="200",host="example.com",method="overridden_GET",scheme="http",url="/ok"} 1`) } +func TestMiddlewareConfig_StatusCodeResolver(t *testing.T) { + e := echo.New() + customRegistry := prometheus.NewRegistry() + customResolver := func(c echo.Context, err error) int { + if err == nil { + return c.Response().Status + } + msg := err.Error() + if strings.Contains(msg, "NOT FOUND") { + return http.StatusNotFound + } + if strings.Contains(msg, "NOT Authorized") { + return http.StatusUnauthorized + } + return http.StatusInternalServerError + } + e.Use(NewMiddlewareWithConfig(MiddlewareConfig{ + Skipper: func(c echo.Context) bool { + return strings.HasSuffix(c.Path(), "ignore") + }, + Subsystem: "myapp", + Registerer: customRegistry, + StatusCodeResolver: customResolver, + })) + e.GET("/metrics", NewHandlerWithConfig(HandlerConfig{Gatherer: customRegistry})) + + e.GET("/handler_for_ok", func(c echo.Context) error { + return c.JSON(http.StatusOK, "OK") + }) + e.GET("/handler_for_nok", func(c echo.Context) error { + return c.JSON(http.StatusConflict, "NOK") + }) + e.GET("/handler_for_not_found", func(c echo.Context) error { + return errors.New("NOT FOUND") + }) + e.GET("/handler_for_not_authorized", func(c echo.Context) error { + return errors.New("NOT Authorized") + }) + e.GET("/handler_for_unknown_error", func(c echo.Context) error { + return errors.New("i do not know") + }) + + assert.Equal(t, http.StatusOK, request(e, "/handler_for_ok")) + assert.Equal(t, http.StatusConflict, request(e, "/handler_for_nok")) + assert.Equal(t, http.StatusInternalServerError, request(e, "/handler_for_not_found")) + assert.Equal(t, http.StatusInternalServerError, request(e, "/handler_for_not_authorized")) + assert.Equal(t, http.StatusInternalServerError, request(e, "/handler_for_unknown_error")) + + body, code := requestBody(e, "/metrics") + assert.Equal(t, http.StatusOK, code) + assert.Contains(t, body, fmt.Sprintf("%s_requests_total", "myapp")) + assert.Contains(t, body, `myapp_requests_total{code="200",host="example.com",method="GET",url="/handler_for_ok"} 1`) + assert.Contains(t, body, `myapp_requests_total{code="409",host="example.com",method="GET",url="/handler_for_nok"} 1`) + assert.Contains(t, body, `myapp_requests_total{code="404",host="example.com",method="GET",url="/handler_for_not_found"} 1`) + assert.Contains(t, body, `myapp_requests_total{code="401",host="example.com",method="GET",url="/handler_for_not_authorized"} 1`) + assert.Contains(t, body, `myapp_requests_total{code="500",host="example.com",method="GET",url="/handler_for_unknown_error"} 1`) +} + func TestMiddlewareConfig_HistogramOptsFunc(t *testing.T) { e := echo.New() customRegistry := prometheus.NewRegistry()