Files
InfoGenie/infogenie-backend-go/config/config.go
2026-03-28 20:59:52 +08:00

234 lines
6.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package config
import (
"fmt"
"os"
"strconv"
"strings"
"github.com/joho/godotenv"
)
type AppConfig struct {
	// Env is the normalized APP_ENV value; Load only accepts "development" or "production".
	Env string

	// Port comes from APP_PORT (default "5002"); presumably the server listen port.
	Port string

	// DB holds the database connection settings resolved from the DB_* variables.
	DB DBConfig

	// Mail holds the outgoing-mail settings resolved from the MAIL_* variables.
	Mail MailConfig

	// AI holds AI provider settings; Load initializes it with an empty provider map.
	AI AIConfig

	// AuthCenter holds the auth-center API settings from the AUTH_CENTER_* variables.
	AuthCenter AuthCenterConfig

	// SiteAdminToken (from INFOGENIE_SITE_ADMIN_TOKEN) must match the admin
	// passphrase used by the frontend; it authorizes updates to the site display
	// configuration (e.g. the "60s" feature toggles). When empty, such writes are
	// forbidden.
	SiteAdminToken string
}
// DBConfig holds the database connection parameters rendered by DSN.
// NOTE(review): the DSN format suggests a MySQL server — confirm.
type DBConfig struct {
	Host     string
	Port     string // kept as a string and inserted verbatim into the DSN
	Name     string
	User     string
	Password string
}
// DSN renders the connection settings as a data source name of the form
//
//	user:password@tcp(host:port)/name?charset=utf8mb4&parseTime=True&loc=Local&timeout=10s
//
// which is the go-sql-driver/mysql DSN syntax (presumably the driver used by the
// caller — confirm). An IPv6 literal host such as "::1" is wrapped in brackets so
// the host/port separator stays unambiguous; a host already written as "[...]" is
// used unchanged. User and Password are inserted verbatim, without escaping.
func (d DBConfig) DSN() string {
	host := d.Host
	// Without brackets "::1" + ":3306" would yield "::1:3306", which cannot be
	// split back into host and port.
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}
	return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=10s",
		d.User, d.Password, host, d.Port, d.Name)
}
// MailConfig holds the outgoing-mail settings read by Load.
type MailConfig struct {
	Host     string // MAIL_HOST, default "smtp.qq.com"
	Port     int    // parsed from MAIL_PORT, default 465
	Username string // MAIL_USERNAME, empty when unset
	Password string // MAIL_PASSWORD, empty when unset
}
// AuthCenterConfig holds the auth-center settings read by Load.
type AuthCenterConfig struct {
	APIURL     string // AUTH_CENTER_API_URL, default "https://auth.api.shumengya.top"
	AdminToken string // AUTH_CENTER_ADMIN_TOKEN, empty when unset
}
// AIProviderConfig describes a single AI provider entry. Field order and JSON
// tags determine its serialized form.
type AIProviderConfig struct {
	APIKey  string `json:"api_key"`
	APIBase string `json:"api_base"`
	// NOTE(review): the tag is the singular "model" even though the field holds a
	// list — presumably matching the stored JSON layout; verify before renaming.
	Models []string `json:"model"`
}
// AIConfig groups AI provider settings. Load initializes Providers to an empty,
// non-nil map; presumably keyed by provider name — confirm with the code that fills it.
type AIConfig struct {
	Providers map[string]AIProviderConfig
}
// Cfg is the configuration from the most recent successful Load call; Load
// assigns it only on success and without any synchronization.
var Cfg *AppConfig
const (
	// Accepted APP_ENV values (after normalizeEnv); Load rejects anything else.
	envDevelopment = "development"
	envProduction  = "production"

	// Fallback database settings used by Load in development when the matching
	// DB_* variable is unset or blank. Production never falls back (see
	// getEnvByEnvironment).
	// NOTE(review): these are credentials committed to source; keep them limited
	// to a disposable test database.
	defaultDevDBHost     = "10.1.1.100"
	defaultDevDBPort     = "3306"
	defaultDevDBName     = "infogenie-test"
	defaultDevDBUser     = "infogenie-test"
	defaultDevDBPassword = "infogenie-test"
)
// Load builds the application configuration from the process environment,
// optionally merged with a ".env.<APP_ENV>" file, validates it and stores the
// result in Cfg.
//
// APP_ENV must normalize to "development" (the default when unset) or
// "production". In production every DB_* variable is required; in development
// missing DB_* variables fall back to the built-in test-database defaults.
//
// An error is returned, and Cfg is left untouched, when:
//   - APP_ENV is unsupported;
//   - the env file cannot be loaded;
//   - MAIL_PORT is not a valid TCP port;
//   - a required variable is missing;
//   - the database target fails validateDBConfig.
func Load() (*AppConfig, error) {
	env := normalizeEnv(os.Getenv("APP_ENV"))
	if env != envDevelopment && env != envProduction {
		return nil, fmt.Errorf("不支持的APP_ENV: %s", env)
	}
	// The env file is chosen by APP_ENV, so APP_ENV itself must come from the
	// real process environment.
	if err := loadEnvFile(env); err != nil {
		return nil, err
	}

	// Previously an invalid MAIL_PORT was silently turned into port 0.
	mailPort, err := parseMailPort(getEnv("MAIL_PORT", "465"))
	if err != nil {
		return nil, err
	}
	db, err := loadDBConfigFromEnv(env)
	if err != nil {
		return nil, err
	}

	cfg := &AppConfig{
		Env:  env,
		Port: getEnv("APP_PORT", "5002"),
		DB:   db,
		Mail: MailConfig{
			Host:     getEnv("MAIL_HOST", "smtp.qq.com"),
			Port:     mailPort,
			Username: getEnv("MAIL_USERNAME", ""),
			Password: getEnv("MAIL_PASSWORD", ""),
		},
		AuthCenter: AuthCenterConfig{
			APIURL:     getEnv("AUTH_CENTER_API_URL", "https://auth.api.shumengya.top"),
			AdminToken: getEnv("AUTH_CENTER_ADMIN_TOKEN", ""),
		},
		SiteAdminToken: getEnv("INFOGENIE_SITE_ADMIN_TOKEN", ""),
		// AI configuration is now read entirely from the database (ai_config.json
		// is no longer loaded), so start with an empty provider map.
		AI: AIConfig{Providers: make(map[string]AIProviderConfig)},
	}
	if err := validateDBConfig(cfg); err != nil {
		return nil, err
	}

	Cfg = cfg
	return cfg, nil
}

// parseMailPort converts a raw MAIL_PORT value (surrounding whitespace
// ignored) into a TCP port number, rejecting non-numeric values and ports
// outside 1-65535.
func parseMailPort(raw string) (int, error) {
	port, err := strconv.Atoi(strings.TrimSpace(raw))
	if err != nil {
		return 0, fmt.Errorf("无效的MAIL_PORT %q: %w", raw, err)
	}
	if port < 1 || port > 65535 {
		return 0, fmt.Errorf("MAIL_PORT超出范围(1-65535): %d", port)
	}
	return port, nil
}

// loadDBConfigFromEnv resolves DB_HOST, DB_PORT, DB_NAME, DB_USER and
// DB_PASSWORD in that order via getEnvByEnvironment. It returns the first
// error encountered, e.g. a missing variable in production.
func loadDBConfigFromEnv(env string) (DBConfig, error) {
	var db DBConfig
	fields := []struct {
		key      string
		fallback string
		dst      *string
	}{
		{"DB_HOST", defaultDevDBHost, &db.Host},
		{"DB_PORT", defaultDevDBPort, &db.Port},
		{"DB_NAME", defaultDevDBName, &db.Name},
		{"DB_USER", defaultDevDBUser, &db.User},
		{"DB_PASSWORD", defaultDevDBPassword, &db.Password},
	}
	for _, f := range fields {
		v, err := getEnvByEnvironment(env, f.key, f.fallback)
		if err != nil {
			return DBConfig{}, err
		}
		*f.dst = v
	}
	return db, nil
}
// loadEnvFile loads variables from ".env.<env>" (resolved relative to the
// current working directory) via godotenv.Load, if the file exists.
//
// A missing file is not an error: configuration then comes solely from the
// process environment. Any other failure to stat or parse the file is returned
// with the file name attached. Previously all stat errors, e.g. permission
// denied, were silently treated as "no file".
func loadEnvFile(env string) error {
	envFile := fmt.Sprintf(".env.%s", env)
	if _, err := os.Stat(envFile); err != nil {
		if os.IsNotExist(err) {
			return nil
		}
		return fmt.Errorf("检查环境文件 %s 失败: %w", envFile, err)
	}
	if err := godotenv.Load(envFile); err != nil {
		return fmt.Errorf("加载环境文件 %s 失败: %w", envFile, err)
	}
	return nil
}
// validateDBConfig checks that the database target fits cfg.Env, so that a
// development setup cannot point at a production-looking database and vice versa.
//
//   - development: the target must satisfy isDevelopmentDBTarget and must not
//     satisfy looksProductionLike.
//   - production: Host, Port, Name, User and Password must all be non-blank (all
//     missing keys are reported together), and the target must not satisfy
//     isDevelopmentDBTarget.
//
// Any other cfg.Env value is rejected.
func validateDBConfig(cfg *AppConfig) error {
	db := cfg.DB

	if cfg.Env == envDevelopment {
		if !isDevelopmentDBTarget(db) {
			return fmt.Errorf("开发环境必须使用测试数据库: host=%s name=%s", db.Host, db.Name)
		}
		if looksProductionLike(db) {
			return fmt.Errorf("开发环境检测到生产数据库配置: host=%s name=%s", db.Host, db.Name)
		}
		return nil
	}

	if cfg.Env != envProduction {
		return fmt.Errorf("不支持的APP_ENV: %s", cfg.Env)
	}

	// Order here fixes the order of keys in the error message.
	required := [...]struct {
		key   string
		value string
	}{
		{"DB_HOST", db.Host},
		{"DB_PORT", db.Port},
		{"DB_NAME", db.Name},
		{"DB_USER", db.User},
		{"DB_PASSWORD", db.Password},
	}
	var missing []string
	for _, field := range required {
		if strings.TrimSpace(field.value) == "" {
			missing = append(missing, field.key)
		}
	}
	if len(missing) != 0 {
		return fmt.Errorf("生产环境缺少必需数据库配置: %s", strings.Join(missing, ", "))
	}

	if isDevelopmentDBTarget(db) {
		return fmt.Errorf("生产环境数据库配置看起来像开发/测试环境: host=%s name=%s", db.Host, db.Name)
	}
	return nil
}
// isDevelopmentDBTarget reports whether db looks like a development/test
// database. The check is case-insensitive and ignores surrounding whitespace.
//
// The target counts as development/test when any of these hold:
//   - the host is a loopback address: "localhost", "127.*", "::1" or "[::1]";
//   - the host is on the 10.1.1.* test network;
//   - the host or database name contains "dev", "test" or "local".
//
// IPv6 loopback is included so that it is treated the same way as 127.0.0.1.
func isDevelopmentDBTarget(db DBConfig) bool {
	host := strings.ToLower(strings.TrimSpace(db.Host))
	name := strings.ToLower(strings.TrimSpace(db.Name))

	switch {
	case host == "localhost", host == "::1", host == "[::1]":
		return true
	case strings.HasPrefix(host, "127."), strings.HasPrefix(host, "10.1.1."):
		return true
	}

	for _, marker := range []string{"dev", "test", "local"} {
		if strings.Contains(host, marker) || strings.Contains(name, marker) {
			return true
		}
	}
	return false
}
// looksProductionLike reports whether db appears to be a production database.
// The comparison is case-insensitive and ignores surrounding whitespace. It
// matches when the host mentions "bigmengya" or "shumengya.top", or the
// database name mentions "prod"/"production".
func looksProductionLike(db DBConfig) bool {
	host := strings.ToLower(strings.TrimSpace(db.Host))
	for _, marker := range []string{"bigmengya", "shumengya.top"} {
		if strings.Contains(host, marker) {
			return true
		}
	}

	name := strings.ToLower(strings.TrimSpace(db.Name))
	for _, marker := range []string{"prod", "production"} {
		if strings.Contains(name, marker) {
			return true
		}
	}
	return false
}
// normalizeEnv trims and lower-cases a raw APP_ENV value and defaults to
// envDevelopment when nothing remains. Unrecognized values are returned in
// normalized form for the caller to reject.
func normalizeEnv(raw string) string {
	if env := strings.ToLower(strings.TrimSpace(raw)); env != "" {
		return env
	}
	return envDevelopment
}
// getEnv returns the environment variable key, or fallback when it is unset or
// empty. The value is not trimmed, so a whitespace-only value counts as set.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
// getEnvByEnvironment returns the whitespace-trimmed value of the environment
// variable key. When the trimmed value is empty, it behaves as follows:
//   - production (envProduction): returns an error, because the variable is mandatory;
//   - any other env: returns devFallback.
func getEnvByEnvironment(env, key, devFallback string) (string, error) {
	value := strings.TrimSpace(os.Getenv(key))
	switch {
	case value != "":
		return value, nil
	case env == envProduction:
		return "", fmt.Errorf("生产环境缺少必需配置: %s", key)
	default:
		return devFallback, nil
	}
}