| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- from __future__ import annotations
- import jwt
- from jwt import PyJWTError
- from app.core.config import settings
- def _parse_base64_binary(text: str) -> bytes:
- """
- Base64 解码实现
- """
- decode_map = _init_decode_map()
- PADDING = 127
- # 长度估算逻辑
- buflen = _guess_length(text, decode_map, PADDING)
- out = bytearray(buflen)
- o = 0
- length = len(text)
- quadruplet = [0] * 4
- q = 0
- for i in range(length):
- ch = text[i]
- char_code = ord(ch)
- # 只检查 ASCII 范围内的字符
- if char_code < 128:
- v = decode_map[char_code]
- # 只有当 v != -1 时才加入 quadruplet
- if v != -1:
- quadruplet[q] = v
- q += 1
- # 当 q == 4 时进行解码
- if q == 4:
- # 第一个字节:quadruplet[0] << 2 | quadruplet[1] >> 4
- byte_val = (quadruplet[0] << 2) | (quadruplet[1] >> 4)
- out[o] = byte_val & 0xFF # 确保在 0-255 范围内
- o += 1
- # 第二个字节:只有 quadruplet[2] 不是填充时才计算
- if quadruplet[2] != PADDING:
- byte_val = (quadruplet[1] << 4) | (quadruplet[2] >> 2)
- out[o] = byte_val & 0xFF # 确保在 0-255 范围内
- o += 1
- # 第三个字节:只有 quadruplet[3] 不是填充时才计算
- if quadruplet[3] != PADDING:
- byte_val = (quadruplet[2] << 6) | quadruplet[3]
- out[o] = byte_val & 0xFF # 确保在 0-255 范围内
- o += 1
- q = 0
- # 返回正确长度的字节数组
- if buflen == o:
- return bytes(out)
- else:
- # 如果长度不匹配,创建新数组并复制
- nb = bytearray(o)
- nb[:] = out[:o]
- return bytes(nb)
- def _guess_length(text: str, decode_map: list, padding: int) -> int:
- """
- 长度估算逻辑
- """
- length = len(text)
- # 从末尾开始找到第一个非填充字符
- j = length - 1
- while j >= 0:
- char_code = ord(text[j])
- # 只检查 ASCII 字符
- if char_code < 128:
- code = decode_map[char_code]
- if code != padding:
- if code == -1:
- # 包含无效字符,使用标准估算
- return length // 4 * 3
- break
- j -= 1
- j += 1
- padding_count = length - j
- # 计算输出长度
- if padding_count > 2:
- return length // 4 * 3
- else:
- return length // 4 * 3 - padding_count
- def _init_decode_map() -> list:
- """
- 解码映射表
- """
- decode_map = [-1] * 128
- # A-Z: 0-25
- for i in range(65, 91): # 'A' to 'Z'
- decode_map[i] = i - 65
- # a-z: 26-51
- for i in range(97, 123): # 'a' to 'z'
- decode_map[i] = i - 97 + 26
- # 0-9: 52-61
- for i in range(48, 58): # '0' to '9'
- decode_map[i] = i - 48 + 52
- # '+' -> 62, '/' -> 63, '=' -> 127
- decode_map[43] = 62 # '+'
- decode_map[47] = 63 # '/'
- decode_map[61] = 127 # '=' (PADDING)
- return decode_map
- def verify_jwt_token(token: str) -> dict | None:
- """
- 验证JWT token并返回payload
- """
- try:
- payload = jwt.decode(
- token,
- _parse_base64_binary(settings.JWT_SECRET_KEY),
- algorithms=[settings.JWT_ALGORITHM]
- )
- return payload
- except PyJWTError:
- return None
|