package storage import ( "crypto/rand" "errors" "os" "strings" "time" "sproutgate-backend/internal/models" ) // InviteEntry 管理员发放的注册邀请码。 type InviteEntry struct { Code string `json:"code"` Note string `json:"note,omitempty"` MaxUses int `json:"maxUses"` // 0 表示不限次数 Uses int `json:"uses"` ExpiresAt string `json:"expiresAt,omitempty"` // RFC3339,空表示不过期 CreatedAt string `json:"createdAt"` } // RegistrationConfig 注册策略与邀请码列表(data/config/registration.json)。 type RegistrationConfig struct { RequireInviteCode bool `json:"requireInviteCode"` Invites []InviteEntry `json:"invites"` } func normalizeInviteCode(raw string) string { return strings.ToUpper(strings.TrimSpace(raw)) } func (s *Store) loadOrCreateRegistrationConfig() error { s.mu.Lock() defer s.mu.Unlock() if _, err := os.Stat(s.registrationPath); errors.Is(err, os.ErrNotExist) { cfg := RegistrationConfig{RequireInviteCode: false, Invites: []InviteEntry{}} if err := writeJSONFile(s.registrationPath, cfg); err != nil { return err } s.registrationConfig = cfg return nil } var cfg RegistrationConfig if err := readJSONFile(s.registrationPath, &cfg); err != nil { return err } if cfg.Invites == nil { cfg.Invites = []InviteEntry{} } s.registrationConfig = cfg return nil } func (s *Store) persistRegistrationConfigLocked() error { return writeJSONFile(s.registrationPath, s.registrationConfig) } // RegistrationRequireInvite 是否强制要求邀请码才能发起注册(发邮件验证码)。 func (s *Store) RegistrationRequireInvite() bool { s.mu.Lock() defer s.mu.Unlock() return s.registrationConfig.RequireInviteCode } // GetRegistrationConfig 返回配置副本(管理端)。 func (s *Store) GetRegistrationConfig() RegistrationConfig { s.mu.Lock() defer s.mu.Unlock() out := s.registrationConfig out.Invites = append([]InviteEntry(nil), s.registrationConfig.Invites...) return out } // SetRegistrationRequireInvite 更新是否强制邀请码。 func (s *Store) SetRegistrationRequireInvite(require bool) error { s.mu.Lock() defer s.mu.Unlock() s.registrationConfig.RequireInviteCode = require return s.persistRegistrationConfigLocked() } func inviteEntryValid(e *InviteEntry) error { if strings.TrimSpace(e.ExpiresAt) != "" { t, err := time.Parse(time.RFC3339, e.ExpiresAt) if err == nil && time.Now().After(t) { return errors.New("invite code expired") } } if e.MaxUses > 0 && e.Uses >= e.MaxUses { return errors.New("invite code has been fully used") } return nil } // ValidateInviteForRegister 校验邀请码是否可用(发验证码前,不扣次)。 func (s *Store) ValidateInviteForRegister(code string) error { n := normalizeInviteCode(code) if n == "" { return errors.New("invite code is required") } s.mu.Lock() defer s.mu.Unlock() for i := range s.registrationConfig.Invites { e := &s.registrationConfig.Invites[i] if strings.EqualFold(e.Code, n) { return inviteEntryValid(e) } } return errors.New("invalid invite code") } // RedeemInvite 邮箱验证通过创建用户后扣减邀请码使用次数。 func (s *Store) RedeemInvite(code string) error { n := normalizeInviteCode(code) if n == "" { return nil } s.mu.Lock() defer s.mu.Unlock() for i := range s.registrationConfig.Invites { e := &s.registrationConfig.Invites[i] if strings.EqualFold(e.Code, n) { if err := inviteEntryValid(e); err != nil { return err } e.Uses++ return s.persistRegistrationConfigLocked() } } return errors.New("invalid invite code") } const inviteCodeAlphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" func randomInviteToken(n int) (string, error) { b := make([]byte, n) if _, err := rand.Read(b); err != nil { return "", err } var sb strings.Builder sb.Grow(n) for i := 0; i < n; i++ { sb.WriteByte(inviteCodeAlphabet[int(b[i])%len(inviteCodeAlphabet)]) } return sb.String(), nil } // AddInviteEntry 生成新邀请码并写入配置。 func (s *Store) AddInviteEntry(note string, maxUses int, expiresAt string) (InviteEntry, error) { s.mu.Lock() defer s.mu.Unlock() var code string for attempt := 0; attempt < 24; attempt++ { c, err := randomInviteToken(8) if err != nil { return InviteEntry{}, err } dup := false for _, ex := range s.registrationConfig.Invites { if strings.EqualFold(ex.Code, c) { dup = true break } } if !dup { code = c break } } if code == "" { return InviteEntry{}, errors.New("failed to generate unique invite code") } expiresAt = strings.TrimSpace(expiresAt) if expiresAt != "" { if _, err := time.Parse(time.RFC3339, expiresAt); err != nil { return InviteEntry{}, errors.New("invalid expiresAt (use RFC3339)") } } if maxUses < 0 { maxUses = 0 } entry := InviteEntry{ Code: code, Note: strings.TrimSpace(note), MaxUses: maxUses, Uses: 0, ExpiresAt: expiresAt, CreatedAt: models.NowISO(), } s.registrationConfig.Invites = append(s.registrationConfig.Invites, entry) if err := s.persistRegistrationConfigLocked(); err != nil { s.registrationConfig.Invites = s.registrationConfig.Invites[:len(s.registrationConfig.Invites)-1] return InviteEntry{}, err } return entry, nil } // DeleteInviteEntry 按码删除(大小写不敏感)。 func (s *Store) DeleteInviteEntry(code string) error { n := normalizeInviteCode(code) if n == "" { return errors.New("code is required") } s.mu.Lock() defer s.mu.Unlock() for i, e := range s.registrationConfig.Invites { if strings.EqualFold(e.Code, n) { s.registrationConfig.Invites = append(s.registrationConfig.Invites[:i], s.registrationConfig.Invites[i+1:]...) return s.persistRegistrationConfigLocked() } } return errors.New("invite not found") }