mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
65b2797626
Move CreateModel, BuildReasoningOptions, ReasoningBudgetTokens and related types from internal/agent to internal/models as NewSDKChatModel, SDKModelConfig, etc. This eliminates duplicate ClientType constants and centralises all Twilight AI SDK instance creation in a single package. NewSDKEmbeddingModel now accepts a clientType parameter and dispatches to the native Google embedding provider for google-generative-ai, instead of always using the OpenAI-compatible endpoint.
176 lines
4.4 KiB
Go
176 lines
4.4 KiB
Go
package agent
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
sdk "github.com/memohai/twilight-ai/sdk"
|
|
|
|
agenttools "github.com/memohai/memoh/internal/agent/tools"
|
|
"github.com/memohai/memoh/internal/models"
|
|
)
|
|
|
|
func decorateReadMediaTools(model *sdk.Model, tools []sdk.Tool) ([]sdk.Tool, *readMediaDecorationState) {
|
|
if len(tools) == 0 {
|
|
return tools, nil
|
|
}
|
|
|
|
clientType := models.ResolveClientType(model)
|
|
state := &readMediaDecorationState{
|
|
pendingImages: make(map[string]sdk.ImagePart),
|
|
}
|
|
wrapped := make([]sdk.Tool, 0, len(tools))
|
|
found := false
|
|
|
|
for _, tool := range tools {
|
|
if tool.Name != agenttools.ReadMediaToolName || tool.Execute == nil {
|
|
wrapped = append(wrapped, tool)
|
|
continue
|
|
}
|
|
|
|
found = true
|
|
originalExecute := tool.Execute
|
|
toolCopy := tool
|
|
toolCopy.Execute = func(ctx *sdk.ToolExecContext, input any) (any, error) {
|
|
output, err := originalExecute(ctx, input)
|
|
if err != nil {
|
|
return output, err
|
|
}
|
|
|
|
publicResult, image, ok := normalizeReadMediaOutput(output, clientType)
|
|
if !ok {
|
|
return output, nil
|
|
}
|
|
if ctx != nil && strings.TrimSpace(ctx.ToolCallID) != "" && strings.TrimSpace(image.Image) != "" {
|
|
if _, exists := state.pendingImages[ctx.ToolCallID]; !exists {
|
|
state.pendingOrder = append(state.pendingOrder, ctx.ToolCallID)
|
|
}
|
|
state.pendingImages[ctx.ToolCallID] = image
|
|
}
|
|
return publicResult, nil
|
|
}
|
|
wrapped = append(wrapped, toolCopy)
|
|
}
|
|
|
|
if !found {
|
|
return tools, nil
|
|
}
|
|
|
|
return wrapped, state
|
|
}
|
|
|
|
type readMediaDecorationState struct {
|
|
pendingOrder []string
|
|
pendingImages map[string]sdk.ImagePart
|
|
prepareCalls int
|
|
injections []readMediaInjection
|
|
}
|
|
|
|
type readMediaInjection struct {
|
|
afterStep int
|
|
message sdk.Message
|
|
}
|
|
|
|
func (s *readMediaDecorationState) prepareStep(params *sdk.GenerateParams) *sdk.GenerateParams {
|
|
if s == nil || params == nil {
|
|
return nil
|
|
}
|
|
|
|
afterStep := s.prepareCalls
|
|
s.prepareCalls++
|
|
|
|
if len(s.pendingOrder) == 0 {
|
|
return nil
|
|
}
|
|
|
|
parts := make([]sdk.MessagePart, 0, len(s.pendingOrder))
|
|
for _, toolCallID := range s.pendingOrder {
|
|
image, ok := s.pendingImages[toolCallID]
|
|
delete(s.pendingImages, toolCallID)
|
|
if !ok || strings.TrimSpace(image.Image) == "" {
|
|
continue
|
|
}
|
|
parts = append(parts, image)
|
|
}
|
|
s.pendingOrder = s.pendingOrder[:0]
|
|
|
|
if len(parts) == 0 {
|
|
return nil
|
|
}
|
|
|
|
message := sdk.Message{
|
|
Role: sdk.MessageRoleUser,
|
|
Content: parts,
|
|
}
|
|
s.injections = append(s.injections, readMediaInjection{
|
|
afterStep: afterStep,
|
|
message: message,
|
|
})
|
|
|
|
next := *params
|
|
next.Messages = append(append([]sdk.Message(nil), params.Messages...), message)
|
|
return &next
|
|
}
|
|
|
|
func (s *readMediaDecorationState) mergeMessages(steps []sdk.StepResult, fallback []sdk.Message) []sdk.Message {
|
|
if s == nil || len(s.injections) == 0 {
|
|
return fallback
|
|
}
|
|
if len(steps) == 0 {
|
|
merged := append([]sdk.Message(nil), fallback...)
|
|
for _, injection := range s.injections {
|
|
merged = append(merged, injection.message)
|
|
}
|
|
return merged
|
|
}
|
|
|
|
merged := make([]sdk.Message, 0, len(fallback)+len(s.injections))
|
|
injectionIndex := 0
|
|
for stepIndex, step := range steps {
|
|
merged = append(merged, step.Messages...)
|
|
for injectionIndex < len(s.injections) && s.injections[injectionIndex].afterStep == stepIndex {
|
|
merged = append(merged, s.injections[injectionIndex].message)
|
|
injectionIndex++
|
|
}
|
|
}
|
|
for injectionIndex < len(s.injections) {
|
|
merged = append(merged, s.injections[injectionIndex].message)
|
|
injectionIndex++
|
|
}
|
|
return merged
|
|
}
|
|
|
|
func normalizeReadMediaOutput(output any, clientType string) (any, sdk.ImagePart, bool) {
|
|
switch value := output.(type) {
|
|
case agenttools.ReadMediaToolOutput:
|
|
return value.Public, buildReadMediaImagePart(clientType, value.ImageBase64, value.ImageMediaType), true
|
|
case *agenttools.ReadMediaToolOutput:
|
|
if value == nil {
|
|
return nil, sdk.ImagePart{}, false
|
|
}
|
|
return value.Public, buildReadMediaImagePart(clientType, value.ImageBase64, value.ImageMediaType), true
|
|
default:
|
|
return nil, sdk.ImagePart{}, false
|
|
}
|
|
}
|
|
|
|
func buildReadMediaImagePart(clientType, imageBase64, mediaType string) sdk.ImagePart {
|
|
imageBase64 = strings.TrimSpace(imageBase64)
|
|
mediaType = strings.TrimSpace(mediaType)
|
|
if imageBase64 == "" {
|
|
return sdk.ImagePart{}
|
|
}
|
|
if mediaType == "" {
|
|
mediaType = "image/png"
|
|
}
|
|
|
|
image := imageBase64
|
|
if clientType != string(models.ClientTypeAnthropicMessages) {
|
|
image = fmt.Sprintf("data:%s;base64,%s", mediaType, imageBase64)
|
|
}
|
|
return sdk.ImagePart{
|
|
Image: image,
|
|
MediaType: mediaType,
|
|
}
|
|
}
|