Django-rest-framework频率认证与权限

引入

1.什么是认证、权限和限制

  • 认证是对用户身份进行验证, 然后权限、限制组件决定是否拒绝这个请求
  • 简单来说 : 认证确定了你是谁
  • 权限确定你能不能访问某个接口
  • 限制确定你访问的某个接口的频率

一.认证组件

0.作用

  • 校验用户, 三种类型 : 游客、合法用户、非法用户
  • 游客:代表校验通过,直接进入下一步校验(权限校验)
  • 非法用户:代表校验失败,抛出异常,返回403权限异常结果

1.源码分析

  • 分析步骤 APIView---->dispatch---->self.initial(request, *args, **kwargs)
  • initial方法里面认证、权限、频率, 我们先看认证 self.perform_authentication(request)

img

  • perform_authentication 方法里面就一句话 request.user, 调用Request类的 user 方法

img

  • 再调用 _authenticate 方法来开始验证

img

def _authenticate(self):
    # 遍历拿到一个认证类对象
    # self.authenticators是一个个我们配置的认证类产生的的认证类对象组成的列表[obj1,obj2..]
    # 如果我们没有进行配置,就从默认配置文件中读取配置类
    for authenticator in self.authenticators:
        try:
            # 认证类对象调用authenticate方法传入(认证类对象,request请求对象)
            # 返回登入的用户与认证的信息组成的 tuple
            # try 捕获异常,捕获到了则代表认证失败
            user_auth_tuple = authenticator.authenticate(self)
        except exceptions.APIException:
            self._not_authenticated()
            raise

        # 如果元组有内容
        if user_auth_tuple is not None:
            self._authenticator = authenticator
            # 将登入用户与登入认证分别解压赋值到self.user,self.auth中
            self.user, self.auth = user_auth_tuple
            return
    # 如果元组为空没有登入的用户和认证,则表示该用户是匿名用户
    self._not_authenticated()
  • 认证类局部配置
# 在某个视图类中添加该认证配置
# self.authenticators中得到的一个个认证类对象组成的列表就是从该配置中的类实例化出来的

authentication_classes = ['认证类1','认证类2','认证类3']
  • 认证类全局配置
# 在settings.py配置文件中进行配置
REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': ["mydrf.auth.UserAuth", ],
}    

2.认证组件的用法 (自定义Token认证)

必须重写 authenticate 方法

  • 创建表模型 : models.py
from django.db import models

# 存放用户名和密码
class User(models.Model):
    username = models.CharField(max_length=32)
    password = models.CharField(max_length=32)
    level = models.IntegerField(choices=((1, "超级用户"), (2, "普通用户"), (3, "游客")))

# 存放用户名外键和token
class UserToken(models.Model):
    user = models.OneToOneField(to="User", on_delete=models.CASCADE)
    token = models.CharField(max_length=64)

class Book(models.Model):
    title = models.CharField(max_length=32)
    price = models.IntegerField()
  • 创建自定义的 Response 类 : myresponse.py
from rest_framework.response import Response

# 自定义 Response 对象
class UserResponse(Response):
    def __init__(self, code=100, msg=None, data=None, status=None,
                 template_name=None, headers=None,
                 exception=False, content_type=None, **kwargs):
        dic = {'status': code, 'msg': msg}
        if data:
            dic['data'] = data
        if kwargs:
            dic.update(kwargs)

        super().__init__(data=dic, status=status,
                         template_name=template_name, headers=headers,
                         exception=exception, content_type=content_type)
  • 创建序列化类 : serializer.py
from rest_framework import serializers
from drf_test import models

class BookModelSerializer(serializers.ModelSerializer):
    class Meta:
        model = models.Book
        fields = "__all__"

class UserModelSerializer(serializers.ModelSerializer):
    class Meta:
        model = models.User
        fields = "__all__"
  • 创建认证类 : auth.py
from drf_test.models import UserToken
from rest_framework.exceptions import APIException
from rest_framework.authentication import BaseAuthentication

# 继承 BaseAuthentication 类
class UserAuth(BaseAuthentication):
    # 必须重写该方法
    def authenticate(self, request):
        token = request.META.get('HTTP_TOKEN')
        print(token)
        if token:
            token_obj = UserToken.objects.filter(token=token).first()
            if token_obj:
                # 必须返回两个参数
                return token_obj.user,token
            else:
                raise APIException('token 無效')
        raise APIException('token 不存在')
  • views.py
from drf_test import models
from drf_test.myresponse import UserResponse
from rest_framework.viewsets import ViewSetMixin
from drf_test.serializer import UserModelSerializer
from drf_test.serializer import BookModelSerializer
from rest_framework.decorators import action
from rest_framework.generics import ListAPIView, RetrieveAPIView, CreateAPIView
from drf_test.auth import UserAuth
import uuid

# 用户登入
class LoginView(ViewSetMixin, CreateAPIView):
    queryset = models.User.objects.all()
    serializer_class = UserModelSerializer
    # 用户认证(局部配置)
    authentication_classes = [UserAuth]

    @action(methods=['POst'], detail=False)
    def login(self, request):
        username = request.data.get('username')
        password = request.data.get('password')
        user_obj = models.User.objects.filter(username=username, password=password).first()
        if user_obj:
            token = uuid.uuid4()
            models.UserToken.objects.update_or_create(defaults={'token': token}, user=user_obj)
            return UserResponse(handers={'token': token},msg='成功')
        else:
            return UserResponse(code=201, msg='用户名或密码错误!')

# 书籍查询
class BookView(ViewSetMixin, ListAPIView, CreateAPIView,RetrieveAPIView):
    queryset = models.Book.objects.all()
    serializer_class = BookModelSerializer

    @action(detail=True)
    def get(self, request, pk):
        # 查出前几条(不是第几条)
        book_qs = self.get_queryset()[:int(pk)]
        serializer = self.get_serializer(instance=book_qs, many=True)
        return UserResponse(book_info=serializer.data)
  • 全局配置认证
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": ["drf_test.auth.UserAuth", ],
}

3.自定义token认证(高级版)

user_auth.py ,方便copy

"""
常用的认证以及DRF的认证
"""
from dimension import models
import jwt
from rest_framework import exceptions
from rest_framework_jwt.authentication import JSONWebTokenAuthentication
from rest_framework.exceptions import AuthenticationFailed
from rest_framework_jwt.authentication import jwt_decode_handler
from django.conf import settings
from django.contrib.auth import get_user_model
from six import text_type
from django.utils.translation import ugettext as _
from rest_framework import exceptions
from rest_framework_jwt.utils import jwt_decode_handler

User = get_user_model()

class OpAuthJwtAuthentication(object):
    """
    统一JWT认证
    """

    def authenticate(self, request):
        token = self.get_header_authorization(request) or self.get_cookie_authorization(request)
        if not token:
            raise AuthenticationFailed('当前用户未登录,请登录!')
        try:
            payload = jwt_decode_handler(token)
        except jwt.ExpiredSignature:
            msg = _('会话过期,请重新登录!')
            raise exceptions.AuthenticationFailed(msg)
        except jwt.DecodeError:
            msg = _('签名错误,请重试!')
            raise exceptions.AuthenticationFailed(msg)
        except jwt.InvalidTokenError:
            raise exceptions.AuthenticationFailed()
        except User.DoesNotExist:
            raise exceptions.AuthenticationFailed()

        user_id = payload.get('user_id', None)
        if not user_id:
            return None
        user_id_field = settings.USER_FIELD or 'user_id'
        user = User.objects.filter(**{user_id_field: user_id}).first()
        if not user or not user.is_active:
            return None
        return user, token

    def authenticate_header(self, request):
        pass

    @classmethod
    def get_header_authorization(cls, request):
        """
        获取header里的认证信息, 通常用于跨域携带请求
        :param request:
        :return:
        """
        auth = request.META.get('HTTP_AUTHORIZATION', b'')
        if isinstance(auth, text_type):
            auth = auth.encode(settings.JWT_AUTH.get('HTTP_HEADER_ENCODING', 'iso-8859-1'))
        if not auth:
            return ''
        auth = str(auth, encoding='utf-8').split()
        if len(auth) != 2 or auth[0].upper() != settings.JWT_AUTH.get('JWT_AUTH_HEADER_PREFIX', 'JWT').upper():
            return ''
        return auth[1]

    @classmethod
    def get_cookie_authorization(cls, request):
        """
        获取cookie里JWT认证信息
        :param request:
        :return:
        """
        auth = request.COOKIES.get(settings.JWT_AUTH.get('JWT_AUTH_COOKIE', 'AUTH_JWT'), '')
        auth = auth.split()
        if len(auth) != 2 or auth[0].upper() != settings.JWT_AUTH.get('JWT_AUTH_HEADER_PREFIX', 'JWT'):
            return ''
        return auth[1]

class JsonAuthentication(JSONWebTokenAuthentication):
    def authenticate(self, request):
        jwt_value = self.get_jwt_value(request)
        if not jwt_value:
            raise AuthenticationFailed('当前用户未登录,请登录!')
        # 验证签名,验证是否过期
        try:
            # 得到荷载
            payload = jwt_decode_handler(jwt_value)
            # 效率更高一写,不需要查数据库了
            user = models.User(user_id=payload['user_id'])
        except jwt.ExpiredSignature:
            msg = '会话过期,请重新登录!'
            raise exceptions.AuthenticationFailed(msg)
        except jwt.DecodeError:
            msg = '签名错误,请重试!'
            raise exceptions.AuthenticationFailed(msg)
        except jwt.InvalidTokenError:
            raise exceptions.AuthenticationFailed("签名验证错误,非法用户")
        return (user, jwt_value)

二.权限组件

0.作用

  • 权限控制可以限制用户对于视图的访问和对于具体数据对象的访问
  • 认证通过, 可以进行下一步验证 (频率认证)
  • 认证失败, 抛出权限异常结果

1.源码分析

  • 权限组件入口 : APIView 中 dispatch 的 self.check_permissions(request)
def check_permissions(self, request):
    # 遍历权限对象列表得到一个个权限对象(权限器),进行权限认证
    for permission in self.get_permissions():
        # 权限类一定有一个has_permission权限方法,用来做权限认证的
        # 参数:权限对象self、请求对象request、视图类对象
        # 返回值:有权限返回True,无权限返回False
        if not permission.has_permission(request, self):
            self.permission_denied(
                request,
                message=getattr(permission, 'message', None),
                code=getattr(permission, 'code', None)
            )
# 执行 has_permission 方法,会执行你全局或局部配置的自定义权限类,通过返回值来确定是否抛出异常
# True 认证通过, False 认证不通过, 再通过"self.permission_denied"抛出异常
# 一旦抛出异常, dispatch中try之下的diamante将不再运行

3.权限组件用法

  • 创建权限类, 需要继承 BasePermission 父类, 必须重写 has_permissionhas_object_permission 方法
class UserPermission(BasePermission):
    # 自定义的提示信息(原本是英文)
    message = '没有权限访问!'
    # 重写该方法
    def has_permission(self, request, view):
        # 权限在认证之后, 所以可以取到user
        if request.user.level == 1:
            return True
        else:
            self.message = f'你是{request.user.get_level_display()}用户,没有权限!'
            return False
  • 局部配置使用
from drf_test.auth import UserPernission

# 在需要权限认证的视图类中书写
permission_classes = [UserPermission]
  • 全局配置使用
REST_FRAMEWORK = {
    "DEFAULT_PERMISSION_CLASSES": ["drf_test.auth.UserPermission", ],
}

三.频率组件

0.作用

  • 限制视图接口被访问的频率次数
  • 限制条件 : IP、ID、唯一键
  • 频率周期 : 时(h)、分(m)、秒(s)
  • 频率次数 : [num] / s
  • 没有达到限制频率可正常访问接口
  • 达到了频率限制次数, 在限制时间内不能进行访问, 超过时间后可以正常访问

1.源码分析

  • 频率组件入口 : APIView 中 dispatch 的 self.check_throttles(request)

原文内容

def check_throttles(self, request):
    throttle_durations = []
    # 1)遍历配置的频率认证类,初始化得到一个个频率认证类对象(会调用频率认证类的 __init__() 方法)
    # 2)频率认证类对象调用 allow_request 方法,判断是否限次(没有限次可访问,限次不可访问)
    # 3)频率认证类对象在限次后,调用 wait 方法,获取还需等待多长时间可以进行下一次访问
    # 注:频率认证类都是继承 SimpleRateThrottle 类
    for throttle in self.get_throttles():
        if not throttle.allow_request(request, self):
            # 只要频率限制了,allow_request 返回False了,才会调用wait
            throttle_durations.append(throttle.wait())

            if throttle_durations:
                # Filter out `None` values which may happen in case of config / rate
                # changes, see #1438
                durations = [
                    duration for duration in throttle_durations
                    if duration is not None
                ]

                duration = max(durations, default=None)
                self.throttled(request, duration)

class SimpleRateThrottle(BaseThrottle):
  def __init__(self):
     if not getattr(self, 'rate', None):
        # 得到settings配置的 次数/时间 赋值给rate
        self.rate = self.get_rate()
    # 将切分后的 '次数/时间' 解压赋值 num_requests=次数,duration=时间
    self.num_requests, self.duration = self.parse_rate(self.rate)
  # 自定义频率限制时 需要我们实现的方法
  def get_cache_key(self, request, view):
      raise NotImplementedError('.get_cache_key() must be overridden')

2.频率组件的使用

必须重写 allow_request, 由于我们继承了SimpleRateThrottle类, 对基类进行了进一步的封装, 所以只需要重写get_cache_key,返回什么就以什么作为限制条件(ip,用户id)

  • 先创建一个频率类
class MyThrottles(SimpleRateThrottle):
    scope = 'ip'

    def get_cache_key(self, request, view):
        # 返回什么就以什么作为限制条件
        return self.get_ident(request)  # IP作为限制(也可以下面写法)
        # return request.Meta.get('REMOTE_ADDR')
        # return request.user.id  # 以用户id作为限制
  • view.py
class BookView(ViewSetMixin, ListAPIView, CreateAPIView,RetrieveAPIView):
    queryset = models.Book.objects.all()
    serializer_class = BookModelSerializer
    authentication_classes = [UserAuth]
    permission_classes = [UserPermission]
    throttle_classes = [MyThrottles]

    @action(detail=True)
    def get(self, request, pk):
        # 查出前几条(不是第几条)
        book_qs = self.get_queryset()[:int(pk)]
        serializer = self.get_serializer(instance=book_qs, many=True)
        return UserResponse(book_info=serializer.data)
  • 设置限制次数
# setting.py 文件
REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_RATES': {
        'ip': '5/m',  #一分钟访问5次
    },
}
  • 局部使用
# 在某个视图类中书写
throttle_classes = [MyThrottles]
  • 全局使用
# 在settings.py
REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_CLASSES": ["app01.auth.MyThrottle", ],
    'DEFAULT_THROTTLE_RATES': {
        'ip': '5/m',  #一分钟访问5次
    },
}

3.自定义频率类

  • 实现步骤 :
  1. 取出访问者ip
  2. 判断当前ip不在访问字典里,添加进去,并且直接返回True, 表示第一次访问,在字典里,继续往下走
  3. 循环判断当前 ip 的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据 pop 掉,这样列表中只有 60s 以内的访问时间
  4. 判断,当列表长度小于3(代表访问的次数小于3),说明一分钟以内访问不足三次,再把当前时间插入到列表第一个位置,返回True,顺利通过
  5. 当大于等于3,说明一分钟内访问超过三次,返回 False 验证失败
  • 实现代码
class MyThrottles(BaseThrottle):
    VISIT_RECORD = {}  # 记录访问者的字典
    def __init__(self):
        self.history=None  # 用来存时间戳次数的列表
    def allow_request(self,request, view):
        #(1)取出访问者ip
        ip=request.META.get('REMOTE_ADDR')
        import time
        ctime=time.time()
        # (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问
        if ip not in self.VISIT_RECORD:
            self.VISIT_RECORD[ip]=[ctime,]  # 将当前时间放在列表的第一个位置(第一次访问)
            return True
        self.history=self.VISIT_RECORD.get(ip,[])  # 拿出某个ip的时间戳列表
        # (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
        while self.history and ctime-self.history[-1]>60:
            self.history.pop()
        # (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
        # (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
        if len(self.history)<3:
            self.history.insert(0,ctime)
            return True
        else:
            return False
    def wait(self):
        import time
        ctime=time.time()
        return 60-(ctime-self.history[-1])  # 展示还剩多少秒
        # ctime-self.history[-1]  (表示最早一次登入到现在过去了多少时间)

SimpleRateThrottle 内部源码实现方式也是如此, 只不过通过配置, 它的可扩展性更高

四.接口补充

接口 : 一种规范, 协议, 用来保证交互

  • 上面我们使用认证、权限、频率都必须要重写父类中的方法, 不然直接 raise 异常
  • 在规范子类的行为上, Python中有abc模块来强制规定子类必须要重写父类的某一个方法, 不然就报错
  • 但其实Python并不推崇这种做法(接口), Python推崇鸭子类型
  • 鸭子类型 : 即你不需要强制性的规定子类写某个方法, 子类中只要有相同的方法, 那么就可以说他们是一类

例 : 制作14根螺丝, 有4个4厘米长的, 那么我们就可以说这4个4厘米长的螺丝属于"4厘米螺丝"

  • 而认证、权限、频率是通过抛出异常的方式来进行限制的, 只要不重写, 就直接抛出异常 :

img

版权声明:
作者:淘小欣
链接:https://blog.taoxiaoxin.club/179.html
来源:淘小欣的博客
文章版权归作者所有,未经允许请勿转载。

THE END
分享
二维码
海报
Django-rest-framework频率认证与权限
引入 1.什么是认证、权限和限制 认证是对用户身份进行验证, 然后权限、限制组件决定是否拒绝这个请求 简单来说 : 认证确定了你是谁 权限确定你能不能访问某个……
<<上一篇
下一篇>>