jwt_utils.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. from __future__ import annotations
  2. import jwt
  3. from jwt import PyJWTError
  4. from app.core.config import settings
  5. def _parse_base64_binary(text: str) -> bytes:
  6. """
  7. Base64 解码实现
  8. """
  9. decode_map = _init_decode_map()
  10. PADDING = 127
  11. # 长度估算逻辑
  12. buflen = _guess_length(text, decode_map, PADDING)
  13. out = bytearray(buflen)
  14. o = 0
  15. length = len(text)
  16. quadruplet = [0] * 4
  17. q = 0
  18. for i in range(length):
  19. ch = text[i]
  20. char_code = ord(ch)
  21. # 只检查 ASCII 范围内的字符
  22. if char_code < 128:
  23. v = decode_map[char_code]
  24. # 只有当 v != -1 时才加入 quadruplet
  25. if v != -1:
  26. quadruplet[q] = v
  27. q += 1
  28. # 当 q == 4 时进行解码
  29. if q == 4:
  30. # 第一个字节:quadruplet[0] << 2 | quadruplet[1] >> 4
  31. byte_val = (quadruplet[0] << 2) | (quadruplet[1] >> 4)
  32. out[o] = byte_val & 0xFF # 确保在 0-255 范围内
  33. o += 1
  34. # 第二个字节:只有 quadruplet[2] 不是填充时才计算
  35. if quadruplet[2] != PADDING:
  36. byte_val = (quadruplet[1] << 4) | (quadruplet[2] >> 2)
  37. out[o] = byte_val & 0xFF # 确保在 0-255 范围内
  38. o += 1
  39. # 第三个字节:只有 quadruplet[3] 不是填充时才计算
  40. if quadruplet[3] != PADDING:
  41. byte_val = (quadruplet[2] << 6) | quadruplet[3]
  42. out[o] = byte_val & 0xFF # 确保在 0-255 范围内
  43. o += 1
  44. q = 0
  45. # 返回正确长度的字节数组
  46. if buflen == o:
  47. return bytes(out)
  48. else:
  49. # 如果长度不匹配,创建新数组并复制
  50. nb = bytearray(o)
  51. nb[:] = out[:o]
  52. return bytes(nb)
  53. def _guess_length(text: str, decode_map: list, padding: int) -> int:
  54. """
  55. 长度估算逻辑
  56. """
  57. length = len(text)
  58. # 从末尾开始找到第一个非填充字符
  59. j = length - 1
  60. while j >= 0:
  61. char_code = ord(text[j])
  62. # 只检查 ASCII 字符
  63. if char_code < 128:
  64. code = decode_map[char_code]
  65. if code != padding:
  66. if code == -1:
  67. # 包含无效字符,使用标准估算
  68. return length // 4 * 3
  69. break
  70. j -= 1
  71. j += 1
  72. padding_count = length - j
  73. # 计算输出长度
  74. if padding_count > 2:
  75. return length // 4 * 3
  76. else:
  77. return length // 4 * 3 - padding_count
  78. def _init_decode_map() -> list:
  79. """
  80. 解码映射表
  81. """
  82. decode_map = [-1] * 128
  83. # A-Z: 0-25
  84. for i in range(65, 91): # 'A' to 'Z'
  85. decode_map[i] = i - 65
  86. # a-z: 26-51
  87. for i in range(97, 123): # 'a' to 'z'
  88. decode_map[i] = i - 97 + 26
  89. # 0-9: 52-61
  90. for i in range(48, 58): # '0' to '9'
  91. decode_map[i] = i - 48 + 52
  92. # '+' -> 62, '/' -> 63, '=' -> 127
  93. decode_map[43] = 62 # '+'
  94. decode_map[47] = 63 # '/'
  95. decode_map[61] = 127 # '=' (PADDING)
  96. return decode_map
  97. def verify_jwt_token(token: str) -> dict | None:
  98. """
  99. 验证JWT token并返回payload
  100. """
  101. try:
  102. payload = jwt.decode(
  103. token,
  104. _parse_base64_binary(settings.JWT_SECRET_KEY),
  105. algorithms=[settings.JWT_ALGORITHM]
  106. )
  107. return payload
  108. except PyJWTError:
  109. return None