feat(models): add image model type support

Add a dedicated image model type so bots can use image API models without overloading chat model capabilities, while keeping existing chat-based image generation selectable.
This commit is contained in:
Acbox
2026-04-16 16:00:22 +08:00
parent 33e18e7e64
commit ddda00f980
23 changed files with 326 additions and 68 deletions
+147 -37
View File
@@ -29,6 +29,11 @@ type ImageGenProvider struct {
dataMount string
}
type generatedImageFile struct {
Data string
MediaType string
}
func NewImageGenProvider(
log *slog.Logger,
settingsSvc *settings.Service,
@@ -65,6 +70,10 @@ func (p *ImageGenProvider) Tools(ctx context.Context, session SessionContext) ([
if strings.TrimSpace(botSettings.ImageModelID) == "" {
return nil, nil
}
modelResp, err := p.models.GetByID(ctx, botSettings.ImageModelID)
if err != nil || !supportsImageGeneration(modelResp) {
return nil, nil
}
sess := session
return []sdk.Tool{
{
@@ -112,7 +121,7 @@ func (p *ImageGenProvider) execGenerateImage(ctx context.Context, session Sessio
if err != nil {
return nil, fmt.Errorf("failed to load image model: %w", err)
}
if !modelResp.HasCompatibility(models.CompatImageOutput) {
if !supportsImageGeneration(modelResp) {
return nil, errors.New("configured model does not support image generation")
}
@@ -127,43 +136,9 @@ func (p *ImageGenProvider) execGenerateImage(ctx context.Context, session Sessio
return nil, fmt.Errorf("failed to resolve provider credentials: %w", err)
}
sdkModel := models.NewSDKChatModel(models.SDKModelConfig{
ModelID: modelResp.ModelID,
ClientType: provider.ClientType,
APIKey: creds.APIKey,
BaseURL: providers.ProviderConfigString(provider, "base_url"),
})
userMsg := fmt.Sprintf("Generate an image with the following description. Size: %s\n\n%s", size, prompt)
result, err := sdk.GenerateTextResult(ctx,
sdk.WithModel(sdkModel),
sdk.WithMessages([]sdk.Message{
{Role: sdk.MessageRoleUser, Content: []sdk.MessagePart{sdk.TextPart{Text: userMsg}}},
}),
)
file, imgBytes, ext, err := generateImage(ctx, modelResp, provider, creds, prompt, size)
if err != nil {
return nil, fmt.Errorf("image generation failed: %w", err)
}
if len(result.Files) == 0 {
if result.Text != "" {
return map[string]any{"error": "no image generated", "model_response": result.Text}, nil
}
return nil, errors.New("no image was generated by the model")
}
file := result.Files[0]
imgBytes, err := base64.StdEncoding.DecodeString(file.Data)
if err != nil {
return nil, fmt.Errorf("failed to decode generated image: %w", err)
}
ext := "png"
switch {
case strings.Contains(file.MediaType, "jpeg"), strings.Contains(file.MediaType, "jpg"):
ext = "jpg"
case strings.Contains(file.MediaType, "webp"):
ext = "webp"
return nil, err
}
containerPath := fmt.Sprintf("%s/%d.%s", imageGenDir, time.Now().UnixMilli(), ext)
@@ -196,3 +171,138 @@ func (p *ImageGenProvider) execGenerateImage(ctx context.Context, session Sessio
"size_bytes": len(imgBytes),
}, nil
}
// supportsImageGeneration reports whether the configured model can be used
// to generate images: chat models must carry the image-output compatibility,
// image models must carry the generate compatibility. Any other model type
// is rejected.
func supportsImageGeneration(model models.GetResponse) bool {
	if model.Type == models.ModelTypeChat {
		return model.HasCompatibility(models.CompatImageOutput)
	}
	if model.Type == models.ModelTypeImage {
		return model.HasCompatibility(models.CompatGenerate)
	}
	return false
}
// generateImage routes an image request to the appropriate generation path
// based on the model type: chat models are prompted via the text API, image
// models go through the dedicated image-generation API. It returns the raw
// generated file, the decoded image bytes, and the file extension.
func generateImage(
	ctx context.Context,
	modelResp models.GetResponse,
	provider sqlc.Provider,
	creds providers.ModelCredentials,
	prompt string,
	size string,
) (generatedImageFile, []byte, string, error) {
	if modelResp.Type == models.ModelTypeChat {
		return generateImageFromChatModel(ctx, modelResp, provider, creds, prompt, size)
	}
	if modelResp.Type == models.ModelTypeImage {
		return generateImageFromImageModel(ctx, modelResp, provider, creds, prompt, size)
	}
	return generatedImageFile{}, nil, "", fmt.Errorf("unsupported image model type: %s", modelResp.Type)
}
// generateImageFromChatModel asks a chat-capable model to produce an image by
// embedding the request in a user message, then extracts and decodes the
// first file returned by the model.
func generateImageFromChatModel(
	ctx context.Context,
	modelResp models.GetResponse,
	provider sqlc.Provider,
	creds providers.ModelCredentials,
	prompt string,
	size string,
) (generatedImageFile, []byte, string, error) {
	var none generatedImageFile
	chatModel := models.NewSDKChatModel(models.SDKModelConfig{
		ModelID:    modelResp.ModelID,
		ClientType: provider.ClientType,
		APIKey:     creds.APIKey,
		BaseURL:    providers.ProviderConfigString(provider, "base_url"),
	})
	// Chat models take no dedicated size parameter, so the requested size is
	// folded into the instruction text.
	instruction := fmt.Sprintf("Generate an image with the following description. Size: %s\n\n%s", size, prompt)
	result, err := sdk.GenerateTextResult(ctx,
		sdk.WithModel(chatModel),
		sdk.WithMessages([]sdk.Message{
			{Role: sdk.MessageRoleUser, Content: []sdk.MessagePart{sdk.TextPart{Text: instruction}}},
		}),
	)
	if err != nil {
		return none, nil, "", fmt.Errorf("image generation failed: %w", err)
	}
	if len(result.Files) == 0 {
		// Surface any textual reply so the caller can see why no file came back.
		if result.Text != "" {
			return none, nil, "", fmt.Errorf("no image generated: %s", result.Text)
		}
		return none, nil, "", errors.New("no image was generated by the model")
	}
	first := result.Files[0]
	file := generatedImageFile{Data: first.Data, MediaType: first.MediaType}
	imgBytes, ext, err := decodeGeneratedImage(file)
	if err != nil {
		return none, nil, "", err
	}
	return file, imgBytes, ext, nil
}
// generateImageFromImageModel calls the provider's image-generation API
// directly, requesting inline base64 PNG output, and decodes the first
// returned image.
func generateImageFromImageModel(
	ctx context.Context,
	modelResp models.GetResponse,
	provider sqlc.Provider,
	creds providers.ModelCredentials,
	prompt string,
	size string,
) (generatedImageFile, []byte, string, error) {
	var none generatedImageFile
	genModel := models.NewSDKImageGenerationModel(models.SDKModelConfig{
		ModelID:    modelResp.ModelID,
		ClientType: provider.ClientType,
		APIKey:     creds.APIKey,
		BaseURL:    providers.ProviderConfigString(provider, "base_url"),
	})
	// A nil model means this provider client type has no image API support.
	if genModel == nil {
		return none, nil, "", errors.New("configured provider does not support image generation API")
	}
	result, err := sdk.GenerateImage(ctx,
		sdk.WithImageGenerationModel(genModel),
		sdk.WithImagePrompt(prompt),
		sdk.WithImageSize(size),
		sdk.WithImageResponseFormat("b64_json"),
		sdk.WithImageOutputFormat("png"),
	)
	if err != nil {
		return none, nil, "", fmt.Errorf("image generation failed: %w", err)
	}
	if len(result.Data) == 0 {
		return none, nil, "", errors.New("no image was generated by the model")
	}
	payload := result.Data[0].B64JSON
	if strings.TrimSpace(payload) == "" {
		return none, nil, "", errors.New("image model did not return inline image data")
	}
	// PNG is what we asked for via WithImageOutputFormat above.
	file := generatedImageFile{Data: payload, MediaType: "image/png"}
	imgBytes, ext, err := decodeGeneratedImage(file)
	if err != nil {
		return none, nil, "", err
	}
	return file, imgBytes, ext, nil
}
// decodeGeneratedImage base64-decodes an inline image payload and derives a
// file extension from its media type, defaulting to "png" for anything not
// recognized as jpeg or webp.
func decodeGeneratedImage(file generatedImageFile) ([]byte, string, error) {
	raw, err := base64.StdEncoding.DecodeString(file.Data)
	if err != nil {
		return nil, "", fmt.Errorf("failed to decode generated image: %w", err)
	}
	ext := "png"
	if strings.Contains(file.MediaType, "jpeg") || strings.Contains(file.MediaType, "jpg") {
		ext = "jpg"
	} else if strings.Contains(file.MediaType, "webp") {
		ext = "webp"
	}
	return raw, ext, nil
}
+1 -1
View File
@@ -70,7 +70,7 @@ func (h *ModelsHandler) Create(c echo.Context) error {
// @Summary List all models
// @Description Get a list of all configured models, optionally filtered by type or provider client type
// @Tags models
// @Param type query string false "Model type (chat, embedding)"
// @Param type query string false "Model type (chat, embedding, image, speech)"
// @Param client_type query string false "Provider client type (openai-responses, openai-completions, anthropic-messages, google-generative-ai)"
// @Success 200 {array} models.GetResponse
// @Failure 400 {object} ErrorResponse
+10 -2
View File
@@ -332,12 +332,20 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
for _, m := range remoteModels {
modelType := models.ModelTypeChat
if strings.TrimSpace(m.Type) == string(models.ModelTypeEmbedding) {
switch strings.TrimSpace(m.Type) {
case string(models.ModelTypeEmbedding):
modelType = models.ModelTypeEmbedding
case string(models.ModelTypeImage):
modelType = models.ModelTypeImage
}
compatibilities := m.Compatibilities
if len(compatibilities) == 0 {
compatibilities = []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning}
switch modelType {
case models.ModelTypeImage:
compatibilities = []string{models.CompatGenerate, models.CompatEdit}
case models.ModelTypeChat:
compatibilities = []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning}
}
}
name := strings.TrimSpace(m.Name)
if name == "" {
+4 -4
View File
@@ -128,7 +128,7 @@ func (s *Service) List(ctx context.Context) ([]GetResponse, error) {
// ListByType returns models filtered by type (chat, embedding, image, or speech).
func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetResponse, error) {
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
if !IsValidModelType(modelType) {
return nil, fmt.Errorf("invalid model type: %s", modelType)
}
@@ -165,7 +165,7 @@ func (s *Service) ListEnabled(ctx context.Context) ([]GetResponse, error) {
// ListEnabledByType returns models from enabled providers filtered by type.
func (s *Service) ListEnabledByType(ctx context.Context, modelType ModelType) ([]GetResponse, error) {
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
if !IsValidModelType(modelType) {
return nil, fmt.Errorf("invalid model type: %s", modelType)
}
dbModels, err := s.queries.ListEnabledModelsByType(ctx, string(modelType))
@@ -206,7 +206,7 @@ func (s *Service) ListByProviderID(ctx context.Context, providerID string) ([]Ge
// ListByProviderIDAndType returns models filtered by provider ID and type.
func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string, modelType ModelType) ([]GetResponse, error) {
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
if !IsValidModelType(modelType) {
return nil, fmt.Errorf("invalid model type: %s", modelType)
}
if strings.TrimSpace(providerID) == "" {
@@ -361,7 +361,7 @@ func (s *Service) Count(ctx context.Context) (int64, error) {
// CountByType returns the number of models of a specific type.
func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, error) {
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
if !IsValidModelType(modelType) {
return 0, fmt.Errorf("invalid model type: %s", modelType)
}
+15
View File
@@ -50,6 +50,19 @@ func TestModel_Validate(t *testing.T) {
},
wantErr: false,
},
{
name: "valid image model",
model: models.Model{
ModelID: "gpt-image-1",
Name: "GPT Image 1",
ProviderID: "11111111-1111-1111-1111-111111111111",
Type: models.ModelTypeImage,
Config: models.ModelConfig{
Compatibilities: []string{models.CompatGenerate, models.CompatEdit},
},
},
wantErr: false,
},
{
name: "missing model_id",
model: models.Model{
@@ -129,12 +142,14 @@ func TestModel_HasCompatibility(t *testing.T) {
assert.True(t, m.HasCompatibility("tool-call"))
assert.True(t, m.HasCompatibility("reasoning"))
assert.False(t, m.HasCompatibility("image-output"))
assert.False(t, m.HasCompatibility("generate"))
}
func TestModelTypes(t *testing.T) {
t.Run("ModelType constants", func(t *testing.T) {
assert.Equal(t, models.ModelTypeChat, models.ModelType("chat"))
assert.Equal(t, models.ModelTypeEmbedding, models.ModelType("embedding"))
assert.Equal(t, models.ModelTypeImage, models.ModelType("image"))
})
t.Run("ClientType constants", func(t *testing.T) {
+35
View File
@@ -8,6 +8,7 @@ import (
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
openaiimages "github.com/memohai/twilight-ai/provider/openai/images"
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
sdk "github.com/memohai/twilight-ai/sdk"
@@ -121,6 +122,40 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
}
}
// NewSDKImageGenerationModel builds an SDK image-generation model for the
// given config. It returns nil when the provider client type has no image
// API support (see imageProviderOptions).
func NewSDKImageGenerationModel(cfg SDKModelConfig) *sdk.ImageGenerationModel {
	if opts := imageProviderOptions(cfg); opts != nil {
		return openaiimages.New(opts...).GenerationModel(cfg.ModelID)
	}
	return nil
}
// NewSDKImageEditModel builds an SDK image-edit model for the given config.
// It returns nil when the provider client type has no image API support
// (see imageProviderOptions).
func NewSDKImageEditModel(cfg SDKModelConfig) *sdk.ImageEditModel {
	if opts := imageProviderOptions(cfg); opts != nil {
		return openaiimages.New(opts...).EditModel(cfg.ModelID)
	}
	return nil
}
// imageProviderOptions translates an SDKModelConfig into OpenAI image client
// options. Only the OpenAI completions/responses client types are supported;
// any other client type yields nil so callers can detect the lack of an
// image API.
func imageProviderOptions(cfg SDKModelConfig) []openaiimages.Option {
	ct := ClientType(cfg.ClientType)
	if ct != ClientTypeOpenAICompletions && ct != ClientTypeOpenAIResponses {
		return nil
	}
	opts := []openaiimages.Option{openaiimages.WithAPIKey(cfg.APIKey)}
	if cfg.HTTPClient != nil {
		opts = append(opts, openaiimages.WithHTTPClient(cfg.HTTPClient))
	}
	if cfg.BaseURL != "" {
		opts = append(opts, openaiimages.WithBaseURL(cfg.BaseURL))
	}
	return opts
}
// BuildReasoningOptions returns SDK generation options for reasoning/thinking.
func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption {
if cfg.ReasoningConfig == nil || !cfg.ReasoningConfig.Enabled {
+19 -2
View File
@@ -11,6 +11,7 @@ type ModelType string
const (
ModelTypeChat ModelType = "chat"
ModelTypeEmbedding ModelType = "embedding"
ModelTypeImage ModelType = "image"
ModelTypeSpeech ModelType = "speech"
)
@@ -30,6 +31,8 @@ const (
CompatVision = "vision"
CompatToolCall = "tool-call"
CompatImageOutput = "image-output"
CompatGenerate = "generate"
CompatEdit = "edit"
CompatReasoning = "reasoning"
)
@@ -43,7 +46,12 @@ const (
// validCompatibilities enumerates accepted compatibility tokens.
var validCompatibilities = map[string]struct{}{
CompatVision: {}, CompatToolCall: {}, CompatImageOutput: {}, CompatReasoning: {},
CompatVision: {},
CompatToolCall: {},
CompatImageOutput: {},
CompatGenerate: {},
CompatEdit: {},
CompatReasoning: {},
}
var validReasoningEfforts = map[string]struct{}{
@@ -70,6 +78,15 @@ type Model struct {
Config ModelConfig `json:"config"`
}
// IsValidModelType reports whether modelType is one of the supported model
// type constants: chat, embedding, image, or speech.
func IsValidModelType(modelType ModelType) bool {
	return modelType == ModelTypeChat ||
		modelType == ModelTypeEmbedding ||
		modelType == ModelTypeImage ||
		modelType == ModelTypeSpeech
}
func (m *Model) Validate() error {
if m.ModelID == "" {
return errors.New("model ID is required")
@@ -80,7 +97,7 @@ func (m *Model) Validate() error {
if _, err := uuid.Parse(m.ProviderID); err != nil {
return errors.New("provider ID must be a valid UUID")
}
if m.Type != ModelTypeChat && m.Type != ModelTypeEmbedding && m.Type != ModelTypeSpeech {
if !IsValidModelType(m.Type) {
return errors.New("invalid model type")
}
if m.Type == ModelTypeEmbedding {