Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/.env.template
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ SENTRY_ENVIRONMENT=development
# For Google Cloud Routes /storage /email (internal use only)
GOOGLE_APPLICATION_CREDENTIALS=
STORAGE_ROUTE_KEY=
# MAX_UPLOAD_SIZE=104857600
EMAIL_SEND_ROUTE_KEY=
EMAIL_QUEUE_ROUTE_KEY=

Expand Down
23 changes: 23 additions & 0 deletions api/configs/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,26 @@ func GetEnvLimit() int64 {

return limit
}

func GetEnvMaxUploadSize() int64 {
const (
defaultLimit int64 = 30 * 1024 * 1024
hardCapLimit int64 = 50 * 1024 * 1024
)

limitString, exist := os.LookupEnv("MAX_UPLOAD_SIZE")
if !exist {
return defaultLimit
}

limit, err := strconv.ParseInt(limitString, 10, 64)
if err != nil {
return defaultLimit
}

if limit > hardCapLimit {
return hardCapLimit
}

return limit
}
75 changes: 58 additions & 17 deletions api/controllers/storage.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
package controllers

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"

"cloud.google.com/go/storage"
"github.com/gin-gonic/gin"
"google.golang.org/api/iterator"

"github.com/UTDNebula/nebula-api/api/configs"
"github.com/UTDNebula/nebula-api/api/schema"
)

Expand All @@ -37,9 +40,14 @@ func getOrCreateBucket(client *storage.Client, bucket string) (*storage.BucketHa
bucketHandle := client.Bucket(schema.BUCKET_PREFIX + bucket)
_, err := bucketHandle.Attrs(ctx)
if err != nil {
err = bucketHandle.Create(ctx, PROJECT_ID, nil)
if err != nil {
return nil, errors.New("failed to create bucket: " + err.Error())

if errors.Is(err, storage.ErrBucketNotExist) {
err = bucketHandle.Create(ctx, PROJECT_ID, nil)
if err != nil {
return nil, errors.New("failed to create bucket: " + err.Error())
}
} else {
return nil, err
}
}
return bucketHandle, nil
Expand Down Expand Up @@ -203,6 +211,33 @@ func ObjectInfo(c *gin.Context) {
func PostObject(c *gin.Context) {
bucket := c.Param("bucket")
objectID := c.Param("objectID")

maxUploadSize := configs.GetEnvMaxUploadSize()

// Force early 413 check via Content-Length if present
if c.Request.ContentLength > maxUploadSize {
respond(c, http.StatusRequestEntityTooLarge, "error", fmt.Sprintf("File too large. Maximum allowed size is %d bytes (%dMB)", maxUploadSize, maxUploadSize/(1024*1024)))
return
}

// Use MaxBytesReader to limit the body
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxUploadSize)

// Read and validate the entire (capped) request body before touching GCS.

Comment thread
mikehquan19 marked this conversation as resolved.
fileBytes, readErr := io.ReadAll(c.Request.Body)
if readErr != nil {
var maxBytesErr *http.MaxBytesError
if errors.As(readErr, &maxBytesErr) {
respond(c, http.StatusRequestEntityTooLarge, "error", fmt.Sprintf("File too large. Maximum allowed size is %d bytes (%dMB)", maxUploadSize, maxUploadSize/(1024*1024)))
return
}
respondWithInternalError(c, readErr)
return
}

fileReader := bytes.NewReader(fileBytes)

client := getClient(c)
ctx := context.Background()

Expand All @@ -212,14 +247,6 @@ func PostObject(c *gin.Context) {
return
}

// Read body as byte stream
fileReader := c.Request.Body
if fileReader == nil {
respond(c, http.StatusBadRequest, "error", "Empty body")
return
}
defer fileReader.Close()

objectHandle := bucketHandle.Object(objectID)
if objectHandle == nil {
respondWithInternalError(c, err)
Expand Down Expand Up @@ -251,11 +278,7 @@ func PostObject(c *gin.Context) {

// Generate public URL
escapedObject := url.PathEscape(objectID)
url := fmt.Sprintf(
"https://storage.googleapis.com/%s/%s",
schema.BUCKET_PREFIX+bucket,
escapedObject,
)
url := fmt.Sprintf("https://storage.googleapis.com/%s/%s", schema.BUCKET_PREFIX+bucket, escapedObject)

objectInfo := schema.ObjectInfoFromAttrs(attrs, url)
respond(c, http.StatusOK, "success", objectInfo)
Expand Down Expand Up @@ -330,10 +353,28 @@ func ObjectSignedURL(c *gin.Context) {
respondWithInternalError(c, err)
return
}

headers := append([]string{}, body.Headers...)
// Upload size limits for signed URL uploads.
if strings.EqualFold(body.Method, http.MethodPut) || strings.EqualFold(body.Method, http.MethodPost) {
maxUploadSize := configs.GetEnvMaxUploadSize()
hasContentLengthRange := false
for _, header := range headers {
if strings.HasPrefix(strings.ToLower(header), "x-goog-content-length-range:") {
hasContentLengthRange = true
break
}
}

if !hasContentLengthRange {
headers = append(headers, fmt.Sprintf("x-goog-content-length-range:0,%d", maxUploadSize))
}
}

opts := &storage.SignedURLOptions{
Scheme: storage.SigningSchemeV4,
Method: body.Method,
Headers: body.Headers,
Headers: headers,
Expires: expirationTime,
}

Expand Down
89 changes: 89 additions & 0 deletions api/max_upload_size_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package main

import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"

"cloud.google.com/go/storage"
"github.com/UTDNebula/nebula-api/api/controllers"
"github.com/gin-gonic/gin"
)

func TestMaxUploadSize(t *testing.T) {
// Set the environment variable for max upload size (e.g., 100 bytes)
os.Setenv("MAX_UPLOAD_SIZE", "100")
defer os.Unsetenv("MAX_UPLOAD_SIZE")

// Setup Gin
gin.SetMode(gin.TestMode)
router := gin.New()

router.POST("/storage/:bucket/:objectID", func(c *gin.Context) {

c.Set("gcsClient", &storage.Client{})
controllers.PostObject(c)
})

t.Run("Upload within limit", func(t *testing.T) {

defer func() {
if r := recover(); r != nil {

}
}()

data := make([]byte, 50)
req, _ := http.NewRequest("POST", "/storage/test-bucket/small-file", bytes.NewBuffer(data))
req.Header.Set("Content-Type", "application/octet-stream")

w := httptest.NewRecorder()
router.ServeHTTP(w, req)

if w.Code == http.StatusRequestEntityTooLarge {
t.Errorf("Expected status NOT 413, got %d", w.Code)
}
})

t.Run("Upload exceeding limit via Content-Length", func(t *testing.T) {
data := make([]byte, 150)
req, _ := http.NewRequest("POST", "/storage/test-bucket/large-file", bytes.NewBuffer(data))
req.Header.Set("Content-Type", "application/octet-stream")

w := httptest.NewRecorder()
router.ServeHTTP(w, req)

if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("Expected status 413, got %d. Body: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "File too large") {
t.Errorf("Expected error message 'File too large', got %s", w.Body.String())
}
})

t.Run("Upload exceeding limit via Stream", func(t *testing.T) {
pr, pw := io.Pipe()
go func() {
pw.Write(make([]byte, 150))
pw.Close()
}()

req, _ := http.NewRequest("POST", "/storage/test-bucket/stream-file", pr)
req.ContentLength = -1
req.Header.Set("Content-Type", "application/octet-stream")

w := httptest.NewRecorder()
router.ServeHTTP(w, req)

if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("Expected status 413, got %d. Body: %s", w.Code, w.Body.String())
}
if !strings.Contains(w.Body.String(), "File too large") {
t.Errorf("Expected error message 'File too large', got %s", w.Body.String())
}
})
}
6 changes: 3 additions & 3 deletions api/routes/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ func initStorageClient() *storage.Client {
} else {
// We're not running on the cloud, get JSON service account key from .env
encodedCreds, exist := os.LookupEnv("GOOGLE_APPLICATION_CREDENTIALS")
jsonCredss := []byte(encodedCreds)
if !exist {
log.Println("Error loading 'GOOGLE_APPLICATION_CREDENTIALS' from the .env file, skipping cloud storage routes")
return
}
c, err = storage.NewClient(ctx, option.WithAuthCredentialsJSON(option.ServiceAccount, jsonCredss))
jsonCreds := []byte(encodedCreds)
c, err = storage.NewClient(ctx, option.WithAuthCredentialsJSON(option.ServiceAccount, jsonCreds))
}
if err != nil {
log.Printf("Failed to create GCS client: %v", err)
log.Printf("Error initializing GCS client: %v", err)
return
}
client = c
Expand Down
Loading