fix(route): handle concurrent route creation race condition

Catch unique constraint violation (SQLSTATE 23505) during route creation
in ResolveConversation and fall back to Find, preventing duplicate key
errors when concurrent inbound messages hit the same chat simultaneously.
Add shared IsUniqueViolation helper to internal/db.
This commit is contained in:
BBQ
2026-02-13 20:26:22 +08:00
parent f1d53e1c2c
commit 670698090f
3 changed files with 179 additions and 0 deletions
+8
View File
@@ -219,6 +219,14 @@ func (s *DBService) ResolveConversation(ctx context.Context, input ResolveInput)
ReplyTarget: input.ReplyTarget,
})
if err != nil {
// Concurrent insert race: another goroutine created the same route between
// our Find and Create calls. Fall back to Find the winning row.
if dbpkg.IsUniqueViolation(err) {
existing, findErr := s.Find(ctx, input.BotID, input.Platform, input.ConversationID, input.ThreadID)
if findErr == nil {
return ResolveConversationResult{ChatID: existing.ChatID, RouteID: existing.ID, Created: false}, nil
}
}
return ResolveConversationResult{}, fmt.Errorf("create route: %w", err)
}
+49
View File
@@ -0,0 +1,49 @@
package db
import (
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
// ParseUUID converts a string UUID to pgtype.UUID.
func ParseUUID(id string) (pgtype.UUID, error) {
parsed, err := uuid.Parse(strings.TrimSpace(id))
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err)
}
var pgID pgtype.UUID
pgID.Valid = true
copy(pgID.Bytes[:], parsed[:])
return pgID, nil
}
// TimeFromPg converts a pgtype.Timestamptz to time.Time.
func TimeFromPg(value pgtype.Timestamptz) time.Time {
if value.Valid {
return value.Time
}
return time.Time{}
}
// TextToString returns the string value of pgtype.Text, or "" when invalid.
func TextToString(value pgtype.Text) string {
if !value.Valid {
return ""
}
return value.String
}
// IsUniqueViolation reports whether err is a PostgreSQL unique constraint violation (SQLSTATE 23505).
func IsUniqueViolation(err error) bool {
var pgErr *pgconn.PgError
if !errors.As(err, &pgErr) {
return false
}
return pgErr.Code == "23505"
}
+122
View File
@@ -0,0 +1,122 @@
package db
import (
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
func TestParseUUID(t *testing.T) {
validUUID := uuid.MustParse("550e8400-e29b-41d4-a716-446655440000")
tests := []struct {
name string
id string
wantErr bool
want pgtype.UUID
}{
{
name: "valid",
id: "550e8400-e29b-41d4-a716-446655440000",
wantErr: false,
want: pgtype.UUID{Bytes: validUUID, Valid: true},
},
{
name: "valid with whitespace",
id: " 550e8400-e29b-41d4-a716-446655440000 ",
wantErr: false,
want: pgtype.UUID{Bytes: validUUID, Valid: true},
},
{
name: "invalid format",
id: "not-a-uuid",
wantErr: true,
},
{
name: "empty",
id: "",
wantErr: true,
},
{
name: "partial",
id: "550e8400-e29b",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseUUID(tt.id)
if (err != nil) != tt.wantErr {
t.Errorf("ParseUUID() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && (got.Valid != tt.want.Valid || got.Bytes != tt.want.Bytes) {
t.Errorf("ParseUUID() = %v, want %v", got, tt.want)
}
})
}
}
func TestTimeFromPg(t *testing.T) {
now := time.Now()
tests := []struct {
name string
value pgtype.Timestamptz
want time.Time
}{
{"valid", pgtype.Timestamptz{Time: now, Valid: true}, now},
{"invalid", pgtype.Timestamptz{}, time.Time{}},
{"valid zero", pgtype.Timestamptz{Time: time.Time{}, Valid: true}, time.Time{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := TimeFromPg(tt.value)
if !got.Equal(tt.want) {
t.Errorf("TimeFromPg() = %v, want %v", got, tt.want)
}
})
}
}
func TestTextToString(t *testing.T) {
tests := []struct {
name string
value pgtype.Text
want string
}{
{"valid", pgtype.Text{String: "hello", Valid: true}, "hello"},
{"invalid", pgtype.Text{}, ""},
{"valid empty", pgtype.Text{String: "", Valid: true}, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := TextToString(tt.value); got != tt.want {
t.Errorf("TextToString() = %q, want %q", got, tt.want)
}
})
}
}
func TestIsUniqueViolation(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"plain error", fmt.Errorf("some error"), false},
{"unique violation", &pgconn.PgError{Code: "23505"}, true},
{"other pg error", &pgconn.PgError{Code: "23503"}, false},
{"wrapped unique violation", fmt.Errorf("wrapped: %w", &pgconn.PgError{Code: "23505"}), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsUniqueViolation(tt.err); got != tt.want {
t.Errorf("IsUniqueViolation() = %v, want %v", got, tt.want)
}
})
}
}