diff --git a/db/migrations/0001_init.up.sql b/db/migrations/0001_init.up.sql index 490d5cb9..48639ab7 100644 --- a/db/migrations/0001_init.up.sql +++ b/db/migrations/0001_init.up.sql @@ -55,6 +55,23 @@ CREATE TABLE IF NOT EXISTS snapshots ( created_at TIMESTAMPTZ NOT NULL DEFAULT now() ); +CREATE TABLE IF NOT EXISTS models ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + model_id TEXT NOT NULL, + name TEXT, + base_url TEXT NOT NULL, + api_key TEXT NOT NULL, + client_type TEXT NOT NULL, + dimensions INTEGER, + type TEXT NOT NULL DEFAULT 'chat', + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + CONSTRAINT models_model_id_unique UNIQUE (model_id), + CONSTRAINT models_type_check CHECK (type IN ('chat', 'embedding')), + CONSTRAINT models_client_type_check CHECK (client_type IN ('openai', 'anthropic', 'google')), + CONSTRAINT models_dimensions_check CHECK (type != 'embedding' OR dimensions IS NOT NULL) +); + CREATE INDEX IF NOT EXISTS idx_snapshots_container_id ON snapshots(container_id); CREATE INDEX IF NOT EXISTS idx_snapshots_parent_id ON snapshots(parent_snapshot_id); diff --git a/db/queries/models.sql b/db/queries/models.sql new file mode 100644 index 00000000..6d21f972 --- /dev/null +++ b/db/queries/models.sql @@ -0,0 +1,71 @@ +-- name: CreateModel :one +INSERT INTO models (model_id, name, base_url, api_key, client_type, dimensions, type) +VALUES ( + sqlc.arg(model_id), + sqlc.arg(name), + sqlc.arg(base_url), + sqlc.arg(api_key), + sqlc.arg(client_type), + sqlc.arg(dimensions), + sqlc.arg(type) +) +RETURNING *; + +-- name: GetModelByID :one +SELECT * FROM models WHERE id = sqlc.arg(id); + +-- name: GetModelByModelID :one +SELECT * FROM models WHERE model_id = sqlc.arg(model_id); + +-- name: ListModels :many +SELECT * FROM models +ORDER BY created_at DESC; + +-- name: ListModelsByType :many +SELECT * FROM models +WHERE type = sqlc.arg(type) +ORDER BY created_at DESC; + +-- name: ListModelsByClientType :many +SELECT * FROM models +WHERE client_type = sqlc.arg(client_type) +ORDER BY created_at DESC; + +-- name: UpdateModel :one +UPDATE models +SET + name = sqlc.arg(name), + base_url = sqlc.arg(base_url), + api_key = sqlc.arg(api_key), + client_type = sqlc.arg(client_type), + dimensions = sqlc.arg(dimensions), + type = sqlc.arg(type), + updated_at = now() +WHERE id = sqlc.arg(id) +RETURNING *; + +-- name: UpdateModelByModelID :one +UPDATE models +SET + name = sqlc.arg(name), + base_url = sqlc.arg(base_url), + api_key = sqlc.arg(api_key), + client_type = sqlc.arg(client_type), + dimensions = sqlc.arg(dimensions), + type = sqlc.arg(type), + updated_at = now() +WHERE model_id = sqlc.arg(model_id) +RETURNING *; + +-- name: DeleteModel :exec +DELETE FROM models WHERE id = sqlc.arg(id); + +-- name: DeleteModelByModelID :exec +DELETE FROM models WHERE model_id = sqlc.arg(model_id); + +-- name: CountModels :one +SELECT COUNT(*) FROM models; + +-- name: CountModelsByType :one +SELECT COUNT(*) FROM models WHERE type = sqlc.arg(type); + diff --git a/docs/docs.go b/docs/docs.go index 8344d67e..2d46cf22 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -610,6 +610,395 @@ const docTemplate = `{ } } } + }, + "/models": { + "get": { + "description": "Get a list of all configured models, optionally filtered by type or client type", + "tags": [ + "models" + ], + "summary": "List all models", + "parameters": [ + { + "type": "string", + "description": "Model type (chat, embedding)", + "name": "type", + "in": "query" + }, + { + "type": "string", + "description": "Client type (openai, anthropic, google)", + "name": "client_type", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/models.GetResponse" + } + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "post": { + "description": "Create a new model configuration", + "tags": [ + "models" + ], + "summary": "Create a new model", + "parameters": [ + { + "description": "Model configuration", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/models.AddRequest" + } + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/models.AddResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/models/count": { + "get": { + "description": "Get the total count of models, optionally filtered by type", + "tags": [ + "models" + ], + "summary": "Get model count", + "parameters": [ + { + "type": "string", + "description": "Model type (chat, embedding)", + "name": "type", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.CountResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/models/model/{modelId}": { + "get": { + "description": "Get a model configuration by its model_id field (e.g., gpt-4)", + "tags": [ + "models" + ], + "summary": "Get model by model ID", + "parameters": [ + { + "type": "string", + "description": "Model ID (e.g., gpt-4)", + "name": "modelId", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "put": { + "description": "Update a model configuration by its model_id field (e.g., gpt-4)", + "tags": [ + "models" + ], + "summary": "Update model by model ID", + "parameters": [ + { + "type": "string", + "description": "Model ID (e.g., gpt-4)", + "name": "modelId", + "in": "path", + "required": true + }, + { + "description": "Updated model configuration", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/models.UpdateRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "delete": { + "description": "Delete a model configuration by its model_id field (e.g., gpt-4)", + "tags": [ + "models" + ], + "summary": "Delete model by model ID", + "parameters": [ + { + "type": "string", + "description": "Model ID (e.g., gpt-4)", + "name": "modelId", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/models/{id}": { + "get": { + "description": "Get a model configuration by its internal UUID", + "tags": [ + "models" + ], + "summary": "Get model by internal ID", + "parameters": [ + { + "type": "string", + "description": "Model internal ID (UUID)", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "put": { + "description": "Update a model configuration by its internal UUID", + "tags": [ + "models" + ], + "summary": "Update model by internal ID", + "parameters": [ + { + "type": "string", + "description": "Model internal ID (UUID)", + "name": "id", + "in": "path", + "required": true + }, + { + "description": "Updated model configuration", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/models.UpdateRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "delete": { + "description": "Delete a model configuration by its internal UUID", + "tags": [ + "models" + ], + "summary": "Delete model by internal ID", + "parameters": [ + { + "type": "string", + "description": "Model internal ID (UUID)", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } } }, "definitions": { @@ -923,18 +1312,128 @@ const docTemplate = `{ "type": "string" } } + }, + "models.AddRequest": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "$ref": "#/definitions/models.ClientType" + }, + "dimensions": { + "type": "integer" + }, + "model_id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, + "models.AddResponse": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "model_id": { + "type": "string" + } + } + }, + "models.ClientType": { + "type": "string", + "enum": [ + "openai", + "anthropic", + "google" + ], + "x-enum-varnames": [ + "ClientTypeOpenAI", + "ClientTypeAnthropic", + "ClientTypeGoogle" + ] + }, + "models.CountResponse": { + "type": "object", + "properties": { + "count": { + "type": "integer" + } + } + }, + "models.GetResponse": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "$ref": "#/definitions/models.ClientType" + }, + "dimensions": { + "type": "integer" + }, + "model_id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, + "models.UpdateRequest": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "$ref": "#/definitions/models.ClientType" + }, + "dimensions": { + "type": "integer" + }, + "model_id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "type": { + "type": "string" + } + } } } }` // SwaggerInfo holds exported Swagger Info so clients can modify it var SwaggerInfo = &swag.Spec{ - Version: "1.0", + Version: "", Host: "", - BasePath: "/", - Schemes: []string{"http"}, - Title: "Memoh Go API", - Description: "User-scoped filesystem API for containerd-backed data.", + BasePath: "", + Schemes: []string{}, + Title: "", + Description: "", InfoInstanceName: "swagger", SwaggerTemplate: docTemplate, LeftDelim: "{{", diff --git a/docs/swagger.json b/docs/swagger.json index 16e55989..0c058f78 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -1,15 +1,8 @@ { - "schemes": [ - "http" - ], "swagger": "2.0", "info": { - "description": "User-scoped filesystem API for containerd-backed data.", - "title": "Memoh Go API", - "contact": {}, - "version": "1.0" + "contact": {} }, - "basePath": "/", "paths": { "/auth/login": { "post": { @@ -606,6 +599,395 @@ } } } + }, + "/models": { + "get": { + "description": "Get a list of all configured models, optionally filtered by type or client type", + "tags": [ + "models" + ], + "summary": "List all models", + "parameters": [ + { + "type": "string", + "description": "Model type (chat, embedding)", + "name": "type", + "in": "query" + }, + { + "type": "string", + "description": "Client type (openai, anthropic, google)", + "name": "client_type", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/models.GetResponse" + } + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "post": { + "description": "Create a new model configuration", + "tags": [ + "models" + ], + "summary": "Create a new model", + "parameters": [ + { + "description": "Model configuration", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/models.AddRequest" + } + } + ], + "responses": { + "201": { + "description": "Created", + "schema": { + "$ref": "#/definitions/models.AddResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/models/count": { + "get": { + "description": "Get the total count of models, optionally filtered by type", + "tags": [ + "models" + ], + "summary": "Get model count", + "parameters": [ + { + "type": "string", + "description": "Model type (chat, embedding)", + "name": "type", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.CountResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/models/model/{modelId}": { + "get": { + "description": "Get a model configuration by its model_id field (e.g., gpt-4)", + "tags": [ + "models" + ], + "summary": "Get model by model ID", + "parameters": [ + { + "type": "string", + "description": "Model ID (e.g., gpt-4)", + "name": "modelId", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "put": { + "description": "Update a model configuration by its model_id field (e.g., gpt-4)", + "tags": [ + "models" + ], + "summary": "Update model by model ID", + "parameters": [ + { + "type": "string", + "description": "Model ID (e.g., gpt-4)", + "name": "modelId", + "in": "path", + "required": true + }, + { + "description": "Updated model configuration", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/models.UpdateRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "delete": { + "description": "Delete a model configuration by its model_id field (e.g., gpt-4)", + "tags": [ + "models" + ], + "summary": "Delete model by model ID", + "parameters": [ + { + "type": "string", + "description": "Model ID (e.g., gpt-4)", + "name": "modelId", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } + }, + "/models/{id}": { + "get": { + "description": "Get a model configuration by its internal UUID", + "tags": [ + "models" + ], + "summary": "Get model by internal ID", + "parameters": [ + { + "type": "string", + "description": "Model internal ID (UUID)", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "put": { + "description": "Update a model configuration by its internal UUID", + "tags": [ + "models" + ], + "summary": "Update model by internal ID", + "parameters": [ + { + "type": "string", + "description": "Model internal ID (UUID)", + "name": "id", + "in": "path", + "required": true + }, + { + "description": "Updated model configuration", + "name": "payload", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/models.UpdateRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/models.GetResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + }, + "delete": { + "description": "Delete a model configuration by its internal UUID", + "tags": [ + "models" + ], + "summary": "Delete model by internal ID", + "parameters": [ + { + "type": "string", + "description": "Model internal ID (UUID)", + "name": "id", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.ErrorResponse" + } + } + } + } } }, "definitions": { @@ -919,6 +1301,116 @@ "type": "string" } } + }, + "models.AddRequest": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "$ref": "#/definitions/models.ClientType" + }, + "dimensions": { + "type": "integer" + }, + "model_id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, + "models.AddResponse": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "model_id": { + "type": "string" + } + } + }, + "models.ClientType": { + "type": "string", + "enum": [ + "openai", + "anthropic", + "google" + ], + "x-enum-varnames": [ + "ClientTypeOpenAI", + "ClientTypeAnthropic", + "ClientTypeGoogle" + ] + }, + "models.CountResponse": { + "type": "object", + "properties": { + "count": { + "type": "integer" + } + } + }, + "models.GetResponse": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "$ref": "#/definitions/models.ClientType" + }, + "dimensions": { + "type": "integer" + }, + "model_id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "type": { + "type": "string" + } + } + }, + "models.UpdateRequest": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "client_type": { + "$ref": "#/definitions/models.ClientType" + }, + "dimensions": { + "type": "integer" + }, + "model_id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "type": { + "type": "string" + } + } } } } \ No newline at end of file diff --git a/docs/swagger.yaml b/docs/swagger.yaml index ba781d74..af179497 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -1,4 +1,3 @@ -basePath: / definitions: handlers.ApplyPatchRequest: properties: @@ -203,11 +202,81 @@ definitions: memory_id: type: string type: object + models.AddRequest: + properties: + api_key: + type: string + base_url: + type: string + client_type: + $ref: '#/definitions/models.ClientType' + dimensions: + type: integer + model_id: + type: string + name: + type: string + type: + type: string + type: object + models.AddResponse: + properties: + id: + type: string + model_id: + type: string + type: object + models.ClientType: + enum: + - openai + - anthropic + - google + type: string + x-enum-varnames: + - ClientTypeOpenAI + - ClientTypeAnthropic + - ClientTypeGoogle + models.CountResponse: + properties: + count: + type: integer + type: object + models.GetResponse: + properties: + api_key: + type: string + base_url: + type: string + client_type: + $ref: '#/definitions/models.ClientType' + dimensions: + type: integer + model_id: + type: string + name: + type: string + type: + type: string + type: object + models.UpdateRequest: + properties: + api_key: + type: string + base_url: + type: string + client_type: + $ref: '#/definitions/models.ClientType' + dimensions: + type: integer + model_id: + type: string + name: + type: string + type: + type: string + type: object info: contact: {} - description: User-scoped filesystem API for containerd-backed data. - title: Memoh Go API - version: "1.0" paths: /auth/login: post: @@ -599,6 +668,262 @@ paths: summary: Update memory tags: - memory -schemes: -- http + /models: + get: + description: Get a list of all configured models, optionally filtered by type + or client type + parameters: + - description: Model type (chat, embedding) + in: query + name: type + type: string + - description: Client type (openai, anthropic, google) + in: query + name: client_type + type: string + responses: + "200": + description: OK + schema: + items: + $ref: '#/definitions/models.GetResponse' + type: array + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: List all models + tags: + - models + post: + description: Create a new model configuration + parameters: + - description: Model configuration + in: body + name: payload + required: true + schema: + $ref: '#/definitions/models.AddRequest' + responses: + "201": + description: Created + schema: + $ref: '#/definitions/models.AddResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Create a new model + tags: + - models + /models/{id}: + delete: + description: Delete a model configuration by its internal UUID + parameters: + - description: Model internal ID (UUID) + in: path + name: id + required: true + type: string + responses: + "204": + description: No Content + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Delete model by internal ID + tags: + - models + get: + description: Get a model configuration by its internal UUID + parameters: + - description: Model internal ID (UUID) + in: path + name: id + required: true + type: string + responses: + "200": + description: OK + schema: + $ref: '#/definitions/models.GetResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Get model by internal ID + tags: + - models + put: + description: Update a model configuration by its internal UUID + parameters: + - description: Model internal ID (UUID) + in: path + name: id + required: true + type: string + - description: Updated model configuration + in: body + name: payload + required: true + schema: + $ref: '#/definitions/models.UpdateRequest' + responses: + "200": + description: OK + schema: + $ref: '#/definitions/models.GetResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Update model by internal ID + tags: + - models + /models/count: + get: + description: Get the total count of models, optionally filtered by type + parameters: + - description: Model type (chat, embedding) + in: query + name: type + type: string + responses: + "200": + description: OK + schema: + $ref: '#/definitions/models.CountResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Get model count + tags: + - models + /models/model/{modelId}: + delete: + description: Delete a model configuration by its model_id field (e.g., gpt-4) + parameters: + - description: Model ID (e.g., gpt-4) + in: path + name: modelId + required: true + type: string + responses: + "204": + description: No Content + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Delete model by model ID + tags: + - models + get: + description: Get a model configuration by its model_id field (e.g., gpt-4) + parameters: + - description: Model ID (e.g., gpt-4) + in: path + name: modelId + required: true + type: string + responses: + "200": + description: OK + schema: + $ref: '#/definitions/models.GetResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Get model by model ID + tags: + - models + put: + description: Update a model configuration by its model_id field (e.g., gpt-4) + parameters: + - description: Model ID (e.g., gpt-4) + in: path + name: modelId + required: true + type: string + - description: Updated model configuration + in: body + name: payload + required: true + schema: + $ref: '#/definitions/models.UpdateRequest' + responses: + "200": + description: OK + schema: + $ref: '#/definitions/models.GetResponse' + "400": + description: Bad Request + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "404": + description: Not Found + schema: + $ref: '#/definitions/handlers.ErrorResponse' + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.ErrorResponse' + summary: Update model by model ID + tags: + - models swagger: "2.0" diff --git a/go.mod b/go.mod index 8d8fa760..3fff1f63 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/containerd/containerd/api v1.10.0 github.com/containerd/containerd/v2 v2.2.1 github.com/containerd/errdefs v1.0.0 - github.com/cyphar/filepath-securejoin v0.6.1 + github.com/cyphar/filepath-securejoin v0.5.1 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.8.0 @@ -16,12 +16,12 @@ require ( github.com/opencontainers/runtime-spec v1.3.0 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 github.com/qdrant/go-client v1.16.2 + github.com/stretchr/testify v1.11.1 github.com/swaggo/swag v1.16.6 golang.org/x/crypto v0.47.0 ) require ( - cyphar.com/go-pathrs v0.2.2 // indirect github.com/KyleBanks/depth v1.2.1 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/hcsshim v0.14.0-rc.1 // indirect @@ -86,8 +86,9 @@ require ( golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect golang.org/x/time v0.14.0 // indirect - golang.org/x/tools v0.41.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260114163908-3f89685c29c3 // indirect - google.golang.org/grpc v1.78.0 // indirect - google.golang.org/protobuf v1.36.11 // indirect + golang.org/x/tools v0.40.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba // indirect + google.golang.org/grpc v1.76.0 // indirect + google.golang.org/protobuf v1.36.10 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c224c192..da0ea083 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,4 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cyphar.com/go-pathrs v0.2.2 h1:y9w7hxbkr3zEL78Fjzeg4HEhs2xNy+fbwHiHGJJY2Xo= -cyphar.com/go-pathrs v0.2.2/go.mod h1:y8f1EMG7r+hCuFf/rXsKqMJrJAUoADZGNh5/vZPKcGc= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= @@ -41,8 +39,8 @@ github.com/containerd/ttrpc v1.2.7 h1:qIrroQvuOL9HQ1X6KHe2ohc7p+HP/0VE6XPU7elJRq github.com/containerd/ttrpc v1.2.7/go.mod h1:YCXHsb32f+Sq5/72xHubdiJRQY9inL4a4ZQrAbN1q9o= github.com/containerd/typeurl/v2 v2.2.3 h1:yNA/94zxWdvYACdYO8zofhrTVuQY73fFU1y++dYSw40= github.com/containerd/typeurl/v2 v2.2.3/go.mod h1:95ljDnPfD3bAbDJRugOiShd/DlAAsxGtUBhJxIn7SCk= -github.com/cyphar/filepath-securejoin v0.6.1 h1:5CeZ1jPXEiYt3+Z6zqprSAgSWiggmpVyciv8syjIpVE= -github.com/cyphar/filepath-securejoin v0.6.1/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1AX0a9kM5XL+NwKoYSc= +github.com/cyphar/filepath-securejoin v0.5.1 h1:eYgfMq5yryL4fbWfkLpFFy2ukSELzaJOTaUTuh+oF48= +github.com/cyphar/filepath-securejoin v0.5.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -267,8 +265,8 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= -golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -280,15 +278,15 @@ google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260114163908-3f89685c29c3 h1:C4WAdL+FbjnGlpp2S+HMVhBeCq2Lcib4xZqfPNF6OoQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260114163908-3f89685c29c3/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba h1:UKgtfRM7Yh93Sya0Fo8ZzhDP4qBckrrxEr2oF5UIVb8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= -google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= +google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A= +google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -298,8 +296,8 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= -google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/chat/chat.go b/internal/chat/chat.go new file mode 100644 index 00000000..0efde489 --- /dev/null +++ b/internal/chat/chat.go @@ -0,0 +1,2 @@ +package chat + diff --git a/internal/db/sqlc/models.go b/internal/db/sqlc/models.go index 045b2c8b..63b9d97a 100644 --- a/internal/db/sqlc/models.go +++ b/internal/db/sqlc/models.go @@ -41,6 +41,19 @@ type LifecycleEvent struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } +type Model struct { + ID pgtype.UUID `json:"id"` + ModelID string `json:"model_id"` + Name pgtype.Text `json:"name"` + BaseUrl string `json:"base_url"` + ApiKey string `json:"api_key"` + ClientType string `json:"client_type"` + Dimensions pgtype.Int4 `json:"dimensions"` + Type string `json:"type"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + type Snapshot struct { ID string `json:"id"` ContainerID string `json:"container_id"` diff --git a/internal/db/sqlc/models.sql.go b/internal/db/sqlc/models.sql.go new file mode 100644 index 00000000..9c4d23be --- /dev/null +++ b/internal/db/sqlc/models.sql.go @@ -0,0 +1,356 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: models.sql + +package sqlc + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const countModels = `-- name: CountModels :one +SELECT COUNT(*) FROM models +` + +func (q *Queries) CountModels(ctx context.Context) (int64, error) { + row := q.db.QueryRow(ctx, countModels) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countModelsByType = `-- name: CountModelsByType :one +SELECT COUNT(*) FROM models WHERE type = $1 +` + +func (q *Queries) CountModelsByType(ctx context.Context, type_ string) (int64, error) { + row := q.db.QueryRow(ctx, countModelsByType, type_) + var count int64 + err := row.Scan(&count) + return count, err +} + +const createModel = `-- name: CreateModel :one +INSERT INTO models (model_id, name, base_url, api_key, client_type, dimensions, type) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) +RETURNING id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at +` + +type CreateModelParams struct { + ModelID string `json:"model_id"` + Name pgtype.Text `json:"name"` + BaseUrl string `json:"base_url"` + ApiKey string `json:"api_key"` + ClientType string `json:"client_type"` + Dimensions pgtype.Int4 `json:"dimensions"` + Type string `json:"type"` +} + +func (q *Queries) CreateModel(ctx context.Context, arg CreateModelParams) (Model, error) { + row := q.db.QueryRow(ctx, createModel, + arg.ModelID, + arg.Name, + arg.BaseUrl, + arg.ApiKey, + arg.ClientType, + arg.Dimensions, + arg.Type, + ) + var i Model + err := row.Scan( + &i.ID, + &i.ModelID, + &i.Name, + &i.BaseUrl, + &i.ApiKey, + &i.ClientType, + &i.Dimensions, + &i.Type, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const deleteModel = `-- name: DeleteModel :exec +DELETE FROM models WHERE id = $1 +` + +func (q *Queries) DeleteModel(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteModel, id) + return err +} + +const deleteModelByModelID = `-- name: DeleteModelByModelID :exec +DELETE FROM models WHERE model_id = $1 +` + +func (q *Queries) DeleteModelByModelID(ctx context.Context, modelID string) error { + _, err := q.db.Exec(ctx, deleteModelByModelID, modelID) + return err +} + +const getModelByID = `-- name: GetModelByID :one +SELECT id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at FROM models WHERE id = $1 +` + +func (q *Queries) GetModelByID(ctx context.Context, id pgtype.UUID) (Model, error) { + row := q.db.QueryRow(ctx, getModelByID, id) + var i Model + err := row.Scan( + &i.ID, + &i.ModelID, + &i.Name, + &i.BaseUrl, + &i.ApiKey, + &i.ClientType, + &i.Dimensions, + &i.Type, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getModelByModelID = `-- name: GetModelByModelID :one +SELECT id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at FROM models WHERE model_id = $1 +` + +func (q *Queries) GetModelByModelID(ctx context.Context, modelID string) (Model, error) { + row := q.db.QueryRow(ctx, getModelByModelID, modelID) + var i Model + err := row.Scan( + &i.ID, + &i.ModelID, + &i.Name, + &i.BaseUrl, + &i.ApiKey, + &i.ClientType, + &i.Dimensions, + &i.Type, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const listModels = `-- name: ListModels :many +SELECT id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at FROM models +ORDER BY created_at DESC +` + +func (q *Queries) ListModels(ctx context.Context) ([]Model, error) { + rows, err := q.db.Query(ctx, listModels) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Model + for rows.Next() { + var i Model + if err := rows.Scan( + &i.ID, + &i.ModelID, + &i.Name, + &i.BaseUrl, + &i.ApiKey, + &i.ClientType, + &i.Dimensions, + &i.Type, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listModelsByClientType = `-- name: ListModelsByClientType :many +SELECT id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at FROM models +WHERE client_type = $1 +ORDER BY created_at DESC +` + +func (q *Queries) ListModelsByClientType(ctx context.Context, clientType string) ([]Model, error) { + rows, err := q.db.Query(ctx, listModelsByClientType, clientType) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Model + for rows.Next() { + var i Model + if err := rows.Scan( + &i.ID, + &i.ModelID, + &i.Name, + &i.BaseUrl, + &i.ApiKey, + &i.ClientType, + &i.Dimensions, + &i.Type, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listModelsByType = `-- name: ListModelsByType :many +SELECT id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at FROM models +WHERE type = $1 +ORDER BY created_at DESC +` + +func (q *Queries) ListModelsByType(ctx context.Context, type_ string) ([]Model, error) { + rows, err := q.db.Query(ctx, listModelsByType, type_) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Model + for rows.Next() { + var i Model + if err := rows.Scan( + &i.ID, + &i.ModelID, + &i.Name, + &i.BaseUrl, + &i.ApiKey, + &i.ClientType, + &i.Dimensions, + &i.Type, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateModel = `-- name: UpdateModel :one +UPDATE models +SET + name = $1, + base_url = $2, + api_key = $3, + client_type = $4, + dimensions = $5, + type = $6, + updated_at = now() +WHERE id = $7 +RETURNING id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at +` + +type UpdateModelParams struct { + Name pgtype.Text `json:"name"` + BaseUrl string `json:"base_url"` + ApiKey string `json:"api_key"` + ClientType string `json:"client_type"` + Dimensions pgtype.Int4 `json:"dimensions"` + Type string `json:"type"` + ID pgtype.UUID `json:"id"` +} + +func (q *Queries) UpdateModel(ctx context.Context, arg UpdateModelParams) (Model, error) { + row := q.db.QueryRow(ctx, updateModel, + arg.Name, + arg.BaseUrl, + arg.ApiKey, + arg.ClientType, + arg.Dimensions, + arg.Type, + arg.ID, + ) + var i Model + err := row.Scan( + &i.ID, + &i.ModelID, + &i.Name, + &i.BaseUrl, + &i.ApiKey, + &i.ClientType, + &i.Dimensions, + &i.Type, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const updateModelByModelID = `-- name: UpdateModelByModelID :one +UPDATE models +SET + name = $1, + base_url = $2, + api_key = $3, + client_type = $4, + dimensions = $5, + type = $6, + updated_at = now() +WHERE model_id = $7 +RETURNING id, model_id, name, base_url, api_key, client_type, dimensions, type, created_at, updated_at +` + +type UpdateModelByModelIDParams struct { + Name pgtype.Text `json:"name"` + BaseUrl string `json:"base_url"` + ApiKey string `json:"api_key"` + ClientType string `json:"client_type"` + Dimensions pgtype.Int4 `json:"dimensions"` + Type string `json:"type"` + ModelID string `json:"model_id"` +} + +func (q *Queries) UpdateModelByModelID(ctx context.Context, arg UpdateModelByModelIDParams) (Model, error) { + row := q.db.QueryRow(ctx, updateModelByModelID, + arg.Name, + arg.BaseUrl, + arg.ApiKey, + arg.ClientType, + arg.Dimensions, + arg.Type, + arg.ModelID, + ) + var i Model + err := row.Scan( + &i.ID, + &i.ModelID, + &i.Name, + &i.BaseUrl, + &i.ApiKey, + &i.ClientType, + &i.Dimensions, + &i.Type, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/handlers/models.go b/internal/handlers/models.go new file mode 100644 index 00000000..49484f2e --- /dev/null +++ b/internal/handlers/models.go @@ -0,0 +1,258 @@ +package handlers + +import ( + "net/http" + + "github.com/labstack/echo/v4" + + "github.com/memohai/memoh/internal/models" +) + +type ModelsHandler struct { + service *models.Service +} + +func NewModelsHandler(service *models.Service) *ModelsHandler { + return &ModelsHandler{service: service} +} + +func (h *ModelsHandler) Register(e *echo.Echo) { + group := e.Group("/models") + group.POST("", h.Create) + group.GET("", h.List) + group.GET("/:id", h.GetByID) + group.GET("/model/:modelId", h.GetByModelID) + group.PUT("/:id", h.UpdateByID) + group.PUT("/model/:modelId", h.UpdateByModelID) + group.DELETE("/:id", h.DeleteByID) + group.DELETE("/model/:modelId", h.DeleteByModelID) + group.GET("/count", h.Count) +} + +// Create godoc +// @Summary Create a new model +// @Description Create a new model configuration +// @Tags models +// @Param payload body models.AddRequest true "Model configuration" +// @Success 201 {object} models.AddResponse +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models [post] +func (h *ModelsHandler) Create(c echo.Context) error { + var req models.AddRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + resp, err := h.service.Create(c.Request().Context(), req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusCreated, resp) +} + +// List godoc +// @Summary List all models +// @Description Get a list of all configured models, optionally filtered by type or client type +// @Tags models +// @Param type query string false "Model type (chat, embedding)" +// @Param client_type query string false "Client type (openai, anthropic, google)" +// @Success 200 {array} models.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models [get] +func (h *ModelsHandler) List(c echo.Context) error { + modelType := c.QueryParam("type") + clientType := c.QueryParam("client_type") + + var resp []models.GetResponse + var err error + + if modelType != "" { + resp, err = h.service.ListByType(c.Request().Context(), models.ModelType(modelType)) + } else if clientType != "" { + resp, err = h.service.ListByClientType(c.Request().Context(), models.ClientType(clientType)) + } else { + resp, err = h.service.List(c.Request().Context()) + } + + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + +// GetByID godoc +// @Summary Get model by internal ID +// @Description Get a model configuration by its internal UUID +// @Tags models +// @Param id path string true "Model internal ID (UUID)" +// @Success 200 {object} models.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models/{id} [get] +func (h *ModelsHandler) GetByID(c echo.Context) error { + id := c.Param("id") + if id == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + + resp, err := h.service.GetByID(c.Request().Context(), id) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + +// GetByModelID godoc +// @Summary Get model by model ID +// @Description Get a model configuration by its model_id field (e.g., gpt-4) +// @Tags models +// @Param modelId path string true "Model ID (e.g., gpt-4)" +// @Success 200 {object} models.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models/model/{modelId} [get] +func (h *ModelsHandler) GetByModelID(c echo.Context) error { + modelID := c.Param("modelId") + if modelID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "modelId is required") + } + + resp, err := h.service.GetByModelID(c.Request().Context(), modelID) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + +// UpdateByID godoc +// @Summary Update model by internal ID +// @Description Update a model configuration by its internal UUID +// @Tags models +// @Param id path string true "Model internal ID (UUID)" +// @Param payload body models.UpdateRequest true "Updated model configuration" +// @Success 200 {object} models.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models/{id} [put] +func (h *ModelsHandler) UpdateByID(c echo.Context) error { + id := c.Param("id") + if id == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + + var req models.UpdateRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + resp, err := h.service.UpdateByID(c.Request().Context(), id, req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + +// UpdateByModelID godoc +// @Summary Update model by model ID +// @Description Update a model configuration by its model_id field (e.g., gpt-4) +// @Tags models +// @Param modelId path string true "Model ID (e.g., gpt-4)" +// @Param payload body models.UpdateRequest true "Updated model configuration" +// @Success 200 {object} models.GetResponse +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models/model/{modelId} [put] +func (h *ModelsHandler) UpdateByModelID(c echo.Context) error { + modelID := c.Param("modelId") + if modelID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "modelId is required") + } + + var req models.UpdateRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + resp, err := h.service.UpdateByModelID(c.Request().Context(), modelID, req) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, resp) +} + +// DeleteByID godoc +// @Summary Delete model by internal ID +// @Description Delete a model configuration by its internal UUID +// @Tags models +// @Param id path string true "Model internal ID (UUID)" +// @Success 204 "No Content" +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models/{id} [delete] +func (h *ModelsHandler) DeleteByID(c echo.Context) error { + id := c.Param("id") + if id == "" { + return echo.NewHTTPError(http.StatusBadRequest, "id is required") + } + + if err := h.service.DeleteByID(c.Request().Context(), id); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.NoContent(http.StatusNoContent) +} + +// DeleteByModelID godoc +// @Summary Delete model by model ID +// @Description Delete a model configuration by its model_id field (e.g., gpt-4) +// @Tags models +// @Param modelId path string true "Model ID (e.g., gpt-4)" +// @Success 204 "No Content" +// @Failure 400 {object} ErrorResponse +// @Failure 404 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models/model/{modelId} [delete] +func (h *ModelsHandler) DeleteByModelID(c echo.Context) error { + modelID := c.Param("modelId") + if modelID == "" { + return echo.NewHTTPError(http.StatusBadRequest, "modelId is required") + } + + if err := h.service.DeleteByModelID(c.Request().Context(), modelID); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.NoContent(http.StatusNoContent) +} + +// Count godoc +// @Summary Get model count +// @Description Get the total count of models, optionally filtered by type +// @Tags models +// @Param type query string false "Model type (chat, embedding)" +// @Success 200 {object} models.CountResponse +// @Failure 400 {object} ErrorResponse +// @Failure 500 {object} ErrorResponse +// @Router /models/count [get] +func (h *ModelsHandler) Count(c echo.Context) error { + modelType := c.QueryParam("type") + + var count int64 + var err error + + if modelType != "" { + count, err = h.service.CountByType(c.Request().Context(), models.ModelType(modelType)) + } else { + count, err = h.service.Count(c.Request().Context()) + } + + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, models.CountResponse{Count: count}) +} diff --git a/internal/handlers/swagger.go b/internal/handlers/swagger.go index b7edbcda..64f2c937 100644 --- a/internal/handlers/swagger.go +++ b/internal/handlers/swagger.go @@ -8,7 +8,7 @@ import ( "github.com/labstack/echo/v4" ) -//go:generate go run github.com/swaggo/swag/cmd/swag@latest init -g ../../cmd/agent/docs.go -o ../../docs --parseDependency --parseInternal +//go:generate go run github.com/swaggo/swag/cmd/swag@latest init -g swagger.go -o ../../docs --parseDependency --parseInternal var ( swaggerSpec []byte diff --git a/internal/models/models.go b/internal/models/models.go new file mode 100644 index 00000000..31771eec --- /dev/null +++ b/internal/models/models.go @@ -0,0 +1,302 @@ +package models + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db/sqlc" +) + +// Service provides CRUD operations for models +type Service struct { + queries *sqlc.Queries +} + +// NewService creates a new models service +func NewService(queries *sqlc.Queries) *Service { + return &Service{ + queries: queries, + } +} + +// Create adds a new model to the database +func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, error) { + model := Model(req) + if err := model.Validate(); err != nil { + return AddResponse{}, fmt.Errorf("validation failed: %w", err) + } + + // Convert to sqlc params + params := sqlc.CreateModelParams{ + ModelID: model.ModelID, + BaseUrl: model.BaseURL, + ApiKey: model.APIKey, + ClientType: string(model.ClientType), + Type: string(model.Type), + } + + // Handle optional name field + if model.Name != "" { + params.Name = pgtype.Text{String: model.Name, Valid: true} + } + + // Handle optional dimensions field (only for embedding models) + if model.Type == ModelTypeEmbedding && model.Dimensions > 0 { + params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} + } + + created, err := s.queries.CreateModel(ctx, params) + if err != nil { + return AddResponse{}, fmt.Errorf("failed to create model: %w", err) + } + + // Convert pgtype.UUID to string + var idStr string + if created.ID.Valid { + id, err := uuid.FromBytes(created.ID.Bytes[:]) + if err != nil { + return AddResponse{}, fmt.Errorf("failed to convert UUID: %w", err) + } + idStr = id.String() + } + + return AddResponse{ + ID: idStr, + ModelID: created.ModelID, + }, nil +} + +// GetByID retrieves a model by its internal UUID +func (s *Service) GetByID(ctx context.Context, id string) (GetResponse, error) { + uuid, err := parseUUID(id) + if err != nil { + return GetResponse{}, fmt.Errorf("invalid ID: %w", err) + } + + dbModel, err := s.queries.GetModelByID(ctx, uuid) + if err != nil { + return GetResponse{}, fmt.Errorf("failed to get model: %w", err) + } + + return convertToGetResponse(dbModel), nil +} + +// GetByModelID retrieves a model by its model_id field +func (s *Service) GetByModelID(ctx context.Context, modelID string) (GetResponse, error) { + if modelID == "" { + return GetResponse{}, fmt.Errorf("model_id is required") + } + + dbModel, err := s.queries.GetModelByModelID(ctx, modelID) + if err != nil { + return GetResponse{}, fmt.Errorf("failed to get model: %w", err) + } + + return convertToGetResponse(dbModel), nil +} + +// List returns all models +func (s *Service) List(ctx context.Context) ([]GetResponse, error) { + dbModels, err := s.queries.ListModels(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list models: %w", err) + } + + return convertToGetResponseList(dbModels), nil +} + +// ListByType returns models filtered by type (chat or embedding) +func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetResponse, error) { + if modelType != ModelTypeChat && modelType != ModelTypeEmbedding { + return nil, fmt.Errorf("invalid model type: %s", modelType) + } + + dbModels, err := s.queries.ListModelsByType(ctx, string(modelType)) + if err != nil { + return nil, fmt.Errorf("failed to list models by type: %w", err) + } + + return convertToGetResponseList(dbModels), nil +} + +// ListByClientType returns models filtered by client type +func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) { + if clientType != ClientTypeOpenAI && clientType != ClientTypeAnthropic && clientType != ClientTypeGoogle { + return nil, fmt.Errorf("invalid client type: %s", clientType) + } + + dbModels, err := s.queries.ListModelsByClientType(ctx, string(clientType)) + if err != nil { + return nil, fmt.Errorf("failed to list models by client type: %w", err) + } + + return convertToGetResponseList(dbModels), nil +} + +// UpdateByID updates a model by its internal UUID +func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) { + uuid, err := parseUUID(id) + if err != nil { + return GetResponse{}, fmt.Errorf("invalid ID: %w", err) + } + + model := Model(req) + if err := model.Validate(); err != nil { + return GetResponse{}, fmt.Errorf("validation failed: %w", err) + } + + params := sqlc.UpdateModelParams{ + ID: uuid, + BaseUrl: model.BaseURL, + ApiKey: model.APIKey, + ClientType: string(model.ClientType), + Type: string(model.Type), + } + + if model.Name != "" { + params.Name = pgtype.Text{String: model.Name, Valid: true} + } + + if model.Type == ModelTypeEmbedding && model.Dimensions > 0 { + params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} + } + + updated, err := s.queries.UpdateModel(ctx, params) + if err != nil { + return GetResponse{}, fmt.Errorf("failed to update model: %w", err) + } + + return convertToGetResponse(updated), nil +} + +// UpdateByModelID updates a model by its model_id field +func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req UpdateRequest) (GetResponse, error) { + if modelID == "" { + return GetResponse{}, fmt.Errorf("model_id is required") + } + + model := Model(req) + if err := model.Validate(); err != nil { + return GetResponse{}, fmt.Errorf("validation failed: %w", err) + } + + params := sqlc.UpdateModelByModelIDParams{ + ModelID: modelID, + BaseUrl: model.BaseURL, + ApiKey: model.APIKey, + ClientType: string(model.ClientType), + Type: string(model.Type), + } + + if model.Name != "" { + params.Name = pgtype.Text{String: model.Name, Valid: true} + } + + if model.Type == ModelTypeEmbedding && model.Dimensions > 0 { + params.Dimensions = pgtype.Int4{Int32: int32(model.Dimensions), Valid: true} + } + + updated, err := s.queries.UpdateModelByModelID(ctx, params) + if err != nil { + return GetResponse{}, fmt.Errorf("failed to update model: %w", err) + } + + return convertToGetResponse(updated), nil +} + +// DeleteByID deletes a model by its internal UUID +func (s *Service) DeleteByID(ctx context.Context, id string) error { + uuid, err := parseUUID(id) + if err != nil { + return fmt.Errorf("invalid ID: %w", err) + } + + if err := s.queries.DeleteModel(ctx, uuid); err != nil { + return fmt.Errorf("failed to delete model: %w", err) + } + + return nil +} + +// DeleteByModelID deletes a model by its model_id field +func (s *Service) DeleteByModelID(ctx context.Context, modelID string) error { + if modelID == "" { + return fmt.Errorf("model_id is required") + } + + if err := s.queries.DeleteModelByModelID(ctx, modelID); err != nil { + return fmt.Errorf("failed to delete model: %w", err) + } + + return nil +} + +// Count returns the total number of models +func (s *Service) Count(ctx context.Context) (int64, error) { + count, err := s.queries.CountModels(ctx) + if err != nil { + return 0, fmt.Errorf("failed to count models: %w", err) + } + return count, nil +} + +// CountByType returns the number of models of a specific type +func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, error) { + if modelType != ModelTypeChat && modelType != ModelTypeEmbedding { + return 0, fmt.Errorf("invalid model type: %s", modelType) + } + + count, err := s.queries.CountModelsByType(ctx, string(modelType)) + if err != nil { + return 0, fmt.Errorf("failed to count models by type: %w", err) + } + return count, nil +} + +// Helper functions + +func parseUUID(id string) (pgtype.UUID, error) { + parsed, err := uuid.Parse(id) + if err != nil { + return pgtype.UUID{}, fmt.Errorf("invalid UUID format: %w", err) + } + + var pgUUID pgtype.UUID + copy(pgUUID.Bytes[:], parsed[:]) + pgUUID.Valid = true + + return pgUUID, nil +} + +func convertToGetResponse(dbModel sqlc.Model) GetResponse { + resp := GetResponse{ + ModelId: dbModel.ModelID, + Model: Model{ + ModelID: dbModel.ModelID, + BaseURL: dbModel.BaseUrl, + APIKey: dbModel.ApiKey, + ClientType: ClientType(dbModel.ClientType), + Type: ModelType(dbModel.Type), + }, + } + + if dbModel.Name.Valid { + resp.Model.Name = dbModel.Name.String + } + + if dbModel.Dimensions.Valid { + resp.Model.Dimensions = int(dbModel.Dimensions.Int32) + } + + return resp +} + +func convertToGetResponseList(dbModels []sqlc.Model) []GetResponse { + responses := make([]GetResponse, 0, len(dbModels)) + for _, dbModel := range dbModels { + responses = append(responses, convertToGetResponse(dbModel)) + } + return responses +} diff --git a/internal/models/models_test.go b/internal/models/models_test.go new file mode 100644 index 00000000..315e669b --- /dev/null +++ b/internal/models/models_test.go @@ -0,0 +1,297 @@ +package models_test + +import ( + "testing" + + "github.com/memohai/memoh/internal/models" + "github.com/stretchr/testify/assert" +) + +// This is an example test file demonstrating how to use the models service +// Actual tests would require database setup and mocking + +func ExampleService_Create() { + // Example usage - in real code, you would initialize with actual database connection + // service := models.NewService(queries) + + // ctx := context.Background() + // req := models.AddRequest{ + // ModelID: "gpt-4", + // Name: "GPT-4", + // BaseURL: "https://api.openai.com/v1", + // APIKey: "sk-...", + // ClientType: models.ClientTypeOpenAI, + // Type: models.ModelTypeChat, + // } + + // resp, err := service.Create(ctx, req) + // if err != nil { + // // handle error + // } + // fmt.Printf("Created model with ID: %s\n", resp.ID) +} + +func ExampleService_GetByModelID() { + // Example usage + // service := models.NewService(queries) + + // ctx := context.Background() + // resp, err := service.GetByModelID(ctx, "gpt-4") + // if err != nil { + // // handle error + // } + // fmt.Printf("Model: %+v\n", resp.Model) +} + +func ExampleService_List() { + // Example usage + // service := models.NewService(queries) + + // ctx := context.Background() + // models, err := service.List(ctx) + // if err != nil { + // // handle error + // } + // for _, model := range models { + // fmt.Printf("Model ID: %s, Type: %s\n", model.ModelID, model.Type) + // } +} + +func ExampleService_ListByType() { + // Example usage + // service := models.NewService(queries) + + // ctx := context.Background() + // chatModels, err := service.ListByType(ctx, models.ModelTypeChat) + // if err != nil { + // // handle error + // } + // fmt.Printf("Found %d chat models\n", len(chatModels)) +} + +func ExampleService_UpdateByModelID() { + // Example usage + // service := models.NewService(queries) + + // ctx := context.Background() + // req := models.UpdateRequest{ + // ModelID: "gpt-4", + // Name: "GPT-4 Turbo", + // BaseURL: "https://api.openai.com/v1", + // APIKey: "sk-...", + // ClientType: models.ClientTypeOpenAI, + // Type: models.ModelTypeChat, + // } + + // resp, err := service.UpdateByModelID(ctx, "gpt-4", req) + // if err != nil { + // // handle error + // } + // fmt.Printf("Updated model: %s\n", resp.ModelId) +} + +func ExampleService_DeleteByModelID() { + // Example usage + // service := models.NewService(queries) + + // ctx := context.Background() + // err := service.DeleteByModelID(ctx, "gpt-4") + // if err != nil { + // // handle error + // } + // fmt.Println("Model deleted successfully") +} + +func TestModel_Validate(t *testing.T) { + tests := []struct { + name string + model models.Model + wantErr bool + }{ + { + name: "valid chat model", + model: models.Model{ + ModelID: "gpt-4", + Name: "GPT-4", + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test", + ClientType: models.ClientTypeOpenAI, + Type: models.ModelTypeChat, + }, + wantErr: false, + }, + { + name: "valid embedding model", + model: models.Model{ + ModelID: "text-embedding-ada-002", + Name: "Ada Embeddings", + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test", + ClientType: models.ClientTypeOpenAI, + Type: models.ModelTypeEmbedding, + Dimensions: 1536, + }, + wantErr: false, + }, + { + name: "missing model_id", + model: models.Model{ + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test", + ClientType: models.ClientTypeOpenAI, + Type: models.ModelTypeChat, + }, + wantErr: true, + }, + { + name: "missing base_url", + model: models.Model{ + ModelID: "gpt-4", + APIKey: "sk-test", + ClientType: models.ClientTypeOpenAI, + Type: models.ModelTypeChat, + }, + wantErr: true, + }, + { + name: "missing api_key", + model: models.Model{ + ModelID: "gpt-4", + BaseURL: "https://api.openai.com/v1", + ClientType: models.ClientTypeOpenAI, + Type: models.ModelTypeChat, + }, + wantErr: true, + }, + { + name: "invalid client type", + model: models.Model{ + ModelID: "gpt-4", + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test", + ClientType: "invalid", + Type: models.ModelTypeChat, + }, + wantErr: true, + }, + { + name: "invalid model type", + model: models.Model{ + ModelID: "gpt-4", + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test", + ClientType: models.ClientTypeOpenAI, + Type: "invalid", + }, + wantErr: true, + }, + { + name: "embedding model missing dimensions", + model: models.Model{ + ModelID: "text-embedding-ada-002", + BaseURL: "https://api.openai.com/v1", + APIKey: "sk-test", + ClientType: models.ClientTypeOpenAI, + Type: models.ModelTypeEmbedding, + Dimensions: 0, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.model.Validate() + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestModelTypes(t *testing.T) { + t.Run("ModelType constants", func(t *testing.T) { + assert.Equal(t, models.ModelType("chat"), models.ModelTypeChat) + assert.Equal(t, models.ModelType("embedding"), models.ModelTypeEmbedding) + }) + + t.Run("ClientType constants", func(t *testing.T) { + assert.Equal(t, models.ClientType("openai"), models.ClientTypeOpenAI) + assert.Equal(t, models.ClientType("anthropic"), models.ClientTypeAnthropic) + assert.Equal(t, models.ClientType("google"), models.ClientTypeGoogle) + }) +} + +// Integration test example (requires actual database) +// func TestService_Integration(t *testing.T) { +// if testing.Short() { +// t.Skip("Skipping integration test") +// } +// +// ctx := context.Background() +// +// // Setup database connection +// pool, err := db.Open(ctx, config.PostgresConfig{ +// Host: "localhost", +// Port: 5432, +// User: "test", +// Password: "test", +// Database: "test_db", +// SSLMode: "disable", +// }) +// require.NoError(t, err) +// defer pool.Close() +// +// queries := sqlc.New(pool) +// service := models.NewService(queries) +// +// // Test Create +// createReq := models.AddRequest{ +// ModelID: "test-gpt-4", +// Name: "Test GPT-4", +// BaseURL: "https://api.openai.com/v1", +// APIKey: "sk-test", +// ClientType: models.ClientTypeOpenAI, +// Type: models.ModelTypeChat, +// } +// createResp, err := service.Create(ctx, createReq) +// require.NoError(t, err) +// assert.NotEmpty(t, createResp.ID) +// assert.Equal(t, "test-gpt-4", createResp.ModelID) +// +// // Test GetByModelID +// getResp, err := service.GetByModelID(ctx, "test-gpt-4") +// require.NoError(t, err) +// assert.Equal(t, "test-gpt-4", getResp.ModelID) +// assert.Equal(t, "Test GPT-4", getResp.Name) +// +// // Test List +// models, err := service.List(ctx) +// require.NoError(t, err) +// assert.NotEmpty(t, models) +// +// // Test Update +// updateReq := models.UpdateRequest{ +// ModelID: "test-gpt-4", +// Name: "Updated GPT-4", +// BaseURL: "https://api.openai.com/v1", +// APIKey: "sk-test-updated", +// ClientType: models.ClientTypeOpenAI, +// Type: models.ModelTypeChat, +// } +// updateResp, err := service.UpdateByModelID(ctx, "test-gpt-4", updateReq) +// require.NoError(t, err) +// assert.Equal(t, "Updated GPT-4", updateResp.Name) +// +// // Test Count +// count, err := service.Count(ctx) +// require.NoError(t, err) +// assert.Greater(t, count, int64(0)) +// +// // Test Delete +// err = service.DeleteByModelID(ctx, "test-gpt-4") +// require.NoError(t, err) +// } + diff --git a/internal/models/types.go b/internal/models/types.go new file mode 100644 index 00000000..3256bf07 --- /dev/null +++ b/internal/models/types.go @@ -0,0 +1,91 @@ +package models + +import ( + "errors" +) + +type ModelType string + +const ( + ModelTypeChat = "chat" + ModelTypeEmbedding = "embedding" +) + +type ClientType string + +const ( + ClientTypeOpenAI ClientType = "openai" + ClientTypeAnthropic ClientType = "anthropic" + ClientTypeGoogle ClientType = "google" +) + +type Model struct { + ModelID string `json:"model_id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + APIKey string `json:"api_key"` + ClientType ClientType `json:"client_type"` + Type ModelType `json:"type"` + Dimensions int `json:"dimensions"` +} + +func (m *Model) Validate() error { + if m.ModelID == "" { + return errors.New("model ID is required") + } + if m.BaseURL == "" { + return errors.New("base URL is required") + } + if m.APIKey == "" { + return errors.New("API key is required") + } + if m.ClientType == "" { + return errors.New("client type is required") + } + if m.Type != ModelTypeChat && m.Type != ModelTypeEmbedding { + return errors.New("invalid model type") + } + if m.ClientType != ClientTypeOpenAI && m.ClientType != ClientTypeAnthropic && m.ClientType != ClientTypeGoogle { + return errors.New("invalid client type") + } + if m.Type == ModelTypeEmbedding && m.Dimensions <= 0 { + return errors.New("dimensions must be greater than 0") + } + return nil +} + +type AddRequest Model + +type AddResponse struct { + ID string `json:"id"` + ModelID string `json:"model_id"` +} + +type GetRequest struct { + ID string `json:"id"` +} + +type GetResponse struct { + ModelId string `json:"model_id"` + Model +} + +type UpdateRequest Model + +type ListRequest struct { + Type ModelType `json:"type,omitempty"` + ClientType ClientType `json:"client_type,omitempty"` +} + +type DeleteRequest struct { + ID string `json:"id,omitempty"` + ModelID string `json:"model_id,omitempty"` +} + +type DeleteResponse struct { + Message string `json:"message"` +} + +type CountResponse struct { + Count int64 `json:"count"` +} diff --git a/package.json b/package.json index 2dd62037..8323d47b 100644 --- a/package.json +++ b/package.json @@ -37,7 +37,8 @@ "vue-eslint-parser": "^10.2.0" }, "dependencies": { - "dotenv": "^17.2.3" + "dotenv": "^17.2.3", + "drizzle-kit": "^0.31.8" }, "pnpm": { "peerDependencyRules": { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 08e1f3a9..fc8f1dd0 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -11,6 +11,9 @@ importers: dotenv: specifier: ^17.2.3 version: 17.2.3 + drizzle-kit: + specifier: ^0.31.8 + version: 0.31.8 devDependencies: '@types/node': specifier: ^25.0.3