Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[chore] Manual implementation of extensionauth.Client interface #38451

Merged
merged 5 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions extension/asapauthextension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,31 @@ import (

"bitbucket.org/atlassian/go-asap/v2"
"github.com/SermoDigital/jose/crypto"
"go.opentelemetry.io/collector/component"
"go.opentelemetry.io/collector/extension/extensionauth"
"google.golang.org/grpc/credentials"
)

var _ extensionauth.Client = (*asapAuthExtension)(nil)

type asapAuthExtension struct {
component.StartFunc
component.ShutdownFunc

provisioner asap.Provisioner
privateKey any
}

// PerRPCCredentials returns extensionauth.Client.
func (e *asapAuthExtension) PerRPCCredentials() (credentials.PerRPCCredentials, error) {
return &perRPCAuth{provisioner: e.provisioner, privateKey: e.privateKey}, nil
}

// RoundTripper implements extensionauth.Client.
func (e *asapAuthExtension) RoundTripper(base http.RoundTripper) (http.RoundTripper, error) {
return asap.NewTransportDecorator(e.provisioner, e.privateKey)(base), nil
}

func createASAPClientAuthenticator(cfg *Config) (extensionauth.Client, error) {
pk, err := asap.NewPrivateKey([]byte(cfg.PrivateKey))
if err != nil {
Expand All @@ -24,14 +45,10 @@ func createASAPClientAuthenticator(cfg *Config) (extensionauth.Client, error) {
p := asap.NewCachingProvisioner(asap.NewProvisioner(
cfg.KeyID, cfg.TTL, cfg.Issuer, cfg.Audience, crypto.SigningMethodRS256))

return extensionauth.NewClient(
extensionauth.WithClientRoundTripper(func(base http.RoundTripper) (http.RoundTripper, error) {
return asap.NewTransportDecorator(p, pk)(base), nil
}),
extensionauth.WithClientPerRPCCredentials(func() (credentials.PerRPCCredentials, error) {
return &perRPCAuth{provisioner: p, privateKey: pk}, nil
}),
)
return &asapAuthExtension{
provisioner: p,
privateKey: pk,
}, nil
}

// perRPCAuth is a gRPC credentials.PerRPCCredentials implementation that returns an 'authorization' header.
Expand Down
49 changes: 25 additions & 24 deletions extension/basicauthextension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,29 @@ var (
errInvalidFormat = errors.New("invalid authorization format")
)

type basicAuth struct {
htpasswd *HtpasswdSettings
clientAuth *ClientAuthSettings
matchFunc func(username, password string) bool
}

func newClientAuthExtension(cfg *Config) (extensionauth.Client, error) {
ba := basicAuth{
clientAuth: cfg.ClientAuth,
}
return extensionauth.NewClient(
extensionauth.WithClientRoundTripper(ba.roundTripper),
extensionauth.WithClientPerRPCCredentials(ba.perRPCCredentials),
)
func newClientAuthExtension(cfg *Config) extensionauth.Client {
return &basicAuthClient{clientAuth: cfg.ClientAuth}
}

func newServerAuthExtension(cfg *Config) (extensionauth.Server, error) {
if cfg.Htpasswd == nil || (cfg.Htpasswd.File == "" && cfg.Htpasswd.Inline == "") {
return nil, errNoCredentialSource
}

ba := basicAuth{
return &basicAuthServer{
htpasswd: cfg.Htpasswd,
}
return extensionauth.NewServer(
extensionauth.WithServerStart(ba.serverStart),
extensionauth.WithServerAuthenticate(ba.authenticate),
)
}, nil
}

func (ba *basicAuth) serverStart(_ context.Context, _ component.Host) error {
var _ extensionauth.Server = (*basicAuthServer)(nil)

type basicAuthServer struct {
htpasswd *HtpasswdSettings
matchFunc func(username, password string) bool
component.ShutdownFunc
}

func (ba *basicAuthServer) Start(_ context.Context, _ component.Host) error {
var rs []io.Reader

if ba.htpasswd.File != "" {
Expand Down Expand Up @@ -86,7 +78,7 @@ func (ba *basicAuth) serverStart(_ context.Context, _ component.Host) error {
return nil
}

func (ba *basicAuth) authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) {
func (ba *basicAuthServer) Authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) {
auth := getAuthHeader(headers)
if auth == "" {
return ctx, errNoAuth
Expand Down Expand Up @@ -209,7 +201,16 @@ func (b *basicAuthRoundTripper) RoundTrip(request *http.Request) (*http.Response
return b.base.RoundTrip(newRequest)
}

func (ba *basicAuth) roundTripper(base http.RoundTripper) (http.RoundTripper, error) {
var _ extensionauth.Client = (*basicAuthClient)(nil)

type basicAuthClient struct {
component.StartFunc
component.ShutdownFunc

clientAuth *ClientAuthSettings
}

func (ba *basicAuthClient) RoundTripper(base http.RoundTripper) (http.RoundTripper, error) {
if strings.Contains(ba.clientAuth.Username, ":") {
return nil, errInvalidFormat
}
Expand All @@ -219,7 +220,7 @@ func (ba *basicAuth) roundTripper(base http.RoundTripper) (http.RoundTripper, er
}, nil
}

func (ba *basicAuth) perRPCCredentials() (creds.PerRPCCredentials, error) {
func (ba *basicAuthClient) PerRPCCredentials() (creds.PerRPCCredentials, error) {
if strings.Contains(ba.clientAuth.Username, ":") {
return nil, errInvalidFormat
}
Expand Down
8 changes: 3 additions & 5 deletions extension/basicauthextension/extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,12 @@ func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
}

func TestBasicAuth_ClientValid(t *testing.T) {
ext, err := newClientAuthExtension(&Config{
ext := newClientAuthExtension(&Config{
ClientAuth: &ClientAuthSettings{
Username: "username",
Password: "password",
},
})
require.NoError(t, err)
require.NotNil(t, ext)

require.NoError(t, ext.Start(context.Background(), componenttest.NewNopHost()))
Expand Down Expand Up @@ -273,19 +272,18 @@ func TestBasicAuth_ClientValid(t *testing.T) {

func TestBasicAuth_ClientInvalid(t *testing.T) {
t.Run("invalid username format", func(t *testing.T) {
ext, err := newClientAuthExtension(&Config{
ext := newClientAuthExtension(&Config{
ClientAuth: &ClientAuthSettings{
Username: "user:name",
Password: "password",
},
})
require.NoError(t, err)
require.NotNil(t, ext)

require.NoError(t, ext.Start(context.Background(), componenttest.NewNopHost()))

base := &mockRoundTripper{}
_, err = ext.RoundTripper(base)
_, err := ext.RoundTripper(base)
assert.Error(t, err)

_, err = ext.PerRPCCredentials()
Expand Down
2 changes: 1 addition & 1 deletion extension/basicauthextension/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ func createExtension(_ context.Context, _ extension.Settings, cfg component.Conf
if cfg.(*Config).Htpasswd != nil {
return newServerAuthExtension(cfg.(*Config))
}
return newClientAuthExtension(cfg.(*Config))
return newClientAuthExtension(cfg.(*Config)), nil
}
36 changes: 24 additions & 12 deletions extension/headerssetterextension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"net/http"

"go.opentelemetry.io/collector/component"
"go.opentelemetry.io/collector/extension/extensionauth"
"go.uber.org/zap"
"google.golang.org/grpc/credentials"
Expand All @@ -22,6 +23,28 @@ type Header struct {
source source.Source
}

var _ extensionauth.Client = (*headerSetterExtension)(nil)

type headerSetterExtension struct {
component.StartFunc
component.ShutdownFunc

headers []Header
}

// PerRPCCredentials implements extensionauth.Client.
func (h *headerSetterExtension) PerRPCCredentials() (credentials.PerRPCCredentials, error) {
return &headersPerRPC{headers: h.headers}, nil
}

// RoundTripper implements extensionauth.Client.
func (h *headerSetterExtension) RoundTripper(base http.RoundTripper) (http.RoundTripper, error) {
return &headersRoundTripper{
base: base,
headers: h.headers,
}, nil
}

func newHeadersSetterExtension(cfg *Config, logger *zap.Logger) (extensionauth.Client, error) {
if cfg == nil {
return nil, errors.New("extension configuration is not provided")
Expand Down Expand Up @@ -63,18 +86,7 @@ func newHeadersSetterExtension(cfg *Config, logger *zap.Logger) (extensionauth.C
headers = append(headers, Header{action: a, source: s})
}

return extensionauth.NewClient(
extensionauth.WithClientRoundTripper(
func(base http.RoundTripper) (http.RoundTripper, error) {
return &headersRoundTripper{
base: base,
headers: headers,
}, nil
}),
extensionauth.WithClientPerRPCCredentials(func() (credentials.PerRPCCredentials, error) {
return &headersPerRPC{headers: headers}, nil
}),
)
return &headerSetterExtension{headers: headers}, nil
}

// headersPerRPC is a gRPC credentials.PerRPCCredentials implementation sets
Expand Down
15 changes: 11 additions & 4 deletions extension/oauth2clientauthextension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"fmt"
"net/http"

"go.opentelemetry.io/collector/component"
"go.opentelemetry.io/collector/extension/extensionauth"
"go.uber.org/multierr"
"go.uber.org/zap"
"golang.org/x/oauth2"
Expand All @@ -16,9 +18,14 @@ import (
grpcOAuth "google.golang.org/grpc/credentials/oauth"
)

var _ extensionauth.Client = (*clientAuthenticator)(nil)

// clientAuthenticator provides implementation for providing client authentication using OAuth2 client credentials
// workflow for both gRPC and HTTP clients.
type clientAuthenticator struct {
component.StartFunc
component.ShutdownFunc

clientCredentials *clientCredentialsConfig
logger *zap.Logger
client *http.Client
Expand Down Expand Up @@ -75,9 +82,9 @@ func (ewts errorWrappingTokenSource) Token() (*oauth2.Token, error) {
return tok, nil
}

// roundTripper returns oauth2.Transport, an http.RoundTripper that performs "client-credential" OAuth flow and
// RoundTripper returns oauth2.Transport, an http.RoundTripper that performs "client-credential" OAuth flow and
// also auto refreshes OAuth tokens as needed.
func (o *clientAuthenticator) roundTripper(base http.RoundTripper) (http.RoundTripper, error) {
func (o *clientAuthenticator) RoundTripper(base http.RoundTripper) (http.RoundTripper, error) {
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, o.client)
return &oauth2.Transport{
Source: errorWrappingTokenSource{
Expand All @@ -88,9 +95,9 @@ func (o *clientAuthenticator) roundTripper(base http.RoundTripper) (http.RoundTr
}, nil
}

// perRPCCredentials returns gRPC PerRPCCredentials that supports "client-credential" OAuth flow. The underneath
// PerRPCCredentials returns gRPC PerRPCCredentials that supports "client-credential" OAuth flow. The underneath
// oauth2.clientcredentials.Config instance will manage tokens performing auto refresh as necessary.
func (o *clientAuthenticator) perRPCCredentials() (credentials.PerRPCCredentials, error) {
func (o *clientAuthenticator) PerRPCCredentials() (credentials.PerRPCCredentials, error) {
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, o.client)
return grpcOAuth.TokenSource{
TokenSource: errorWrappingTokenSource{
Expand Down
8 changes: 4 additions & 4 deletions extension/oauth2clientauthextension/extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func TestRoundTripper(t *testing.T) {
}

assert.NotNil(t, oauth2Authenticator)
roundTripper, err := oauth2Authenticator.roundTripper(baseRoundTripper)
roundTripper, err := oauth2Authenticator.RoundTripper(baseRoundTripper)
assert.NoError(t, err)

// test roundTripper is an OAuth RoundTripper
Expand Down Expand Up @@ -266,7 +266,7 @@ func TestOAuth2PerRPCCredentials(t *testing.T) {
return
}
assert.NoError(t, err)
perRPCCredentials, err := oauth2Authenticator.perRPCCredentials()
perRPCCredentials, err := oauth2Authenticator.PerRPCCredentials()
assert.NoError(t, err)
// test perRPCCredentials is an grpc OAuthTokenSource
_, ok := perRPCCredentials.(grpcOAuth.TokenSource)
Expand Down Expand Up @@ -294,7 +294,7 @@ func TestFailContactingOAuth(t *testing.T) {
require.NoError(t, err)

// Test for gRPC connections
credential, err := oauth2Authenticator.perRPCCredentials()
credential, err := oauth2Authenticator.PerRPCCredentials()
require.NoError(t, err)

_, err = credential.GetRequestMetadata(context.Background())
Expand All @@ -303,7 +303,7 @@ func TestFailContactingOAuth(t *testing.T) {

transport := http.DefaultTransport.(*http.Transport).Clone()
baseRoundTripper := (http.RoundTripper)(transport)
roundTripper, err := oauth2Authenticator.roundTripper(baseRoundTripper)
roundTripper, err := oauth2Authenticator.RoundTripper(baseRoundTripper)
require.NoError(t, err)

client := &http.Client{
Expand Down
11 changes: 1 addition & 10 deletions extension/oauth2clientauthextension/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"go.opentelemetry.io/collector/component"
"go.opentelemetry.io/collector/extension"
"go.opentelemetry.io/collector/extension/extensionauth"

"github.com/open-telemetry/opentelemetry-collector-contrib/extension/oauth2clientauthextension/internal/metadata"
)
Expand All @@ -31,13 +30,5 @@ func createDefaultConfig() component.Config {
}

func createExtension(_ context.Context, set extension.Settings, cfg component.Config) (extension.Extension, error) {
ca, err := newClientAuthenticator(cfg.(*Config), set.Logger)
if err != nil {
return nil, err
}

return extensionauth.NewClient(
extensionauth.WithClientRoundTripper(ca.roundTripper),
extensionauth.WithClientPerRPCCredentials(ca.perRPCCredentials),
)
return newClientAuthenticator(cfg.(*Config), set.Logger)
}
17 changes: 7 additions & 10 deletions extension/oidcauthextension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"go.uber.org/zap"
)

var _ extensionauth.Server = (*oidcExtension)(nil)

type oidcExtension struct {
cfg *Config

Expand All @@ -45,23 +47,18 @@ var (
errNotAuthenticated = errors.New("authentication didn't succeed")
)

func newExtension(cfg *Config, logger *zap.Logger) (extensionauth.Server, error) {
func newExtension(cfg *Config, logger *zap.Logger) extensionauth.Server {
if cfg.Attribute == "" {
cfg.Attribute = defaultAttribute
}

oe := &oidcExtension{
return &oidcExtension{
cfg: cfg,
logger: logger,
}
return extensionauth.NewServer(
extensionauth.WithServerStart(oe.start),
extensionauth.WithServerAuthenticate(oe.authenticate),
extensionauth.WithServerShutdown(oe.shutdown),
)
}

func (e *oidcExtension) start(ctx context.Context, _ component.Host) error {
func (e *oidcExtension) Start(ctx context.Context, _ component.Host) error {
err := e.setProviderConfig(ctx, e.cfg)
if err != nil {
return fmt.Errorf("failed to get configuration from the auth server: %w", err)
Expand All @@ -72,7 +69,7 @@ func (e *oidcExtension) start(ctx context.Context, _ component.Host) error {
return nil
}

func (e *oidcExtension) shutdown(context.Context) error {
func (e *oidcExtension) Shutdown(context.Context) error {
if e.client != nil {
e.client.CloseIdleConnections()
}
Expand All @@ -84,7 +81,7 @@ func (e *oidcExtension) shutdown(context.Context) error {
}

// authenticate checks whether the given context contains valid auth data. Successfully authenticated calls will always return a nil error and a context with the auth data.
func (e *oidcExtension) authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) {
func (e *oidcExtension) Authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) {
var authHeaders []string
for k, v := range headers {
if strings.EqualFold(k, e.cfg.Attribute) {
Expand Down
Loading