mirror of
https://github.com/HeyPuter/puter.git
synced 2026-05-12 12:30:47 +00:00
New image models (#1542)
* Refactor OpenAIImageGenerationService to more easily support new models, add GPT-image-1 * set default image model to gpt-image-1
This commit is contained in:
@@ -96,6 +96,8 @@ class AIInterfaceService extends BaseService {
|
||||
parameters: {
|
||||
prompt: { type: 'string' },
|
||||
quality: { type: 'string' },
|
||||
model: { type: 'string' },
|
||||
ratio: { type: 'json' },
|
||||
},
|
||||
result_choices: [
|
||||
{
|
||||
|
||||
@@ -38,6 +38,17 @@ class OpenAIImageGenerationService extends BaseService {
|
||||
|
||||
_construct () {
|
||||
this.models_ = {
|
||||
'gpt-image-1': {
|
||||
"low:1024x1024": 0.011,
|
||||
"low:1024x1536": 0.016,
|
||||
"low:1536x1024": 0.016,
|
||||
"medium:1024x1024": 0.042,
|
||||
"medium:1024x1536": 0.063,
|
||||
"medium:1536x1024": 0.063,
|
||||
"high:1024x1024": 0.167,
|
||||
"high:1024x1536": 0.25,
|
||||
"high:1536x1024": 0.25
|
||||
},
|
||||
'dall-e-3': {
|
||||
'1024x1024': 0.04,
|
||||
'1024x1792': 0.08,
|
||||
@@ -87,17 +98,19 @@ class OpenAIImageGenerationService extends BaseService {
|
||||
* @returns {Promise<string>} URL of the generated image
|
||||
* @throws {Error} If prompt is not a string or ratio is invalid
|
||||
*/
|
||||
async generate ({ prompt, quality, test_mode }) {
|
||||
async generate (params) {
|
||||
const { prompt, quality, test_mode, model, ratio } = params;
|
||||
|
||||
if ( test_mode ) {
|
||||
return new TypedValue({
|
||||
$: 'string:url:web',
|
||||
content_type: 'image',
|
||||
}, 'https://puter-sample-data.puter.site/image_example.png');
|
||||
}
|
||||
|
||||
const url = await this.generate(prompt, {
|
||||
quality,
|
||||
ratio: this.constructor.RATIO_SQUARE,
|
||||
ratio: ratio || this.constructor.RATIO_SQUARE,
|
||||
model
|
||||
});
|
||||
|
||||
const image = new TypedValue({
|
||||
@@ -113,6 +126,10 @@ class OpenAIImageGenerationService extends BaseService {
|
||||
static RATIO_SQUARE = { w: 1024, h: 1024 };
|
||||
static RATIO_PORTRAIT = { w: 1024, h: 1792 };
|
||||
static RATIO_LANDSCAPE = { w: 1792, h: 1024 };
|
||||
|
||||
// GPT-Image-1 specific ratios
|
||||
static RATIO_GPT_PORTRAIT = { w: 1024, h: 1536 };
|
||||
static RATIO_GPT_LANDSCAPE = { w: 1536, h: 1024 };
|
||||
|
||||
async generate (prompt, {
|
||||
ratio,
|
||||
@@ -123,11 +140,13 @@ class OpenAIImageGenerationService extends BaseService {
|
||||
throw new Error('`prompt` must be a string');
|
||||
}
|
||||
|
||||
if ( ! ratio || ! this._validate_ratio(ratio) ) {
|
||||
throw new Error('`ratio` must be a valid ratio');
|
||||
if ( ! ratio || ! this._validate_ratio(ratio, model) ) {
|
||||
throw new Error('`ratio` must be a valid ratio for model ' + model);
|
||||
}
|
||||
|
||||
model = model ?? 'dall-e-3';
|
||||
// Somewhat sane defaults
|
||||
model = model ?? 'gpt-image-1';
|
||||
quality = quality ?? 'low'
|
||||
|
||||
if ( ! this.models_[model] ) {
|
||||
throw APIError.create('field_invalid', null, {
|
||||
@@ -138,29 +157,24 @@ class OpenAIImageGenerationService extends BaseService {
|
||||
});
|
||||
}
|
||||
|
||||
if ( quality && quality !== 'standard' && quality !== 'hd' ) {
|
||||
// Validate quality based on the model
|
||||
const validQualities = this._getValidQualities(model);
|
||||
if ( quality !== undefined && !validQualities.includes(quality) ) {
|
||||
throw APIError.create('field_invalid', null, {
|
||||
key: 'quality',
|
||||
expected: 'one of: standard, hd',
|
||||
expected: 'one of: ' + validQualities.join(', ').replace(/^$/, 'none (no quality)'),
|
||||
got: quality,
|
||||
});
|
||||
}
|
||||
|
||||
console.log('SPECIFIED QUALITY:', quality);
|
||||
|
||||
|
||||
const size = `${ratio.w}x${ratio.h}`;
|
||||
const price_key = (quality === 'hd' ? 'hd:' : '') + size;
|
||||
const price_key = this._buildPriceKey(model, quality, size);
|
||||
if ( ! this.models_[model][price_key] ) {
|
||||
const availableSizes = Object.keys(this.models_[model]);
|
||||
throw APIError.create('field_invalid', null, {
|
||||
key: 'size',
|
||||
expected: 'one of: standard, hd',
|
||||
got: quality,
|
||||
});
|
||||
}
|
||||
|
||||
if ( ! this.models_[model][size] ) {
|
||||
throw APIError.create('internal_error', null, {
|
||||
message: `price of ${size} not known for model ${model}`
|
||||
key: 'size/quality combination',
|
||||
expected: 'one of: ' + availableSizes.join(', '),
|
||||
got: price_key,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -189,11 +203,15 @@ class OpenAIImageGenerationService extends BaseService {
|
||||
// We can charge immediately
|
||||
await svc_cost.record_cost({ cost: exact_cost });
|
||||
|
||||
const result = await this.openai.images.generate({
|
||||
// Build API parameters based on model
|
||||
const apiParams = this._buildApiParams(model, {
|
||||
user: user_private_uid,
|
||||
prompt,
|
||||
size,
|
||||
quality
|
||||
});
|
||||
|
||||
const result = await this.openai.images.generate(apiParams);
|
||||
|
||||
// Tiny base64 result for testing
|
||||
// const result = {
|
||||
@@ -217,19 +235,118 @@ class OpenAIImageGenerationService extends BaseService {
|
||||
size: `${ratio.w}x${ratio.h}`,
|
||||
};
|
||||
|
||||
if (quality) {
|
||||
spending_meta.size = quality + ":" + spending_meta.size;
|
||||
}
|
||||
|
||||
const svc_spending = Context.get('services').get('spending');
|
||||
svc_spending.record_spending('openai', 'image-generation', spending_meta);
|
||||
|
||||
const url = result.data?.[0]?.url;
|
||||
const url = result.data?.[0]?.url || (result.data?.[0]?.b64_json ? "data:image/png;base64," + result.data[0].b64_json : null);
|
||||
|
||||
if (!url) {
|
||||
throw new Error('Failed to extract image URL from OpenAI response');
|
||||
}
|
||||
|
||||
return url;
|
||||
}
|
||||
|
||||
_validate_ratio (ratio) {
|
||||
return false
|
||||
|| ratio === this.constructor.RATIO_SQUARE
|
||||
|| ratio === this.constructor.RATIO_PORTRAIT
|
||||
|| ratio === this.constructor.RATIO_LANDSCAPE
|
||||
;
|
||||
/**
|
||||
* Get valid quality levels for a specific model
|
||||
* @param {string} model - The model name
|
||||
* @returns {Array<string>} Array of valid quality levels
|
||||
* @private
|
||||
*/
|
||||
_getValidQualities(model) {
|
||||
if (model === 'gpt-image-1') {
|
||||
return ['low', 'medium', 'high'];
|
||||
}
|
||||
if (model === 'dall-e-2') {
|
||||
return [''];
|
||||
}
|
||||
if (model === 'dall-e-3') {
|
||||
return ['', 'hd'];
|
||||
}
|
||||
// Fallback for unknown models - assume no quality tiers
|
||||
return [''];
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the price key for a model based on quality and size
|
||||
* @param {string} model - The model name
|
||||
* @param {string} quality - The quality level
|
||||
* @param {string} size - The image size (e.g., "1024x1024")
|
||||
* @returns {string} The price key
|
||||
* @private
|
||||
*/
|
||||
_buildPriceKey(model, quality, size) {
|
||||
if (model === 'gpt-image-1') {
|
||||
// gpt-image-1 uses format: "quality:size" - default to low if not specified
|
||||
const qualityLevel = quality || 'low';
|
||||
return `${qualityLevel}:${size}`;
|
||||
} else {
|
||||
// dall-e models use format: "hd:size" or just "size"
|
||||
return (quality === 'hd' ? 'hd:' : '') + size;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build API parameters based on the model
|
||||
* @param {string} model - The model name
|
||||
* @param {Object} baseParams - Base parameters for the API call
|
||||
* @returns {Object} API parameters object
|
||||
* @private
|
||||
*/
|
||||
_buildApiParams(model, baseParams) {
|
||||
const apiParams = {
|
||||
user: baseParams.user,
|
||||
prompt: baseParams.prompt,
|
||||
size: baseParams.size,
|
||||
};
|
||||
|
||||
if (model === 'gpt-image-1') {
|
||||
// gpt-image-1 requires the model parameter and uses different quality mapping
|
||||
apiParams.model = model;
|
||||
// Default to low quality if not specified, consistent with _buildPriceKey
|
||||
apiParams.quality = baseParams.quality || 'low';
|
||||
} else {
|
||||
// dall-e models
|
||||
apiParams.model = model;
|
||||
if (baseParams.quality === 'hd') {
|
||||
apiParams.quality = 'hd';
|
||||
}
|
||||
}
|
||||
|
||||
return apiParams;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get valid ratios for a specific model
|
||||
* @param {string} model - The model name
|
||||
* @returns {Array<Object>} Array of valid ratio objects
|
||||
* @private
|
||||
*/
|
||||
_getValidRatios(model) {
|
||||
const commonRatios = [this.constructor.RATIO_SQUARE];
|
||||
|
||||
if (model === 'gpt-image-1') {
|
||||
return [
|
||||
...commonRatios,
|
||||
this.constructor.RATIO_GPT_PORTRAIT,
|
||||
this.constructor.RATIO_GPT_LANDSCAPE
|
||||
];
|
||||
} else {
|
||||
// DALL-E models
|
||||
return [
|
||||
...commonRatios,
|
||||
this.constructor.RATIO_PORTRAIT,
|
||||
this.constructor.RATIO_LANDSCAPE
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
_validate_ratio (ratio, model) {
|
||||
const validRatios = this._getValidRatios(model);
|
||||
return validRatios.includes(ratio);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -110,6 +110,17 @@ class TrackSpendingService extends BaseService {
|
||||
*/
|
||||
static ImageGenerationStrategy = class ImageGenerationStrategy {
|
||||
static models = {
|
||||
'gpt-image-1': {
|
||||
"low:1024x1024": 0.011,
|
||||
"low:1024x1536": 0.016,
|
||||
"low:1536x1024": 0.016,
|
||||
"medium:1024x1024": 0.042,
|
||||
"medium:1024x1536": 0.063,
|
||||
"medium:1536x1024": 0.063,
|
||||
"high:1024x1024": 0.167,
|
||||
"high:1024x1536": 0.25,
|
||||
"high:1536x1024": 0.25
|
||||
},
|
||||
'dall-e-3': {
|
||||
'1024x1024': 0.04,
|
||||
'1024x1792': 0.08,
|
||||
|
||||
@@ -621,6 +621,15 @@ class AI{
|
||||
if (typeof args[1] === 'boolean' && args[1] === true) {
|
||||
testMode = true;
|
||||
}
|
||||
|
||||
if (typeof args[0] === 'string' && typeof args[1] === "object") {
|
||||
options = args[1];
|
||||
options.prompt = args[0];
|
||||
}
|
||||
|
||||
if (typeof args[0] === 'object') {
|
||||
options = args[0]
|
||||
}
|
||||
|
||||
// Call the original chat.complete method
|
||||
return await utils.make_driver_method(['prompt'], 'puter-image-generation', undefined, 'generate', {
|
||||
|
||||
Reference in New Issue
Block a user