From aa3dcea46233df54f2e5808d27798a6daecaee16 Mon Sep 17 00:00:00 2001
From: KernelDeimos
Date: Fri, 22 Nov 2024 15:52:37 -0500
Subject: [PATCH] refactor: central controller for all LLM services

Adds AIChatService, an implementor of puter-chat-completion which can
delegate to other implementors (implementors that have registered with
AIChatService at initialization) based on details of the request.

Makes AIChatService the test implementation. AIChatService then
delegates to FakeChatService when in test mode.

Adds `models()` method to puter-chat-completion. This method, instead
of returning only the names of supported models, includes other details
such as the cost and maximum output size.

Implements `models()` on Claude and XAI. Registers Claude and XAI with
AIChatService.
---
 .../src/modules/puterai/AIChatService.js      | 130 ++++++++++++++++++
 .../src/modules/puterai/AIInterfaceService.js |   5 +
 .../src/modules/puterai/AITestModeService.js  |   2 +-
 .../src/modules/puterai/ClaudeService.js      |  63 ++++++++-
 .../src/modules/puterai/FakeChatService.js    |  20 ++-
 .../puterai/OpenAICompletionService.js        |   6 -
 .../src/modules/puterai/PuterAIModule.js      |   3 +
 src/backend/src/modules/puterai/XAIService.js |  36 ++++-
 .../src/services/drivers/DriverService.js     |  33 ++++-
 src/backend/src/util/context.js               |   3 +
 10 files changed, 276 insertions(+), 25 deletions(-)
 create mode 100644 src/backend/src/modules/puterai/AIChatService.js

diff --git a/src/backend/src/modules/puterai/AIChatService.js b/src/backend/src/modules/puterai/AIChatService.js
new file mode 100644
index 000000000..f63476035
--- /dev/null
+++ b/src/backend/src/modules/puterai/AIChatService.js
@@ -0,0 +1,130 @@
+const BaseService = require("../../services/BaseService");
+const { Context } = require("../../util/context");
+
+class AIChatService extends BaseService {
+    _construct () {
+        this.providers = [];
+
+        this.simple_model_list = [];
+        this.detail_model_list = [];
+        this.detail_model_map = {};
+    }
+    _init () {
+        const svc_driver = this.services.get('driver')
+
+        for ( const provider of this.providers ) {
+            svc_driver.register_service_alias('ai-chat', provider.service_name);
+        }
+    }
+
+    async ['__on_boot.consolidation'] () {
+        // TODO: get models and pricing for each model
+        for ( const provider of this.providers ) {
+            const delegate = this.services.get(provider.service_name)
+                .as('puter-chat-completion');
+
+            // Populate simple model list
+            {
+                const models = await delegate.list();
+                this.simple_model_list.push(...models);
+            }
+
+            // Populate detail model list and map
+            {
+                const models = await delegate.models();
+                const annotated_models = [];
+                for ( const model of models ) {
+                    annotated_models.push({
+                        ...model,
+                        provider: provider.service_name,
+                    });
+                }
+                this.detail_model_list.push(...annotated_models);
+                for ( const model of annotated_models ) {
+                    if ( this.detail_model_map[model.id] ) {
+                        let array = this.detail_model_map[model.id];
+                        // replace with array
+                        if ( ! Array.isArray(array) ) {
+                            array = [array];
+                            this.detail_model_map[model.id] = array;
+                        }
+
+                        array.push(model);
+                        continue;
+                    }
+
+                    this.detail_model_map[model.id] = model;
+                }
+            }
+        }
+    }
+
+    register_provider (spec) {
+        this.providers.push(spec);
+    }
+
+    static IMPLEMENTS = {
+        ['driver-capabilities']: {
+            supports_test_mode (iface, method_name) {
+                return iface === 'puter-chat-completion' &&
+                    method_name === 'complete';
+            }
+        },
+        ['puter-chat-completion']: {
+            async models () {
+                const delegate = this.get_delegate();
+                if ( ! delegate ) return await this.models_();
+                return await delegate.models();
+            },
+            async list () {
+                const delegate = this.get_delegate();
+                if ( ! delegate ) return await this.list_();
+                return await delegate.list();
+            },
+            async complete (parameters) {
+                const client_driver_call = Context.get('client_driver_call');
+                const { test_mode } = client_driver_call;
+                let { intended_service } = client_driver_call;
+
+                if ( test_mode ) {
+                    intended_service = 'fake-chat';
+                }
+
+                if ( intended_service === this.service_name ) {
+                    throw new Error('Calling ai-chat directly is not yet supported');
+                }
+
+                const svc_driver = this.services.get('driver');
+                const ret = await svc_driver.call_new_({
+                    actor: Context.get('actor'),
+                    service_name: intended_service,
+                    iface: 'puter-chat-completion',
+                    method: 'complete',
+                    args: parameters,
+                });
+                ret.result.via_ai_chat_service = true;
+                return ret.result;
+            }
+        }
+    }
+
+    async models_ () {
+        return this.detail_model_list;
+    }
+
+    async list_ () {
+        return this.simple_model_list;
+    }
+
+    get_delegate () {
+        const client_driver_call = Context.get('client_driver_call');
+        if ( client_driver_call.intended_service === this.service_name ) {
+            return undefined;
+        }
+        console.log('getting service', client_driver_call.intended_service);
+        const service = this.services.get(client_driver_call.intended_service);
+        return service.as('puter-chat-completion');
+    }
+}
+
+module.exports = { AIChatService };
diff --git a/src/backend/src/modules/puterai/AIInterfaceService.js b/src/backend/src/modules/puterai/AIInterfaceService.js
index bd0d90832..17052cbda 100644
--- a/src/backend/src/modules/puterai/AIInterfaceService.js
+++ b/src/backend/src/modules/puterai/AIInterfaceService.js
@@ -28,6 +28,11 @@ class AIInterfaceService extends BaseService {
         col_interfaces.set('puter-chat-completion', {
             description: 'Chatbot.',
             methods: {
+                models: {
+                    description: 'List supported models and their details.',
+                    result: { type: 'json' },
+                    parameters: {},
+                },
                 list: {
                     description: 'List supported models',
                     result: { type: 'json' },
diff --git a/src/backend/src/modules/puterai/AITestModeService.js b/src/backend/src/modules/puterai/AITestModeService.js
index 64289acae..97dbfc385 100644
--- a/src/backend/src/modules/puterai/AITestModeService.js
+++ b/src/backend/src/modules/puterai/AITestModeService.js
@@ -3,7 +3,7 @@ const BaseService = require("../../services/BaseService");
 class AITestModeService extends BaseService {
     async _init () {
         const svc_driver = this.services.get('driver');
-        svc_driver.register_test_service('puter-chat-completion', 'openai-completion');
+        svc_driver.register_test_service('puter-chat-completion', 'ai-chat');
     }
 }
 
diff --git a/src/backend/src/modules/puterai/ClaudeService.js b/src/backend/src/modules/puterai/ClaudeService.js
index 33efd5a5e..f302a1dd3 100644
--- a/src/backend/src/modules/puterai/ClaudeService.js
+++ b/src/backend/src/modules/puterai/ClaudeService.js
@@ -22,17 +22,29 @@ class ClaudeService extends BaseService {
         this.anthropic = new Anthropic({
             apiKey: this.config.apiKey
         });
+
+        const svc_aiChat = this.services.get('ai-chat');
+        svc_aiChat.register_provider({
+            service_name: this.service_name,
+            alias: true,
+        });
     }
 
     static IMPLEMENTS = {
         ['puter-chat-completion']: {
+            async models () {
+                return await this.models_();
+            },
             async list () {
-                return [
-                    'claude-3-5-sonnet-latest',
-                    'claude-3-5-sonnet-20241022',
-                    'claude-3-5-sonnet-20240620',
-                    'claude-3-haiku-20240307',
-                ];
+                const models = await this.models_();
+                const model_names = [];
+                for ( const model of models ) {
+                    model_names.push(model.id);
+                    if ( model.aliases ) {
+                        model_names.push(...model.aliases);
+                    }
+                }
+                return model_names;
             },
             async complete ({ messages, stream, model }) {
                 const adapted_messages = [];
@@ -112,6 +124,45 @@ class ClaudeService extends BaseService {
             }
         }
     }
+
+    async models_ () {
+        return [
+            {
+                id: 'claude-3-5-sonnet-20241022',
+                aliases: ['claude-3-5-sonnet-latest'],
+                cost: {
+                    currency: 'usd-cents',
+                    tokens: 1_000_000,
+                    input: 300,
+                    output: 1500,
+                },
+                qualitative_speed: 'fast',
+                max_output: 8192,
+                training_cutoff: '2024-04',
+            },
+            {
+                id: 'claude-3-5-sonnet-20240620',
+                succeeded_by: 'claude-3-5-sonnet-20241022',
+                cost: {
+                    currency: 'usd-cents',
+                    tokens: 1_000_000,
+                    input: 300,
+                    output: 1500,
+                },
+            },
+            {
+                id: 'claude-3-haiku-20240307',
+                // aliases: ['claude-3-haiku-latest'],
+                cost: {
+                    currency: 'usd-cents',
+                    tokens: 1_000_000,
+                    input: 25,
+                    output: 125,
+                },
+                qualitative_speed: 'fastest',
+            },
+        ];
+    }
 }
 
 module.exports = {
diff --git a/src/backend/src/modules/puterai/FakeChatService.js b/src/backend/src/modules/puterai/FakeChatService.js
index 915d39005..500f9d7d5 100644
--- a/src/backend/src/modules/puterai/FakeChatService.js
+++ b/src/backend/src/modules/puterai/FakeChatService.js
@@ -7,7 +7,19 @@ class FakeChatService extends BaseService {
                 return ['fake'];
             },
             async complete ({ messages, stream, model }) {
+                const { LoremIpsum } = require('lorem-ipsum');
+                const li = new LoremIpsum({
+                    sentencesPerParagraph: {
+                        max: 8,
+                        min: 4
+                    },
+                    wordsPerSentence: {
+                        max: 20,
+                        min: 12
+                    },
+                });
                 return {
+                    "index": 0,
                     message: {
                         "id": "00000000-0000-0000-0000-000000000000",
                         "type": "message",
@@ -16,7 +28,9 @@ class FakeChatService extends BaseService {
                         "content": [
                             {
                                 "type": "text",
-                                "text": "I am a fake AI, I don't know how to respond to anything."
+                                "text": li.generateParagraphs(
+                                    Math.floor(Math.random() * 3) + 1
+                                )
                             }
                         ],
                         "stop_reason": "end_turn",
@@ -25,7 +39,9 @@ class FakeChatService extends BaseService {
                             "input_tokens": 0,
                             "output_tokens": 1
                         }
-                    }
+                    },
+                    "logprobs": null,
+                    "finish_reason": "stop"
                 }
             }
         }
diff --git a/src/backend/src/modules/puterai/OpenAICompletionService.js b/src/backend/src/modules/puterai/OpenAICompletionService.js
index 5c4fc1f07..7c170fe31 100644
--- a/src/backend/src/modules/puterai/OpenAICompletionService.js
+++ b/src/backend/src/modules/puterai/OpenAICompletionService.js
@@ -22,12 +22,6 @@ class OpenAICompletionService extends BaseService {
     }
 
     static IMPLEMENTS = {
-        ['driver-capabilities']: {
-            supports_test_mode (iface, method_name) {
-                return iface === 'puter-chat-completion' &&
-                    method_name === 'complete';
-            }
-        },
         ['puter-chat-completion']: {
             async list () {
                 return [
diff --git a/src/backend/src/modules/puterai/PuterAIModule.js b/src/backend/src/modules/puterai/PuterAIModule.js
index 81f8279b5..8844592df 100644
--- a/src/backend/src/modules/puterai/PuterAIModule.js
+++ b/src/backend/src/modules/puterai/PuterAIModule.js
@@ -57,6 +57,9 @@ class PuterAIModule extends AdvancedBase {
             // services.registerService('claude', ClaudeEnoughService);
         }
 
+        const { AIChatService } = require('./AIChatService');
+        services.registerService('ai-chat', AIChatService);
+
         const { FakeChatService } = require('./FakeChatService');
         services.registerService('fake-chat', FakeChatService);
 
diff --git a/src/backend/src/modules/puterai/XAIService.js b/src/backend/src/modules/puterai/XAIService.js
index e5a8febbd..44f584261 100644
--- a/src/backend/src/modules/puterai/XAIService.js
+++ b/src/backend/src/modules/puterai/XAIService.js
@@ -31,14 +31,29 @@ class XAIService extends BaseService {
             apiKey: this.global_config.services.xai.apiKey,
             baseURL: 'https://api.x.ai'
         });
+
+        const svc_aiChat = this.services.get('ai-chat');
+        svc_aiChat.register_provider({
+            service_name: this.service_name,
+            alias: true,
+        });
     }
 
     static IMPLEMENTS = {
         ['puter-chat-completion']: {
+            async models () {
+                return await this.models_();
+            },
             async list () {
-                return [
-                    'grok-beta',
-                ];
+                const models = await this.models_();
+                const model_names = [];
+                for ( const model of models ) {
+                    model_names.push(model.id);
+                    if ( model.aliases ) {
+                        model_names.push(...model.aliases);
+                    }
+                }
+                return model_names;
             },
             async complete ({ messages, stream, model }) {
                 model = this.adapt_model(model);
@@ -121,6 +136,21 @@ class XAIService extends BaseService {
             }
         }
     }
+
+    async models_ () {
+        return [
+            {
+                id: 'grok-beta',
+                name: 'Grok Beta',
+                cost: {
+                    currency: 'usd-cents',
+                    tokens: 1_000_000,
+                    input: 500,
+                    output: 1500,
+                },
+            }
+        ];
+    }
 }
 
 module.exports = {
diff --git a/src/backend/src/services/drivers/DriverService.js b/src/backend/src/services/drivers/DriverService.js
index c9220b67b..52a5ed7da 100644
--- a/src/backend/src/services/drivers/DriverService.js
+++ b/src/backend/src/services/drivers/DriverService.js
@@ -38,6 +38,7 @@ class DriverService extends BaseService {
         this.drivers = {};
         this.interface_to_implementation = {};
         this.interface_to_test_service = {};
+        this.service_aliases = {};
     }
 
     async ['__on_registry.collections'] () {
@@ -82,6 +83,10 @@ class DriverService extends BaseService {
     register_test_service (interface_name, service_name) {
         this.interface_to_test_service[interface_name] = service_name;
     }
+
+    register_service_alias (service_name, alias) {
+        this.service_aliases[alias] = service_name;
+    }
 
     get_interface (interface_name) {
         const o = {};
@@ -152,6 +157,12 @@ class DriverService extends BaseService {
             driver = this.interface_to_test_service[iface];
         }
 
+        const client_driver_call = {
+            intended_service: driver,
+            test_mode,
+        };
+        driver = this.service_aliases[driver] ?? driver;
+
         const driver_service_exists = (() => {
             console.log('CHECKING FOR THIS', driver, iface);
             return this.services.has(driver) &&
@@ -165,13 +176,17 @@ class DriverService extends BaseService {
             if ( test_mode && caps && caps.supports_test_mode(iface, method) ) {
                 skip_usage = true;
             }
-            
-            return await this.call_new_({
-                actor,
-                service,
-                service_name: driver,
-                iface, method, args: processed_args,
-                skip_usage,
+
+            return await Context.sub({
+                client_driver_call,
+            }).arun(async () => {
+                return await this.call_new_({
+                    actor,
+                    service,
+                    service_name: driver,
+                    iface, method, args: processed_args,
+                    skip_usage,
+                });
             });
         }
 
@@ -261,6 +276,10 @@ class DriverService extends BaseService {
         iface, method, args,
         skip_usage,
     }) {
+        if ( ! service ) {
+            service = this.services.get(service_name);
+        }
+
         const svc_permission = this.services.get('permission');
         const reading = await svc_permission.scan(
             actor,
diff --git a/src/backend/src/util/context.js b/src/backend/src/util/context.js
index cf68cb4ca..552c64cb3 100644
--- a/src/backend/src/util/context.js
+++ b/src/backend/src/util/context.js
@@ -63,6 +63,9 @@ class Context {
     static arun (cb) {
         return this.get().arun(cb);
     }
+    static sub (values, opt_name) {
+        return this.get().sub(values, opt_name);
+    }
     get (k) {
         return this.values_[k];
     }