Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/blobstore/configuration/ac_blob_access_creator.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (bac *acBlobAccessCreator) NewCustomBlobAccess(terminationGroup program.Gro
DigestKeyFormat: base.DigestKeyFormat.Combine(bac.contentAddressableStorage.DigestKeyFormat),
}, "completeness_checking", nil
case *pb.BlobAccessConfiguration_Grpc:
client, err := bac.grpcClientFactory.NewClientFromConfiguration(backend.Grpc, terminationGroup)
client, err := bac.grpcClientFactory.NewClientFromConfiguration(backend.Grpc.Client, terminationGroup)
if err != nil {
return BlobAccessInfo{}, "", err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/blobstore/configuration/cas_blob_access_creator.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ func (bac *casBlobAccessCreator) NewCustomBlobAccess(terminationGroup program.Gr
DigestKeyFormat: base.DigestKeyFormat,
}, "existence_caching", nil
case *pb.BlobAccessConfiguration_Grpc:
client, err := bac.grpcClientFactory.NewClientFromConfiguration(backend.Grpc, terminationGroup)
client, err := bac.grpcClientFactory.NewClientFromConfiguration(backend.Grpc.Client, terminationGroup)
if err != nil {
return BlobAccessInfo{}, "", err
}
// TODO: Should we provide a configuration option, so
// that digest.KeyWithoutInstance can be used?
return BlobAccessInfo{
BlobAccess: grpcclients.NewCASBlobAccess(client, uuid.NewRandom, 65536),
BlobAccess: grpcclients.NewCASBlobAccess(client, uuid.NewRandom, 64<<10, backend.Grpc.EnableCompression),
DigestKeyFormat: digest.KeyWithInstance,
}, "grpc", nil
case *pb.BlobAccessConfiguration_ReferenceExpanding:
Expand Down
2 changes: 1 addition & 1 deletion pkg/blobstore/configuration/fsac_blob_access_creator.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (fsacBlobAccessCreator) GetDefaultCapabilitiesProvider() capabilities.Provi
func (bac *fsacBlobAccessCreator) NewCustomBlobAccess(terminationGroup program.Group, configuration *pb.BlobAccessConfiguration, nestedCreator NestedBlobAccessCreator) (BlobAccessInfo, string, error) {
switch backend := configuration.Backend.(type) {
case *pb.BlobAccessConfiguration_Grpc:
client, err := bac.grpcClientFactory.NewClientFromConfiguration(backend.Grpc, terminationGroup)
client, err := bac.grpcClientFactory.NewClientFromConfiguration(backend.Grpc.Client, terminationGroup)
if err != nil {
return BlobAccessInfo{}, "", err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/blobstore/configuration/icas_blob_access_creator.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (icasBlobAccessCreator) GetDefaultCapabilitiesProvider() capabilities.Provi
func (bac *icasBlobAccessCreator) NewCustomBlobAccess(terminationGroup program.Group, configuration *pb.BlobAccessConfiguration, nestedCreator NestedBlobAccessCreator) (BlobAccessInfo, string, error) {
switch backend := configuration.Backend.(type) {
case *pb.BlobAccessConfiguration_Grpc:
client, err := bac.grpcClientFactory.NewClientFromConfiguration(backend.Grpc, terminationGroup)
client, err := bac.grpcClientFactory.NewClientFromConfiguration(backend.Grpc.Client, terminationGroup)
if err != nil {
return BlobAccessInfo{}, "", err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/blobstore/configuration/iscc_blob_access_creator.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (isccBlobAccessCreator) GetDefaultCapabilitiesProvider() capabilities.Provi
func (bac *isccBlobAccessCreator) NewCustomBlobAccess(terminationGroup program.Group, configuration *pb.BlobAccessConfiguration, nestedCreator NestedBlobAccessCreator) (BlobAccessInfo, string, error) {
switch backend := configuration.Backend.(type) {
case *pb.BlobAccessConfiguration_Grpc:
client, err := bac.grpcClientFactory.NewClientFromConfiguration(backend.Grpc, terminationGroup)
client, err := bac.grpcClientFactory.NewClientFromConfiguration(backend.Grpc.Client, terminationGroup)
if err != nil {
return BlobAccessInfo{}, "", err
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/blobstore/grpcclients/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ go_library(
"//pkg/util",
"@bazel_remote_apis//build/bazel/remote/execution/v2:remote_execution_go_proto",
"@com_github_google_uuid//:uuid",
"@com_github_klauspost_compress//zstd",
"@org_golang_google_genproto_googleapis_bytestream//:bytestream",
"@org_golang_google_grpc//:grpc",
"@org_golang_google_grpc//codes",
Expand All @@ -43,6 +44,7 @@ go_test(
"@bazel_remote_apis//build/bazel/remote/execution/v2:remote_execution_go_proto",
"@bazel_remote_apis//build/bazel/semver:semver_go_proto",
"@com_github_google_uuid//:uuid",
"@com_github_klauspost_compress//zstd",
"@com_github_stretchr_testify//require",
"@org_golang_google_genproto_googleapis_bytestream//:bytestream",
"@org_golang_google_grpc//:grpc",
Expand Down
219 changes: 209 additions & 10 deletions pkg/blobstore/grpcclients/cas_blob_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package grpcclients
import (
"context"
"io"
"slices"
"sync"
"sync/atomic"

remoteexecution "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
"github.com/buildbarn/bb-storage/pkg/blobstore"
Expand All @@ -11,10 +14,13 @@ import (
"github.com/buildbarn/bb-storage/pkg/digest"
"github.com/buildbarn/bb-storage/pkg/util"
"github.com/google/uuid"
"github.com/klauspost/compress/zstd"

"google.golang.org/genproto/googleapis/bytestream"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

type casBlobAccess struct {
Expand All @@ -23,20 +29,26 @@ type casBlobAccess struct {
capabilitiesClient remoteexecution.CapabilitiesClient
uuidGenerator util.UUIDGenerator
readChunkSize int
enableZSTDCompression bool
supportedCompressors atomic.Pointer[[]remoteexecution.Compressor_Value]
}

// NewCASBlobAccess creates a BlobAccess handle that relays any requests
// to a GRPC service that implements the bytestream.ByteStream and
// to a gRPC service that implements the bytestream.ByteStream and
// remoteexecution.ContentAddressableStorage services. Those are the
// services that Bazel uses to access blobs stored in the Content
// Addressable Storage.
func NewCASBlobAccess(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int) blobstore.BlobAccess {
//
// If enableZSTDCompression is true, the client will use ZSTD compression
// for ByteStream operations if the server supports it.
func NewCASBlobAccess(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int, enableZSTDCompression bool) blobstore.BlobAccess {
return &casBlobAccess{
byteStreamClient: bytestream.NewByteStreamClient(client),
contentAddressableStorageClient: remoteexecution.NewContentAddressableStorageClient(client),
capabilitiesClient: remoteexecution.NewCapabilitiesClient(client),
uuidGenerator: uuidGenerator,
readChunkSize: readChunkSize,
enableZSTDCompression: enableZSTDCompression,
}
}

Expand All @@ -62,11 +74,140 @@ func (r *byteStreamChunkReader) Close() {
}
}

// zstdByteStreamChunkReader holds the state needed to turn a
// ZSTD-compressed ByteStream Read() response stream into chunks of
// decompressed data.
type zstdByteStreamChunkReader struct {
	// client is the ByteStream Read() stream from which compressed
	// data is received.
	client bytestream.ByteStream_ReadClient
	// cancel is invoked by Close() to cancel the context that was
	// handed to this reader together with the stream.
	cancel context.CancelFunc
	// readChunkSize is the size in bytes of the buffer that is
	// allocated for every call to Read().
	readChunkSize int

	// zstdReader yields the decompressed data. It remains nil until
	// the first call to Read().
	zstdReader io.ReadCloser
	// wg tracks the goroutine started by Read() that forwards
	// received data into the decompressor.
	wg sync.WaitGroup
}

// Read returns the next chunk of decompressed data, at most
// readChunkSize bytes long.
//
// On the first call a pipe is created, a goroutine is started that
// copies the payload of every message received from the gRPC stream
// into the pipe, and the read end of the pipe is wrapped in a ZSTD
// decompressor. Subsequent calls only read from the decompressor.
func (r *zstdByteStreamChunkReader) Read() ([]byte, error) {
	if r.zstdReader == nil {
		pipeReader, pipeWriter := io.Pipe()

		// Feed compressed data into the pipe in the background. The
		// goroutine terminates once the stream ends or fails, or
		// once writing into the pipe fails.
		r.wg.Add(1)
		go func() {
			defer r.wg.Done()
			defer pipeWriter.Close()
			for {
				response, recvErr := r.client.Recv()
				if recvErr == io.EOF {
					// Regular end of stream; the deferred Close()
					// signals EOF to the decompressor.
					return
				}
				if recvErr != nil {
					pipeWriter.CloseWithError(recvErr)
					return
				}
				if _, pipeErr := pipeWriter.Write(response.Data); pipeErr != nil {
					pipeWriter.CloseWithError(pipeErr)
					return
				}
			}
		}()

		var err error
		r.zstdReader, err = util.NewZstdReadCloser(pipeReader, zstd.WithDecoderConcurrency(1))
		if err != nil {
			// Closing the read end makes the goroutine's next pipe
			// write fail, so that it does not linger.
			pipeReader.Close()
			return nil, err
		}
	}

	chunk := make([]byte, r.readChunkSize)
	n, err := r.zstdReader.Read(chunk)
	if n <= 0 {
		return nil, err
	}
	// Data was produced. Only io.EOF is passed along with it; any
	// other error is dropped from this call's result.
	// NOTE(review): presumably such an error resurfaces on the next
	// call - confirm against util.NewZstdReadCloser.
	if err != io.EOF {
		err = nil
	}
	return chunk[:n], err
}

// Close releases all resources held by the reader: it closes the
// decompressor (if one was created), cancels the context of the gRPC
// stream, waits for the background goroutine started by Read() to
// terminate, and finally drains whatever is left on the stream.
func (r *zstdByteStreamChunkReader) Close() {
	if r.zstdReader != nil {
		r.zstdReader.Close()
	}
	r.cancel()

	// Wait for the background goroutine before touching the stream.
	// gRPC does not permit calling Recv() on the same stream from
	// multiple goroutines, so draining must not overlap with the
	// Recv() loop in that goroutine. Cancelling the context above
	// makes a pending Recv() in the goroutine fail.
	// NOTE(review): this assumes that closing zstdReader unblocks a
	// goroutine that is stuck writing into the pipe - confirm against
	// util.NewZstdReadCloser.
	r.wg.Wait()

	// Drain the gRPC stream.
	for {
		if _, err := r.client.Recv(); err != nil {
			break
		}
	}
}

// zstdByteStreamWriter holds the state needed to forward written data
// as a sequence of ByteStream WriteRequest messages.
type zstdByteStreamWriter struct {
	// client is the ByteStream Write() stream on which requests are
	// sent.
	client bytestream.ByteStream_WriteClient
	// cancel is invoked by Close() to cancel the context that was
	// handed to this writer together with the stream.
	cancel context.CancelFunc

	// resourceName is placed in the next request that is sent. It is
	// reset to the empty string after the first successful Write().
	resourceName string
	// writeOffset is the total number of bytes sent so far, and is
	// used as the offset of the next request.
	writeOffset int64
}

// Write sends p as the payload of a single WriteRequest at the current
// write offset. On success the offset is advanced by len(p) and the
// resource name is cleared, so that later requests are sent with an
// empty resource name. On failure zero bytes are reported as written
// and the error returned by Send() is passed through unmodified.
func (w *zstdByteStreamWriter) Write(p []byte) (int, error) {
	request := bytestream.WriteRequest{
		ResourceName: w.resourceName,
		WriteOffset:  w.writeOffset,
		Data:         p,
	}
	if err := w.client.Send(&request); err != nil {
		return 0, err
	}

	written := len(p)
	w.resourceName = ""
	w.writeOffset += int64(written)
	return written, nil
}

// Close finalizes the upload by sending a request with FinishWrite
// set, closing the stream and waiting for the server's response. The
// context of the stream is cancelled in all cases.
//
// The returned error is the final outcome of the RPC where one is
// available, and the error returned by Send() otherwise.
func (w *zstdByteStreamWriter) Close() error {
	sendErr := w.client.Send(&bytestream.WriteRequest{
		ResourceName: w.resourceName,
		WriteOffset:  w.writeOffset,
		FinishWrite:  true,
	})
	if sendErr != nil && sendErr != io.EOF {
		// The error was generated locally and already describes the
		// failure. Tear down the stream; the result of CloseAndRecv()
		// carries no additional information and is deliberately
		// ignored.
		w.cancel()
		w.client.CloseAndRecv()
		return sendErr
	}

	// Either the final request was sent, or Send() returned io.EOF.
	// In gRPC, io.EOF from Send() means the server has already
	// terminated the RPC, and the actual status can only be obtained
	// by receiving from the stream. Returning that status instead of
	// a bare io.EOF lets callers see the real failure, or success if
	// the server completed the RPC early.
	_, err := w.client.CloseAndRecv()
	w.cancel()
	return err
}

// resourceNameHeader is the gRPC metadata key under which Get() and
// Put() attach the ByteStream resource name to the outgoing context.
const resourceNameHeader = "build.bazel.remote.execution.v2.resource-name"

// shouldUseZSTDCompression reports whether ByteStream transfers for
// the provided digest should use ZSTD compression.
//
// It returns false without contacting the server if compression was
// not enabled on this client. Otherwise, if no list of supported
// compressors has been recorded yet, GetCapabilities() is called for
// the instance name of the digest so that the list gets populated. Any
// error from that call is returned to the caller.
func (ba *casBlobAccess) shouldUseZSTDCompression(ctx context.Context, digest digest.Digest) (bool, error) {
	if !ba.enableZSTDCompression {
		return false, nil
	}

	compressors := ba.supportedCompressors.Load()
	if compressors == nil {
		// Nothing recorded yet; GetCapabilities() stores the
		// server's list as a side effect.
		instanceName := digest.GetDigestFunction().GetInstanceName()
		if _, err := ba.GetCapabilities(ctx, instanceName); err != nil {
			return false, err
		}
		compressors = ba.supportedCompressors.Load()
	}

	serverSupportsZSTD := slices.Contains(*compressors, remoteexecution.Compressor_ZSTD)
	return serverSupportsZSTD, nil
}

func (ba *casBlobAccess) Get(ctx context.Context, digest digest.Digest) buffer.Buffer {
useCompression, err := ba.shouldUseZSTDCompression(ctx, digest)
if err != nil {
return buffer.NewBufferFromError(err)
}

compressor := remoteexecution.Compressor_IDENTITY
if useCompression {
compressor = remoteexecution.Compressor_ZSTD
}

ctxWithCancel, cancel := context.WithCancel(ctx)
resourceName := digest.GetByteStreamReadPath(remoteexecution.Compressor_IDENTITY)
resourceName := digest.GetByteStreamReadPath(compressor)
client, err := ba.byteStreamClient.Read(
metadata.AppendToOutgoingContext(ctxWithCancel, resourceNameHeader, resourceName),
&bytestream.ReadRequest{
Expand All @@ -77,6 +218,15 @@ func (ba *casBlobAccess) Get(ctx context.Context, digest digest.Digest) buffer.B
cancel()
return buffer.NewBufferFromError(err)
}

if useCompression {
return buffer.NewCASBufferFromChunkReader(digest, &zstdByteStreamChunkReader{
client: client,
cancel: cancel,
readChunkSize: ba.readChunkSize,
}, buffer.BackendProvided(buffer.Irreparable(digest)))
}

return buffer.NewCASBufferFromChunkReader(digest, &byteStreamChunkReader{
client: client,
cancel: cancel,
Expand All @@ -89,19 +239,61 @@ func (ba *casBlobAccess) GetFromComposite(ctx context.Context, parentDigest, chi
}

func (ba *casBlobAccess) Put(ctx context.Context, digest digest.Digest, b buffer.Buffer) error {
r := b.ToChunkReader(0, ba.readChunkSize)
defer r.Close()
useCompression, err := ba.shouldUseZSTDCompression(ctx, digest)
if err != nil {
b.Discard()
return err
}

compressor := remoteexecution.Compressor_IDENTITY
if useCompression {
compressor = remoteexecution.Compressor_ZSTD
}

ctxWithCancel, cancel := context.WithCancel(ctx)
resourceName := digest.GetByteStreamWritePath(uuid.Must(ba.uuidGenerator()), remoteexecution.Compressor_IDENTITY)
resourceName := digest.GetByteStreamWritePath(uuid.Must(ba.uuidGenerator()), compressor)
client, err := ba.byteStreamClient.Write(
metadata.AppendToOutgoingContext(ctxWithCancel, resourceNameHeader, resourceName),
)
if err != nil {
cancel()
b.Discard()
return err
}

if useCompression {
byteStreamWriter := &zstdByteStreamWriter{
client: client,
resourceName: resourceName,
writeOffset: 0,
cancel: cancel,
}

zstdWriter, err := zstd.NewWriter(byteStreamWriter, zstd.WithEncoderConcurrency(1))
if err != nil {
cancel()
client.CloseAndRecv()
return status.Errorf(codes.Internal, "Failed to create zstd writer: %v", err)
}

if err := b.IntoWriter(zstdWriter); err != nil {
zstdWriter.Close()
byteStreamWriter.Close()
return err
}

if err := zstdWriter.Close(); err != nil {
byteStreamWriter.Close()
return err
}

return byteStreamWriter.Close()
}

// Non-compressed path
r := b.ToChunkReader(0, ba.readChunkSize)
defer r.Close()

writeOffset := int64(0)
for {
if data, err := r.Read(); err == nil {
Expand Down Expand Up @@ -140,6 +332,10 @@ func (ba *casBlobAccess) Put(ctx context.Context, digest digest.Digest, b buffer
}

// FindMissing reports which of the provided digests are absent from
// the remote Content Addressable Storage. All of the work is delegated
// to findMissingBlobsInternal(), using this instance's
// ContentAddressableStorage client.
func (ba *casBlobAccess) FindMissing(ctx context.Context, digests digest.Set) (digest.Set, error) {
	casClient := ba.contentAddressableStorageClient
	return findMissingBlobsInternal(ctx, digests, casClient)
}

func findMissingBlobsInternal(ctx context.Context, digests digest.Set, cas remoteexecution.ContentAddressableStorageClient) (digest.Set, error) {
// Partition all digests by digest function, as the
// FindMissingBlobs() RPC can only process digests for a single
// instance name and digest function.
Expand All @@ -157,7 +353,7 @@ func (ba *casBlobAccess) FindMissing(ctx context.Context, digests digest.Set) (d
BlobDigests: blobDigests,
DigestFunction: digestFunction.GetEnumValue(),
}
response, err := ba.contentAddressableStorageClient.FindMissingBlobs(ctx, &request)
response, err := cas.FindMissingBlobs(ctx, &request)
if err != nil {
return digest.EmptySet, err
}
Expand All @@ -180,11 +376,14 @@ func (ba *casBlobAccess) GetCapabilities(ctx context.Context, instanceName diges
return nil, err
}

cacheCapabilities := serverCapabilities.CacheCapabilities

// Store supported compressors for compression negotiation.
ba.supportedCompressors.Store(&cacheCapabilities.SupportedCompressors)

// Only return fields that pertain to the Content Addressable
// Storage. Don't set 'max_batch_total_size_bytes', as we don't
// issue batch operations. The same holds for fields related to
// compression support.
cacheCapabilities := serverCapabilities.CacheCapabilities
// issue batch operations.
return &remoteexecution.ServerCapabilities{
CacheCapabilities: &remoteexecution.CacheCapabilities{
DigestFunctions: digest.RemoveUnsupportedDigestFunctions(cacheCapabilities.DigestFunctions),
Expand Down
Loading