mirror of
https://github.com/memohai/Memoh.git
synced 2026-04-27 07:16:19 +09:00
fix(route): handle concurrent route creation race condition
Catch unique constraint violation (SQLSTATE 23505) during route creation in ResolveConversation and fall back to Find, preventing duplicate key errors when concurrent inbound messages hit the same chat simultaneously. Add shared IsUniqueViolation helper to internal/db.
This commit is contained in:
@@ -219,6 +219,14 @@ func (s *DBService) ResolveConversation(ctx context.Context, input ResolveInput)
|
||||
ReplyTarget: input.ReplyTarget,
|
||||
})
|
||||
if err != nil {
|
||||
// Concurrent insert race: another goroutine created the same route between
|
||||
// our Find and Create calls. Fall back to Find the winning row.
|
||||
if dbpkg.IsUniqueViolation(err) {
|
||||
existing, findErr := s.Find(ctx, input.BotID, input.Platform, input.ConversationID, input.ThreadID)
|
||||
if findErr == nil {
|
||||
return ResolveConversationResult{ChatID: existing.ChatID, RouteID: existing.ID, Created: false}, nil
|
||||
}
|
||||
}
|
||||
return ResolveConversationResult{}, fmt.Errorf("create route: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
// ParseUUID converts a string UUID to pgtype.UUID.
|
||||
func ParseUUID(id string) (pgtype.UUID, error) {
|
||||
parsed, err := uuid.Parse(strings.TrimSpace(id))
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err)
|
||||
}
|
||||
var pgID pgtype.UUID
|
||||
pgID.Valid = true
|
||||
copy(pgID.Bytes[:], parsed[:])
|
||||
return pgID, nil
|
||||
}
|
||||
|
||||
// TimeFromPg converts a pgtype.Timestamptz to time.Time.
|
||||
func TimeFromPg(value pgtype.Timestamptz) time.Time {
|
||||
if value.Valid {
|
||||
return value.Time
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// TextToString returns the string value of pgtype.Text, or "" when invalid.
|
||||
func TextToString(value pgtype.Text) string {
|
||||
if !value.Valid {
|
||||
return ""
|
||||
}
|
||||
return value.String
|
||||
}
|
||||
|
||||
// IsUniqueViolation reports whether err is a PostgreSQL unique constraint violation (SQLSTATE 23505).
|
||||
func IsUniqueViolation(err error) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
if !errors.As(err, &pgErr) {
|
||||
return false
|
||||
}
|
||||
return pgErr.Code == "23505"
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
func TestParseUUID(t *testing.T) {
|
||||
validUUID := uuid.MustParse("550e8400-e29b-41d4-a716-446655440000")
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
wantErr bool
|
||||
want pgtype.UUID
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
id: "550e8400-e29b-41d4-a716-446655440000",
|
||||
wantErr: false,
|
||||
want: pgtype.UUID{Bytes: validUUID, Valid: true},
|
||||
},
|
||||
{
|
||||
name: "valid with whitespace",
|
||||
id: " 550e8400-e29b-41d4-a716-446655440000 ",
|
||||
wantErr: false,
|
||||
want: pgtype.UUID{Bytes: validUUID, Valid: true},
|
||||
},
|
||||
{
|
||||
name: "invalid format",
|
||||
id: "not-a-uuid",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
id: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "partial",
|
||||
id: "550e8400-e29b",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseUUID(tt.id)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseUUID() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && (got.Valid != tt.want.Valid || got.Bytes != tt.want.Bytes) {
|
||||
t.Errorf("ParseUUID() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeFromPg(t *testing.T) {
|
||||
now := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
value pgtype.Timestamptz
|
||||
want time.Time
|
||||
}{
|
||||
{"valid", pgtype.Timestamptz{Time: now, Valid: true}, now},
|
||||
{"invalid", pgtype.Timestamptz{}, time.Time{}},
|
||||
{"valid zero", pgtype.Timestamptz{Time: time.Time{}, Valid: true}, time.Time{}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := TimeFromPg(tt.value)
|
||||
if !got.Equal(tt.want) {
|
||||
t.Errorf("TimeFromPg() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextToString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value pgtype.Text
|
||||
want string
|
||||
}{
|
||||
{"valid", pgtype.Text{String: "hello", Valid: true}, "hello"},
|
||||
{"invalid", pgtype.Text{}, ""},
|
||||
{"valid empty", pgtype.Text{String: "", Valid: true}, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := TextToString(tt.value); got != tt.want {
|
||||
t.Errorf("TextToString() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsUniqueViolation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"plain error", fmt.Errorf("some error"), false},
|
||||
{"unique violation", &pgconn.PgError{Code: "23505"}, true},
|
||||
{"other pg error", &pgconn.PgError{Code: "23503"}, false},
|
||||
{"wrapped unique violation", fmt.Errorf("wrapped: %w", &pgconn.PgError{Code: "23505"}), true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsUniqueViolation(tt.err); got != tt.want {
|
||||
t.Errorf("IsUniqueViolation() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user