diff --git a/internal/channel/route/service.go b/internal/channel/route/service.go index 12cfc381..7436a9f5 100644 --- a/internal/channel/route/service.go +++ b/internal/channel/route/service.go @@ -219,6 +219,14 @@ func (s *DBService) ResolveConversation(ctx context.Context, input ResolveInput) ReplyTarget: input.ReplyTarget, }) if err != nil { + // Concurrent insert race: another goroutine created the same route between + // our Find and Create calls. Fall back to Find the winning row. + if dbpkg.IsUniqueViolation(err) { + existing, findErr := s.Find(ctx, input.BotID, input.Platform, input.ConversationID, input.ThreadID) + if findErr == nil { + return ResolveConversationResult{ChatID: existing.ChatID, RouteID: existing.ID, Created: false}, nil + } + } return ResolveConversationResult{}, fmt.Errorf("create route: %w", err) } diff --git a/internal/db/utils.go b/internal/db/utils.go new file mode 100644 index 00000000..42bd5ba9 --- /dev/null +++ b/internal/db/utils.go @@ -0,0 +1,49 @@ +package db + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" +) + +// ParseUUID converts a string UUID to pgtype.UUID. +func ParseUUID(id string) (pgtype.UUID, error) { + parsed, err := uuid.Parse(strings.TrimSpace(id)) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid UUID: %w", err) + } + var pgID pgtype.UUID + pgID.Valid = true + copy(pgID.Bytes[:], parsed[:]) + return pgID, nil +} + +// TimeFromPg converts a pgtype.Timestamptz to time.Time. +func TimeFromPg(value pgtype.Timestamptz) time.Time { + if value.Valid { + return value.Time + } + return time.Time{} +} + +// TextToString returns the string value of pgtype.Text, or "" when invalid. +func TextToString(value pgtype.Text) string { + if !value.Valid { + return "" + } + return value.String +} + +// IsUniqueViolation reports whether err is a PostgreSQL unique constraint violation (SQLSTATE 23505). +func IsUniqueViolation(err error) bool { + var pgErr *pgconn.PgError + if !errors.As(err, &pgErr) { + return false + } + return pgErr.Code == "23505" +} diff --git a/internal/db/utils_test.go b/internal/db/utils_test.go new file mode 100644 index 00000000..92b01386 --- /dev/null +++ b/internal/db/utils_test.go @@ -0,0 +1,122 @@ +package db + +import ( + "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" +) + +func TestParseUUID(t *testing.T) { + validUUID := uuid.MustParse("550e8400-e29b-41d4-a716-446655440000") + tests := []struct { + name string + id string + wantErr bool + want pgtype.UUID + }{ + { + name: "valid", + id: "550e8400-e29b-41d4-a716-446655440000", + wantErr: false, + want: pgtype.UUID{Bytes: validUUID, Valid: true}, + }, + { + name: "valid with whitespace", + id: " 550e8400-e29b-41d4-a716-446655440000 ", + wantErr: false, + want: pgtype.UUID{Bytes: validUUID, Valid: true}, + }, + { + name: "invalid format", + id: "not-a-uuid", + wantErr: true, + }, + { + name: "empty", + id: "", + wantErr: true, + }, + { + name: "partial", + id: "550e8400-e29b", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseUUID(tt.id) + if (err != nil) != tt.wantErr { + t.Errorf("ParseUUID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && (got.Valid != tt.want.Valid || got.Bytes != tt.want.Bytes) { + t.Errorf("ParseUUID() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTimeFromPg(t *testing.T) { + now := time.Now() + tests := []struct { + name string + value pgtype.Timestamptz + want time.Time + }{ + {"valid", pgtype.Timestamptz{Time: now, Valid: true}, now}, + {"invalid", pgtype.Timestamptz{}, time.Time{}}, + {"valid zero", pgtype.Timestamptz{Time: time.Time{}, Valid: true}, time.Time{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := TimeFromPg(tt.value) + if !got.Equal(tt.want) { + t.Errorf("TimeFromPg() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTextToString(t *testing.T) { + tests := []struct { + name string + value pgtype.Text + want string + }{ + {"valid", pgtype.Text{String: "hello", Valid: true}, "hello"}, + {"invalid", pgtype.Text{}, ""}, + {"valid empty", pgtype.Text{String: "", Valid: true}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TextToString(tt.value); got != tt.want { + t.Errorf("TextToString() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestIsUniqueViolation(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"plain error", fmt.Errorf("some error"), false}, + {"unique violation", &pgconn.PgError{Code: "23505"}, true}, + {"other pg error", &pgconn.PgError{Code: "23503"}, false}, + {"wrapped unique violation", fmt.Errorf("wrapped: %w", &pgconn.PgError{Code: "23505"}), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsUniqueViolation(tt.err); got != tt.want { + t.Errorf("IsUniqueViolation() = %v, want %v", got, tt.want) + } + }) + } +}