diff --git a/src/platform/src/Bridge/Ollama/OllamaClient.php b/src/platform/src/Bridge/Ollama/OllamaClient.php index 8cc893bcb1..2eb3e45f4c 100644 --- a/src/platform/src/Bridge/Ollama/OllamaClient.php +++ b/src/platform/src/Bridge/Ollama/OllamaClient.php @@ -38,8 +38,8 @@ public function supports(Model $model): bool public function request(Model $model, array|string $payload, array $options = []): RawHttpResult { return match (true) { - \in_array(Capability::INPUT_MESSAGES, $model->getCapabilities(), true) => $this->doCompletionRequest($payload, $options), - \in_array(Capability::EMBEDDINGS, $model->getCapabilities(), true) => $this->doEmbeddingsRequest($model, $payload, $options), + $model->supports(Capability::INPUT_MESSAGES) => $this->doCompletionRequest($payload, $options), + $model->supports(Capability::EMBEDDINGS) => $this->doEmbeddingsRequest($model, $payload, $options), default => throw new InvalidArgumentException(\sprintf('Unsupported model "%s": "%s".', $model::class, $model->getName())), }; } diff --git a/src/platform/src/Bridge/Ollama/OllamaResultConverter.php b/src/platform/src/Bridge/Ollama/OllamaResultConverter.php index 25988d82be..f12cda417e 100644 --- a/src/platform/src/Bridge/Ollama/OllamaResultConverter.php +++ b/src/platform/src/Bridge/Ollama/OllamaResultConverter.php @@ -47,9 +47,9 @@ public function convert(RawResultInterface $result, array $options = []): Result : $this->doConvertCompletion($data); } - public function getTokenUsageExtractor(): ?TokenUsageExtractorInterface + public function getTokenUsageExtractor(): TokenUsageExtractorInterface { - return null; + return new TokenUsageExtractor(); } /** diff --git a/src/platform/src/Bridge/Ollama/TokenUsageExtractor.php b/src/platform/src/Bridge/Ollama/TokenUsageExtractor.php new file mode 100644 index 0000000000..bfeb3f7f0f --- /dev/null +++ b/src/platform/src/Bridge/Ollama/TokenUsageExtractor.php @@ -0,0 +1,52 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Ollama; + +use Symfony\AI\Platform\Result\RawResultInterface; +use Symfony\AI\Platform\TokenUsage\TokenUsage; +use Symfony\AI\Platform\TokenUsage\TokenUsageExtractorInterface; +use Symfony\AI\Platform\TokenUsage\TokenUsageInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Guillaume Loulier + */ +final class TokenUsageExtractor implements TokenUsageExtractorInterface +{ + public function extract(RawResultInterface $rawResult, array $options = []): ?TokenUsageInterface + { + $response = $rawResult->getObject(); + if (!$response instanceof ResponseInterface) { + return null; + } + + if ($options['stream'] ?? false) { + foreach ($rawResult->getDataStream() as $chunk) { + if ($chunk['done']) { + return new TokenUsage( + $chunk['prompt_eval_count'], + $chunk['eval_count'] + ); + } + } + + return null; + } + + $payload = $response->toArray(); + + return new TokenUsage( + $payload['prompt_eval_count'], + $payload['eval_count'] + ); + } +} diff --git a/src/platform/tests/Bridge/Ollama/OllamaClientTest.php b/src/platform/tests/Bridge/Ollama/OllamaClientTest.php index f23b7a143f..1c039b8d8d 100644 --- a/src/platform/tests/Bridge/Ollama/OllamaClientTest.php +++ b/src/platform/tests/Bridge/Ollama/OllamaClientTest.php @@ -28,7 +28,7 @@ final class OllamaClientTest extends TestCase { public function testSupportsModel() { - $client = new OllamaClient(new MockHttpClient(), 'http://localhost:1234'); + $client = new OllamaClient(new MockHttpClient(), 'http://127.0.0.1:1234'); $this->assertTrue($client->supports(new Ollama('llama3.2'))); $this->assertFalse($client->supports(new Model('any-model'))); @@ -97,6 +97,8 @@ public function testStreamingIsSupported() 'created_at' => '2025-08-23T10:00:00Z', 'message' => ['role' => 'assistant', 'content' => 'Hello world'], 'done' => true, + 'prompt_eval_count' => 10, + 'eval_count' => 10, ])."\n\n", [ 'response_headers' => [ 'content-type' => 'text/event-stream', diff --git a/src/platform/tests/Bridge/Ollama/TokenUsageExtractorTest.php b/src/platform/tests/Bridge/Ollama/TokenUsageExtractorTest.php new file mode 100644 index 0000000000..4bc966f95c --- /dev/null +++ b/src/platform/tests/Bridge/Ollama/TokenUsageExtractorTest.php @@ -0,0 +1,88 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Ollama; + +use PHPUnit\Framework\MockObject\MockObject; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Ollama\TokenUsageExtractor; +use Symfony\AI\Platform\Result\InMemoryRawResult; +use Symfony\AI\Platform\TokenUsage\TokenUsage; +use Symfony\Contracts\HttpClient\ResponseInterface; + +final class TokenUsageExtractorTest extends TestCase +{ + public function testItHandlesStreamResponsesWithoutProcessing() + { + $extractor = new TokenUsageExtractor(); + + $this->assertNull($extractor->extract(new InMemoryRawResult(), ['stream' => true])); + } + + public function testItDoesNothingWithoutUsageData() + { + $extractor = new TokenUsageExtractor(); + + $this->assertNull($extractor->extract(new InMemoryRawResult(['some' => 'data']))); + } + + public function testItExtractsTokenUsage() + { + $extractor = new TokenUsageExtractor(); + $result = new InMemoryRawResult([], object: $this->createResponseObject()); + + $tokenUsage = $extractor->extract($result); + + $this->assertInstanceOf(TokenUsage::class, $tokenUsage); + $this->assertSame(10, $tokenUsage->getPromptTokens()); + $this->assertSame(10, $tokenUsage->getCompletionTokens()); + } + + public function testItExtractsTokenUsageFromStreamResult() + { + $extractor = new TokenUsageExtractor(); + + $result = new InMemoryRawResult([], [ + [ + 'model' => 'foo', + 'response' => 'First chunk', + 'done' => false, + ], + [ + 'model' => 'foo', + 'response' => 'Hello World!', + 'done' => true, + 'prompt_eval_count' => 10, + 'eval_count' => 10, + ], + ], object: $this->createResponseObject()); + + $tokenUsage = $extractor->extract($result, ['stream' => true]); + + $this->assertInstanceOf(TokenUsage::class, $tokenUsage); + $this->assertSame(10, $tokenUsage->getPromptTokens()); + $this->assertSame(10, $tokenUsage->getCompletionTokens()); + } + + private function createResponseObject(): ResponseInterface|MockObject + { + $response = $this->createStub(ResponseInterface::class); + $response->method('toArray')->willReturn([ + 'model' => 'foo', + 'response' => 'Hello World!', + 'done' => true, + 'prompt_eval_count' => 10, + 'eval_count' => 10, + ]); + + return $response; + } +} diff --git a/src/platform/tests/CachedPlatformTest.php b/src/platform/tests/CachedPlatformTest.php index 7562e95df8..05908f463e 100644 --- a/src/platform/tests/CachedPlatformTest.php +++ b/src/platform/tests/CachedPlatformTest.php @@ -26,7 +26,7 @@ final class CachedPlatformTest extends TestCase { public function testPlatformCanReturnCachedResultWhenCalledTwice() { - $httpResponse = $this->createStub(SymfonyHttpResponse::class); + $httpResponse = $this->createMock(SymfonyHttpResponse::class); $rawHttpResult = new RawHttpResult($httpResponse); $resultConverter = self::createMock(ResultConverterInterface::class); @@ -47,6 +47,8 @@ public function testPlatformCanReturnCachedResultWhenCalledTwice() 'prompt_cache_key' => 'symfony', ]); + $this->assertTrue($deferredResult->getMetadata()->has('cached_at')); + $this->assertSame('test content', $deferredResult->getResult()->getContent()); $secondDeferredResult = $cachedPlatform->invoke('foo', 'bar', [ @@ -54,5 +56,7 @@ public function testPlatformCanReturnCachedResultWhenCalledTwice() ]); $this->assertSame('test content', $secondDeferredResult->getResult()->getContent()); + $this->assertTrue($secondDeferredResult->getMetadata()->has('cached_at')); + $this->assertSame($deferredResult->getMetadata()->get('cached_at'), $secondDeferredResult->getMetadata()->get('cached_at')); } }