diff --git a/src/backend/drivers/ai-chat/ChatCompletionDriver.ts b/src/backend/drivers/ai-chat/ChatCompletionDriver.ts index 1d9d56a34..d7702be90 100644 --- a/src/backend/drivers/ai-chat/ChatCompletionDriver.ts +++ b/src/backend/drivers/ai-chat/ChatCompletionDriver.ts @@ -20,6 +20,7 @@ import { OpenAiResponsesChatProvider } from './providers/openai/OpenAiChatRespon import { OpenRouterProvider } from './providers/openrouter/OpenRouterProvider.js'; import { TogetherAIProvider } from './providers/together/TogetherAIProvider.js'; import { XAIProvider } from './providers/xai/XAIProvider.js'; +import { ZAIProvider } from './providers/zai/ZAIProvider.js'; import type { IChatCompleteResult, IChatModel, @@ -588,6 +589,18 @@ export class ChatCompletionDriver extends PuterDriver { ); } + const zai = providers['zai']; + const zaiKey = readKey(zai); + if (zaiKey) { + this.#providers['zai'] = new ZAIProvider( + { + apiKey: zaiKey, + apiBaseUrl: zai?.apiBaseUrl as string | undefined, + }, + metering, + ); + } + const openrouter = providers['openrouter']; const openrouterKey = readKey(openrouter); if (openrouterKey) { diff --git a/src/backend/drivers/ai-chat/providers/zai/ZAIProvider.ts b/src/backend/drivers/ai-chat/providers/zai/ZAIProvider.ts new file mode 100644 index 000000000..2bc7b776f --- /dev/null +++ b/src/backend/drivers/ai-chat/providers/zai/ZAIProvider.ts @@ -0,0 +1,220 @@ +/* + * Copyright (C) 2024-present Puter Technologies Inc. + * + * This file is part of Puter. + * + * Puter is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import { OpenAI } from 'openai'; +import { ChatCompletionCreateParams } from 'openai/resources/index.js'; +import { Context } from '../../../../core/context.js'; +import type { MeteringService } from '../../../../services/metering/MeteringService.js'; +import type { IChatProvider, ICompleteArguments } from '../../types.js'; +import * as OpenAIUtil from '../../utils/OpenAIUtil.js'; +import { ZAI_MODELS } from './models.js'; + +/** Constructor settings for ZAIProvider: the API key plus an optional override for the API base URL. */ type ZAIConfig = { + apiBaseUrl?: string; + apiKey: string; +}; + +/** Z.AI-specific request fields. complete() reads them from params.custom and copies the ones that are provided into the outgoing request. */ type ZAICustomParams = { + do_sample?: boolean; + request_id?: string; + response_format?: unknown; + stop?: string[]; + thinking?: { + type?: 'enabled' | 'disabled'; + clear_thinking?: boolean; + }; + tool_stream?: boolean; + user_id?: string; +}; + +/** Returns value unchanged (same reference) when it is a non-null, non-array object; for anything else it returns a new empty object. */ const asRecord = (value: unknown): Record => + value && typeof value === 'object' && !Array.isArray(value) + ? (value as Record) + : {}; + +/** Chat provider that sends completion requests to Z.AI through the openai SDK client and reports token usage to the metering service. */ export class ZAIProvider implements IChatProvider { + #openai: OpenAI; + + #meteringService: MeteringService; + + /** Id of the model that complete() falls back to when the requested model matches no id or alias. */ #defaultModel = 'glm-5.1'; + + /** Creates the SDK client; baseURL falls back to the hard-coded Z.AI endpoint when config.apiBaseUrl is null or undefined. */ constructor(config: ZAIConfig, meteringService: MeteringService) { + this.#openai = new OpenAI({ + apiKey: config.apiKey, + baseURL: config.apiBaseUrl ?? 
'https://api.z.ai/api/paas/v4', + }); + this.#meteringService = meteringService; + } + + /** Returns the default model id. */ getDefaultModel() { + return this.#defaultModel; + } + + /** Returns the ZAI_MODELS list imported from the models module. */ models() { + return ZAI_MODELS; + } + + /** Returns a flat array holding, for each model in list order, its id followed by its aliases. */ list() { + const modelIds: string[] = []; + for (const model of this.models()) { + modelIds.push(model.id); + if (model.aliases) { + modelIds.push(...model.aliases); + } + } + return modelIds; + } + + /** Runs one chat completion: resolves the model, preprocesses the messages, builds and sends the request, meters the reported usage and returns the result produced by OpenAIUtil.handle_completion_output. */ async complete( + params: ICompleteArguments, + ): ReturnType { + const { + custom, + max_tokens, + stream, + temperature, + tools, + tool_choice, + top_p, + } = params; + let { messages, model } = params; + const actor = Context.get('actor'); + const availableModels = this.models(); + /* Match the requested model against ids and aliases; otherwise use the default model. The trailing non-null assertion relies on the default id being present in the model list. */ const modelUsed = + availableModels.find((m) => + [m.id, ...(m.aliases || [])].includes(model), + ) || availableModels.find((m) => m.id === this.getDefaultModel())!; + + messages = await OpenAIUtil.process_input_messages(messages); + /* Removes cache_control from each message object in place before sending. NOTE(review): presumably because the Z.AI endpoint does not accept that field - confirm. */ messages = messages.map((message) => { + delete message.cache_control; + return message; + }); + + const customParams = asRecord(custom) as ZAICustomParams; + /* A user_id supplied through custom params wins; otherwise one is derived from the actor user id (plus the app uid when present) and cut to 128 characters. Stays undefined when there is no actor user id. */ const userId = + customParams.user_id ?? + (actor?.user?.id + ? `puter-${actor.user.id}${actor.app?.uid ? `-${actor.app.uid}` : ''}`.slice( + 0, + 128, + ) + : undefined); + + /* Optional fields are spread in only when provided (some checks are truthiness checks, others compare against undefined), so omitted options are absent from the request. NOTE(review): the cast that closes this literal is presumably needed because fields such as do_sample, thinking and user_id are not part of the SDK type - confirm. */ const completionParams: ChatCompletionCreateParams = { + messages, + model: modelUsed.id, + ...(tools ? { tools } : {}), + ...(tool_choice !== undefined ? { tool_choice } : {}), + ...(max_tokens !== undefined ? { max_tokens } : {}), + ...(temperature !== undefined ? { temperature } : {}), + ...(top_p !== undefined ? { top_p } : {}), + ...(customParams.do_sample !== undefined + ? { do_sample: customParams.do_sample } + : {}), + ...(customParams.request_id + ? { request_id: customParams.request_id } + : {}), + ...(customParams.response_format + ? { response_format: customParams.response_format } + : {}), + ...(customParams.stop ? { stop: customParams.stop } : {}), + ...(customParams.thinking + ? 
{ thinking: customParams.thinking } + : {}), + ...(customParams.tool_stream !== undefined + ? { tool_stream: customParams.tool_stream } + : {}), + ...(userId ? { user_id: userId } : {}), + stream: !!stream, + /* stream_options with include_usage is added only when stream is truthy. */ ...(stream + ? { + stream_options: { include_usage: true }, + } + : {}), + } as ChatCompletionCreateParams; + + const completion = + await this.#openai.chat.completions.create(completionParams); + + const result = await OpenAIUtil.handle_completion_output({ + usage_calculator: ({ usage }) => { + /* When no usage object is passed in, zeroed counters are metered instead. */ const trackedUsage = usage + ? OpenAIUtil.extractMeteredUsage(usage) + : { + prompt_tokens: 0, + completion_tokens: 0, + cached_tokens: 0, + }; + /* Per-key cost is the counter value times the model cost for that key (0 when the model defines none). */ const costsOverrideFromModel = Object.fromEntries( + Object.entries(trackedUsage).map(([key, value]) => { + return [key, value * Number(modelUsed.costs[key] ?? 0)]; + }), + ); + this.#meteringService.utilRecordUsageObject( + trackedUsage, + actor, + `zai:${modelUsed.id}`, + costsOverrideFromModel, + ); + return trackedUsage; + }, + stream, + completion, + }); + + /* Called for every result; it returns early unless the result carries a truthy message. */ this.#normalizeReasoningContent(result); + return result; + } + + /** Always throws; this provider has no moderation implementation. */ checkModeration( + _text: string, + ): ReturnType { + throw new Error('Method not implemented.'); + } + + /** Renames reasoning_content to reasoning on result.message and on each element of message.content when that is an array. An existing reasoning value is kept, and reasoning_content is deleted either way. Mutates result in place; results without a truthy message are left untouched. */ #normalizeReasoningContent( + result: Awaited>, + ) { + if (!('message' in result) || !result.message) return; + + const message = result.message as Record; + if ( + message.reasoning === undefined && + message.reasoning_content !== undefined + ) { + message.reasoning = message.reasoning_content; + } + delete message.reasoning_content; + + if (!Array.isArray(message.content)) return; + + for (const contentPart of message.content) { + /* asRecord yields a throwaway empty object for non-object parts, so those parts are effectively skipped. */ const part = asRecord(contentPart); + if ( + part.reasoning === undefined && + part.reasoning_content !== undefined + ) { + part.reasoning = part.reasoning_content; + } + delete part.reasoning_content; + } + } +} diff --git a/src/backend/drivers/ai-chat/providers/zai/models.ts b/src/backend/drivers/ai-chat/providers/zai/models.ts new file mode 100644 index 
000000000..eac619b87 --- /dev/null +++ b/src/backend/drivers/ai-chat/providers/zai/models.ts @@ -0,0 +1,192 @@ +import type { IChatModel } from '../../types.js'; + +/* Unit helpers: cents per US dollar, one million (used as the tokens value of a costs object) and one thousand (used for context and max_tokens sizes). */ const CENTS_PER_USD = 100; +const MTOK = 1_000_000; +const K = 1_000; + +/** Builds a costs object from USD prices: tokens is set to MTOK and each price is multiplied by CENTS_PER_USD. cachedInputUsd defaults to 0. NOTE(review): presumably the prices are per MTOK tokens, as the function name suggests - confirm against the metering code. */ const usdPerMToken = ( + inputUsd: number, + outputUsd: number, + cachedInputUsd = 0, +) => ({ + tokens: MTOK, + prompt_tokens: inputUsd * CENTS_PER_USD, + completion_tokens: outputUsd * CENTS_PER_USD, + cached_tokens: cachedInputUsd * CENTS_PER_USD, +}); + +/** Builds an IChatModel entry whose input and output modalities are text only. puterId and the single alias are derived from id, tool_call is true, and costs_currency is usd-cents. */ const textModel = ( + id: string, + name: string, + context: number, + maxTokens: number, + costs: IChatModel['costs'], +): IChatModel => ({ + puterId: `zai:zai/${id}`, + id, + name, + aliases: [`zai/${id}`], + modalities: { input: ['text'], output: ['text'] }, + open_weights: false, + tool_call: true, + context, + max_tokens: maxTokens, + costs_currency: 'usd-cents', + input_cost_key: 'prompt_tokens', + output_cost_key: 'completion_tokens', + costs, +}); + +/** Builds the same entry shape as textModel, except that the input modalities are text, image, video and file. */ const visionModel = ( + id: string, + name: string, + context: number, + maxTokens: number, + costs: IChatModel['costs'], +): IChatModel => ({ + puterId: `zai:zai/${id}`, + id, + name, + aliases: [`zai/${id}`], + modalities: { input: ['text', 'image', 'video', 'file'], output: ['text'] }, + open_weights: false, + tool_call: true, + context, + max_tokens: maxTokens, + costs_currency: 'usd-cents', + input_cost_key: 'prompt_tokens', + output_cost_key: 'completion_tokens', + costs, +}); + +// Hardcoded from https://docs.z.ai/api-reference/llm/chat-completion and +// https://docs.z.ai/guides/overview/pricing. 
+/** Model catalogue exported by this module. Each entry passes id, display name, context, max_tokens and a costs object to textModel or visionModel; entries built with usdPerMToken(0, 0, 0) carry zero cost for every key. */ export const ZAI_MODELS: IChatModel[] = [ + textModel( + 'glm-5.1', + 'GLM-5.1', + 200 * K, + 128 * K, + usdPerMToken(1.4, 4.4, 0.26), + ), + textModel('glm-5', 'GLM-5', 200 * K, 128 * K, usdPerMToken(1, 3.2, 0.2)), + textModel( + 'glm-5-turbo', + 'GLM-5-Turbo', + 200 * K, + 128 * K, + usdPerMToken(1.2, 4, 0.24), + ), + textModel( + 'glm-4.7', + 'GLM-4.7', + 200 * K, + 128 * K, + usdPerMToken(0.6, 2.2, 0.11), + ), + textModel( + 'glm-4.7-flashx', + 'GLM-4.7-FlashX', + 200 * K, + 128 * K, + usdPerMToken(0.07, 0.4, 0.01), + ), + textModel( + 'glm-4.7-flash', + 'GLM-4.7-Flash', + 200 * K, + 128 * K, + usdPerMToken(0, 0, 0), + ), + textModel( + 'glm-4.6', + 'GLM-4.6', + 200 * K, + 128 * K, + usdPerMToken(0.6, 2.2, 0.11), + ), + textModel( + 'glm-4.5', + 'GLM-4.5', + 128 * K, + 96 * K, + usdPerMToken(0.6, 2.2, 0.11), + ), + textModel( + 'glm-4.5-x', + 'GLM-4.5-X', + 128 * K, + 96 * K, + usdPerMToken(2.2, 8.9, 0.45), + ), + textModel( + 'glm-4.5-air', + 'GLM-4.5-Air', + 128 * K, + 96 * K, + usdPerMToken(0.2, 1.1, 0.03), + ), + textModel( + 'glm-4.5-airx', + 'GLM-4.5-AirX', + 128 * K, + 96 * K, + usdPerMToken(1.1, 4.5, 0.22), + ), + textModel( + 'glm-4.5-flash', + 'GLM-4.5-Flash', + 128 * K, + 96 * K, + usdPerMToken(0, 0, 0), + ), + textModel( + 'glm-4-32b-0414-128k', + 'GLM-4-32B-0414-128K', + 128 * K, + 16 * K, + usdPerMToken(0.1, 0.1, 0), + ), + visionModel( + 'glm-5v-turbo', + 'GLM-5V-Turbo', + 200 * K, + 128 * K, + usdPerMToken(1.2, 4, 0.24), + ), + visionModel( + 'glm-4.6v', + 'GLM-4.6V', + 128 * K, + 32 * K, + usdPerMToken(0.3, 0.9, 0.05), + ), + visionModel( + 'glm-4.6v-flashx', + 'GLM-4.6V-FlashX', + 128 * K, + 32 * K, + usdPerMToken(0.04, 0.4, 0.004), + ), + visionModel( + 'glm-4.6v-flash', + 'GLM-4.6V-Flash', + 128 * K, + 32 * K, + usdPerMToken(0, 0, 0), + ), + visionModel( + 'glm-4.5v', + 'GLM-4.5V', + 128 * K, + 16 * K, + usdPerMToken(0.6, 1.8, 0.11), + ), + visionModel( + 'autoglm-phone-multilingual', + 'AutoGLM-Phone-Multilingual', + 4 * K, + 4 * K, + 
usdPerMToken(0, 0, 0), + ), +];