初始化提交

This commit is contained in:
2025-12-14 15:40:49 +08:00
commit 410b2f068d
72 changed files with 10460 additions and 0 deletions

View File

@@ -0,0 +1,76 @@
from flask import Flask
from flask_cors import CORS
from config import config
from extensions import db, jwt, mail
import os
# 初始化扩展
# jwt = JWTManager() # Moved to extensions.py
# mail = Mail() # Moved to extensions.py
def create_app(config_name='default'):
"""应用工厂函数"""
app = Flask(__name__)
# 加载配置
app.config.from_object(config[config_name])
# 初始化扩展
db.init_app(app)
jwt.init_app(app)
mail.init_app(app)
CORS(app, resources={
r"/api/*": {"origins": "*"},
r"/v1/*": {"origins": "*"}
})
# 注册蓝图
from routes.auth import auth_bp
from routes.user import user_bp
from routes.order import order_bp
from routes.api_service import api_bp
from routes.admin import admin_bp
from routes.apikey import apikey_bp
from routes.v1_api import v1_bp
app.register_blueprint(auth_bp, url_prefix='/api/auth')
app.register_blueprint(user_bp, url_prefix='/api/user')
app.register_blueprint(order_bp, url_prefix='/api/order')
app.register_blueprint(api_bp, url_prefix='/api/service')
app.register_blueprint(admin_bp, url_prefix='/api/admin')
app.register_blueprint(apikey_bp, url_prefix='/api/apikey')
app.register_blueprint(v1_bp, url_prefix='/v1')
# 创建数据库表
with app.app_context():
db.create_all()
# 创建默认管理员账户
from models import User
admin = User.query.filter_by(email='admin@nba.com').first()
if not admin:
admin = User(
email='admin@nba.com',
username='Admin',
is_admin=True,
is_active=True
)
admin.set_password('admin123')
db.session.add(admin)
db.session.commit()
print('默认管理员账户已创建: admin@nba.com / admin123')
@app.route('/')
def index():
return {'message': 'Nano Banana API Transfer Service', 'version': '1.0.0'}
@app.route('/health')
def health():
return {'status': 'healthy'}
return app
if __name__ == '__main__':
app = create_app(os.getenv('FLASK_ENV', 'development'))
app.run(host='0.0.0.0', port=5000, debug=True)

View File

@@ -0,0 +1,58 @@
import os
from datetime import timedelta
from dotenv import load_dotenv
load_dotenv()
class Config:
"""基础配置"""
# Flask配置
SECRET_KEY = os.getenv('SECRET_KEY', 'dev-secret-key-change-in-production')
# 数据库配置
SQLALCHEMY_DATABASE_URI = os.getenv('DATABASE_URI', 'sqlite:///nba_transfer.db')
SQLALCHEMY_TRACK_MODIFICATIONS = False
# JWT配置
JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', 'jwt-secret-key-change-in-production')
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24)
JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
# 邮件配置
MAIL_SERVER = os.getenv('MAIL_SERVER', 'smtp.qq.com')
MAIL_PORT = int(os.getenv('MAIL_PORT', 465))
MAIL_USE_TLS = os.getenv('MAIL_USE_TLS', 'False') == 'True'
MAIL_USE_SSL = os.getenv('MAIL_USE_SSL', 'True') == 'True'
MAIL_USERNAME = os.getenv('MAIL_USERNAME')
MAIL_PASSWORD = os.getenv('MAIL_PASSWORD')
MAIL_DEFAULT_SENDER = os.getenv('MAIL_USERNAME')
# 支付配置
WECHAT_PAY_APP_ID = os.getenv('WECHAT_PAY_APP_ID')
WECHAT_PAY_MCH_ID = os.getenv('WECHAT_PAY_MCH_ID')
WECHAT_PAY_API_KEY = os.getenv('WECHAT_PAY_API_KEY')
ALIPAY_APP_ID = os.getenv('ALIPAY_APP_ID')
ALIPAY_PRIVATE_KEY = os.getenv('ALIPAY_PRIVATE_KEY')
ALIPAY_PUBLIC_KEY = os.getenv('ALIPAY_PUBLIC_KEY')
# 注意:
# 模型 API 配置 (DeepSeek, NanoBanana) 和 价格策略
# 已迁移至 modelapiservice 模块下的 config.py 中独立管理
# 此处不再保留,避免重复定义和混乱
class DevelopmentConfig(Config):
"""开发环境配置"""
DEBUG = True
class ProductionConfig(Config):
"""生产环境配置"""
DEBUG = False
config = {
'development': DevelopmentConfig,
'production': ProductionConfig,
'default': DevelopmentConfig
}

View File

@@ -0,0 +1,7 @@
from flask_sqlalchemy import SQLAlchemy
from flask_jwt_extended import JWTManager
from flask_mail import Mail
db = SQLAlchemy()
jwt = JWTManager()
mail = Mail()

View File

@@ -0,0 +1,16 @@
# DeepSeek 模型配置
import os
# 价格配置 (CNY per 1M tokens)
TOKEN_PRICE_INPUT = 400.0 / 1000000 # 4元 / 1M tokens
TOKEN_PRICE_OUTPUT = 1600.0 / 1000000 # 16元 / 1M tokens
# 最低余额要求
MIN_BALANCE = 0.01
# API 配置 (优先从环境变量获取)
def get_config():
return {
'api_url': os.getenv('DEEPSEEK_API_URL', 'https://api.deepseek.com'),
'api_key': os.getenv('DEEPSEEK_API_KEY')
}

View File

@@ -0,0 +1,54 @@
from ..base import ModelService
from .config import TOKEN_PRICE_INPUT, TOKEN_PRICE_OUTPUT, MIN_BALANCE, get_config
from typing import Dict, Any, Tuple, Generator, Union
import json
import logging
logger = logging.getLogger(__name__)
class DeepSeekService(ModelService):
def get_api_config(self) -> Tuple[str, str]:
config = get_config()
return config['api_url'], config['api_key']
def check_balance(self, balance: float) -> Tuple[bool, float, str]:
if balance < MIN_BALANCE:
return False, MIN_BALANCE, f'余额不足。当前余额: {balance}, 需要至少: {MIN_BALANCE}'
return True, 0.0, ""
def calculate_cost(self, usage: Dict[str, Any] = None, stream: bool = False) -> float:
if not usage:
return 0.0
prompt_tokens = usage.get('prompt_tokens', 0)
completion_tokens = usage.get('completion_tokens', 0)
cost = (prompt_tokens * TOKEN_PRICE_INPUT) + (completion_tokens * TOKEN_PRICE_OUTPUT)
return cost
def prepare_payload(self, data: Dict[str, Any]) -> Dict[str, Any]:
payload = {k: v for k, v in data.items() if k not in ['user']}
# 强制 DeepSeek 返回 usage
if data.get('stream', False):
if 'stream_options' not in payload:
payload['stream_options'] = {"include_usage": True}
return payload
def handle_response(self, response, stream: bool) -> Union[Dict[str, Any], Generator]:
# Response handling is mostly done in the route handler for stream/json split
# but we can provide helper methods to parse usage
pass
@staticmethod
def parse_stream_usage(chunk_text: str) -> Dict[str, Any]:
"""解析流式响应中的 usage"""
try:
for line in chunk_text.split('\n'):
if line.startswith('data: ') and line != 'data: [DONE]':
json_str = line[6:]
data_obj = json.loads(json_str)
if 'usage' in data_obj and data_obj['usage']:
return data_obj['usage']
except Exception:
pass
return None

View File

@@ -0,0 +1,12 @@
# Nano Banana 模型配置
import os
# 价格配置 (CNY per call)
IMAGE_GENERATION_PRICE = float(os.getenv('IMAGE_GENERATION_PRICE', 0.15))
# API 配置 (优先从环境变量获取)
def get_config():
return {
'api_url': os.getenv('NANO_BANANA_API_URL', 'https://api.nanobanana.com/v1'),
'api_key': os.getenv('NANO_BANANA_API_KEY')
}

View File

@@ -0,0 +1,26 @@
from ..base import ModelService
from .config import IMAGE_GENERATION_PRICE, get_config
from typing import Dict, Any, Tuple, Generator, Union
import logging
logger = logging.getLogger(__name__)
class NanoBananaService(ModelService):
def get_api_config(self) -> Tuple[str, str]:
config = get_config()
return config['api_url'], config['api_key']
def check_balance(self, balance: float) -> Tuple[bool, float, str]:
if balance < IMAGE_GENERATION_PRICE:
return False, IMAGE_GENERATION_PRICE, f'余额不足。当前余额: {balance}, 需要: {IMAGE_GENERATION_PRICE}'
return True, IMAGE_GENERATION_PRICE, ""
def calculate_cost(self, usage: Dict[str, Any] = None, stream: bool = False) -> float:
# 固定按次计费
return IMAGE_GENERATION_PRICE
def prepare_payload(self, data: Dict[str, Any]) -> Dict[str, Any]:
return {k: v for k, v in data.items() if k not in ['user']}
def handle_response(self, response, stream: bool) -> Union[Dict[str, Any], Generator]:
pass

View File

@@ -0,0 +1,10 @@
from .DeepSeek.service import DeepSeekService
from .NanoBanana.service import NanoBananaService
def get_model_service(model_name: str):
"""根据模型名称获取对应的服务实例"""
if model_name.startswith('deepseek-'):
return DeepSeekService()
else:
# 默认使用 Nano Banana 服务 (包括文生图等)
return NanoBananaService()

View File

@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, Generator, Union, Tuple
import requests
class ModelService(ABC):
"""大模型服务基类"""
@abstractmethod
def get_api_config(self) -> Tuple[str, str]:
"""获取 API URL 和 Key"""
pass
@abstractmethod
def calculate_cost(self, usage: Dict[str, Any] = None, stream: bool = False) -> float:
"""计算费用"""
pass
@abstractmethod
def check_balance(self, balance: float) -> Tuple[bool, float, str]:
"""检查余额是否充足
Returns: (is_sufficient, estimated_cost, error_message)
"""
pass
@abstractmethod
def prepare_payload(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""准备请求载荷"""
pass
@abstractmethod
def handle_response(self, response: requests.Response, stream: bool) -> Union[Dict[str, Any], Generator]:
"""处理响应"""
pass

View File

@@ -0,0 +1,233 @@
from datetime import datetime, timedelta
from flask_sqlalchemy import SQLAlchemy
from werkzeug.security import generate_password_hash, check_password_hash
import secrets
import string
from extensions import db
# db = SQLAlchemy() # Moved to extensions.py
class User(db.Model):
"""用户模型"""
__tablename__ = 'users'
id = db.Column(db.Integer, primary_key=True)
email = db.Column(db.String(120), unique=True, nullable=False, index=True)
password_hash = db.Column(db.String(255), nullable=False)
username = db.Column(db.String(80))
balance = db.Column(db.Float, default=0.0) # 账户余额
is_active = db.Column(db.Boolean, default=True)
is_admin = db.Column(db.Boolean, default=False)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# 关系
orders = db.relationship('Order', backref='user', lazy='dynamic', cascade='all, delete-orphan')
api_calls = db.relationship('ApiCall', backref='user', lazy='dynamic', cascade='all, delete-orphan')
transactions = db.relationship('Transaction', backref='user', lazy='dynamic', cascade='all, delete-orphan')
api_keys = db.relationship('APIKey', backref='user', lazy='dynamic', cascade='all, delete-orphan')
verification_codes = db.relationship('VerificationCode', backref='user', lazy='dynamic', cascade='all, delete-orphan')
# 验证状态
email_verified = db.Column(db.Boolean, default=False)
email_verified_at = db.Column(db.DateTime)
def set_password(self, password):
"""设置密码"""
self.password_hash = generate_password_hash(password)
def check_password(self, password):
"""验证密码"""
return check_password_hash(self.password_hash, password)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'email': self.email,
'username': self.username,
'balance': self.balance,
'is_active': self.is_active,
'is_admin': self.is_admin,
'created_at': self.created_at.isoformat(),
'updated_at': self.updated_at.isoformat()
}
class Order(db.Model):
"""订单模型"""
__tablename__ = 'orders'
id = db.Column(db.Integer, primary_key=True)
order_no = db.Column(db.String(50), unique=True, nullable=False, index=True)
user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
amount = db.Column(db.Float, nullable=False) # 充值金额
payment_method = db.Column(db.String(20)) # wechat, alipay
status = db.Column(db.String(20), default='pending') # pending, paid, cancelled, failed
transaction_id = db.Column(db.String(100)) # 第三方交易ID
created_at = db.Column(db.DateTime, default=datetime.utcnow)
paid_at = db.Column(db.DateTime)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'order_no': self.order_no,
'user_id': self.user_id,
'amount': self.amount,
'payment_method': self.payment_method,
'status': self.status,
'transaction_id': self.transaction_id,
'created_at': self.created_at.isoformat(),
'paid_at': self.paid_at.isoformat() if self.paid_at else None
}
class Transaction(db.Model):
"""交易记录模型"""
__tablename__ = 'transactions'
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
type = db.Column(db.String(20), nullable=False) # recharge, consume, refund
amount = db.Column(db.Float, nullable=False)
balance_before = db.Column(db.Float, nullable=False)
balance_after = db.Column(db.Float, nullable=False)
description = db.Column(db.String(255))
order_id = db.Column(db.Integer, db.ForeignKey('orders.id'))
api_call_id = db.Column(db.Integer, db.ForeignKey('api_calls.id'))
created_at = db.Column(db.DateTime, default=datetime.utcnow)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'user_id': self.user_id,
'type': self.type,
'amount': self.amount,
'balance_before': self.balance_before,
'balance_after': self.balance_after,
'description': self.description,
'order_id': self.order_id,
'api_call_id': self.api_call_id,
'created_at': self.created_at.isoformat()
}
class ApiCall(db.Model):
"""API调用记录模型"""
__tablename__ = 'api_calls'
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
api_type = db.Column(db.String(50), default='text_to_image') # text_to_image
prompt = db.Column(db.Text)
parameters = db.Column(db.Text) # JSON格式的参数
status = db.Column(db.String(20), default='pending') # pending, processing, success, failed
result_url = db.Column(db.String(500)) # 生成结果的URL
cost = db.Column(db.Float, default=0.0) # 本次调用费用
error_message = db.Column(db.Text)
request_time = db.Column(db.DateTime, default=datetime.utcnow)
response_time = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'user_id': self.user_id,
'api_type': self.api_type,
'prompt': self.prompt,
'status': self.status,
'result_url': self.result_url,
'cost': self.cost,
'error_message': self.error_message,
'request_time': self.request_time.isoformat(),
'response_time': self.response_time.isoformat() if self.response_time else None,
'created_at': self.created_at.isoformat()
}
class SystemConfig(db.Model):
"""系统配置模型"""
__tablename__ = 'system_configs'
id = db.Column(db.Integer, primary_key=True)
key = db.Column(db.String(50), unique=True, nullable=False)
value = db.Column(db.Text)
description = db.Column(db.String(255))
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
def to_dict(self):
"""转换为字典"""
return {
'id': self.id,
'key': self.key,
'value': self.value,
'description': self.description,
'updated_at': self.updated_at.isoformat()
}
class VerificationCode(db.Model):
"""邮箱验证码模型"""
__tablename__ = 'verification_codes'
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
email = db.Column(db.String(120), nullable=False)
code = db.Column(db.String(6), nullable=False) # 6位验证码
purpose = db.Column(db.String(20), default='register') # register, password_reset
used = db.Column(db.Boolean, default=False)
expired_at = db.Column(db.DateTime, nullable=False) # 过期时间
created_at = db.Column(db.DateTime, default=datetime.utcnow)
def is_valid(self):
"""检查验证码是否有效"""
return not self.used and datetime.utcnow() < self.expired_at
@staticmethod
def generate_code():
"""生成6位验证码"""
return ''.join(secrets.choice(string.digits) for _ in range(6))
def to_dict(self):
return {
'id': self.id,
'email': self.email,
'purpose': self.purpose,
'expired_at': self.expired_at.isoformat()
}
class APIKey(db.Model):
"""API密钥模型 - 用户可自己生成"""
__tablename__ = 'api_keys'
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
name = db.Column(db.String(100), nullable=False) # API密钥名称
api_key = db.Column(db.String(100), unique=True, nullable=False, index=True) # 实际密钥
# secret_key 字段已移除
is_active = db.Column(db.Boolean, default=True)
last_used_at = db.Column(db.DateTime) # 最后使用时间
created_at = db.Column(db.DateTime, default=datetime.utcnow)
@staticmethod
def generate_key():
"""生成 API Key"""
api_key = 'sk_' + ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(32))
return api_key
def to_dict(self):
"""转换为字典"""
data = {
'id': self.id,
'name': self.name,
'api_key': self.api_key,
'is_active': self.is_active,
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None,
'created_at': self.created_at.isoformat()
}
return data

View File

@@ -0,0 +1,9 @@
Flask==3.0.0
Flask-SQLAlchemy==3.1.1
Flask-CORS==4.0.0
Flask-JWT-Extended==4.6.0
Flask-Mail==0.9.1
python-dotenv==1.0.0
requests==2.31.0
Werkzeug==3.0.1
email-validator==2.1.0

View File

@@ -0,0 +1,113 @@
"""管理员相关路由"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from models import User
from services.admin_service import AdminService
admin_bp = Blueprint('admin', __name__)
def admin_required():
"""管理员权限验证装饰器"""
current_user_id = get_jwt_identity()
user = User.query.get(current_user_id)
if not user or not user.is_admin:
return None
return user
@admin_bp.route('/users', methods=['GET'])
@jwt_required()
def get_users():
"""获取用户列表"""
if not admin_required():
return jsonify({'error': '需要管理员权限'}), 403
# 分页参数
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 20, type=int)
search = request.args.get('search', '')
result, status_code = AdminService.get_users(page, per_page, search)
return jsonify(result), status_code
@admin_bp.route('/users/<int:user_id>', methods=['GET'])
@jwt_required()
def get_user_detail(user_id):
"""获取用户详情"""
if not admin_required():
return jsonify({'error': '需要管理员权限'}), 403
result, status_code = AdminService.get_user_detail(user_id)
return jsonify(result), status_code
@admin_bp.route('/users/<int:user_id>/toggle-status', methods=['POST'])
@jwt_required()
def toggle_user_status(user_id):
"""启用/禁用用户"""
admin = admin_required()
if not admin:
return jsonify({'error': '需要管理员权限'}), 403
result, status_code = AdminService.toggle_user_status(admin.id, user_id)
return jsonify(result), status_code
@admin_bp.route('/users/<int:user_id>/adjust-balance', methods=['POST'])
@jwt_required()
def adjust_balance(user_id):
"""调整用户余额"""
if not admin_required():
return jsonify({'error': '需要管理员权限'}), 403
data = request.get_json()
result, status_code = AdminService.adjust_balance(user_id, data)
return jsonify(result), status_code
@admin_bp.route('/orders', methods=['GET'])
@jwt_required()
def get_all_orders():
"""获取所有订单"""
if not admin_required():
return jsonify({'error': '需要管理员权限'}), 403
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 20, type=int)
status = request.args.get('status')
result, status_code = AdminService.get_all_orders(page, per_page, status)
return jsonify(result), status_code
@admin_bp.route('/api-calls', methods=['GET'])
@jwt_required()
def get_all_api_calls():
"""获取所有API调用记录"""
if not admin_required():
return jsonify({'error': '需要管理员权限'}), 403
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 20, type=int)
status = request.args.get('status')
result, status_code = AdminService.get_all_api_calls(page, per_page, status)
return jsonify(result), status_code
@admin_bp.route('/stats/overview', methods=['GET'])
@jwt_required()
def get_overview_stats():
"""获取总览统计"""
if not admin_required():
return jsonify({'error': '需要管理员权限'}), 403
result, status_code = AdminService.get_overview_stats()
return jsonify(result), status_code
@admin_bp.route('/stats/chart', methods=['GET'])
@jwt_required()
def get_chart_data():
"""获取图表数据最近7天"""
if not admin_required():
return jsonify({'error': '需要管理员权限'}), 403
days = request.args.get('days', 7, type=int)
result, status_code = AdminService.get_chart_data(days)
return jsonify(result), status_code

View File

@@ -0,0 +1,45 @@
"""API服务相关路由 - 支持 Nano Banana API 代理"""
from flask import Blueprint, request, jsonify, current_app
from flask_jwt_extended import jwt_required, get_jwt_identity
from services.api_proxy_service import ApiProxyService
api_bp = Blueprint('api_service', __name__)
@api_bp.route('/text-to-image', methods=['POST'])
@jwt_required()
def text_to_image():
"""
通用 API 代理 (支持文生图和对话)
"""
current_user_id = get_jwt_identity()
data = request.get_json()
result, status_code = ApiProxyService.handle_api_request(current_user_id, data)
if status_code == 200 and data.get('stream', False):
return current_app.response_class(result, mimetype='text/event-stream')
return jsonify(result), status_code
@api_bp.route('/models', methods=['GET'])
@jwt_required()
def get_models():
"""
获取可用的模型列表
"""
result, status_code = ApiProxyService.get_models()
return jsonify(result), status_code
@api_bp.route('/pricing', methods=['GET'])
def get_pricing():
"""获取价格信息(公开接口)"""
result, status_code = ApiProxyService.get_pricing()
return jsonify(result), status_code
@api_bp.route('/call/<int:call_id>', methods=['GET'])
@jwt_required()
def get_api_call(call_id):
"""获取API调用详情"""
current_user_id = get_jwt_identity()
result, status_code = ApiProxyService.get_api_call(current_user_id, call_id)
return jsonify(result), status_code

View File

@@ -0,0 +1,56 @@
"""用户 API Key 管理相关路由"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from services.apikey_service import APIKeyService
apikey_bp = Blueprint('apikey', __name__)
@apikey_bp.route('/keys', methods=['GET'])
@jwt_required()
def list_api_keys():
"""获取用户的所有 API Key"""
user_id = get_jwt_identity()
result, status_code = APIKeyService.list_api_keys(user_id)
return jsonify(result), status_code
@apikey_bp.route('/keys', methods=['POST'])
@jwt_required()
def create_api_key():
"""创建新的 API Key"""
user_id = get_jwt_identity()
data = request.get_json()
result, status_code = APIKeyService.create_api_key(user_id, data)
return jsonify(result), status_code
@apikey_bp.route('/keys/<int:key_id>', methods=['GET'])
@jwt_required()
def get_api_key(key_id):
"""获取单个 API Key 详情"""
user_id = get_jwt_identity()
result, status_code = APIKeyService.get_api_key(user_id, key_id)
return jsonify(result), status_code
@apikey_bp.route('/keys/<int:key_id>', methods=['PUT'])
@jwt_required()
def update_api_key(key_id):
"""更新 API Key 名称或状态"""
user_id = get_jwt_identity()
data = request.get_json()
result, status_code = APIKeyService.update_api_key(user_id, key_id, data)
return jsonify(result), status_code
@apikey_bp.route('/keys/<int:key_id>', methods=['DELETE'])
@jwt_required()
def delete_api_key(key_id):
"""删除 API Key"""
user_id = get_jwt_identity()
result, status_code = APIKeyService.delete_api_key(user_id, key_id)
return jsonify(result), status_code
@apikey_bp.route('/keys/<int:key_id>/regenerate', methods=['POST'])
@jwt_required()
def regenerate_api_key(key_id):
"""重置/轮换 API Key"""
user_id = get_jwt_identity()
result, status_code = APIKeyService.regenerate_api_key(user_id, key_id)
return jsonify(result), status_code

View File

@@ -0,0 +1,70 @@
"""认证相关路由"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity, create_access_token
from services.auth_service import AuthService
from services.user_service import UserService
auth_bp = Blueprint('auth', __name__)
@auth_bp.route('/send-verification-code', methods=['POST'])
def send_verification_code():
"""发送验证码"""
data = request.get_json()
result, status_code = AuthService.send_verification_code(data)
return jsonify(result), status_code
@auth_bp.route('/register', methods=['POST'])
def register():
"""用户注册 - 需要验证码"""
data = request.get_json()
result, status_code = AuthService.register(data)
return jsonify(result), status_code
@auth_bp.route('/login', methods=['POST'])
def login():
"""用户登录"""
data = request.get_json()
result, status_code = AuthService.login(data)
return jsonify(result), status_code
@auth_bp.route('/me', methods=['GET'])
@jwt_required()
def get_current_user():
"""获取当前用户信息"""
current_user_id = get_jwt_identity()
result, status_code = UserService.get_profile(current_user_id)
return jsonify(result), status_code
@auth_bp.route('/change-password', methods=['POST'])
@jwt_required()
def change_password():
"""修改密码"""
current_user_id = get_jwt_identity()
data = request.get_json()
result, status_code = AuthService.change_password(current_user_id, data)
return jsonify(result), status_code
@auth_bp.route('/reset-password', methods=['POST'])
def reset_password():
"""重置密码 - 需要验证码"""
data = request.get_json()
result, status_code = AuthService.reset_password(data)
return jsonify(result), status_code
@auth_bp.route('/verify-code', methods=['POST'])
def verify_code():
"""验证验证码是否有效(不标记为已使用)"""
data = request.get_json()
result, status_code = AuthService.verify_code(data)
return jsonify(result), status_code
@auth_bp.route('/refresh', methods=['POST'])
@jwt_required(refresh=True)
def refresh():
"""刷新访问令牌"""
current_user_id = get_jwt_identity()
access_token = create_access_token(identity=current_user_id)
return jsonify({
'access_token': access_token
}), 200

View File

@@ -0,0 +1,57 @@
"""订单相关路由"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from services.order_service import OrderService
order_bp = Blueprint('order', __name__)
@order_bp.route('/create', methods=['POST'])
@jwt_required()
def create_order():
"""创建充值订单"""
current_user_id = get_jwt_identity()
data = request.get_json()
result, status_code = OrderService.create_order(current_user_id, data)
return jsonify(result), status_code
@order_bp.route('/list', methods=['GET'])
@jwt_required()
def get_orders():
"""获取订单列表"""
current_user_id = get_jwt_identity()
# 分页参数
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 20, type=int)
status = request.args.get('status')
result, status_code = OrderService.get_orders(current_user_id, page, per_page, status)
return jsonify(result), status_code
@order_bp.route('/<int:order_id>', methods=['GET'])
@jwt_required()
def get_order(order_id):
"""获取订单详情"""
current_user_id = get_jwt_identity()
result, status_code = OrderService.get_order(current_user_id, order_id)
return jsonify(result), status_code
@order_bp.route('/callback/alipay', methods=['POST'])
def alipay_callback():
"""支付宝支付回调(预留接口)"""
data = request.get_json() or request.form.to_dict()
result, status_code = OrderService.alipay_callback(data)
return jsonify(result), status_code
@order_bp.route('/callback/wechat', methods=['POST'])
def wechat_callback():
"""微信支付回调(预留接口)"""
data = request.get_json() or request.form.to_dict()
result, status_code = OrderService.wechat_callback(data)
return jsonify(result), status_code
@order_bp.route('/notify/<order_no>', methods=['POST'])
def payment_notify(order_no):
"""模拟支付通知(仅用于测试)"""
result, status_code = OrderService.payment_notify(order_no)
return jsonify(result), status_code

View File

@@ -0,0 +1,65 @@
"""用户相关路由"""
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from services.user_service import UserService
user_bp = Blueprint('user', __name__)
@user_bp.route('/profile', methods=['GET'])
@jwt_required()
def get_profile():
"""获取用户资料"""
current_user_id = get_jwt_identity()
result, status_code = UserService.get_profile(current_user_id)
return jsonify(result), status_code
@user_bp.route('/profile', methods=['PUT'])
@jwt_required()
def update_profile():
"""更新用户资料"""
current_user_id = get_jwt_identity()
data = request.get_json()
result, status_code = UserService.update_profile(current_user_id, data)
return jsonify(result), status_code
@user_bp.route('/balance', methods=['GET'])
@jwt_required()
def get_balance():
"""获取账户余额"""
current_user_id = get_jwt_identity()
result, status_code = UserService.get_balance(current_user_id)
return jsonify(result), status_code
@user_bp.route('/transactions', methods=['GET'])
@jwt_required()
def get_transactions():
"""获取交易记录"""
current_user_id = get_jwt_identity()
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 20, type=int)
result, status_code = UserService.get_transactions(current_user_id, page, per_page)
return jsonify(result), status_code
@user_bp.route('/api-calls', methods=['GET'])
@jwt_required()
def get_api_calls():
"""获取API调用记录"""
current_user_id = get_jwt_identity()
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 20, type=int)
result, status_code = UserService.get_api_calls(current_user_id, page, per_page)
return jsonify(result), status_code
@user_bp.route('/stats', methods=['GET'])
@jwt_required()
def get_stats():
"""获取用户统计信息"""
current_user_id = get_jwt_identity()
result, status_code = UserService.get_stats(current_user_id)
return jsonify(result), status_code

View File

@@ -0,0 +1,36 @@
"""对外公开的 API 路由 (v1) - 支持 API Key 认证"""
from flask import Blueprint, request, jsonify, current_app
from services.apikey_service import APIKeyService
from services.v1_service import V1Service
v1_bp = Blueprint('v1_api', __name__)
@v1_bp.route('/chat/completions', methods=['POST'])
def chat_completions():
"""
OpenAI 兼容的 Chat Completions 接口
支持多模型,通过 modelapiservice 分发
"""
# 1. 认证
auth_header = request.headers.get('Authorization')
user, error = APIKeyService.authenticate_api_key(auth_header)
if error:
return jsonify({'error': {'message': error, 'type': 'auth_error', 'code': 401}}), 401
# 2. 获取请求数据
data = request.get_json()
if not data:
return jsonify({'error': {'message': '无效的 JSON 请求体', 'type': 'invalid_request_error', 'code': 400}}), 400
result, status_code = V1Service.chat_completions(user, data)
if status_code == 200 and data.get('stream', False):
return current_app.response_class(result, mimetype='text/event-stream')
return jsonify(result), status_code
@v1_bp.route('/text-to-image', methods=['POST'])
def text_to_image_alias():
"""兼容旧接口路径"""
return chat_completions()

View File

@@ -0,0 +1,244 @@
from models import db, User, Order, ApiCall, Transaction
from sqlalchemy import desc, func
from datetime import datetime, timedelta
class AdminService:
@staticmethod
def get_users(page=1, per_page=20, search=''):
"""获取用户列表"""
query = User.query
if search:
query = query.filter(
(User.email.contains(search)) | (User.username.contains(search))
)
# 分页查询
pagination = query.order_by(desc(User.created_at))\
.paginate(page=page, per_page=per_page, error_out=False)
return {
'users': [user.to_dict() for user in pagination.items],
'total': pagination.total,
'page': page,
'per_page': per_page,
'pages': pagination.pages
}, 200
@staticmethod
def get_user_detail(user_id):
"""获取用户详情"""
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
# 统计信息
total_recharge = db.session.query(func.sum(Transaction.amount))\
.filter_by(user_id=user_id, type='recharge').scalar() or 0
total_consume = db.session.query(func.sum(Transaction.amount))\
.filter_by(user_id=user_id, type='consume').scalar() or 0
total_api_calls = ApiCall.query.filter_by(user_id=user_id).count()
return {
'user': user.to_dict(),
'stats': {
'total_recharge': total_recharge,
'total_consume': abs(total_consume),
'total_api_calls': total_api_calls
}
}, 200
@staticmethod
def toggle_user_status(admin_user_id, user_id):
"""启用/禁用用户"""
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
# 检查是否尝试禁用其他管理员 (需要传入当前管理员ID)
if user.is_admin and user.id != admin_user_id:
return {'error': '不能禁用其他管理员'}, 403
user.is_active = not user.is_active
try:
db.session.commit()
return {
'message': f'用户已{"启用" if user.is_active else "禁用"}',
'user': user.to_dict()
}, 200
except Exception as e:
db.session.rollback()
return {'error': '操作失败'}, 500
@staticmethod
def adjust_balance(user_id, data):
"""调整用户余额"""
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
amount = data.get('amount')
description = data.get('description', '管理员调整余额')
if amount is None or amount == 0:
return {'error': '金额不能为0'}, 400
try:
balance_before = user.balance
user.balance += amount
balance_after = user.balance
# 创建交易记录
transaction = Transaction(
user_id=user_id,
type='recharge' if amount > 0 else 'consume',
amount=amount,
balance_before=balance_before,
balance_after=balance_after,
description=description
)
db.session.add(transaction)
db.session.commit()
return {
'message': '余额调整成功',
'user': user.to_dict(),
'transaction': transaction.to_dict()
}, 200
except Exception as e:
db.session.rollback()
return {'error': '余额调整失败'}, 500
@staticmethod
def get_all_orders(page=1, per_page=20, status=None):
"""获取所有订单"""
query = Order.query
if status:
query = query.filter_by(status=status)
# 分页查询
pagination = query.order_by(desc(Order.created_at))\
.paginate(page=page, per_page=per_page, error_out=False)
return {
'orders': [order.to_dict() for order in pagination.items],
'total': pagination.total,
'page': page,
'per_page': per_page,
'pages': pagination.pages
}, 200
@staticmethod
def get_all_api_calls(page=1, per_page=20, status=None):
"""获取所有API调用记录"""
query = ApiCall.query
if status:
query = query.filter_by(status=status)
# 分页查询
pagination = query.order_by(desc(ApiCall.created_at))\
.paginate(page=page, per_page=per_page, error_out=False)
return {
'api_calls': [call.to_dict() for call in pagination.items],
'total': pagination.total,
'page': page,
'per_page': per_page,
'pages': pagination.pages
}, 200
@staticmethod
def get_overview_stats():
"""获取总览统计"""
# 用户统计
total_users = User.query.count()
active_users = User.query.filter_by(is_active=True).count()
# 订单统计
total_orders = Order.query.count()
paid_orders = Order.query.filter_by(status='paid').count()
total_revenue = db.session.query(func.sum(Order.amount))\
.filter_by(status='paid').scalar() or 0
# API调用统计
total_api_calls = ApiCall.query.count()
success_calls = ApiCall.query.filter_by(status='success').count()
failed_calls = ApiCall.query.filter_by(status='failed').count()
# 今日统计
today = datetime.utcnow().date()
today_start = datetime.combine(today, datetime.min.time())
today_users = User.query.filter(User.created_at >= today_start).count()
today_orders = Order.query.filter(Order.created_at >= today_start).count()
today_revenue = db.session.query(func.sum(Order.amount))\
.filter(Order.created_at >= today_start, Order.status == 'paid').scalar() or 0
today_api_calls = ApiCall.query.filter(ApiCall.created_at >= today_start).count()
return {
'total': {
'users': total_users,
'active_users': active_users,
'orders': total_orders,
'paid_orders': paid_orders,
'revenue': total_revenue,
'api_calls': total_api_calls,
'success_calls': success_calls,
'failed_calls': failed_calls
},
'today': {
'users': today_users,
'orders': today_orders,
'revenue': today_revenue,
'api_calls': today_api_calls
}
}, 200
@staticmethod
def get_chart_data(days=7):
"""获取图表数据"""
# 计算日期范围
end_date = datetime.utcnow().date()
start_date = end_date - timedelta(days=days-1)
chart_data = []
for i in range(days):
date = start_date + timedelta(days=i)
date_start = datetime.combine(date, datetime.min.time())
date_end = datetime.combine(date, datetime.max.time())
# 统计当天数据
users = User.query.filter(
User.created_at >= date_start,
User.created_at <= date_end
).count()
revenue = db.session.query(func.sum(Order.amount))\
.filter(
Order.created_at >= date_start,
Order.created_at <= date_end,
Order.status == 'paid'
).scalar() or 0
api_calls = ApiCall.query.filter(
ApiCall.created_at >= date_start,
ApiCall.created_at <= date_end
).count()
chart_data.append({
'date': date.isoformat(),
'users': users,
'revenue': float(revenue),
'api_calls': api_calls
})
return {
'chart_data': chart_data
}, 200

View File

@@ -0,0 +1,258 @@
from flask import current_app
from models import db, User, ApiCall, Transaction
from datetime import datetime
import requests
import json
import logging
from modelapiservice import get_model_service
logger = logging.getLogger(__name__)
class ApiProxyService:
@staticmethod
def deduct_balance(user_id, api_call_id, cost, model):
"""统一扣费逻辑"""
try:
user = User.query.get(user_id)
api_call = ApiCall.query.get(api_call_id)
if not user or not api_call:
return
# 刷新用户数据
db.session.refresh(user)
balance_before = user.balance
user.balance -= cost
balance_after = user.balance
api_call.status = 'success'
api_call.cost = cost
db.session.add(api_call)
transaction = Transaction(
user_id=user.id,
type='consume',
amount=-cost,
balance_before=balance_before,
balance_after=balance_after,
description=f'API调用 - {model}',
api_call_id=api_call.id
)
db.session.add(transaction)
db.session.commit()
logger.info(f"扣费成功: 用户 {user.id}, 消费 {cost}, 余额 {balance_before} -> {balance_after}")
except Exception as e:
logger.error(f'扣费失败: {e}')
db.session.rollback()
@staticmethod
def handle_api_request(user_id, data):
"""
通用 API 代理 (支持文生图和对话)
"""
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
if not user.is_active:
return {'error': '账户已被禁用'}, 403
model = data.get('model')
messages = data.get('messages', [])
stream = data.get('stream', False)
# 验证必要字段
if not model:
return {'error': 'model字段不能为空'}, 400
if not messages or len(messages) == 0:
return {'error': 'messages不能为空'}, 400
# 获取模型服务并检查余额
try:
service = get_model_service(model)
except Exception as e:
return {'error': f'不支持的模型: {model}'}, 400
is_sufficient, estimated_cost, error_msg = service.check_balance(user.balance)
if not is_sufficient:
return {
'error': '余额不足',
'message': error_msg,
'required': estimated_cost,
'balance': user.balance
}, 402
prompt = messages[0].get('content', '') if messages else ''
if not prompt:
prompt = "Empty prompt"
# 创建API调用记录
api_call = ApiCall(
user_id=user_id,
api_type='chat_completion',
prompt=prompt[:500], # 截断
parameters=json.dumps({
'model': model,
'stream': stream
}),
status='processing',
cost=estimated_cost,
request_time=datetime.utcnow()
)
try:
db.session.add(api_call)
db.session.flush()
# 准备请求
api_url, api_key = service.get_api_config()
if not api_url or not api_key:
raise ValueError(f'模型 {model} API 配置未完成')
headers = {
'Authorization': f'Bearer {api_key}',
'Content-Type': 'application/json'
}
payload = service.prepare_payload(data)
target_url = f'{api_url}/chat/completions'
logger.info(f'API 转发: {target_url}, User: {user.id}, Model: {model}')
response = requests.post(
target_url,
headers=headers,
json=payload,
stream=stream,
timeout=300
)
if response.status_code != 200:
error_msg = f'第三方 API 返回错误: {response.status_code}'
try:
error_detail = response.json()
error_msg += f' - {error_detail}'
except:
error_msg += f' - {response.text[:200]}'
api_call.status = 'failed'
api_call.error_message = error_msg
db.session.commit()
return {'error': 'API 调用失败', 'details': error_msg}, 502
# 处理响应
if stream:
# 流式响应处理
def generate():
final_usage = None
try:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
if hasattr(service, 'parse_stream_usage'):
try:
text_chunk = chunk.decode('utf-8', errors='ignore')
usage = service.parse_stream_usage(text_chunk)
if usage:
final_usage = usage
except:
pass
yield chunk
# 计算最终费用
actual_cost = service.calculate_cost(final_usage, stream=True)
if actual_cost == 0 and estimated_cost > 0:
actual_cost = estimated_cost
with current_app.app_context():
ApiProxyService.deduct_balance(user.id, api_call.id, actual_cost, model)
except Exception as e:
logger.error(f'Stream error: {e}')
return generate(), 200 # Special return for stream
else:
result = response.json()
api_call.status = 'success'
api_call.response_time = datetime.utcnow()
# 计算费用
usage = result.get('usage')
final_cost = service.calculate_cost(usage, stream=False)
if final_cost == 0 and estimated_cost > 0:
final_cost = estimated_cost
# 简化响应格式
simplified_result = {
'success': True,
'api_call_id': api_call.id,
'cost': final_cost,
'model': model,
'content': ''
}
if 'choices' in result and len(result['choices']) > 0:
content = result['choices'][0].get('message', {}).get('content', '')
simplified_result['content'] = content
api_call.result_url = content[:500]
ApiProxyService.deduct_balance(user.id, api_call.id, final_cost, model)
return simplified_result, 200
except Exception as e:
logger.error(f'API 调用异常: {str(e)}', exc_info=True)
if api_call.id:
api_call.status = 'failed'
api_call.error_message = str(e)
db.session.commit()
return {'error': '服务异常', 'message': str(e)}, 500
@staticmethod
def get_models():
"""获取可用的模型列表"""
# 暂时返回硬编码的模型列表,后续可以从各 Service 聚合
return {
'object': 'list',
'data': [
{
'id': 'deepseek-chat',
'object': 'model',
'owned_by': 'deepseek',
'description': 'DeepSeek Chat V3'
},
{
'id': 'deepseek-reasoner',
'object': 'model',
'owned_by': 'deepseek',
'description': 'DeepSeek Reasoner (R1)'
},
# ... 其他模型 ...
]
}, 200
@staticmethod
def get_pricing():
"""获取价格信息"""
pricing = {
'text_to_image': {
'price': current_app.config.get('IMAGE_GENERATION_PRICE', 0),
'currency': 'CNY',
'unit': '每张图片'
}
}
return {
'pricing': pricing
}, 200
@staticmethod
def get_api_call(user_id, call_id):
"""获取API调用详情"""
api_call = ApiCall.query.filter_by(id=call_id, user_id=user_id).first()
if not api_call:
return {'error': 'API调用记录不存在'}, 404
return api_call.to_dict(), 200

View File

@@ -0,0 +1,165 @@
from models import db, User, APIKey
from datetime import datetime
class APIKeyService:
@staticmethod
def list_api_keys(user_id):
"""获取用户的所有 API Key"""
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
keys = APIKey.query.filter_by(user_id=user_id).all()
return {
'total': len(keys),
'keys': [key.to_dict() for key in keys]
}, 200
@staticmethod
def create_api_key(user_id, data):
"""创建新的 API Key"""
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
name = data.get('name', '').strip()
if not name:
return {'error': 'API Key 名称不能为空'}, 400
if len(name) > 100:
return {'error': 'API Key 名称长度不能超过100个字符'}, 400
# 生成 API Key
api_key = APIKey.generate_key()
# 创建数据库记录
new_key = APIKey(
user_id=user_id,
name=name,
api_key=api_key
)
try:
db.session.add(new_key)
db.session.commit()
return {
'message': 'API Key 创建成功',
'key': new_key.to_dict()
}, 201
except Exception as e:
db.session.rollback()
return {'error': '创建失败,请稍后重试'}, 500
@staticmethod
def get_api_key(user_id, key_id):
"""获取单个 API Key 详情"""
key = APIKey.query.filter_by(id=key_id, user_id=user_id).first()
if not key:
return {'error': 'API Key 不存在'}, 404
return key.to_dict(), 200
@staticmethod
def update_api_key(user_id, key_id, data):
"""更新 API Key 名称或状态"""
key = APIKey.query.filter_by(id=key_id, user_id=user_id).first()
if not key:
return {'error': 'API Key 不存在'}, 404
if 'name' in data:
name = data.get('name', '').strip()
if not name or len(name) > 100:
return {'error': 'API Key 名称无效'}, 400
key.name = name
if 'is_active' in data:
key.is_active = bool(data.get('is_active'))
try:
db.session.commit()
return {
'message': 'API Key 更新成功',
'key': key.to_dict()
}, 200
except Exception as e:
db.session.rollback()
return {'error': '更新失败,请稍后重试'}, 500
@staticmethod
def delete_api_key(user_id, key_id):
"""删除 API Key"""
key = APIKey.query.filter_by(id=key_id, user_id=user_id).first()
if not key:
return {'error': 'API Key 不存在'}, 404
try:
db.session.delete(key)
db.session.commit()
return {'message': 'API Key 已删除'}, 200
except Exception as e:
db.session.rollback()
return {'error': '删除失败,请稍后重试'}, 500
@staticmethod
def regenerate_api_key(user_id, key_id):
"""重置/轮换 API Key"""
key = APIKey.query.filter_by(id=key_id, user_id=user_id).first()
if not key:
return {'error': 'API Key 不存在'}, 404
# 生成新的 API Key
new_api_key = APIKey.generate_key()
key.api_key = new_api_key
try:
db.session.commit()
return {
'message': 'API Key 已重置',
'key': key.to_dict()
}, 200
except Exception as e:
db.session.rollback()
return {'error': '重置失败,请稍后重试'}, 500
@staticmethod
def authenticate_api_key(auth_header):
"""验证 API Key 并返回用户"""
if not auth_header:
return None, "缺少 Authorization 头"
parts = auth_header.split()
if parts[0].lower() != "bearer":
return None, "Authorization 头格式错误"
if len(parts) == 1:
return None, "无效的 Token"
api_key_str = parts[1]
# 查找 API Key
api_key = APIKey.query.filter_by(api_key=api_key_str).first()
if not api_key:
return None, "无效的 API Key"
if not api_key.is_active:
return None, "API Key 已被禁用"
# 更新最后使用时间
api_key.last_used_at = datetime.utcnow()
db.session.commit()
user = User.query.get(api_key.user_id)
if not user or not user.is_active:
return None, "账户不存在或已被禁用"
return user, None

View File

@@ -0,0 +1,290 @@
from flask import current_app, jsonify
from flask_mail import Message
from flask_jwt_extended import create_access_token, create_refresh_token
from models import db, User, VerificationCode
from email_validator import validate_email, EmailNotValidError
from datetime import datetime, timedelta
import re
class AuthService:
@staticmethod
def send_verification_email(email, code, purpose='register'):
"""发送验证码邮件"""
try:
subject = '验证码' if purpose == 'register' else '密码重置验证码'
body = f"""
您好!
感谢您使用 Nano Banana API 转售平台。
您的验证码是:{code}
此验证码有效期为10分钟请勿泄露给他人。
如果不是您本人操作,请忽略此邮件。
---
Nano Banana API 平台
"""
msg = Message(
subject=f'[Nano Banana] {subject}',
recipients=[email],
body=body
)
# 这里使用 Flask-Mail 发送
from extensions import mail
mail.send(msg)
return True
except Exception as e:
print(f'发送邮件失败: {str(e)}')
return False
@staticmethod
def send_verification_code(data):
"""发送验证码逻辑"""
email = data.get('email', '').strip().lower()
purpose = data.get('purpose', 'register')
if not email:
return {'error': '邮箱不能为空'}, 400
# 验证邮箱格式
try:
validate_email(email, check_deliverability=False)
except EmailNotValidError:
return {'error': '邮箱格式不正确'}, 400
# 如果是注册,检查邮箱是否已存在
if purpose == 'register':
user = User.query.filter_by(email=email).first()
if user:
# 如果用户已存在且已验证邮箱,则提示已注册
if user.email_verified:
return {'error': '该邮箱已被注册'}, 400
# 如果用户存在但未验证,可以继续发送验证码(重发)
else:
# 创建临时用户记录用于验证码关联
user = User(email=email, username=email.split('@')[0])
user.set_password('temp_pending_registration')
db.session.add(user)
db.session.commit()
else:
# 密码重置:检查用户是否存在
user = User.query.filter_by(email=email).first()
if not user:
return {'error': '用户不存在'}, 404
# 生成验证码
code = VerificationCode.generate_code()
# 删除旧的未使用验证码
VerificationCode.query.filter_by(
email=email,
purpose=purpose,
used=False
).delete()
# 创建新的验证码
verification = VerificationCode(
user_id=user.id,
email=email,
code=code,
purpose=purpose,
expired_at=datetime.utcnow() + timedelta(minutes=10)
)
db.session.add(verification)
db.session.commit()
# 发送邮件
send_success = AuthService.send_verification_email(email, code, purpose)
if send_success:
return {
'message': '验证码已发送到您的邮箱',
'email': email,
'purpose': purpose
}, 200
else:
return {'error': '邮件发送失败,请检查邮箱配置'}, 500
@staticmethod
def register(data):
"""用户注册逻辑"""
if not data or not data.get('email') or not data.get('password') or not data.get('code'):
return {'error': '邮箱、密码和验证码不能为空'}, 400
email = data.get('email').strip().lower()
password = data.get('password')
code = data.get('code')
username = data.get('username', '').strip()
try:
validate_email(email, check_deliverability=False)
except EmailNotValidError:
return {'error': '邮箱格式不正确'}, 400
if len(password) < 6:
return {'error': '密码长度至少为6位'}, 400
verification = VerificationCode.query.filter_by(
email=email,
code=code,
purpose='register',
used=False
).first()
if not verification:
return {'error': '验证码不存在或已过期'}, 400
if not verification.is_valid():
return {'error': '验证码已过期'}, 400
existing_user = User.query.filter_by(email=email).filter(
User.id != verification.user_id
).first()
if existing_user:
return {'error': '该邮箱已被注册'}, 400
user = User.query.get(verification.user_id)
user.username = username or email.split('@')[0]
user.set_password(password)
user.email_verified = True
user.email_verified_at = datetime.utcnow()
user.is_active = True
verification.used = True
try:
db.session.commit()
access_token = create_access_token(identity=user.id)
refresh_token = create_refresh_token(identity=user.id)
return {
'message': '注册成功',
'access_token': access_token,
'refresh_token': refresh_token,
'user': user.to_dict()
}, 201
except Exception as e:
db.session.rollback()
print(f'注册错误: {str(e)}')
return {'error': '注册失败,请稍后重试'}, 500
@staticmethod
def login(data):
"""用户登录逻辑"""
if not data or not data.get('email') or not data.get('password'):
return {'error': '邮箱和密码不能为空'}, 400
email = data.get('email').strip().lower()
password = data.get('password')
user = User.query.filter_by(email=email).first()
if not user or not user.check_password(password):
return {'error': '邮箱或密码错误'}, 401
if not user.is_active:
return {'error': '账户已被禁用'}, 403
# 如果用户未验证邮箱 (理论上注册流程保证了已验证,但为了兼容旧数据)
if not user.email_verified:
return {'error': '邮箱未验证,请先完成注册验证'}, 403
access_token = create_access_token(identity=user.id)
refresh_token = create_refresh_token(identity=user.id)
return {
'message': '登录成功',
'access_token': access_token,
'refresh_token': refresh_token,
'user': user.to_dict()
}, 200
@staticmethod
def change_password(user_id, data):
"""修改密码"""
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
old_password = data.get('old_password')
new_password = data.get('new_password')
if not old_password or not new_password:
return {'error': '参数不能为空'}, 400
if not user.check_password(old_password):
return {'error': '原密码错误'}, 401
if len(new_password) < 6:
return {'error': '新密码长度至少为6位'}, 400
user.set_password(new_password)
db.session.commit()
return {'message': '密码修改成功'}, 200
@staticmethod
def reset_password(data):
"""重置密码"""
email = data.get('email', '').strip().lower()
code = data.get('code')
new_password = data.get('new_password')
if not email or not code or not new_password:
return {'error': '参数不能为空'}, 400
# 验证验证码
verification = VerificationCode.query.filter_by(
email=email,
code=code,
purpose='password_reset',
used=False
).first()
if not verification or not verification.is_valid():
return {'error': '验证码不存在或已过期'}, 400
user = User.query.get(verification.user_id)
if not user:
return {'error': '用户不存在'}, 404
if len(new_password) < 6:
return {'error': '密码长度至少为6位'}, 400
# 更新密码
user.set_password(new_password)
verification.used = True
db.session.commit()
return {'message': '密码重置成功'}, 200
@staticmethod
def verify_code(data):
"""验证验证码是否有效"""
email = data.get('email', '').strip().lower()
code = data.get('code')
purpose = data.get('purpose', 'register')
if not email or not code:
return {'error': '参数不能为空'}, 400
verification = VerificationCode.query.filter_by(
email=email,
code=code,
purpose=purpose,
used=False
).first()
if not verification:
return {'error': '验证码错误'}, 400
if not verification.is_valid():
return {'error': '验证码已过期'}, 400
return {'message': '验证码有效'}, 200

View File

@@ -0,0 +1,150 @@
from flask import current_app
from models import db, User, Order, Transaction
from datetime import datetime
import uuid
from sqlalchemy import desc
class OrderService:
@staticmethod
def _generate_order_no():
"""生成订单号"""
return f"ORD{datetime.utcnow().strftime('%Y%m%d%H%M%S')}{uuid.uuid4().hex[:8].upper()}"
@staticmethod
def create_order(user_id, data):
"""创建充值订单"""
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
amount = data.get('amount')
payment_method = data.get('payment_method', 'alipay') # alipay, wechat
# 验证金额
if not amount or amount <= 0:
return {'error': '充值金额必须大于0'}, 400
# 验证支付方式
if payment_method not in ['alipay', 'wechat']:
return {'error': '不支持的支付方式'}, 400
# 创建订单
order = Order(
order_no=OrderService._generate_order_no(),
user_id=user_id,
amount=amount,
payment_method=payment_method,
status='pending'
)
try:
db.session.add(order)
db.session.commit()
# TODO: 调用支付接口
# 这里返回支付参数,前端跳转到支付页面
payment_params = {
'order_no': order.order_no,
'amount': order.amount,
'payment_method': payment_method,
# 实际项目中应返回支付宝或微信的支付参数
'qr_code': f'https://payment.example.com/qr/{order.order_no}',
'payment_url': f'https://payment.example.com/pay/{order.order_no}'
}
return {
'message': '订单创建成功',
'order': order.to_dict(),
'payment': payment_params
}, 201
except Exception as e:
db.session.rollback()
return {'error': '订单创建失败'}, 500
@staticmethod
def get_orders(user_id, page=1, per_page=20, status=None):
"""获取订单列表"""
query = Order.query.filter_by(user_id=user_id)
if status:
query = query.filter_by(status=status)
# 分页查询
pagination = query.order_by(desc(Order.created_at))\
.paginate(page=page, per_page=per_page, error_out=False)
return {
'orders': [order.to_dict() for order in pagination.items],
'total': pagination.total,
'page': page,
'per_page': per_page,
'pages': pagination.pages
}, 200
@staticmethod
def get_order(user_id, order_id):
"""获取订单详情"""
order = Order.query.filter_by(id=order_id, user_id=user_id).first()
if not order:
return {'error': '订单不存在'}, 404
return order.to_dict(), 200
@staticmethod
def alipay_callback(data):
"""支付宝支付回调(预留接口)"""
# TODO: 实现支付宝回调逻辑
return {'message': '处理成功'}, 200
@staticmethod
def wechat_callback(data):
"""微信支付回调(预留接口)"""
# TODO: 实现微信支付回调逻辑
return {'message': '处理成功'}, 200
@staticmethod
def payment_notify(order_no):
"""模拟支付通知(仅用于测试)"""
order = Order.query.filter_by(order_no=order_no).first()
if not order:
return {'error': '订单不存在'}, 404
if order.status != 'pending':
return {'error': '订单状态不正确'}, 400
try:
# 更新订单状态
order.status = 'paid'
order.paid_at = datetime.utcnow()
order.transaction_id = f"TXN{uuid.uuid4().hex[:16].upper()}"
# 增加用户余额
user = User.query.get(order.user_id)
balance_before = user.balance
user.balance += order.amount
balance_after = user.balance
# 创建交易记录
transaction = Transaction(
user_id=user.id,
type='recharge',
amount=order.amount,
balance_before=balance_before,
balance_after=balance_after,
description=f'充值 - 订单号: {order.order_no}',
order_id=order.id
)
db.session.add(transaction)
db.session.commit()
return {
'message': '支付成功',
'order': order.to_dict()
}, 200
except Exception as e:
db.session.rollback()
return {'error': '支付处理失败'}, 500

View File

@@ -0,0 +1,89 @@
from flask import request
from models import db, User, Transaction, ApiCall
from sqlalchemy import desc
class UserService:
@staticmethod
def get_profile(user_id):
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
return user.to_dict(), 200
@staticmethod
def update_profile(user_id, data):
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
if 'username' in data:
user.username = data['username'].strip()
try:
db.session.commit()
return {
'message': '资料更新成功',
'user': user.to_dict()
}, 200
except Exception as e:
db.session.rollback()
return {'error': '资料更新失败'}, 500
@staticmethod
def get_balance(user_id):
user = User.query.get(user_id)
if not user:
return {'error': '用户不存在'}, 404
return {
'balance': user.balance,
'user_id': user.id
}, 200
@staticmethod
def get_transactions(user_id, page, per_page):
pagination = Transaction.query.filter_by(user_id=user_id)\
.order_by(desc(Transaction.created_at))\
.paginate(page=page, per_page=per_page, error_out=False)
return {
'transactions': [t.to_dict() for t in pagination.items],
'total': pagination.total,
'page': page,
'per_page': per_page,
'pages': pagination.pages
}, 200
@staticmethod
def get_api_calls(user_id, page, per_page):
pagination = ApiCall.query.filter_by(user_id=user_id)\
.order_by(desc(ApiCall.created_at))\
.paginate(page=page, per_page=per_page, error_out=False)
return {
'api_calls': [call.to_dict() for call in pagination.items],
'total': pagination.total,
'page': page,
'per_page': per_page,
'pages': pagination.pages
}, 200
@staticmethod
def get_stats(user_id):
total_calls = ApiCall.query.filter_by(user_id=user_id).count()
success_calls = ApiCall.query.filter_by(user_id=user_id, status='success').count()
failed_calls = ApiCall.query.filter_by(user_id=user_id, status='failed').count()
total_cost = db.session.query(db.func.sum(Transaction.amount))\
.filter_by(user_id=user_id, type='consume').scalar() or 0
total_recharge = db.session.query(db.func.sum(Transaction.amount))\
.filter_by(user_id=user_id, type='recharge').scalar() or 0
return {
'total_calls': total_calls,
'success_calls': success_calls,
'failed_calls': failed_calls,
'total_cost': abs(total_cost),
'total_recharge': total_recharge
}, 200

View File

@@ -0,0 +1,174 @@
from flask import current_app
from models import db, User, ApiCall
from datetime import datetime
import requests
import json
import logging
from modelapiservice import get_model_service
from services.api_proxy_service import ApiProxyService
logger = logging.getLogger(__name__)
class V1Service:
@staticmethod
def chat_completions(user, data):
"""
OpenAI 兼容的 Chat Completions 接口逻辑
"""
model = data.get('model')
messages = data.get('messages', [])
stream = data.get('stream', False)
if not model:
return {'error': {'message': '缺少 model 参数', 'type': 'invalid_request_error', 'code': 400}}, 400
if not messages:
return {'error': {'message': '缺少 messages 参数', 'type': 'invalid_request_error', 'code': 400}}, 400
# 获取模型服务并检查余额
try:
service = get_model_service(model)
except Exception as e:
return {'error': {'message': f'不支持的模型: {model}', 'type': 'invalid_request_error', 'code': 400}}, 400
is_sufficient, estimated_cost, error_msg = service.check_balance(user.balance)
if not is_sufficient:
return {
'error': {
'message': error_msg,
'type': 'insufficient_quota',
'code': 402
}
}, 402
prompt = messages[-1].get('content', '') if messages else 'Empty prompt'
if len(prompt) > 500:
prompt = prompt[:500] + '...'
api_call = ApiCall(
user_id=user.id,
api_type='chat_completion',
prompt=prompt,
parameters=json.dumps({
'model': model,
'stream': stream
}),
status='processing',
cost=estimated_cost, # 暂时记录预估/固定费用
request_time=datetime.utcnow()
)
try:
db.session.add(api_call)
db.session.flush()
# 准备请求
api_url, api_key = service.get_api_config()
if not api_url or not api_key:
raise ValueError(f'模型 {model} API 配置未完成')
headers = {
'Authorization': f'Bearer {api_key}',
'Content-Type': 'application/json'
}
payload = service.prepare_payload(data)
target_url = f'{api_url}/chat/completions'
logger.info(f'API 转发: {target_url}, User: {user.id}, Model: {model}')
response = requests.post(
target_url,
headers=headers,
json=payload,
stream=stream,
timeout=300
)
if response.status_code != 200:
error_msg = f'Upstream Error: {response.status_code}'
try:
error_detail = response.json()
error_msg += f' - {error_detail}'
except:
error_msg += f' - {response.text[:200]}'
api_call.status = 'failed'
api_call.error_message = error_msg
db.session.commit()
return {
'error': {
'message': error_msg,
'type': 'upstream_error',
'code': 502
}
}, 502
# 处理响应
if stream:
# 流式响应处理
def generate():
final_usage = None
try:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
if hasattr(service, 'parse_stream_usage'):
try:
text_chunk = chunk.decode('utf-8', errors='ignore')
usage = service.parse_stream_usage(text_chunk)
if usage:
final_usage = usage
except:
pass
yield chunk
# 计算最终费用
actual_cost = service.calculate_cost(final_usage, stream=True)
if actual_cost == 0 and estimated_cost > 0:
actual_cost = estimated_cost
# 扣费
with current_app.app_context():
ApiProxyService.deduct_balance(user.id, api_call.id, actual_cost, model)
except Exception as e:
logger.error(f'Stream error: {e}')
return generate(), 200
else:
# 普通响应
result = response.json()
api_call.status = 'success'
api_call.response_time = datetime.utcnow()
# 计算费用
usage = result.get('usage')
final_cost = service.calculate_cost(usage, stream=False)
if final_cost == 0 and estimated_cost > 0:
final_cost = estimated_cost
# 简化响应格式
simplified_result = {
'model': model,
'content': '',
'cost': final_cost
}
if 'choices' in result and len(result['choices']) > 0:
content = result['choices'][0].get('message', {}).get('content', '')
simplified_result['content'] = content
# 同时更新 api_call 记录
api_call.result_url = content[:500]
ApiProxyService.deduct_balance(user.id, api_call.id, final_cost, model)
return simplified_result, 200
except Exception as e:
logger.error(f'API Error: {e}', exc_info=True)
if api_call.id:
api_call.status = 'failed'
api_call.error_message = str(e)
db.session.commit()
return {'error': {'message': 'Internal Server Error', 'type': 'server_error', 'code': 500}}, 500

View File

@@ -0,0 +1,177 @@
# Nano Banana API 中转平台 - 后端
基于 Flask 的 Nano Banana API 中转购买平台后端服务。
## 功能特性
- 📧 邮箱注册登录系统
- 💰 用户余额管理
- 💳 支付接口(支付宝/微信支付)
- 🎨 Nano Banana 文生图 API 中转
- 📊 用户统计和订单管理
- 🔐 JWT 身份认证
- 👨‍💼 管理员后台
## 技术栈
- Flask 3.0
- SQLite 数据库
- Flask-SQLAlchemy ORM
- Flask-JWT-Extended 认证
- Flask-CORS 跨域支持
- Flask-Mail 邮件服务
## 快速开始
### 1. 安装依赖
```bash
pip install -r requirements.txt
```
### 2. 配置环境变量
复制 `.env.example``.env` 并修改配置:
```bash
copy .env.example .env
```
必需配置项:
- `SECRET_KEY`: Flask 密钥
- `JWT_SECRET_KEY`: JWT 密钥
- `NANO_BANANA_API_KEY`: Nano Banana API 密钥
- `MAIL_USERNAME`: 邮箱用户名
- `MAIL_PASSWORD`: 邮箱密码
### 3. 运行应用
```bash
python app.py
```
服务将在 http://localhost:5000 启动
### 4. 默认管理员账号
- 邮箱: admin@nba.com
- 密码: admin123
**首次登录后请立即修改密码!**
## API 文档
### 认证相关 `/api/auth`
- `POST /register` - 用户注册
- `POST /login` - 用户登录
- `POST /refresh` - 刷新令牌
- `GET /me` - 获取当前用户信息
- `POST /change-password` - 修改密码
### 用户相关 `/api/user`
- `GET /profile` - 获取用户资料
- `PUT /profile` - 更新用户资料
- `GET /balance` - 获取账户余额
- `GET /transactions` - 获取交易记录
- `GET /api-calls` - 获取API调用记录
- `GET /stats` - 获取统计信息
### 订单相关 `/api/order`
- `POST /create` - 创建充值订单
- `GET /list` - 获取订单列表
- `GET /<order_id>` - 获取订单详情
- `POST /notify/<order_no>` - 支付通知(测试用)
### API服务 `/api/service`
- `POST /text-to-image` - 文生图API
- `GET /models` - 获取可用模型
- `GET /pricing` - 获取价格信息
- `GET /call/<call_id>` - 获取API调用详情
### 管理员 `/api/admin`
- `GET /users` - 获取用户列表
- `GET /users/<user_id>` - 获取用户详情
- `POST /users/<user_id>/toggle-status` - 启用/禁用用户
- `POST /users/<user_id>/adjust-balance` - 调整用户余额
- `GET /orders` - 获取所有订单
- `GET /api-calls` - 获取所有API调用
- `GET /stats/overview` - 获取总览统计
- `GET /stats/chart` - 获取图表数据
## 数据库结构
### 用户表 (users)
- 邮箱、密码、用户名
- 余额、激活状态、管理员标识
- 创建时间、更新时间
### 订单表 (orders)
- 订单号、用户ID、金额
- 支付方式、订单状态
- 第三方交易ID、支付时间
### 交易记录表 (transactions)
- 用户ID、交易类型充值/消费/退款)
- 金额、前后余额
- 关联订单、关联API调用
### API调用表 (api_calls)
- 用户ID、API类型
- 提示词、参数、状态
- 结果URL、费用、错误信息
## 开发说明
### 添加新的路由
1.`routes/` 目录下创建新的蓝图文件
2.`app.py` 中注册蓝图
### 数据库迁移
```bash
# 进入 Python shell
python
>>> from app import create_app
>>> from models import db
>>> app = create_app()
>>> with app.app_context():
... db.create_all()
```
## 部署建议
### 使用 Gunicorn (生产环境)
```bash
pip install gunicorn
gunicorn -w 4 -b 0.0.0.0:5000 app:app
```
### 使用 Docker
```dockerfile
FROM python:3.9
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . .
CMD ["gunicorn", "-w", "4", "-b", "0.0.0.0:5000", "app:app"]
```
## 注意事项
1. 生产环境请修改所有默认密钥
2. 配置实际的邮件服务器
3. 接入真实的支付接口
4. 配置 HTTPS
5. 定期备份数据库
## 许可证
MIT License

View File

@@ -0,0 +1,28 @@
@echo off
chcp 65001 >nul
title Nano Banana API - 启动后端服务
cd /d "%~dp0NBATransfer-backend"
if not exist ".env" (
echo ❌ 错误:未找到 .env 配置文件
echo 请先运行 一键安装.bat 或手动复制 .env.example 为 .env
pause
exit /b 1
)
echo.
echo ============================================
echo 🍌 Nano Banana API - 后端服务
echo ============================================
echo.
echo 启动后端服务...
echo.
echo 📍 访问地址http://localhost:5000
echo 📍 API 文档:查看 README.md
echo.
echo 按 Ctrl+C 停止服务
echo.
python app.py
pause