mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
feat(models): add image model type support
Add a dedicated image model type so bots can use image API models without overloading chat model capabilities, while keeping existing chat-based image generation selectable.
This commit is contained in:
@@ -29,6 +29,11 @@ type ImageGenProvider struct {
|
||||
dataMount string
|
||||
}
|
||||
|
||||
// generatedImageFile carries a single generated image: the encoded payload
// together with the media type it was produced as.
type generatedImageFile struct {
	// Data is the base64-encoded image content.
	Data string
	// MediaType is the MIME type describing Data (for example "image/png").
	MediaType string
}
|
||||
|
||||
func NewImageGenProvider(
|
||||
log *slog.Logger,
|
||||
settingsSvc *settings.Service,
|
||||
@@ -65,6 +70,10 @@ func (p *ImageGenProvider) Tools(ctx context.Context, session SessionContext) ([
|
||||
if strings.TrimSpace(botSettings.ImageModelID) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
modelResp, err := p.models.GetByID(ctx, botSettings.ImageModelID)
|
||||
if err != nil || !supportsImageGeneration(modelResp) {
|
||||
return nil, nil
|
||||
}
|
||||
sess := session
|
||||
return []sdk.Tool{
|
||||
{
|
||||
@@ -112,7 +121,7 @@ func (p *ImageGenProvider) execGenerateImage(ctx context.Context, session Sessio
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load image model: %w", err)
|
||||
}
|
||||
if !modelResp.HasCompatibility(models.CompatImageOutput) {
|
||||
if !supportsImageGeneration(modelResp) {
|
||||
return nil, errors.New("configured model does not support image generation")
|
||||
}
|
||||
|
||||
@@ -127,43 +136,9 @@ func (p *ImageGenProvider) execGenerateImage(ctx context.Context, session Sessio
|
||||
return nil, fmt.Errorf("failed to resolve provider credentials: %w", err)
|
||||
}
|
||||
|
||||
sdkModel := models.NewSDKChatModel(models.SDKModelConfig{
|
||||
ModelID: modelResp.ModelID,
|
||||
ClientType: provider.ClientType,
|
||||
APIKey: creds.APIKey,
|
||||
BaseURL: providers.ProviderConfigString(provider, "base_url"),
|
||||
})
|
||||
|
||||
userMsg := fmt.Sprintf("Generate an image with the following description. Size: %s\n\n%s", size, prompt)
|
||||
result, err := sdk.GenerateTextResult(ctx,
|
||||
sdk.WithModel(sdkModel),
|
||||
sdk.WithMessages([]sdk.Message{
|
||||
{Role: sdk.MessageRoleUser, Content: []sdk.MessagePart{sdk.TextPart{Text: userMsg}}},
|
||||
}),
|
||||
)
|
||||
file, imgBytes, ext, err := generateImage(ctx, modelResp, provider, creds, prompt, size)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("image generation failed: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Files) == 0 {
|
||||
if result.Text != "" {
|
||||
return map[string]any{"error": "no image generated", "model_response": result.Text}, nil
|
||||
}
|
||||
return nil, errors.New("no image was generated by the model")
|
||||
}
|
||||
|
||||
file := result.Files[0]
|
||||
imgBytes, err := base64.StdEncoding.DecodeString(file.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode generated image: %w", err)
|
||||
}
|
||||
|
||||
ext := "png"
|
||||
switch {
|
||||
case strings.Contains(file.MediaType, "jpeg"), strings.Contains(file.MediaType, "jpg"):
|
||||
ext = "jpg"
|
||||
case strings.Contains(file.MediaType, "webp"):
|
||||
ext = "webp"
|
||||
return nil, err
|
||||
}
|
||||
|
||||
containerPath := fmt.Sprintf("%s/%d.%s", imageGenDir, time.Now().UnixMilli(), ext)
|
||||
@@ -196,3 +171,138 @@ func (p *ImageGenProvider) execGenerateImage(ctx context.Context, session Sessio
|
||||
"size_bytes": len(imgBytes),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func supportsImageGeneration(model models.GetResponse) bool {
|
||||
switch model.Type {
|
||||
case models.ModelTypeChat:
|
||||
return model.HasCompatibility(models.CompatImageOutput)
|
||||
case models.ModelTypeImage:
|
||||
return model.HasCompatibility(models.CompatGenerate)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func generateImage(
|
||||
ctx context.Context,
|
||||
modelResp models.GetResponse,
|
||||
provider sqlc.Provider,
|
||||
creds providers.ModelCredentials,
|
||||
prompt string,
|
||||
size string,
|
||||
) (generatedImageFile, []byte, string, error) {
|
||||
switch modelResp.Type {
|
||||
case models.ModelTypeChat:
|
||||
return generateImageFromChatModel(ctx, modelResp, provider, creds, prompt, size)
|
||||
case models.ModelTypeImage:
|
||||
return generateImageFromImageModel(ctx, modelResp, provider, creds, prompt, size)
|
||||
default:
|
||||
return generatedImageFile{}, nil, "", fmt.Errorf("unsupported image model type: %s", modelResp.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func generateImageFromChatModel(
|
||||
ctx context.Context,
|
||||
modelResp models.GetResponse,
|
||||
provider sqlc.Provider,
|
||||
creds providers.ModelCredentials,
|
||||
prompt string,
|
||||
size string,
|
||||
) (generatedImageFile, []byte, string, error) {
|
||||
sdkModel := models.NewSDKChatModel(models.SDKModelConfig{
|
||||
ModelID: modelResp.ModelID,
|
||||
ClientType: provider.ClientType,
|
||||
APIKey: creds.APIKey,
|
||||
BaseURL: providers.ProviderConfigString(provider, "base_url"),
|
||||
})
|
||||
|
||||
userMsg := fmt.Sprintf("Generate an image with the following description. Size: %s\n\n%s", size, prompt)
|
||||
result, err := sdk.GenerateTextResult(ctx,
|
||||
sdk.WithModel(sdkModel),
|
||||
sdk.WithMessages([]sdk.Message{
|
||||
{Role: sdk.MessageRoleUser, Content: []sdk.MessagePart{sdk.TextPart{Text: userMsg}}},
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
return generatedImageFile{}, nil, "", fmt.Errorf("image generation failed: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Files) == 0 {
|
||||
if result.Text != "" {
|
||||
return generatedImageFile{}, nil, "", fmt.Errorf("no image generated: %s", result.Text)
|
||||
}
|
||||
return generatedImageFile{}, nil, "", errors.New("no image was generated by the model")
|
||||
}
|
||||
|
||||
file := generatedImageFile{
|
||||
Data: result.Files[0].Data,
|
||||
MediaType: result.Files[0].MediaType,
|
||||
}
|
||||
imgBytes, ext, err := decodeGeneratedImage(file)
|
||||
if err != nil {
|
||||
return generatedImageFile{}, nil, "", err
|
||||
}
|
||||
return file, imgBytes, ext, nil
|
||||
}
|
||||
|
||||
func generateImageFromImageModel(
|
||||
ctx context.Context,
|
||||
modelResp models.GetResponse,
|
||||
provider sqlc.Provider,
|
||||
creds providers.ModelCredentials,
|
||||
prompt string,
|
||||
size string,
|
||||
) (generatedImageFile, []byte, string, error) {
|
||||
imageModel := models.NewSDKImageGenerationModel(models.SDKModelConfig{
|
||||
ModelID: modelResp.ModelID,
|
||||
ClientType: provider.ClientType,
|
||||
APIKey: creds.APIKey,
|
||||
BaseURL: providers.ProviderConfigString(provider, "base_url"),
|
||||
})
|
||||
if imageModel == nil {
|
||||
return generatedImageFile{}, nil, "", errors.New("configured provider does not support image generation API")
|
||||
}
|
||||
|
||||
result, err := sdk.GenerateImage(ctx,
|
||||
sdk.WithImageGenerationModel(imageModel),
|
||||
sdk.WithImagePrompt(prompt),
|
||||
sdk.WithImageSize(size),
|
||||
sdk.WithImageResponseFormat("b64_json"),
|
||||
sdk.WithImageOutputFormat("png"),
|
||||
)
|
||||
if err != nil {
|
||||
return generatedImageFile{}, nil, "", fmt.Errorf("image generation failed: %w", err)
|
||||
}
|
||||
if len(result.Data) == 0 {
|
||||
return generatedImageFile{}, nil, "", errors.New("no image was generated by the model")
|
||||
}
|
||||
if strings.TrimSpace(result.Data[0].B64JSON) == "" {
|
||||
return generatedImageFile{}, nil, "", errors.New("image model did not return inline image data")
|
||||
}
|
||||
|
||||
file := generatedImageFile{
|
||||
Data: result.Data[0].B64JSON,
|
||||
MediaType: "image/png",
|
||||
}
|
||||
imgBytes, ext, err := decodeGeneratedImage(file)
|
||||
if err != nil {
|
||||
return generatedImageFile{}, nil, "", err
|
||||
}
|
||||
return file, imgBytes, ext, nil
|
||||
}
|
||||
|
||||
func decodeGeneratedImage(file generatedImageFile) ([]byte, string, error) {
|
||||
imgBytes, err := base64.StdEncoding.DecodeString(file.Data)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to decode generated image: %w", err)
|
||||
}
|
||||
|
||||
ext := "png"
|
||||
switch {
|
||||
case strings.Contains(file.MediaType, "jpeg"), strings.Contains(file.MediaType, "jpg"):
|
||||
ext = "jpg"
|
||||
case strings.Contains(file.MediaType, "webp"):
|
||||
ext = "webp"
|
||||
}
|
||||
return imgBytes, ext, nil
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func (h *ModelsHandler) Create(c echo.Context) error {
|
||||
// @Summary List all models
|
||||
// @Description Get a list of all configured models, optionally filtered by type or provider client type
|
||||
// @Tags models
|
||||
// @Param type query string false "Model type (chat, embedding)"
|
||||
// @Param type query string false "Model type (chat, embedding, image, speech)"
|
||||
// @Param client_type query string false "Provider client type (openai-responses, openai-completions, anthropic-messages, google-generative-ai)"
|
||||
// @Success 200 {array} models.GetResponse
|
||||
// @Failure 400 {object} ErrorResponse
|
||||
|
||||
@@ -332,12 +332,20 @@ func (h *ProvidersHandler) ImportModels(c echo.Context) error {
|
||||
|
||||
for _, m := range remoteModels {
|
||||
modelType := models.ModelTypeChat
|
||||
if strings.TrimSpace(m.Type) == string(models.ModelTypeEmbedding) {
|
||||
switch strings.TrimSpace(m.Type) {
|
||||
case string(models.ModelTypeEmbedding):
|
||||
modelType = models.ModelTypeEmbedding
|
||||
case string(models.ModelTypeImage):
|
||||
modelType = models.ModelTypeImage
|
||||
}
|
||||
compatibilities := m.Compatibilities
|
||||
if len(compatibilities) == 0 {
|
||||
compatibilities = []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning}
|
||||
switch modelType {
|
||||
case models.ModelTypeImage:
|
||||
compatibilities = []string{models.CompatGenerate, models.CompatEdit}
|
||||
case models.ModelTypeChat:
|
||||
compatibilities = []string{models.CompatVision, models.CompatToolCall, models.CompatReasoning}
|
||||
}
|
||||
}
|
||||
name := strings.TrimSpace(m.Name)
|
||||
if name == "" {
|
||||
|
||||
@@ -128,7 +128,7 @@ func (s *Service) List(ctx context.Context) ([]GetResponse, error) {
|
||||
|
||||
// ListByType returns models filtered by type (chat, embedding, image, or speech).
|
||||
func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetResponse, error) {
|
||||
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
|
||||
if !IsValidModelType(modelType) {
|
||||
return nil, fmt.Errorf("invalid model type: %s", modelType)
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ func (s *Service) ListEnabled(ctx context.Context) ([]GetResponse, error) {
|
||||
|
||||
// ListEnabledByType returns models from enabled providers filtered by type.
|
||||
func (s *Service) ListEnabledByType(ctx context.Context, modelType ModelType) ([]GetResponse, error) {
|
||||
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
|
||||
if !IsValidModelType(modelType) {
|
||||
return nil, fmt.Errorf("invalid model type: %s", modelType)
|
||||
}
|
||||
dbModels, err := s.queries.ListEnabledModelsByType(ctx, string(modelType))
|
||||
@@ -206,7 +206,7 @@ func (s *Service) ListByProviderID(ctx context.Context, providerID string) ([]Ge
|
||||
|
||||
// ListByProviderIDAndType returns models filtered by provider ID and type.
|
||||
func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string, modelType ModelType) ([]GetResponse, error) {
|
||||
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
|
||||
if !IsValidModelType(modelType) {
|
||||
return nil, fmt.Errorf("invalid model type: %s", modelType)
|
||||
}
|
||||
if strings.TrimSpace(providerID) == "" {
|
||||
@@ -361,7 +361,7 @@ func (s *Service) Count(ctx context.Context) (int64, error) {
|
||||
|
||||
// CountByType returns the number of models of a specific type.
|
||||
func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, error) {
|
||||
if modelType != ModelTypeChat && modelType != ModelTypeEmbedding && modelType != ModelTypeSpeech {
|
||||
if !IsValidModelType(modelType) {
|
||||
return 0, fmt.Errorf("invalid model type: %s", modelType)
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +50,19 @@ func TestModel_Validate(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid image model",
|
||||
model: models.Model{
|
||||
ModelID: "gpt-image-1",
|
||||
Name: "GPT Image 1",
|
||||
ProviderID: "11111111-1111-1111-1111-111111111111",
|
||||
Type: models.ModelTypeImage,
|
||||
Config: models.ModelConfig{
|
||||
Compatibilities: []string{models.CompatGenerate, models.CompatEdit},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing model_id",
|
||||
model: models.Model{
|
||||
@@ -129,12 +142,14 @@ func TestModel_HasCompatibility(t *testing.T) {
|
||||
assert.True(t, m.HasCompatibility("tool-call"))
|
||||
assert.True(t, m.HasCompatibility("reasoning"))
|
||||
assert.False(t, m.HasCompatibility("image-output"))
|
||||
assert.False(t, m.HasCompatibility("generate"))
|
||||
}
|
||||
|
||||
func TestModelTypes(t *testing.T) {
|
||||
t.Run("ModelType constants", func(t *testing.T) {
|
||||
assert.Equal(t, models.ModelTypeChat, models.ModelType("chat"))
|
||||
assert.Equal(t, models.ModelTypeEmbedding, models.ModelType("embedding"))
|
||||
assert.Equal(t, models.ModelTypeImage, models.ModelType("image"))
|
||||
})
|
||||
|
||||
t.Run("ClientType constants", func(t *testing.T) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
googlegenerative "github.com/memohai/twilight-ai/provider/google/generativeai"
|
||||
openaicodex "github.com/memohai/twilight-ai/provider/openai/codex"
|
||||
openaicompletions "github.com/memohai/twilight-ai/provider/openai/completions"
|
||||
openaiimages "github.com/memohai/twilight-ai/provider/openai/images"
|
||||
openairesponses "github.com/memohai/twilight-ai/provider/openai/responses"
|
||||
sdk "github.com/memohai/twilight-ai/sdk"
|
||||
|
||||
@@ -121,6 +122,40 @@ func NewSDKChatModel(cfg SDKModelConfig) *sdk.Model {
|
||||
}
|
||||
}
|
||||
|
||||
func NewSDKImageGenerationModel(cfg SDKModelConfig) *sdk.ImageGenerationModel {
|
||||
opts := imageProviderOptions(cfg)
|
||||
if opts == nil {
|
||||
return nil
|
||||
}
|
||||
return openaiimages.New(opts...).GenerationModel(cfg.ModelID)
|
||||
}
|
||||
|
||||
func NewSDKImageEditModel(cfg SDKModelConfig) *sdk.ImageEditModel {
|
||||
opts := imageProviderOptions(cfg)
|
||||
if opts == nil {
|
||||
return nil
|
||||
}
|
||||
return openaiimages.New(opts...).EditModel(cfg.ModelID)
|
||||
}
|
||||
|
||||
func imageProviderOptions(cfg SDKModelConfig) []openaiimages.Option {
|
||||
switch ClientType(cfg.ClientType) {
|
||||
case ClientTypeOpenAICompletions, ClientTypeOpenAIResponses:
|
||||
opts := []openaiimages.Option{
|
||||
openaiimages.WithAPIKey(cfg.APIKey),
|
||||
}
|
||||
if cfg.HTTPClient != nil {
|
||||
opts = append(opts, openaiimages.WithHTTPClient(cfg.HTTPClient))
|
||||
}
|
||||
if cfg.BaseURL != "" {
|
||||
opts = append(opts, openaiimages.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
return opts
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// BuildReasoningOptions returns SDK generation options for reasoning/thinking.
|
||||
func BuildReasoningOptions(cfg SDKModelConfig) []sdk.GenerateOption {
|
||||
if cfg.ReasoningConfig == nil || !cfg.ReasoningConfig.Enabled {
|
||||
|
||||
@@ -11,6 +11,7 @@ type ModelType string
|
||||
const (
|
||||
ModelTypeChat ModelType = "chat"
|
||||
ModelTypeEmbedding ModelType = "embedding"
|
||||
ModelTypeImage ModelType = "image"
|
||||
ModelTypeSpeech ModelType = "speech"
|
||||
)
|
||||
|
||||
@@ -30,6 +31,8 @@ const (
|
||||
CompatVision = "vision"
|
||||
CompatToolCall = "tool-call"
|
||||
CompatImageOutput = "image-output"
|
||||
CompatGenerate = "generate"
|
||||
CompatEdit = "edit"
|
||||
CompatReasoning = "reasoning"
|
||||
)
|
||||
|
||||
@@ -43,7 +46,12 @@ const (
|
||||
|
||||
// validCompatibilities enumerates accepted compatibility tokens.
|
||||
var validCompatibilities = map[string]struct{}{
|
||||
CompatVision: {}, CompatToolCall: {}, CompatImageOutput: {}, CompatReasoning: {},
|
||||
CompatVision: {},
|
||||
CompatToolCall: {},
|
||||
CompatImageOutput: {},
|
||||
CompatGenerate: {},
|
||||
CompatEdit: {},
|
||||
CompatReasoning: {},
|
||||
}
|
||||
|
||||
var validReasoningEfforts = map[string]struct{}{
|
||||
@@ -70,6 +78,15 @@ type Model struct {
|
||||
Config ModelConfig `json:"config"`
|
||||
}
|
||||
|
||||
func IsValidModelType(modelType ModelType) bool {
|
||||
switch modelType {
|
||||
case ModelTypeChat, ModelTypeEmbedding, ModelTypeImage, ModelTypeSpeech:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) Validate() error {
|
||||
if m.ModelID == "" {
|
||||
return errors.New("model ID is required")
|
||||
@@ -80,7 +97,7 @@ func (m *Model) Validate() error {
|
||||
if _, err := uuid.Parse(m.ProviderID); err != nil {
|
||||
return errors.New("provider ID must be a valid UUID")
|
||||
}
|
||||
if m.Type != ModelTypeChat && m.Type != ModelTypeEmbedding && m.Type != ModelTypeSpeech {
|
||||
if !IsValidModelType(m.Type) {
|
||||
return errors.New("invalid model type")
|
||||
}
|
||||
if m.Type == ModelTypeEmbedding {
|
||||
|
||||
Reference in New Issue
Block a user