234 lines
6.0 KiB
Go
234 lines
6.0 KiB
Go
package config
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/joho/godotenv"
|
||
)
|
||
|
||
type AppConfig struct {
|
||
Env string
|
||
Port string
|
||
|
||
DB DBConfig
|
||
Mail MailConfig
|
||
AI AIConfig
|
||
AuthCenter AuthCenterConfig
|
||
// SiteAdminToken 与前端管理员口令一致,用于更新站点展示配置(如 60s 功能开关);为空则禁止写入
|
||
SiteAdminToken string
|
||
}
|
||
|
||
type DBConfig struct {
|
||
Host string
|
||
Port string
|
||
Name string
|
||
User string
|
||
Password string
|
||
}
|
||
|
||
func (d DBConfig) DSN() string {
|
||
return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=10s",
|
||
d.User, d.Password, d.Host, d.Port, d.Name)
|
||
}
|
||
|
||
type MailConfig struct {
|
||
Host string
|
||
Port int
|
||
Username string
|
||
Password string
|
||
}
|
||
|
||
type AuthCenterConfig struct {
|
||
APIURL string
|
||
AdminToken string
|
||
}
|
||
|
||
type AIProviderConfig struct {
|
||
APIKey string `json:"api_key"`
|
||
APIBase string `json:"api_base"`
|
||
Models []string `json:"model"`
|
||
}
|
||
|
||
type AIConfig struct {
|
||
Providers map[string]AIProviderConfig
|
||
}
|
||
|
||
var Cfg *AppConfig
|
||
|
||
const (
|
||
envDevelopment = "development"
|
||
envProduction = "production"
|
||
|
||
defaultDevDBHost = "10.1.1.100"
|
||
defaultDevDBPort = "3306"
|
||
defaultDevDBName = "infogenie-test"
|
||
defaultDevDBUser = "infogenie-test"
|
||
defaultDevDBPassword = "infogenie-test"
|
||
)
|
||
|
||
func Load() (*AppConfig, error) {
|
||
env := normalizeEnv(os.Getenv("APP_ENV"))
|
||
if env != envDevelopment && env != envProduction {
|
||
return nil, fmt.Errorf("不支持的APP_ENV: %s", env)
|
||
}
|
||
|
||
if err := loadEnvFile(env); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
mailPort, _ := strconv.Atoi(getEnv("MAIL_PORT", "465"))
|
||
|
||
dbHost, err := getEnvByEnvironment(env, "DB_HOST", defaultDevDBHost)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
dbPort, err := getEnvByEnvironment(env, "DB_PORT", defaultDevDBPort)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
dbName, err := getEnvByEnvironment(env, "DB_NAME", defaultDevDBName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
dbUser, err := getEnvByEnvironment(env, "DB_USER", defaultDevDBUser)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
dbPassword, err := getEnvByEnvironment(env, "DB_PASSWORD", defaultDevDBPassword)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
cfg := &AppConfig{
|
||
Env: env,
|
||
Port: getEnv("APP_PORT", "5002"),
|
||
DB: DBConfig{
|
||
Host: dbHost,
|
||
Port: dbPort,
|
||
Name: dbName,
|
||
User: dbUser,
|
||
Password: dbPassword,
|
||
},
|
||
Mail: MailConfig{
|
||
Host: getEnv("MAIL_HOST", "smtp.qq.com"),
|
||
Port: mailPort,
|
||
Username: getEnv("MAIL_USERNAME", ""),
|
||
Password: getEnv("MAIL_PASSWORD", ""),
|
||
},
|
||
AuthCenter: AuthCenterConfig{
|
||
APIURL: getEnv("AUTH_CENTER_API_URL", "https://auth.api.shumengya.top"),
|
||
AdminToken: getEnv("AUTH_CENTER_ADMIN_TOKEN", ""),
|
||
},
|
||
SiteAdminToken: getEnv("INFOGENIE_SITE_ADMIN_TOKEN", ""),
|
||
}
|
||
|
||
if err := validateDBConfig(cfg); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// AI配置现在完全从数据库读取,不再加载ai_config.json文件
|
||
cfg.AI = AIConfig{Providers: make(map[string]AIProviderConfig)}
|
||
|
||
Cfg = cfg
|
||
return cfg, nil
|
||
}
|
||
|
||
func loadEnvFile(env string) error {
|
||
envFile := fmt.Sprintf(".env.%s", env)
|
||
if _, err := os.Stat(envFile); err == nil {
|
||
return godotenv.Load(envFile)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func validateDBConfig(cfg *AppConfig) error {
|
||
switch cfg.Env {
|
||
case envDevelopment:
|
||
if !isDevelopmentDBTarget(cfg.DB) {
|
||
return fmt.Errorf("开发环境必须使用测试数据库: host=%s name=%s", cfg.DB.Host, cfg.DB.Name)
|
||
}
|
||
if looksProductionLike(cfg.DB) {
|
||
return fmt.Errorf("开发环境检测到生产数据库配置: host=%s name=%s", cfg.DB.Host, cfg.DB.Name)
|
||
}
|
||
case envProduction:
|
||
missing := make([]string, 0, 4)
|
||
if strings.TrimSpace(cfg.DB.Host) == "" {
|
||
missing = append(missing, "DB_HOST")
|
||
}
|
||
if strings.TrimSpace(cfg.DB.Port) == "" {
|
||
missing = append(missing, "DB_PORT")
|
||
}
|
||
if strings.TrimSpace(cfg.DB.Name) == "" {
|
||
missing = append(missing, "DB_NAME")
|
||
}
|
||
if strings.TrimSpace(cfg.DB.User) == "" {
|
||
missing = append(missing, "DB_USER")
|
||
}
|
||
if strings.TrimSpace(cfg.DB.Password) == "" {
|
||
missing = append(missing, "DB_PASSWORD")
|
||
}
|
||
if len(missing) > 0 {
|
||
return fmt.Errorf("生产环境缺少必需数据库配置: %s", strings.Join(missing, ", "))
|
||
}
|
||
if isDevelopmentDBTarget(cfg.DB) {
|
||
return fmt.Errorf("生产环境数据库配置看起来像开发/测试环境: host=%s name=%s", cfg.DB.Host, cfg.DB.Name)
|
||
}
|
||
default:
|
||
return fmt.Errorf("不支持的APP_ENV: %s", cfg.Env)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func isDevelopmentDBTarget(db DBConfig) bool {
|
||
host := strings.ToLower(strings.TrimSpace(db.Host))
|
||
name := strings.ToLower(strings.TrimSpace(db.Name))
|
||
|
||
if host == "localhost" || strings.HasPrefix(host, "127.") || strings.HasPrefix(host, "10.1.1.") {
|
||
return true
|
||
}
|
||
if strings.Contains(host, "dev") || strings.Contains(host, "test") || strings.Contains(host, "local") {
|
||
return true
|
||
}
|
||
if strings.Contains(name, "test") || strings.Contains(name, "dev") || strings.Contains(name, "local") {
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
func looksProductionLike(db DBConfig) bool {
|
||
host := strings.ToLower(strings.TrimSpace(db.Host))
|
||
name := strings.ToLower(strings.TrimSpace(db.Name))
|
||
return strings.Contains(host, "bigmengya") || strings.Contains(host, "shumengya.top") || strings.Contains(name, "prod") || strings.Contains(name, "production")
|
||
}
|
||
|
||
func normalizeEnv(raw string) string {
|
||
env := strings.ToLower(strings.TrimSpace(raw))
|
||
if env == "" {
|
||
return envDevelopment
|
||
}
|
||
return env
|
||
}
|
||
|
||
func getEnv(key, fallback string) string {
|
||
if v := os.Getenv(key); v != "" {
|
||
return v
|
||
}
|
||
return fallback
|
||
}
|
||
|
||
func getEnvByEnvironment(env, key, devFallback string) (string, error) {
|
||
if v := strings.TrimSpace(os.Getenv(key)); v != "" {
|
||
return v, nil
|
||
}
|
||
|
||
if env == envProduction {
|
||
return "", fmt.Errorf("生产环境缺少必需配置: %s", key)
|
||
}
|
||
|
||
return devFallback, nil
|
||
}
|