Files
2026-03-20 20:42:33 +08:00

588 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package storage
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"os"
"path/filepath"
"strings"
"sync"
"sproutgate-backend/internal/models"
)
// AdminConfig is the JSON layout of config/admin.json.
type AdminConfig struct {
// Token is the admin token. loadOrCreateAdminConfig generates a random one
// when the file is missing or the stored value is blank.
Token string `json:"token"`
}
// AuthConfig is the JSON layout of config/auth.json.
type AuthConfig struct {
// JWTSecret is the JWT signing key, base64 (standard encoding). An empty or
// undecodable value is replaced by a fresh random 32-byte key on load.
JWTSecret string `json:"jwtSecret"`
// Issuer is the JWT issuer name; defaults to "sproutgate" when blank.
Issuer string `json:"issuer"`
}
// EmailConfig is the JSON layout of config/email.json (outgoing SMTP mail
// settings). Blank fields are filled with defaults by loadOrCreateEmailConfig.
type EmailConfig struct {
// FromName is the sender display name.
FromName string `json:"fromName"`
// FromAddress is the sender address.
FromAddress string `json:"fromAddress"`
// Username is the SMTP login; falls back to FromAddress when blank.
Username string `json:"username"`
// Password is the SMTP password (stored in plain text in the JSON file).
Password string `json:"password"`
// SMTPHost and SMTPPort locate the SMTP server (defaults: smtp.qiye.aliyun.com:465).
SMTPHost string `json:"smtpHost"`
SMTPPort int `json:"smtpPort"`
// Encryption names the transport security mode; defaults to "SSL".
Encryption string `json:"encryption"`
}
// CheckInConfig is the JSON layout of config/checkin.json.
type CheckInConfig struct {
// RewardCoins is the number of coins granted per daily check-in. Values <= 0
// are treated as 1 everywhere this config is read or written.
RewardCoins int `json:"rewardCoins"`
}
// Store persists accounts and service configuration as JSON files below a
// single data directory. Construct it with NewStore.
type Store struct {
// dataDir is the absolute root data directory.
dataDir string
// usersDir holds one JSON file per account, named by userFileName.
usersDir string
// pendingDir, resetDir and secondaryDir are created by NewStore; none of the
// methods visible here read or write them (presumably used elsewhere in the package).
pendingDir string
resetDir string
secondaryDir string
// Paths of the JSON configuration files inside <dataDir>/config.
adminConfigPath string
authConfigPath string
emailConfigPath string
checkInPath string
registrationPath string
// registrationConfig is presumably populated by loadOrCreateRegistrationConfig,
// which is defined outside this view — TODO confirm.
registrationConfig RegistrationConfig
// Values loaded by the loadOrCreate* helpers during NewStore; their getters
// read them without taking mu.
adminToken string
jwtSecret []byte
issuer string
emailConfig EmailConfig
// checkInConfig may be replaced at runtime by UpdateCheckInConfig; guarded by mu.
checkInConfig CheckInConfig
// mu serializes access to user files and to checkInConfig. It is not
// reentrant: a method holding it must not call another locking method.
mu sync.Mutex
}
// NewStore prepares the on-disk layout under dataDir (default "./data") and
// loads every configuration file, creating defaults where needed.
//
// Layout, relative to the absolute form of dataDir:
//
//	users/ pending/ reset/ secondary/ config/
//	config/{admin,auth,email,checkin,registration}.json
//
// The first directory or config error aborts construction.
func NewStore(dataDir string) (*Store, error) {
	if dataDir == "" {
		dataDir = "./data"
	}
	root, err := filepath.Abs(dataDir)
	if err != nil {
		return nil, err
	}

	configDir := filepath.Join(root, "config")
	store := &Store{
		dataDir:          root,
		usersDir:         filepath.Join(root, "users"),
		pendingDir:       filepath.Join(root, "pending"),
		resetDir:         filepath.Join(root, "reset"),
		secondaryDir:     filepath.Join(root, "secondary"),
		adminConfigPath:  filepath.Join(configDir, "admin.json"),
		authConfigPath:   filepath.Join(configDir, "auth.json"),
		emailConfigPath:  filepath.Join(configDir, "email.json"),
		checkInPath:      filepath.Join(configDir, "checkin.json"),
		registrationPath: filepath.Join(configDir, "registration.json"),
	}

	dirs := []string{store.usersDir, store.pendingDir, store.resetDir, store.secondaryDir, configDir}
	for _, dir := range dirs {
		if mkErr := os.MkdirAll(dir, 0755); mkErr != nil {
			return nil, mkErr
		}
	}

	loaders := []func() error{
		store.loadOrCreateAdminConfig,
		store.loadOrCreateAuthConfig,
		store.loadOrCreateEmailConfig,
		store.loadOrCreateCheckInConfig,
		store.loadOrCreateRegistrationConfig,
	}
	for _, load := range loaders {
		if loadErr := load(); loadErr != nil {
			return nil, loadErr
		}
	}
	return store, nil
}
// DataDir returns the absolute path of the root data directory.
func (s *Store) DataDir() string {
	return s.dataDir
}

// AdminToken returns the admin token loaded from config/admin.json.
func (s *Store) AdminToken() string {
	return s.adminToken
}

// JWTSecret returns the JWT signing key decoded from config/auth.json.
// It returns a copy so callers cannot modify the store's key in place.
func (s *Store) JWTSecret() []byte {
	return append([]byte(nil), s.jwtSecret...)
}

// JWTIssuer returns the JWT issuer name from config/auth.json.
func (s *Store) JWTIssuer() string {
	return s.issuer
}

// EmailConfig returns the outgoing-mail settings (a value copy) loaded at
// construction time, with defaults applied.
func (s *Store) EmailConfig() EmailConfig {
	return s.emailConfig
}
// CheckInConfig returns a snapshot of the current check-in settings. A
// non-positive RewardCoins is reported as 1.
func (s *Store) CheckInConfig() CheckInConfig {
	s.mu.Lock()
	snapshot := s.checkInConfig
	s.mu.Unlock()

	// snapshot is a private copy, so the default can be applied after unlocking.
	if snapshot.RewardCoins > 0 {
		return snapshot
	}
	snapshot.RewardCoins = 1
	return snapshot
}
// UpdateCheckInConfig writes cfg to config/checkin.json and, once the write
// has succeeded, makes it the active check-in configuration. A non-positive
// RewardCoins is stored as 1.
func (s *Store) UpdateCheckInConfig(cfg CheckInConfig) error {
	if cfg.RewardCoins < 1 {
		cfg.RewardCoins = 1
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	writeErr := writeJSONFile(s.checkInPath, cfg)
	if writeErr == nil {
		// Swap the in-memory copy only when it matches what is on disk.
		s.checkInConfig = cfg
	}
	return writeErr
}
// loadOrCreateAdminConfig loads config/admin.json into s.adminToken. When the
// file is missing or its token is blank, a random token is generated and the
// file is (re)written.
func (s *Store) loadOrCreateAdminConfig() error {
	var cfg AdminConfig
	if _, statErr := os.Stat(s.adminConfigPath); !errors.Is(statErr, os.ErrNotExist) {
		if err := readJSONFile(s.adminConfigPath, &cfg); err != nil {
			return err
		}
	}

	// A missing file leaves cfg zero-valued, so it takes the same path as a
	// blank token.
	if strings.TrimSpace(cfg.Token) == "" {
		fresh, err := generateToken()
		if err != nil {
			return err
		}
		cfg.Token = fresh
		if err := writeJSONFile(s.adminConfigPath, cfg); err != nil {
			return err
		}
	}

	s.adminToken = cfg.Token
	return nil
}
// loadOrCreateAuthConfig loads config/auth.json into s.jwtSecret and s.issuer.
//
// A missing file is handled like an empty one. An empty or undecodable secret
// is replaced by a new random 32-byte key (which invalidates previously issued
// tokens). A blank issuer becomes "sproutgate". The file is written once if
// either repair was needed.
func (s *Store) loadOrCreateAuthConfig() error {
	var cfg AuthConfig
	if _, statErr := os.Stat(s.authConfigPath); !errors.Is(statErr, os.ErrNotExist) {
		if err := readJSONFile(s.authConfigPath, &cfg); err != nil {
			return err
		}
	}

	dirty := false

	key, decodeErr := base64.StdEncoding.DecodeString(cfg.JWTSecret)
	if decodeErr != nil || len(key) == 0 {
		fresh, err := generateSecret()
		if err != nil {
			return err
		}
		key = fresh
		cfg.JWTSecret = base64.StdEncoding.EncodeToString(fresh)
		dirty = true
	}

	if strings.TrimSpace(cfg.Issuer) == "" {
		cfg.Issuer = "sproutgate"
		dirty = true
	}

	if dirty {
		if err := writeJSONFile(s.authConfigPath, cfg); err != nil {
			return err
		}
	}

	s.jwtSecret = key
	s.issuer = cfg.Issuer
	return nil
}
// loadOrCreateEmailConfig loads config/email.json into s.emailConfig.
//
// When the file is missing, a default configuration is written with an empty
// username; in memory the username falls back to the sender address. When the
// file exists, blank fields (and a zero port) are filled with defaults. The
// file is rewritten only if at least one default was applied, matching the
// other loaders. This avoids rewriting the file that holds the SMTP password
// on every start, and tolerates a read-only, already-complete config.
func (s *Store) loadOrCreateEmailConfig() error {
	if _, err := os.Stat(s.emailConfigPath); errors.Is(err, os.ErrNotExist) {
		cfg := EmailConfig{
			FromName:    "萌芽账户认证中心",
			FromAddress: "notice@smyhub.com",
			Username:    "",
			Password:    "",
			SMTPHost:    "smtp.qiye.aliyun.com",
			SMTPPort:    465,
			Encryption:  "SSL",
		}
		if err := writeJSONFile(s.emailConfigPath, cfg); err != nil {
			return err
		}
		// The file keeps the empty username; in memory it defaults to the sender.
		cfg.Username = cfg.FromAddress
		s.emailConfig = cfg
		return nil
	}

	var cfg EmailConfig
	if err := readJSONFile(s.emailConfigPath, &cfg); err != nil {
		return err
	}

	changed := false
	fill := func(field *string, def string) {
		if strings.TrimSpace(*field) == "" {
			*field = def
			changed = true
		}
	}
	fill(&cfg.FromName, "萌芽账户认证中心")
	fill(&cfg.FromAddress, "notice@smyhub.com")
	// Evaluated after FromAddress so a defaulted sender is also the default login.
	fill(&cfg.Username, cfg.FromAddress)
	fill(&cfg.SMTPHost, "smtp.qiye.aliyun.com")
	if cfg.SMTPPort == 0 {
		cfg.SMTPPort = 465
		changed = true
	}
	fill(&cfg.Encryption, "SSL")

	if changed {
		if err := writeJSONFile(s.emailConfigPath, cfg); err != nil {
			return err
		}
	}
	s.emailConfig = cfg
	return nil
}
// loadOrCreateCheckInConfig loads config/checkin.json into s.checkInConfig.
// A missing file is handled like an empty one. Whenever RewardCoins is not
// positive it is set to 1 and the file is written.
func (s *Store) loadOrCreateCheckInConfig() error {
	var cfg CheckInConfig
	if _, statErr := os.Stat(s.checkInPath); !errors.Is(statErr, os.ErrNotExist) {
		if err := readJSONFile(s.checkInPath, &cfg); err != nil {
			return err
		}
	}

	if cfg.RewardCoins < 1 {
		cfg.RewardCoins = 1
		if err := writeJSONFile(s.checkInPath, cfg); err != nil {
			return err
		}
	}

	s.checkInConfig = cfg
	return nil
}
// generateSecret returns 32 bytes from crypto/rand. On error it returns a nil
// slice rather than a possibly partially filled buffer.
func generateSecret() ([]byte, error) {
	secret := make([]byte, 32)
	if _, err := rand.Read(secret); err != nil {
		return nil, err
	}
	return secret, nil
}
// generateToken returns a random token: 32 random bytes rendered as unpadded
// URL-safe base64 (43 characters).
func generateToken() (string, error) {
	raw, err := generateSecret()
	if err == nil {
		return base64.RawURLEncoding.EncodeToString(raw), nil
	}
	return "", err
}
// ListUsers decodes every *.json file in the users directory and returns the
// records in os.ReadDir order (sorted by file name). Subdirectories and other
// files are skipped; any unreadable or malformed user file aborts the listing.
func (s *Store) ListUsers() ([]models.UserRecord, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	files, err := os.ReadDir(s.usersDir)
	if err != nil {
		return nil, err
	}

	out := make([]models.UserRecord, 0, len(files))
	for _, f := range files {
		name := f.Name()
		if f.IsDir() || !strings.HasSuffix(name, ".json") {
			continue
		}
		var rec models.UserRecord
		if readErr := readJSONFile(filepath.Join(s.usersDir, name), &rec); readErr != nil {
			return nil, readErr
		}
		out = append(out, rec)
	}
	return out, nil
}
// GetUser loads the record for account. It reports found=false with a nil
// error when the account has no user file.
func (s *Store) GetUser(account string) (models.UserRecord, bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	var (
		rec  models.UserRecord
		file = s.userFilePath(account)
	)
	if _, statErr := os.Stat(file); errors.Is(statErr, os.ErrNotExist) {
		return models.UserRecord{}, false, nil
	}
	if readErr := readJSONFile(file, &rec); readErr != nil {
		return models.UserRecord{}, false, readErr
	}
	return rec, true, nil
}
// CreateUser writes a new user file for record.Account. It fails with
// "account already exists" if the file is already present. CreatedAt defaults
// to the current time, and UpdatedAt is set equal to CreatedAt.
func (s *Store) CreateUser(record models.UserRecord) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	target := s.userFilePath(record.Account)
	if _, statErr := os.Stat(target); statErr == nil {
		return errors.New("account already exists")
	}

	created := record.CreatedAt
	if created == "" {
		created = models.NowISO()
	}
	record.CreatedAt, record.UpdatedAt = created, created
	return writeJSONFile(target, record)
}
// SaveUser overwrites (or creates) the user file for record.Account, stamping
// UpdatedAt with the current time.
func (s *Store) SaveUser(record models.UserRecord) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	target := s.userFilePath(record.Account)
	record.UpdatedAt = models.NowISO()
	return writeJSONFile(target, record)
}
// RecordAuthClient records, after a successful authentication, which
// third-party application the user signed in to. clientID must already be
// normalized. An existing entry gets a new LastSeenAt (and DisplayName, when a
// non-empty clamped name is given); otherwise a new entry is appended. It
// returns os.ErrNotExist when the account has no user file.
func (s *Store) RecordAuthClient(account string, clientID string, displayName string) (models.UserRecord, error) {
	if clientID == "" {
		return models.UserRecord{}, errors.New("client id required")
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	file := s.userFilePath(account)
	var rec models.UserRecord
	if err := readJSONFile(file, &rec); err != nil {
		if errors.Is(err, os.ErrNotExist) {
			// Surface the bare sentinel instead of the wrapped *PathError.
			return models.UserRecord{}, os.ErrNotExist
		}
		return models.UserRecord{}, err
	}

	stamp := models.NowISO()
	name := models.ClampAuthClientName(displayName)

	idx := -1
	for i, client := range rec.AuthClients {
		if client.ClientID == clientID {
			idx = i
			break
		}
	}

	if idx < 0 {
		rec.AuthClients = append(rec.AuthClients, models.AuthClientEntry{
			ClientID:    clientID,
			DisplayName: name,
			FirstSeenAt: stamp,
			LastSeenAt:  stamp,
		})
	} else {
		entry := &rec.AuthClients[idx]
		entry.LastSeenAt = stamp
		if name != "" {
			entry.DisplayName = name
		}
	}

	rec.UpdatedAt = stamp
	if err := writeJSONFile(file, &rec); err != nil {
		return models.UserRecord{}, err
	}
	return rec, nil
}
// RecordVisit logs the first visit of day `today` for account.
//
// It returns the record and whether a new visit was recorded. If
// LastVisitDate already equals today, or models.HasActivityDate reports a
// visit for today, the file is left untouched and false is returned. An empty
// `at` defaults to models.CurrentActivityTime(). os.ErrNotExist is returned
// when the account has no user file.
func (s *Store) RecordVisit(account string, today string, at string) (models.UserRecord, bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	file := s.userFilePath(account)
	if _, statErr := os.Stat(file); errors.Is(statErr, os.ErrNotExist) {
		return models.UserRecord{}, false, os.ErrNotExist
	}
	var rec models.UserRecord
	if readErr := readJSONFile(file, &rec); readErr != nil {
		return models.UserRecord{}, false, readErr
	}

	if visited := rec.LastVisitDate == today || models.HasActivityDate(rec.VisitTimes, today); visited {
		return rec, false, nil
	}

	if strings.TrimSpace(at) == "" {
		at = models.CurrentActivityTime()
	}
	rec.LastVisitDate, rec.LastVisitAt = today, at
	rec.VisitTimes = append(rec.VisitTimes, at)
	if rec.CreatedAt == "" {
		rec.CreatedAt = models.NowISO()
	}
	rec.UpdatedAt = models.NowISO()

	if writeErr := writeJSONFile(file, rec); writeErr != nil {
		return models.UserRecord{}, false, writeErr
	}
	return rec, true, nil
}
// Maximum stored sizes, in bytes, of the last-visit metadata.
const maxLastVisitIPLen = 45
const maxLastVisitDisplayLocationLen = 512

// clampVisitMeta trims surrounding whitespace from ip and displayLocation and
// caps each at its maximum size in bytes.
//
// The cut is moved back to the nearest UTF-8 rune boundary, so multi-byte text
// (e.g. Chinese place names) is never split mid-character. A split would leave
// invalid UTF-8, which encoding/json would store as U+FFFD.
func clampVisitMeta(ip, displayLocation string) (string, string) {
	clamp := func(s string, maxBytes int) string {
		if len(s) <= maxBytes {
			return s
		}
		cut := 0
		// Ranging over a string yields the start byte of each rune; keep the
		// last start that still fits. Invalid bytes count as 1-byte runes.
		for i := range s {
			if i > maxBytes {
				break
			}
			cut = i
		}
		return s[:cut]
	}
	return clamp(strings.TrimSpace(ip), maxLastVisitIPLen),
		clamp(strings.TrimSpace(displayLocation), maxLastVisitDisplayLocationLen)
}
// UpdateLastVisitMeta stores the client IP and human-readable location of the
// user's most recent visit. The frontend supplies the location after querying
// a geolocation service. Both values are trimmed and size-capped by
// clampVisitMeta, and an empty value leaves the stored field unchanged. When
// both are empty, the current record is returned without writing.
// os.ErrNotExist is returned if the account has no user file.
func (s *Store) UpdateLastVisitMeta(account string, ip string, displayLocation string) (models.UserRecord, error) {
	ip, displayLocation = clampVisitMeta(ip, displayLocation)

	if ip == "" && displayLocation == "" {
		// GetUser takes s.mu itself, so it must run before we lock below.
		current, found, err := s.GetUser(account)
		switch {
		case err != nil:
			return models.UserRecord{}, err
		case !found:
			return models.UserRecord{}, os.ErrNotExist
		}
		return current, nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	file := s.userFilePath(account)
	if _, statErr := os.Stat(file); errors.Is(statErr, os.ErrNotExist) {
		return models.UserRecord{}, os.ErrNotExist
	}
	var rec models.UserRecord
	if readErr := readJSONFile(file, &rec); readErr != nil {
		return models.UserRecord{}, readErr
	}

	if ip != "" {
		rec.LastVisitIP = ip
	}
	if displayLocation != "" {
		rec.LastVisitDisplayLocation = displayLocation
	}
	rec.UpdatedAt = models.NowISO()

	if writeErr := writeJSONFile(file, rec); writeErr != nil {
		return models.UserRecord{}, writeErr
	}
	return rec, nil
}
// CheckIn performs the daily check-in for account on day `today`.
//
// Results:
//   - record, 0 coins, already=true, and no write when today's check-in is
//     already present (LastCheckInDate or models.HasActivityDate);
//   - otherwise the updated record, the coins granted, and already=false.
//
// The reward is the configured RewardCoins, with non-positive values treated
// as 1. An empty `at` defaults to models.CurrentActivityTime().
// os.ErrNotExist is returned when the account has no user file.
func (s *Store) CheckIn(account string, today string, at string) (models.UserRecord, int, bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	file := s.userFilePath(account)
	if _, statErr := os.Stat(file); errors.Is(statErr, os.ErrNotExist) {
		return models.UserRecord{}, 0, false, os.ErrNotExist
	}
	var rec models.UserRecord
	if readErr := readJSONFile(file, &rec); readErr != nil {
		return models.UserRecord{}, 0, false, readErr
	}

	if rec.LastCheckInDate == today || models.HasActivityDate(rec.CheckInTimes, today) {
		return rec, 0, true, nil
	}

	coins := 1
	if configured := s.checkInConfig.RewardCoins; configured > 0 {
		coins = configured
	}
	rec.SproutCoins += coins
	rec.LastCheckInDate = today

	if strings.TrimSpace(at) == "" {
		at = models.CurrentActivityTime()
	}
	rec.LastCheckInAt = at
	rec.CheckInTimes = append(rec.CheckInTimes, at)
	if rec.CreatedAt == "" {
		rec.CreatedAt = models.NowISO()
	}
	rec.UpdatedAt = models.NowISO()

	if writeErr := writeJSONFile(file, rec); writeErr != nil {
		return models.UserRecord{}, 0, false, writeErr
	}
	return rec, coins, false, nil
}
// DeleteUser removes the user file for account. Deleting an account that has
// no file is a no-op and returns nil.
func (s *Store) DeleteUser(account string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	target := s.userFilePath(account)
	_, statErr := os.Stat(target)
	if errors.Is(statErr, os.ErrNotExist) {
		return nil
	}
	return os.Remove(target)
}
// userFilePath returns the absolute path of account's JSON file inside the
// users directory.
func (s *Store) userFilePath(account string) string {
	name := userFileName(account)
	return filepath.Join(s.usersDir, name)
}
// userFileName maps an account name to its file name: the account bytes in
// unpadded URL-safe base64 plus ".json". The encoding never contains path
// separators, so arbitrary account strings cannot escape the users directory.
func userFileName(account string) string {
	const ext = ".json"
	return base64.RawURLEncoding.EncodeToString([]byte(account)) + ext
}
// readJSONFile reads the whole file at path and JSON-decodes it into target,
// which must be a pointer. Errors from os.ReadFile are returned unwrapped, so
// callers can test them with errors.Is(err, os.ErrNotExist).
func readJSONFile(path string, target any) error {
	data, readErr := os.ReadFile(path)
	if readErr == nil {
		return json.Unmarshal(data, target)
	}
	return readErr
}
// writeJSONFile marshals value as indented JSON and atomically replaces the
// file at path.
//
// The data goes to a temporary file in the same directory, is fsynced, and is
// then renamed over path. A crash or failed write therefore leaves either the
// old or the new content, never a truncated file. (os.WriteFile truncates
// first, and a half-written user file breaks ListUsers for every account.)
// An existing file keeps its permission bits, as it did with os.WriteFile.
// A new file gets 0644. NOTE(review): unlike os.WriteFile, the 0644 is not
// reduced by the process umask.
func writeJSONFile(path string, value any) error {
	raw, err := json.MarshalIndent(value, "", " ")
	if err != nil {
		return err
	}

	mode := os.FileMode(0644)
	if info, statErr := os.Stat(path); statErr == nil {
		mode = info.Mode().Perm()
	}

	tmp, err := os.CreateTemp(filepath.Dir(path), "."+filepath.Base(path)+".tmp*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	defer func() {
		if tmpName != "" {
			// Best-effort cleanup of an abandoned temp file; the original
			// error is what the caller needs to see.
			_ = os.Remove(tmpName)
		}
	}()

	if _, err := tmp.Write(raw); err != nil {
		_ = tmp.Close()
		return err
	}
	if err := tmp.Sync(); err != nil {
		_ = tmp.Close()
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	if err := os.Chmod(tmpName, mode); err != nil {
		return err
	}
	if err := os.Rename(tmpName, path); err != nil {
		return err
	}
	tmpName = "" // renamed into place: nothing left to clean up
	return nil
}