mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-25 07:00:48 +09:00
fix(image): add normalization for image parts in messages to fix "The messages do not match the ModelMessage[] schema." (#260)
This commit is contained in:
@@ -11,6 +11,8 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -1963,6 +1965,9 @@ func parseLoopDetectionEnabledFromMetadata(payload []byte) bool {
|
||||
func sanitizeMessages(messages []conversation.ModelMessage) []conversation.ModelMessage {
|
||||
cleaned := make([]conversation.ModelMessage, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
if normalized, ok := normalizeImagePartsToDataURL(msg); ok {
|
||||
msg = normalized
|
||||
}
|
||||
if strings.TrimSpace(msg.Role) == "" {
|
||||
continue
|
||||
}
|
||||
@@ -1974,6 +1979,127 @@ func sanitizeMessages(messages []conversation.ModelMessage) []conversation.Model
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func normalizeImagePartsToDataURL(msg conversation.ModelMessage) (conversation.ModelMessage, bool) {
|
||||
if len(msg.Content) == 0 {
|
||||
return msg, false
|
||||
}
|
||||
var parts []map[string]json.RawMessage
|
||||
if err := json.Unmarshal(msg.Content, &parts); err != nil || len(parts) == 0 {
|
||||
return msg, false
|
||||
}
|
||||
|
||||
changed := false
|
||||
for i := range parts {
|
||||
partTypeRaw, ok := parts[i]["type"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var partType string
|
||||
if err := json.Unmarshal(partTypeRaw, &partType); err != nil || !strings.EqualFold(partType, "image") {
|
||||
continue
|
||||
}
|
||||
|
||||
imageRaw, ok := parts[i]["image"]
|
||||
if !ok || len(imageRaw) == 0 {
|
||||
continue
|
||||
}
|
||||
var tmp string
|
||||
if json.Unmarshal(imageRaw, &tmp) == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
if b, ok := decodeIndexedByteObject(imageRaw); ok {
|
||||
payload = b
|
||||
} else if b, ok := decodeByteArray(imageRaw); ok {
|
||||
payload = b
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// action trigger to image only here.
|
||||
mediaType := "application/octet-stream"
|
||||
if mediaTypeRaw, ok := parts[i]["mediaType"]; ok {
|
||||
var mt string
|
||||
if err := json.Unmarshal(mediaTypeRaw, &mt); err == nil && strings.TrimSpace(mt) != "" {
|
||||
mediaType = strings.TrimSpace(mt)
|
||||
}
|
||||
}
|
||||
dataURL := "data:" + mediaType + ";base64," + base64.StdEncoding.EncodeToString(payload)
|
||||
rebuilt, err := json.Marshal(dataURL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
parts[i]["image"] = rebuilt
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return msg, false
|
||||
}
|
||||
rebuiltContent, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return msg, false
|
||||
}
|
||||
msg.Content = rebuiltContent
|
||||
return msg, true
|
||||
}
|
||||
|
||||
// decodeByteArray decodes a JSON array of integers into raw bytes.
//
// It returns (nil, false) when raw is not a JSON array of integers, when the
// array is empty, or when any element is null or outside the range [0, 255].
func decodeByteArray(raw json.RawMessage) ([]byte, bool) {
	// Decode into pointers so a JSON null element shows up as a nil pointer.
	// With a plain []int, encoding/json leaves a null element at its zero
	// value, which would silently turn null into a 0x00 byte.
	var elems []*int
	if err := json.Unmarshal(raw, &elems); err != nil || len(elems) == 0 {
		return nil, false
	}

	out := make([]byte, len(elems))
	for i, p := range elems {
		if p == nil || *p < 0 || *p > 255 {
			return nil, false
		}
		out[i] = byte(*p)
	}
	return out, true
}
|
||||
|
||||
// decodeIndexedByteObject decodes a JSON object that maps decimal indices to
// byte values, such as {"0":82,"1":73}, into raw bytes ordered by index.
//
// It returns (nil, false) when raw is not a JSON object, when the object is
// empty, when a key is not a non-negative integer, when a value is null or
// not an integer in [0, 255], or when the indices do not form the dense
// sequence 0..n-1.
func decodeIndexedByteObject(raw json.RawMessage) ([]byte, bool) {
	var obj map[string]json.RawMessage
	if err := json.Unmarshal(raw, &obj); err != nil || len(obj) == 0 {
		return nil, false
	}

	type indexedByte struct {
		idx int
		val byte
	}
	items := make([]indexedByte, 0, len(obj))
	for k, vRaw := range obj {
		idx, err := strconv.Atoi(k)
		// A dense 0..n-1 sequence of n entries cannot contain an index >= n,
		// so such keys are rejected before any sorting is done.
		if err != nil || idx < 0 || idx >= len(obj) {
			return nil, false
		}
		// Decode through a pointer so a JSON null is seen as nil instead of
		// being silently accepted as the zero byte.
		var val *int
		if err := json.Unmarshal(vRaw, &val); err != nil || val == nil || *val < 0 || *val > 255 {
			return nil, false
		}
		items = append(items, indexedByte{idx: idx, val: byte(*val)})
	}

	// Map iteration order is random; order the entries by index.
	sort.Slice(items, func(i, j int) bool { return items[i].idx < items[j].idx })

	out := make([]byte, len(items))
	for i, it := range items {
		// After sorting, position i must hold index i. A mismatch means a gap
		// or two keys that parse to the same index.
		if it.idx != i {
			return nil, false
		}
		out[i] = it.val
	}
	return out, true
}
|
||||
|
||||
func normalizeGatewaySkill(entry SkillEntry) (gatewaySkill, bool) {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
if name == "" {
|
||||
|
||||
@@ -528,3 +528,52 @@ func TestOutboundAssetRefsToMessageRefs_Empty(t *testing.T) {
|
||||
t.Fatalf("expected nil, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeImagePartsToDataURL_ConvertsIndexedObject(t *testing.T) {
|
||||
msg := conversation.ModelMessage{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[
|
||||
{"type":"text","text":"hello"},
|
||||
{"type":"image","image":{"0":82,"1":73,"2":70,"3":70},"mediaType":"image/webp"}
|
||||
]`),
|
||||
}
|
||||
|
||||
normalized, changed := normalizeImagePartsToDataURL(msg)
|
||||
if !changed {
|
||||
t.Fatal("expected message to be normalized")
|
||||
}
|
||||
|
||||
var parts []map[string]any
|
||||
if err := json.Unmarshal(normalized.Content, &parts); err != nil {
|
||||
t.Fatalf("failed to unmarshal normalized content: %v", err)
|
||||
}
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 parts, got %d", len(parts))
|
||||
}
|
||||
image, ok := parts[1]["image"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected image to be string data url, got %T", parts[1]["image"])
|
||||
}
|
||||
expected := "data:image/webp;base64," + base64.StdEncoding.EncodeToString([]byte{82, 73, 70, 70})
|
||||
if image != expected {
|
||||
t.Fatalf("unexpected data url, got %q", image)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeImagePartsToDataURL_LeavesStringImageUntouched(t *testing.T) {
|
||||
original := `[
|
||||
{"type":"image","image":"data:image/png;base64,AAAA","mediaType":"image/png"}
|
||||
]`
|
||||
msg := conversation.ModelMessage{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(original),
|
||||
}
|
||||
|
||||
normalized, changed := normalizeImagePartsToDataURL(msg)
|
||||
if changed {
|
||||
t.Fatal("expected no normalization for string image")
|
||||
}
|
||||
if string(normalized.Content) != original {
|
||||
t.Fatalf("expected content unchanged, got %s", string(normalized.Content))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user