1 change: 1 addition & 0 deletions README.md
@@ -94,6 +94,7 @@ For more details see the <a href="https://docs.vllm.ai/en/stable/getting_started
- `max-cpu-loras`: maximum number of LoRAs to store in CPU memory, optional, must be >= than max-loras, default is max-loras
- `max-model-len`: model's context window, maximum number of tokens in a single request including input and output, optional, default is 1024
- `max-num-seqs`: maximum number of sequences per iteration (maximum number of inference requests that could be processed at the same time), default is 5
- `max-num-batched-tokens`: maximum number of batched tokens per iteration. If set, limits the total number of tokens (prompt + max output tokens) that can be processed simultaneously across all running requests; when not set or set to 0, only the `max-num-seqs` constraint is enforced. Optional, default is 0 (disabled)
- `mode`: the simulator mode, optional, by default `random`
- `echo`: returns the same text that was sent in the request
- `random`: returns a sentence chosen at random from a set of pre-defined sentences
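To make the interaction between the two limits concrete: with `max-num-seqs: 5` and `max-num-batched-tokens: 2048`, a request with a 100-token prompt and `max_tokens: 500` reserves 600 tokens, so three such requests (1800 tokens) can run together while a fourth waits, even though a sequence slot is still free. A minimal sketch of that arithmetic follows; the numbers are hypothetical and the snippet is illustrative, not code from this PR.

```go
package main

import "fmt"

func main() {
	const maxNumBatchedTokens = 2048
	promptTokens, maxTokens := 100, 500
	perRequest := promptTokens + maxTokens // tokens reserved per running request

	// Count how many such requests fit into the shared budget.
	running := 0
	for (running+1)*perRequest <= maxNumBatchedTokens {
		running++
	}
	// Prints: each request reserves 600 tokens; 3 can run concurrently
	fmt.Printf("each request reserves %d tokens; %d can run concurrently\n", perRequest, running)
}
```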
1 change: 1 addition & 0 deletions manifests/basic-config.yaml
@@ -1,6 +1,7 @@
port: 8001
model: "Qwen/Qwen2-0.5B"
max-num-seqs: 5
max-num-batched-tokens: 1024
mode: "random"
time-to-first-token: 2000
inter-token-latency: 1000
1 change: 1 addition & 0 deletions manifests/config.yaml
@@ -6,6 +6,7 @@ served-model-name:
max-loras: 2
max-cpu-loras: 5
max-num-seqs: 5
max-num-batched-tokens: 2048
lora-modules:
- '{"name":"lora1","path":"/path/to/lora1"}'
- '{"name":"lora2","path":"/path/to/lora2"}'
5 changes: 5 additions & 0 deletions pkg/llm-d-inference-sim/config.go
@@ -41,6 +41,8 @@ type configuration struct {
// MaxNumSeqs is maximum number of sequences per iteration (the maximum
// number of inference requests that could be processed at the same time)
MaxNumSeqs int `yaml:"max-num-seqs"`
// MaxNumBatchedTokens is maximum number of batched tokens per iteration
MaxNumBatchedTokens int `yaml:"max-num-batched-tokens"`
// MaxModelLen is the model's context window, the maximum number of tokens
// in a single request including input and output. Default value is 1024.
MaxModelLen int `yaml:"max-model-len"`
@@ -164,6 +166,9 @@ func (c *configuration) validate() error {
if c.MaxModelLen < 1 {
return errors.New("max model len cannot be less than 1")
}
if c.MaxNumBatchedTokens < 0 {
return errors.New("max num batched tokens cannot be negative")
}

for _, lora := range c.LoraModules {
if lora.Name == "" {
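The validation added above only rejects negative values: 0 keeps the constraint disabled and any positive value enables it. A minimal standalone sketch of that behaviour (hypothetical helper, not the simulator's `configuration` type):

```go
package main

import (
	"errors"
	"fmt"
)

// validateMaxNumBatchedTokens mirrors the new check in configuration.validate():
// negative values are rejected, 0 means "disabled", positive values are accepted.
func validateMaxNumBatchedTokens(v int) error {
	if v < 0 {
		return errors.New("max num batched tokens cannot be negative")
	}
	return nil
}

func main() {
	for _, v := range []int{-1, 0, 2048} {
		fmt.Printf("max-num-batched-tokens=%d -> err=%v\n", v, validateMaxNumBatchedTokens(v))
	}
}
```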
31 changes: 31 additions & 0 deletions pkg/llm-d-inference-sim/config_test.go
@@ -53,6 +53,7 @@ func createDefaultConfig(model string) *configuration {
c.MaxNumSeqs = 5
c.MaxLoras = 2
c.MaxCPULoras = 5
c.MaxNumBatchedTokens = 2048
c.TimeToFirstToken = 2000
c.InterTokenLatency = 1000
c.KVCacheTransferLatency = 100
@@ -184,6 +185,7 @@ var _ = Describe("Simulator configuration", func() {
// basic config file does not contain properties related to lora
c.MaxLoras = 1
c.MaxCPULoras = 1
c.MaxNumBatchedTokens = 1024
c.KVCacheTransferLatency = 50
test = testCase{
name: "config file with command line args with time to transfer kv-cache",
@@ -258,4 +260,33 @@ var _ = Describe("Simulator configuration", func() {
Entry(tests[12].name, tests[12].args),
Entry(tests[13].name, tests[13].args),
)

It("should accept max-num-batched-tokens parameter", func() {
config, err := createSimConfig([]string{
"test",
"--model", qwenModelName,
"--max-num-batched-tokens", "1024",
})
Expect(err).NotTo(HaveOccurred())
Expect(config.MaxNumBatchedTokens).Should(Equal(1024))
})

It("should validate max-num-batched-tokens cannot be negative", func() {
config := newConfig()
config.Model = qwenModelName
config.MaxNumBatchedTokens = -1

err := config.validate()
Expect(err).To(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("max num batched tokens cannot be negative"))
})

It("should allow max-num-batched-tokens to be zero (disabled)", func() {
config := newConfig()
config.Model = qwenModelName
config.MaxNumBatchedTokens = 0

err := config.validate()
Expect(err).NotTo(HaveOccurred())
})
})
1 change: 1 addition & 0 deletions pkg/llm-d-inference-sim/request.go
@@ -105,6 +105,7 @@ type completionReqCtx struct {
httpReqCtx *fasthttp.RequestCtx
isChatCompletion bool
wg *sync.WaitGroup
processingTokens int
}

// chatCompletionRequest defines structure of /chat/completion request
124 changes: 119 additions & 5 deletions pkg/llm-d-inference-sim/simulator.go
@@ -76,6 +76,8 @@ type VllmSimulator struct {
nRunningReqs int64
// nWaitingReqs is the number of inference requests that are waiting to be processed
nWaitingReqs int64
// processingTokensCount tracks the total number of tokens being processed by running requests
processingTokensCount int64
// loraInfo is prometheus gauge
loraInfo *prometheus.GaugeVec
// runningRequests is prometheus gauge
@@ -86,6 +88,8 @@
kvCacheUsagePercentage *prometheus.GaugeVec
// channel for requests to be passed to workers
reqChan chan *completionReqCtx
// channel for processing queue, managed by queue manager
processingChan chan *completionReqCtx
// schema validator for tools parameters
toolsValidator *validator
}
@@ -99,6 +103,7 @@ func New(logger logr.Logger) (*VllmSimulator, error) {
return &VllmSimulator{
logger: logger,
reqChan: make(chan *completionReqCtx, 1000),
processingChan: make(chan *completionReqCtx, 1000),
toolsValidator: toolsValidtor,
}, nil
}
@@ -117,6 +122,9 @@ func (s *VllmSimulator) Start(ctx context.Context) error {
return err
}

// run queue manager that handles request constraints
go s.queueManager(ctx)

// run request processing workers
for i := 1; i <= s.config.MaxNumSeqs; i++ {
go s.reqProcessingWorker(ctx, i)
@@ -149,6 +157,7 @@ func (s *VllmSimulator) parseCommandParamsAndLoadConfig() error {
f.IntVar(&config.Port, "port", config.Port, "Port")
f.StringVar(&config.Model, "model", config.Model, "Currently 'loaded' model")
f.IntVar(&config.MaxNumSeqs, "max-num-seqs", config.MaxNumSeqs, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)")
f.IntVar(&config.MaxNumBatchedTokens, "max-num-batched-tokens", config.MaxNumBatchedTokens, "Maximum number of batched tokens per iteration")
f.IntVar(&config.MaxLoras, "max-loras", config.MaxLoras, "Maximum number of LoRAs in a single batch")
f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory")
f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output")
@@ -375,6 +384,58 @@ func (s *VllmSimulator) isLora(model string) bool {
return false
}

// calculateProcessingTokens calculates the total number of processing tokens for a request
// Returns prompt tokens + max output tokens, or MaxModelLen if max_tokens is not specified
func (s *VllmSimulator) calculateProcessingTokens(req completionRequest) int {
promptTokens := req.getNumberOfPromptTokens()
maxCompletionTokens := req.getMaxCompletionTokens()

// If max_tokens is not specified, return the maximum possible tokens (MaxModelLen)
if maxCompletionTokens == nil {
return s.config.MaxModelLen
}
Collaborator: Isn't it that if maxCompletionTokens is nil, this function should just return s.config.MaxModelLen?

Contributor Author: resolved

// If max_tokens is specified, return prompt tokens + specified max completion tokens
return promptTokens + int(*maxCompletionTokens)
}

// canAcceptRequest checks if a new request can be accepted based on max-num-seqs and max-num-batched-tokens constraints
func (s *VllmSimulator) canAcceptRequest(req completionRequest) bool {
currentRunning := atomic.LoadInt64(&s.nRunningReqs)

// Check max-num-seqs constraint
if currentRunning >= int64(s.config.MaxNumSeqs) {
return false
}

// If max-num-batched-tokens is not configured (0), only check max-num-seqs
if s.config.MaxNumBatchedTokens <= 0 {
return true
}

// Calculate tokens needed for this request
requestTokens := s.calculateProcessingTokens(req)
currentTokens := atomic.LoadInt64(&s.processingTokensCount)

// Check max-num-batched-tokens constraint
return currentTokens+int64(requestTokens) <= int64(s.config.MaxNumBatchedTokens)
}

// addRunningRequest adds a request to the running requests tracking
func (s *VllmSimulator) addRunningRequest(reqCtx *completionReqCtx) {
processingTokens := s.calculateProcessingTokens(reqCtx.completionReq)
reqCtx.processingTokens = processingTokens

atomic.AddInt64(&s.processingTokensCount, int64(processingTokens))
atomic.AddInt64(&s.nRunningReqs, 1)
}

// removeRunningRequest removes a request from the running requests tracking
func (s *VllmSimulator) removeRunningRequest(reqCtx *completionReqCtx) {
atomic.AddInt64(&s.processingTokensCount, -int64(reqCtx.processingTokens))
atomic.AddInt64(&s.nRunningReqs, -1)
}

// handleCompletions general completion requests handler, support both text and chat completion APIs
func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) {
vllmReq, err := s.readRequest(ctx, isChatCompletion)
@@ -400,6 +461,16 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
return
}

// Validate max-num-batched-tokens constraint - reject requests that would never be accepted
if s.config.MaxNumBatchedTokens > 0 {
requestTokens := s.calculateProcessingTokens(vllmReq)
if requestTokens > s.config.MaxNumBatchedTokens {
s.sendCompletionError(ctx, fmt.Sprintf("Request requires %d tokens, but max-num-batched-tokens is set to %d. This request would never be accepted. Please reduce max_tokens or increase max-num-batched-tokens",
requestTokens, s.config.MaxNumBatchedTokens), "BadRequestError", fasthttp.StatusBadRequest)
return
}
}

var wg sync.WaitGroup
wg.Add(1)
reqCtx := &completionReqCtx{
@@ -414,15 +485,54 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
wg.Wait()
}

func (s *VllmSimulator) queueManager(ctx context.Context) {
// Use a slice to maintain the queue of waiting requests
var waitingQueue []*completionReqCtx
ticker := time.NewTicker(10 * time.Millisecond) // Check every 10ms if we can process waiting requests
defer ticker.Stop()

for {
select {
case <-ctx.Done():
s.logger.Info("queueManager stopped")
return
case reqCtx := <-s.reqChan:
// Add new request to the waiting queue
waitingQueue = append(waitingQueue, reqCtx)
case <-ticker.C:
// Periodically check if we can process waiting requests
if len(waitingQueue) == 0 {
continue
}

// Try to process requests from the front of the queue
var newQueue []*completionReqCtx
for _, reqCtx := range waitingQueue {
if s.canAcceptRequest(reqCtx.completionReq) {
// Add to running requests tracking
s.addRunningRequest(reqCtx)

// Send to processing channel
s.processingChan <- reqCtx
} else {
// Can't process yet, keep in queue
newQueue = append(newQueue, reqCtx)
}
}
waitingQueue = newQueue
}
}
}

func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
for {
select {
case <-ctx.Done():
s.logger.Info("reqProcessingWorker stopped:", "worker id", id)
return
case reqCtx, ok := <-s.reqChan:
case reqCtx, ok := <-s.processingChan:
if !ok {
s.logger.Info("reqProcessingWorker worker exiting: reqChan closed")
s.logger.Info("reqProcessingWorker worker exiting: processingChan closed")
return
}
atomic.StoreInt64(&(s.nWaitingReqs), int64(len(s.reqChan)))
@@ -449,7 +559,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
// TODO - check if this request went to the waiting queue - add it to waiting map
s.reportLoras()
}
atomic.AddInt64(&(s.nRunningReqs), 1)

// Note: we don't increment nRunningReqs here because it's already done in addRunningRequest
s.reportRunningRequests()

var responseTokens []string
@@ -514,15 +625,18 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
req.doRemotePrefill())
}
}

// Clean up the running request tracking
s.removeRunningRequest(reqCtx)

reqCtx.wg.Done()
}
}
}

// decrease model usage reference number
func (s *VllmSimulator) responseSentCallback(model string) {

atomic.AddInt64(&(s.nRunningReqs), -1)
// Note: nRunningReqs is now decremented in removeRunningRequest
s.reportRunningRequests()

// Only LoRA models require reference-count handling.
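Taken together, the new code paths amount to a shared token budget plus a polling admission loop: `calculateProcessingTokens` prices a request, `canAcceptRequest` checks the two limits, `addRunningRequest`/`removeRunningRequest` maintain the counters, and `queueManager` retries waiting requests on a 10 ms ticker. The self-contained sketch below reproduces that shape for illustration only; the `admitter` type, its field names, the mutex (used here in place of the PR's atomic counters), and the timings are invented and are not part of the simulator.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// request carries the two numbers that matter for admission.
type request struct {
	promptTokens int
	maxTokens    int
}

// cost mirrors the idea of calculateProcessingTokens: prompt tokens plus the
// requested output budget are reserved for the lifetime of the request.
func (r request) cost() int { return r.promptTokens + r.maxTokens }

// admitter tracks the running-sequence count and the shared token budget.
type admitter struct {
	mu            sync.Mutex
	maxSeqs       int
	maxBatched    int // 0 means the token constraint is disabled
	runningSeqs   int
	runningTokens int
}

// tryAdmit applies the two constraints checked by canAcceptRequest:
// the max-num-seqs limit and, when enabled, the max-num-batched-tokens budget.
func (a *admitter) tryAdmit(r request) bool {
	a.mu.Lock()
	defer a.mu.Unlock()
	if a.runningSeqs >= a.maxSeqs {
		return false
	}
	if a.maxBatched > 0 && a.runningTokens+r.cost() > a.maxBatched {
		return false
	}
	a.runningSeqs++
	a.runningTokens += r.cost()
	return true
}

// release frees the request's reservation, like removeRunningRequest.
func (a *admitter) release(r request) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.runningSeqs--
	a.runningTokens -= r.cost()
}

func main() {
	a := &admitter{maxSeqs: 5, maxBatched: 2048}
	req := request{promptTokens: 100, maxTokens: 500} // reserves 600 tokens

	// Three requests fit into the 2048-token budget (1800 tokens total).
	for i := 0; i < 3; i++ {
		fmt.Println("admitted:", a.tryAdmit(req))
	}
	// The fourth exceeds the budget even though sequence slots remain.
	fmt.Println("fourth admitted immediately:", a.tryAdmit(req))

	// Free one reservation after a while, then poll like queueManager's ticker does.
	go func() {
		time.Sleep(30 * time.Millisecond)
		a.release(req)
	}()
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	for range ticker.C {
		if a.tryAdmit(req) {
			fmt.Println("fourth admitted after a running request finished")
			return
		}
	}
}
```

One consequence of the ticker-based design, visible in both the PR and this sketch, is that an admissible request may sit in the queue for up to one polling interval after capacity frees up; the trade-off is a much simpler queue manager than an event-driven one.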