214 lines
5.9 KiB
Go
214 lines
5.9 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"math"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"infogenie-backend/internal/database"
|
||
"infogenie-backend/internal/model"
|
||
)
|
||
|
||
// ChatMessage is one message of a chat conversation. It is serialized to JSON
// with the keys "role" and "content".
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
||
|
||
type chatRequest struct {
|
||
Model string `json:"model"`
|
||
Messages []ChatMessage `json:"messages"`
|
||
Temperature float64 `json:"temperature"`
|
||
MaxTokens int `json:"max_tokens"`
|
||
}
|
||
|
||
// chatResponse holds the part of a chat-completions response that this file
// decodes: the message content of each returned choice.
type chatResponse struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}
|
||
|
||
// loadAIConfig 从数据库读取AI配置
|
||
func loadAIConfig(provider string) (apiKey, apiBase, defaultModel string, models []string, ok bool) {
|
||
if database.DB == nil {
|
||
return "", "", "", nil, false
|
||
}
|
||
|
||
var config model.AIConfig
|
||
if err := database.DB.Where("provider = ? AND is_enabled = ?", provider, true).First(&config).Error; err != nil {
|
||
return "", "", "", nil, false
|
||
}
|
||
|
||
// 解析models JSON
|
||
var modelList []string
|
||
if config.Models != "" {
|
||
if err := json.Unmarshal([]byte(config.Models), &modelList); err != nil {
|
||
// 如果解析失败,返回空的模型列表
|
||
modelList = []string{}
|
||
}
|
||
}
|
||
|
||
return config.APIKey, config.APIBase, config.DefaultModel, modelList, true
|
||
}
|
||
|
||
// loadRuntimeDeepSeek 读取管理员在后台配置的 DeepSeek 兼容接口(OpenAI 格式),优先于 ai_config.json
|
||
func loadRuntimeDeepSeek() (apiBase, apiKey, defModel string, ok bool) {
|
||
if database.DB == nil {
|
||
return "", "", "", false
|
||
}
|
||
var row model.SiteAIRuntime
|
||
if err := database.DB.First(&row, 1).Error; err != nil {
|
||
return "", "", "", false
|
||
}
|
||
base := strings.TrimSpace(row.APIBase)
|
||
key := strings.TrimSpace(row.APIKey)
|
||
dm := strings.TrimSpace(row.DefaultModel)
|
||
if base != "" && key != "" {
|
||
return base, key, dm, true
|
||
}
|
||
return "", "", "", false
|
||
}
|
||
|
||
func CallDeepSeek(messages []ChatMessage, model string, maxRetries int) (string, error) {
|
||
// 首先尝试从SiteAIRuntime读取配置(向后兼容)
|
||
if base, key, defModel, ok := loadRuntimeDeepSeek(); ok {
|
||
if model == "" {
|
||
model = defModel
|
||
}
|
||
if model == "" {
|
||
model = "deepseek-chat"
|
||
}
|
||
url := strings.TrimSuffix(base, "/") + "/chat/completions"
|
||
return callOpenAICompatible(url, key, model, messages, maxRetries, 90*time.Second)
|
||
}
|
||
|
||
// 从新的AI配置表读取
|
||
if apiKey, apiBase, defaultModel, models, ok := loadAIConfig("deepseek"); ok {
|
||
if model == "" {
|
||
model = defaultModel
|
||
}
|
||
if model == "" {
|
||
model = "deepseek-chat"
|
||
}
|
||
// 验证模型是否在允许列表中
|
||
if len(models) > 0 {
|
||
allowed := false
|
||
for _, m := range models {
|
||
if m == model {
|
||
allowed = true
|
||
break
|
||
}
|
||
}
|
||
if !allowed {
|
||
model = models[0] // 使用第一个允许的模型
|
||
}
|
||
}
|
||
url := strings.TrimSuffix(apiBase, "/") + "/chat/completions"
|
||
return callOpenAICompatible(url, apiKey, model, messages, maxRetries, 90*time.Second)
|
||
}
|
||
|
||
return "", fmt.Errorf("DeepSeek配置未设置,请在管理员后台配置API Key和Base URL")
|
||
}
|
||
|
||
func CallKimi(messages []ChatMessage, model string) (string, error) {
|
||
// 从新的AI配置表读取
|
||
if apiKey, apiBase, defaultModel, models, ok := loadAIConfig("kimi"); ok {
|
||
if model == "" {
|
||
model = defaultModel
|
||
}
|
||
if model == "" {
|
||
model = "kimi-k2-0905-preview"
|
||
}
|
||
// 验证模型是否在允许列表中
|
||
if len(models) > 0 {
|
||
allowed := false
|
||
for _, m := range models {
|
||
if m == model {
|
||
allowed = true
|
||
break
|
||
}
|
||
}
|
||
if !allowed {
|
||
model = models[0] // 使用第一个允许的模型
|
||
}
|
||
}
|
||
url := strings.TrimSuffix(apiBase, "/") + "/v1/chat/completions"
|
||
return callOpenAICompatible(url, apiKey, model, messages, 1, 30*time.Second)
|
||
}
|
||
|
||
return "", fmt.Errorf("Kimi配置未设置,请在管理员后台配置API Key和Base URL")
|
||
}
|
||
|
||
func callOpenAICompatible(url, apiKey, model string, messages []ChatMessage, maxRetries int, timeout time.Duration) (string, error) {
|
||
reqBody := chatRequest{
|
||
Model: model,
|
||
Messages: messages,
|
||
Temperature: 0.7,
|
||
MaxTokens: 2000,
|
||
}
|
||
|
||
bodyBytes, err := json.Marshal(reqBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
client := &http.Client{Timeout: timeout}
|
||
|
||
var lastErr error
|
||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||
req, _ := http.NewRequest("POST", url, bytes.NewReader(bodyBytes))
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
lastErr = err
|
||
if attempt < maxRetries-1 {
|
||
backoff := time.Duration(math.Pow(2, float64(attempt))) * time.Second
|
||
time.Sleep(backoff)
|
||
continue
|
||
}
|
||
return "", fmt.Errorf("API调用异常(已重试%d次): %w", maxRetries, err)
|
||
}
|
||
|
||
respBody, _ := io.ReadAll(resp.Body)
|
||
resp.Body.Close()
|
||
|
||
if resp.StatusCode == 200 {
|
||
var result chatResponse
|
||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
if len(result.Choices) == 0 {
|
||
return "", fmt.Errorf("AI未返回有效内容")
|
||
}
|
||
return result.Choices[0].Message.Content, nil
|
||
}
|
||
|
||
lastErr = fmt.Errorf("API调用失败: %d - %s", resp.StatusCode, string(respBody))
|
||
if attempt < maxRetries-1 {
|
||
backoff := time.Duration(math.Pow(2, float64(attempt))) * time.Second
|
||
time.Sleep(backoff)
|
||
}
|
||
}
|
||
|
||
return "", lastErr
|
||
}
|
||
|
||
func CallAI(provider, model string, messages []ChatMessage) (string, error) {
|
||
switch provider {
|
||
case "deepseek":
|
||
return CallDeepSeek(messages, model, 3)
|
||
case "kimi":
|
||
return CallKimi(messages, model)
|
||
default:
|
||
return "", fmt.Errorf("不支持的AI提供商: %s,目前支持的提供商: deepseek, kimi", provider)
|
||
}
|
||
}
|