Files
InfoGenie/infogenie-backend-go/internal/service/ai.go
2026-03-28 20:59:52 +08:00

214 lines
5.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"strings"
"time"
"infogenie-backend/internal/database"
"infogenie-backend/internal/model"
)
// ChatMessage is one message of a chat conversation. A slice of these is
// carried in chatRequest.Messages, and each value is encoded as a JSON object
// with the keys "role" and "content".
type ChatMessage struct {
	// Role is encoded under the "role" key.
	// NOTE(review): presumably values such as "system", "user" or "assistant" — confirm with callers.
	Role string `json:"role"`
	// Content is the message text, encoded under the "content" key.
	Content string `json:"content"`
}
// chatRequest is the JSON request body that callOpenAICompatible marshals
// and posts to the chat completions endpoint.
type chatRequest struct {
	// Model is the model identifier, encoded under "model".
	Model string `json:"model"`
	// Messages is the conversation to send, encoded under "messages".
	Messages []ChatMessage `json:"messages"`
	// Temperature is encoded under "temperature"; callOpenAICompatible sets it to 0.7.
	Temperature float64 `json:"temperature"`
	// MaxTokens is encoded under "max_tokens"; callOpenAICompatible sets it to 2000.
	MaxTokens int `json:"max_tokens"`
}
// chatResponse is the part of the chat completions response body that
// callOpenAICompatible decodes. Only choices[].message.content is mapped;
// any other fields in the response are ignored by the decoder.
type chatResponse struct {
	// Choices holds the returned completions; callOpenAICompatible reads only the first one.
	Choices []struct {
		// Message wraps the generated message of a single choice.
		Message struct {
			// Content is the generated text.
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}
// loadAIConfig reads the AI configuration for provider from the database.
//
// ok is false, and every other result is its zero value, when database.DB is
// nil or when no row matching "provider = ? AND is_enabled = true" can be
// loaded. On success it returns the row's APIKey, APIBase and DefaultModel
// together with the model list decoded from the row's Models JSON string.
// models is nil when Models is empty, and an empty slice when Models is not
// valid JSON.
func loadAIConfig(provider string) (apiKey, apiBase, defaultModel string, models []string, ok bool) {
	if database.DB == nil {
		return "", "", "", nil, false
	}

	var cfg model.AIConfig
	lookup := database.DB.Where("provider = ? AND is_enabled = ?", provider, true)
	if lookup.First(&cfg).Error != nil {
		return "", "", "", nil, false
	}

	var allowed []string
	if raw := cfg.Models; raw != "" {
		if json.Unmarshal([]byte(raw), &allowed) != nil {
			// A malformed model list is not fatal: fall back to an empty list.
			allowed = []string{}
		}
	}
	return cfg.APIKey, cfg.APIBase, cfg.DefaultModel, allowed, true
}
// loadRuntimeDeepSeek reads the DeepSeek-compatible (OpenAI-format) endpoint
// that an administrator configured in the admin backend; per the original
// note this configuration takes priority over ai_config.json.
//
// It loads the model.SiteAIRuntime row via First(&row, 1) and trims
// surrounding whitespace from APIBase, APIKey and DefaultModel. ok is true
// only when both the trimmed base URL and the trimmed key are non-empty;
// otherwise (including a nil database.DB or a failed lookup) all results are
// zero values. Note the result order: base URL first, then key.
func loadRuntimeDeepSeek() (apiBase, apiKey, defModel string, ok bool) {
	if database.DB == nil {
		return "", "", "", false
	}

	var site model.SiteAIRuntime
	if database.DB.First(&site, 1).Error != nil {
		return "", "", "", false
	}

	endpoint, secret := strings.TrimSpace(site.APIBase), strings.TrimSpace(site.APIKey)
	if endpoint == "" || secret == "" {
		return "", "", "", false
	}
	return endpoint, secret, strings.TrimSpace(site.DefaultModel), true
}
// CallDeepSeek sends messages to a DeepSeek (OpenAI-compatible) endpoint and
// returns the reply text. maxRetries is forwarded to callOpenAICompatible,
// and every request uses a 90-second timeout.
//
// Configuration is resolved in two steps:
//
//  1. If loadRuntimeDeepSeek reports a usable runtime configuration, it is
//     used (kept for backward compatibility). The model is the model
//     argument, else the runtime default, else "deepseek-chat". No
//     allow-list is applied on this path.
//  2. Otherwise the "deepseek" entry from loadAIConfig is used. The model is
//     the model argument, else the configured default, else
//     "deepseek-chat"; if the entry lists allowed models and the selected
//     one is not among them, the first listed model is used instead.
//
// In both cases the request URL is the base URL (without a trailing "/")
// followed by "/chat/completions". An error is returned when neither
// configuration is available.
func CallDeepSeek(messages []ChatMessage, model string, maxRetries int) (string, error) {
	const (
		fallbackModel  = "deepseek-chat"
		requestTimeout = 90 * time.Second
	)

	// Prefer the SiteAIRuntime settings for backward compatibility.
	if base, key, runtimeModel, found := loadRuntimeDeepSeek(); found {
		if model == "" {
			if model = runtimeModel; model == "" {
				model = fallbackModel
			}
		}
		endpoint := strings.TrimSuffix(base, "/") + "/chat/completions"
		return callOpenAICompatible(endpoint, key, model, messages, maxRetries, requestTimeout)
	}

	// Fall back to the newer AI configuration table.
	apiKey, apiBase, defaultModel, models, found := loadAIConfig("deepseek")
	if !found {
		return "", fmt.Errorf("DeepSeek配置未设置请在管理员后台配置API Key和Base URL")
	}

	if model == "" {
		if model = defaultModel; model == "" {
			model = fallbackModel
		}
	}

	// Enforce the allow-list: keep the requested model only if it is listed,
	// otherwise use the first allowed model.
	if len(models) > 0 {
		chosen := models[0]
		for _, candidate := range models {
			if candidate == model {
				chosen = candidate
				break
			}
		}
		model = chosen
	}

	endpoint := strings.TrimSuffix(apiBase, "/") + "/chat/completions"
	return callOpenAICompatible(endpoint, apiKey, model, messages, maxRetries, requestTimeout)
}
// CallKimi sends messages to the Kimi provider described by the "kimi" entry
// from loadAIConfig and returns the reply text.
//
// The model is the model argument if non-empty, otherwise the configured
// default, otherwise "kimi-k2-0905-preview". If the entry lists allowed
// models and the selected one is not among them, the first listed model is
// used instead. The request goes to the base URL (without a trailing "/")
// followed by "/v1/chat/completions", with maxRetries set to 1 and a
// 30-second timeout. An error is returned when no "kimi" configuration can
// be loaded.
func CallKimi(messages []ChatMessage, model string) (string, error) {
	apiKey, apiBase, defaultModel, models, found := loadAIConfig("kimi")
	if !found {
		return "", fmt.Errorf("Kimi配置未设置请在管理员后台配置API Key和Base URL")
	}

	if model == "" {
		if model = defaultModel; model == "" {
			model = "kimi-k2-0905-preview"
		}
	}

	// Enforce the allow-list: keep the requested model only if it is listed,
	// otherwise use the first allowed model.
	if len(models) > 0 {
		chosen := models[0]
		for _, candidate := range models {
			if candidate == model {
				chosen = candidate
				break
			}
		}
		model = chosen
	}

	endpoint := strings.TrimSuffix(apiBase, "/") + "/v1/chat/completions"
	return callOpenAICompatible(endpoint, apiKey, model, messages, 1, 30*time.Second)
}
// callOpenAICompatible posts messages to an OpenAI-style chat completions
// endpoint and returns the content of the first returned choice.
//
// url is the full endpoint URL, apiKey is sent as a Bearer token, and timeout
// bounds each individual HTTP attempt. The request always uses temperature
// 0.7 and max_tokens 2000.
//
// maxRetries is the total number of attempts; values below 1 are treated as
// 1 so the request is always sent at least once. Before each retry the
// function sleeps 1s, 2s, 4s, ... (doubling per attempt).
//
// Transport errors, unreadable response bodies and non-200 statuses are
// retried until the attempts are exhausted, after which the error of the
// last attempt is returned. A request that cannot be built (for example an
// invalid URL), a 200 response that cannot be decoded, and a 200 response
// without choices fail immediately without retrying.
func callOpenAICompatible(url, apiKey, model string, messages []ChatMessage, maxRetries int, timeout time.Duration) (string, error) {
	// Guarantee at least one attempt; with a non-positive count the loop below
	// would not run and the function would return ("", nil).
	if maxRetries < 1 {
		maxRetries = 1
	}

	bodyBytes, err := json.Marshal(chatRequest{
		Model:       model,
		Messages:    messages,
		Temperature: 0.7,
		MaxTokens:   2000,
	})
	if err != nil {
		return "", fmt.Errorf("序列化请求失败: %w", err)
	}

	client := &http.Client{Timeout: timeout}
	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			// Exponential backoff before every retry: 1s, 2s, 4s, ...
			time.Sleep(time.Duration(math.Pow(2, float64(attempt-1))) * time.Second)
		}

		// A fresh request is needed per attempt because the body reader is consumed.
		req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(bodyBytes))
		if err != nil {
			// The same inputs would fail identically on every attempt, so do not retry.
			return "", fmt.Errorf("创建请求失败: %w", err)
		}
		req.Header.Set("Authorization", "Bearer "+apiKey)
		req.Header.Set("Content-Type", "application/json")

		resp, err := client.Do(req)
		if err != nil {
			lastErr = fmt.Errorf("API调用异常已重试%d次: %w", maxRetries, err)
			continue
		}

		respBody, readErr := io.ReadAll(resp.Body)
		resp.Body.Close()
		if readErr != nil {
			lastErr = fmt.Errorf("读取响应失败: %w", readErr)
			continue
		}

		if resp.StatusCode != http.StatusOK {
			lastErr = fmt.Errorf("API调用失败: %d - %s", resp.StatusCode, string(respBody))
			continue
		}

		var result chatResponse
		if err := json.Unmarshal(respBody, &result); err != nil {
			return "", fmt.Errorf("解析响应失败: %w", err)
		}
		if len(result.Choices) == 0 {
			return "", fmt.Errorf("AI未返回有效内容")
		}
		return result.Choices[0].Message.Content, nil
	}
	return "", lastErr
}
// CallAI dispatches a chat request to the named provider and returns the
// reply text.
//
// "deepseek" is forwarded to CallDeepSeek with maxRetries set to 3, and
// "kimi" is forwarded to CallKimi. Any other provider name yields an error
// that names the unsupported provider.
func CallAI(provider, model string, messages []ChatMessage) (string, error) {
	if provider == "deepseek" {
		return CallDeepSeek(messages, model, 3)
	}
	if provider == "kimi" {
		return CallKimi(messages, model)
	}
	return "", fmt.Errorf("不支持的AI提供商: %s目前支持的提供商: deepseek, kimi", provider)
}