215 lines
5.7 KiB
Go
215 lines
5.7 KiB
Go
package storage
|
||
|
||
import (
|
||
"crypto/rand"
|
||
"errors"
|
||
"os"
|
||
"strings"
|
||
"time"
|
||
|
||
"sproutgate-backend/internal/models"
|
||
)
|
||
|
||
// InviteEntry is a registration invite code issued by an administrator.
type InviteEntry struct {
	// Code is the invite code itself. Lookups in this file compare it
	// case-insensitively.
	Code string `json:"code"`
	// Note is an optional free-text note; omitted from JSON when empty.
	Note string `json:"note,omitempty"`
	// MaxUses caps how many times the code may be redeemed. 0 means unlimited
	// (inviteEntryValid treats any value <= 0 as unlimited).
	MaxUses int `json:"maxUses"`
	// Uses is the number of redemptions recorded so far.
	Uses int `json:"uses"`
	// ExpiresAt is an RFC3339 timestamp; empty means the code never expires.
	ExpiresAt string `json:"expiresAt,omitempty"`
	// CreatedAt records when the invite was created (set from models.NowISO).
	CreatedAt string `json:"createdAt"`
}
|
||
|
||
// RegistrationConfig holds the registration policy and the list of invite
// codes (persisted to data/config/registration.json).
type RegistrationConfig struct {
	// RequireInviteCode reports whether an invite code is mandatory to start
	// a registration.
	RequireInviteCode bool `json:"requireInviteCode"`
	// Invites lists every issued invite code.
	Invites []InviteEntry `json:"invites"`
}
|
||
|
||
// normalizeInviteCode canonicalizes a user-supplied invite code by removing
// surrounding whitespace and converting it to upper case.
func normalizeInviteCode(raw string) string {
	trimmed := strings.TrimSpace(raw)
	return strings.ToUpper(trimmed)
}
|
||
|
||
func (s *Store) loadOrCreateRegistrationConfig() error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
if _, err := os.Stat(s.registrationPath); errors.Is(err, os.ErrNotExist) {
|
||
cfg := RegistrationConfig{RequireInviteCode: false, Invites: []InviteEntry{}}
|
||
if err := writeJSONFile(s.registrationPath, cfg); err != nil {
|
||
return err
|
||
}
|
||
s.registrationConfig = cfg
|
||
return nil
|
||
}
|
||
var cfg RegistrationConfig
|
||
if err := readJSONFile(s.registrationPath, &cfg); err != nil {
|
||
return err
|
||
}
|
||
if cfg.Invites == nil {
|
||
cfg.Invites = []InviteEntry{}
|
||
}
|
||
s.registrationConfig = cfg
|
||
return nil
|
||
}
|
||
|
||
func (s *Store) persistRegistrationConfigLocked() error {
|
||
return writeJSONFile(s.registrationPath, s.registrationConfig)
|
||
}
|
||
|
||
// RegistrationRequireInvite 是否强制要求邀请码才能发起注册(发邮件验证码)。
|
||
func (s *Store) RegistrationRequireInvite() bool {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
return s.registrationConfig.RequireInviteCode
|
||
}
|
||
|
||
// GetRegistrationConfig 返回配置副本(管理端)。
|
||
func (s *Store) GetRegistrationConfig() RegistrationConfig {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
out := s.registrationConfig
|
||
out.Invites = append([]InviteEntry(nil), s.registrationConfig.Invites...)
|
||
return out
|
||
}
|
||
|
||
// SetRegistrationRequireInvite 更新是否强制邀请码。
|
||
func (s *Store) SetRegistrationRequireInvite(require bool) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
s.registrationConfig.RequireInviteCode = require
|
||
return s.persistRegistrationConfigLocked()
|
||
}
|
||
|
||
func inviteEntryValid(e *InviteEntry) error {
|
||
if strings.TrimSpace(e.ExpiresAt) != "" {
|
||
t, err := time.Parse(time.RFC3339, e.ExpiresAt)
|
||
if err == nil && time.Now().After(t) {
|
||
return errors.New("invite code expired")
|
||
}
|
||
}
|
||
if e.MaxUses > 0 && e.Uses >= e.MaxUses {
|
||
return errors.New("invite code has been fully used")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ValidateInviteForRegister 校验邀请码是否可用(发验证码前,不扣次)。
|
||
func (s *Store) ValidateInviteForRegister(code string) error {
|
||
n := normalizeInviteCode(code)
|
||
if n == "" {
|
||
return errors.New("invite code is required")
|
||
}
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
for i := range s.registrationConfig.Invites {
|
||
e := &s.registrationConfig.Invites[i]
|
||
if strings.EqualFold(e.Code, n) {
|
||
return inviteEntryValid(e)
|
||
}
|
||
}
|
||
return errors.New("invalid invite code")
|
||
}
|
||
|
||
// RedeemInvite 邮箱验证通过创建用户后扣减邀请码使用次数。
|
||
func (s *Store) RedeemInvite(code string) error {
|
||
n := normalizeInviteCode(code)
|
||
if n == "" {
|
||
return nil
|
||
}
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
for i := range s.registrationConfig.Invites {
|
||
e := &s.registrationConfig.Invites[i]
|
||
if strings.EqualFold(e.Code, n) {
|
||
if err := inviteEntryValid(e); err != nil {
|
||
return err
|
||
}
|
||
e.Uses++
|
||
return s.persistRegistrationConfigLocked()
|
||
}
|
||
}
|
||
return errors.New("invalid invite code")
|
||
}
|
||
|
||
// inviteCodeAlphabet is the character set for generated invite codes. It omits
// the look-alike characters I, O, 0 and 1. Its length is 32, which divides 256
// evenly, so reducing a random byte with % introduces no bias.
const inviteCodeAlphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"

// randomInviteToken returns an n-character string drawn from
// inviteCodeAlphabet using crypto/rand. It fails only if reading random bytes
// fails.
func randomInviteToken(n int) (string, error) {
	entropy := make([]byte, n)
	if _, err := rand.Read(entropy); err != nil {
		return "", err
	}

	alphabetLen := len(inviteCodeAlphabet)
	var token strings.Builder
	token.Grow(n)
	for _, v := range entropy {
		token.WriteByte(inviteCodeAlphabet[int(v)%alphabetLen])
	}
	return token.String(), nil
}
|
||
|
||
// AddInviteEntry 生成新邀请码并写入配置。
|
||
func (s *Store) AddInviteEntry(note string, maxUses int, expiresAt string) (InviteEntry, error) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
var code string
|
||
for attempt := 0; attempt < 24; attempt++ {
|
||
c, err := randomInviteToken(8)
|
||
if err != nil {
|
||
return InviteEntry{}, err
|
||
}
|
||
dup := false
|
||
for _, ex := range s.registrationConfig.Invites {
|
||
if strings.EqualFold(ex.Code, c) {
|
||
dup = true
|
||
break
|
||
}
|
||
}
|
||
if !dup {
|
||
code = c
|
||
break
|
||
}
|
||
}
|
||
if code == "" {
|
||
return InviteEntry{}, errors.New("failed to generate unique invite code")
|
||
}
|
||
expiresAt = strings.TrimSpace(expiresAt)
|
||
if expiresAt != "" {
|
||
if _, err := time.Parse(time.RFC3339, expiresAt); err != nil {
|
||
return InviteEntry{}, errors.New("invalid expiresAt (use RFC3339)")
|
||
}
|
||
}
|
||
if maxUses < 0 {
|
||
maxUses = 0
|
||
}
|
||
entry := InviteEntry{
|
||
Code: code,
|
||
Note: strings.TrimSpace(note),
|
||
MaxUses: maxUses,
|
||
Uses: 0,
|
||
ExpiresAt: expiresAt,
|
||
CreatedAt: models.NowISO(),
|
||
}
|
||
s.registrationConfig.Invites = append(s.registrationConfig.Invites, entry)
|
||
if err := s.persistRegistrationConfigLocked(); err != nil {
|
||
s.registrationConfig.Invites = s.registrationConfig.Invites[:len(s.registrationConfig.Invites)-1]
|
||
return InviteEntry{}, err
|
||
}
|
||
return entry, nil
|
||
}
|
||
|
||
// DeleteInviteEntry 按码删除(大小写不敏感)。
|
||
func (s *Store) DeleteInviteEntry(code string) error {
|
||
n := normalizeInviteCode(code)
|
||
if n == "" {
|
||
return errors.New("code is required")
|
||
}
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
for i, e := range s.registrationConfig.Invites {
|
||
if strings.EqualFold(e.Code, n) {
|
||
s.registrationConfig.Invites = append(s.registrationConfig.Invites[:i], s.registrationConfig.Invites[i+1:]...)
|
||
return s.persistRegistrationConfigLocked()
|
||
}
|
||
}
|
||
return errors.New("invite not found")
|
||
}
|