from __future__ import annotations import jwt from jwt import PyJWTError from app.core.config import settings def _parse_base64_binary(text: str) -> bytes: """ Base64 解码实现 """ decode_map = _init_decode_map() PADDING = 127 # 长度估算逻辑 buflen = _guess_length(text, decode_map, PADDING) out = bytearray(buflen) o = 0 length = len(text) quadruplet = [0] * 4 q = 0 for i in range(length): ch = text[i] char_code = ord(ch) # 只检查 ASCII 范围内的字符 if char_code < 128: v = decode_map[char_code] # 只有当 v != -1 时才加入 quadruplet if v != -1: quadruplet[q] = v q += 1 # 当 q == 4 时进行解码 if q == 4: # 第一个字节:quadruplet[0] << 2 | quadruplet[1] >> 4 byte_val = (quadruplet[0] << 2) | (quadruplet[1] >> 4) out[o] = byte_val & 0xFF # 确保在 0-255 范围内 o += 1 # 第二个字节:只有 quadruplet[2] 不是填充时才计算 if quadruplet[2] != PADDING: byte_val = (quadruplet[1] << 4) | (quadruplet[2] >> 2) out[o] = byte_val & 0xFF # 确保在 0-255 范围内 o += 1 # 第三个字节:只有 quadruplet[3] 不是填充时才计算 if quadruplet[3] != PADDING: byte_val = (quadruplet[2] << 6) | quadruplet[3] out[o] = byte_val & 0xFF # 确保在 0-255 范围内 o += 1 q = 0 # 返回正确长度的字节数组 if buflen == o: return bytes(out) else: # 如果长度不匹配,创建新数组并复制 nb = bytearray(o) nb[:] = out[:o] return bytes(nb) def _guess_length(text: str, decode_map: list, padding: int) -> int: """ 长度估算逻辑 """ length = len(text) # 从末尾开始找到第一个非填充字符 j = length - 1 while j >= 0: char_code = ord(text[j]) # 只检查 ASCII 字符 if char_code < 128: code = decode_map[char_code] if code != padding: if code == -1: # 包含无效字符,使用标准估算 return length // 4 * 3 break j -= 1 j += 1 padding_count = length - j # 计算输出长度 if padding_count > 2: return length // 4 * 3 else: return length // 4 * 3 - padding_count def _init_decode_map() -> list: """ 解码映射表 """ decode_map = [-1] * 128 # A-Z: 0-25 for i in range(65, 91): # 'A' to 'Z' decode_map[i] = i - 65 # a-z: 26-51 for i in range(97, 123): # 'a' to 'z' decode_map[i] = i - 97 + 26 # 0-9: 52-61 for i in range(48, 58): # '0' to '9' decode_map[i] = i - 48 + 52 # '+' -> 62, '/' -> 63, '=' -> 127 decode_map[43] = 62 # '+' decode_map[47] = 63 # '/' decode_map[61] = 127 # '=' (PADDING) return decode_map def verify_jwt_token(token: str) -> dict | None: """ 验证JWT token并返回payload """ try: payload = jwt.decode( token, _parse_base64_binary(settings.JWT_SECRET_KEY), algorithms=[settings.JWT_ALGORITHM] ) return payload except PyJWTError: return None