From 5169d4bb40777bbccdd2841eea47be8b39764a9b Mon Sep 17 00:00:00 2001
From: KernelDeimos
Date: Tue, 11 Feb 2025 16:24:57 -0500
Subject: [PATCH] dev: add stream support to Gemini

---
 .../src/modules/puterai/GeminiService.js      | 49 +++++++++++++++---
 .../modules/puterai/lib/GeminiSquareHole.js   | 51 +++++++++++++++++++
 2 files changed, 92 insertions(+), 8 deletions(-)

diff --git a/src/backend/src/modules/puterai/GeminiService.js b/src/backend/src/modules/puterai/GeminiService.js
index 2cd187e00..4a37d4f59 100644
--- a/src/backend/src/modules/puterai/GeminiService.js
+++ b/src/backend/src/modules/puterai/GeminiService.js
@@ -1,6 +1,8 @@
 const BaseService = require("../../services/BaseService");
 const { GoogleGenerativeAI } = require('@google/generative-ai');
 const GeminiSquareHole = require("./lib/GeminiSquareHole");
+const { TypedValue } = require("../../services/drivers/meta/Runtime");
+const putility = require("@heyputer/putility");
 
 class GeminiService extends BaseService {
     async _init () {
@@ -38,7 +40,6 @@ class GeminiService extends BaseService {
 
                 // History is separate, so the last message gets special treatment.
                 const last_message = messages.pop();
-                console.log('last message?', last_message)
                 const last_message_parts = last_message.parts.map(
                     part => typeof part === 'string' ? part : part.text
                 );
@@ -47,15 +48,36 @@ class GeminiService extends BaseService {
                     history: messages,
                 });
 
-                const genResult = await chat.sendMessage(last_message_parts)
+                const usage_calculator = GeminiSquareHole.create_usage_calculator({
+                    model_details: (await this.models_()).find(m => m.id === model),
+                });
+
+                if ( stream ) {
+                    const genResult = await chat.sendMessageStream(last_message_parts);
+                    const responseStream = genResult.stream;
 
-                debugger;
-                const message = genResult.response.candidates[0];
-                message.content = message.content.parts;
-                message.role = 'assistant';
+                    const usage_promise = new putility.libs.promise.TeePromise();
+                    return new TypedValue({ $: 'ai-chat-intermediate' }, {
+                        stream: true,
+                        init_chat_stream:
+                            GeminiSquareHole.create_chat_stream_handler({
+                                stream: responseStream, usage_promise,
+                            }),
+                        usage_promise: usage_promise.then(usageMetadata => {
+                            return usage_calculator({ usageMetadata });
+                        }),
+                    });
+                } else {
+                    const genResult = await chat.sendMessage(last_message_parts);
 
-                const result = { message };
-                return result;
+                    const message = genResult.response.candidates[0];
+                    message.content = message.content.parts;
+                    message.role = 'assistant';
+
+                    const result = { message };
+                    result.usage = usage_calculator(genResult.response);
+                    return result;
+                }
             }
         }
     }
@@ -73,6 +95,17 @@ class GeminiService extends BaseService {
                     output: 30,
                 },
             },
+            {
+                id: 'gemini-2.0-flash',
+                name: 'Gemini 2.0 Flash',
+                context: 131072,
+                cost: {
+                    currency: 'usd-cents',
+                    tokens: 1_000_000,
+                    input: 10,
+                    output: 40,
+                },
+            },
         ];
     }
 }
diff --git a/src/backend/src/modules/puterai/lib/GeminiSquareHole.js b/src/backend/src/modules/puterai/lib/GeminiSquareHole.js
index 6470d9dae..4df6b2aef 100644
--- a/src/backend/src/modules/puterai/lib/GeminiSquareHole.js
+++ b/src/backend/src/modules/puterai/lib/GeminiSquareHole.js
@@ -18,4 +18,55 @@
 
         return messages;
     }
+
+    static create_usage_calculator = ({ model_details }) => {
+        return ({ usageMetadata }) => {
+            const tokens = [];
+
+            tokens.push({
+                type: 'prompt',
+                model: model_details.id,
+                amount: usageMetadata.promptTokenCount,
+                cost: model_details.cost.input * usageMetadata.promptTokenCount,
+            });
+
+            tokens.push({
+                type: 'completion',
+                model: model_details.id,
+                amount: usageMetadata.candidatesTokenCount,
+                cost: model_details.cost.output * usageMetadata.candidatesTokenCount,
+            });
+
+            return tokens;
+        };
+    };
+
+    static create_chat_stream_handler = ({
+        stream, // GenerateContentStreamResult:stream
+        usage_promise,
+    }) => async ({ chatStream }) => {
+        const message = chatStream.message();
+        const textblock = message.contentBlock({ type: 'text' });
+        let last_usage = null;
+        for await ( const chunk of stream ) {
+            // This is spread across several lines so that the stack trace
+            // is more helpful if we get an exception because of an
+            // inconsistent response from the model.
+            const candidate = chunk.candidates[0];
+            const content = candidate.content;
+            const parts = content.parts;
+            for ( const part of parts ) {
+                const text = part.text;
+                textblock.addText(text);
+            }
+
+            last_usage = chunk.usageMetadata;
+        }
+
+        usage_promise.resolve(last_usage);
+
+        textblock.end();
+        message.end();
+        chatStream.end();
+    };
 }
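
Usage sketch (not part of the patch; the require path and token counts are
illustrative): create_usage_calculator closes over one model's pricing entry
and maps Gemini's usageMetadata onto Puter-style usage records. With the
gemini-2.0-flash entry added to models_() above:

    const GeminiSquareHole =
        require('./src/backend/src/modules/puterai/lib/GeminiSquareHole');

    const usage_calculator = GeminiSquareHole.create_usage_calculator({
        model_details: {
            id: 'gemini-2.0-flash',
            cost: { currency: 'usd-cents', tokens: 1_000_000, input: 10, output: 40 },
        },
    });

    // e.g. a response that consumed 1200 prompt tokens and produced 350
    console.log(usage_calculator({
        usageMetadata: { promptTokenCount: 1200, candidatesTokenCount: 350 },
    }));
    // [ { type: 'prompt',     model: 'gemini-2.0-flash', amount: 1200, cost: 12000 },
    //   { type: 'completion', model: 'gemini-2.0-flash', amount: 350,  cost: 14000 } ]

Note that cost here is the per-million-token rate multiplied by the raw token
count; presumably the accounting layer that consumes these records normalizes
by cost.tokens.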
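Stream-handler sketch (not part of the patch): create_chat_stream_handler only
touches a small surface of chatStream (message() -> contentBlock() ->
addText()/end(), plus end() on each object), so it can be exercised with
hypothetical stand-ins. The real objects come from Puter's chat streaming
layer; the fake stream below yields chunks shaped like the SDK's
GenerateContentStreamResult.stream:

    const GeminiSquareHole =
        require('./src/backend/src/modules/puterai/lib/GeminiSquareHole');

    // Fake Gemini stream: two chunks of text, cumulative usage on each.
    async function* fakeGeminiStream () {
        yield {
            candidates: [ { content: { parts: [ { text: 'Hello, ' } ] } } ],
            usageMetadata: { promptTokenCount: 5, candidatesTokenCount: 2 },
        };
        yield {
            candidates: [ { content: { parts: [ { text: 'world!' } ] } } ],
            usageMetadata: { promptTokenCount: 5, candidatesTokenCount: 4 },
        };
    }

    // Hypothetical chatStream stub covering just the calls the handler makes.
    const chatStream = {
        message: () => ({
            contentBlock: () => ({
                addText: text => process.stdout.write(text),
                end: () => process.stdout.write('\n'),
            }),
            end: () => {},
        }),
        end: () => {},
    };

    // The patch uses putility's TeePromise; any promise exposing resolve() fits.
    let resolve_usage;
    const usage_promise = new Promise(r => { resolve_usage = r; });
    usage_promise.resolve = resolve_usage;

    const handler = GeminiSquareHole.create_chat_stream_handler({
        stream: fakeGeminiStream(),
        usage_promise,
    });

    (async () => {
        await handler({ chatStream });     // prints "Hello, world!"
        console.log(await usage_promise);  // last chunk's usageMetadata
    })();

Resolving the usage promise with the final chunk's usageMetadata is what lets
GeminiService report usage for streamed responses: the TypedValue's
usage_promise chains that metadata through the usage calculator shown above.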