1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
|
import base64, enum, hashlib, hmac, os, typing
# Supported SCRAM hash functions, spelled with their IANA
# "Hash Function Textual Names":
#   https://tools.ietf.org/html/rfc5802#section-4
#   https://www.iana.org/assignments/hash-function-text-names/
# SCRAM.__init__ strips the hyphen ("SHA-1" -> "SHA1") to get the name
# handed to hashlib/hmac.
# MD2 is intentionally absent: it is unacceptably weak.
ALGORITHMS = [
    "MD5",
    "SHA-1",
    "SHA-224",
    "SHA-256",
    "SHA-384",
    "SHA-512",
]
# Recognised values of the "e=" attribute a server may put in its
# server-final-message (RFC 5802 "server-error-value").  SCRAM.server_final
# maps any value not listed here to "other-error".
SCRAM_ERRORS = [
    "invalid-encoding",
    "extensions-not-supported",  # unrecognized 'm' value
    "invalid-proof",
    "channel-bindings-dont-match",
    "server-does-support-channel-binding",
    "channel-binding-not-supported",
    "unsupported-channel-binding-type",
    "unknown-user",
    "invalid-username-encoding",  # invalid utf8 or bad SASLprep
    "no-resources",
]
def _scram_nonce() -> bytes:
    """Return a fresh random nonce for the client-first-message.

    32 random bytes from ``os.urandom`` are base64-encoded, so the result is
    printable and never contains ``,`` (SCRAM's attribute separator).
    """
    random_bytes = os.urandom(32)
    encoded = base64.b64encode(random_bytes)
    return encoded
def _scram_escape(s: bytes) -> bytes:
    """Escape a SCRAM ``saslname``: ``=`` -> ``=3D`` and ``,`` -> ``=2C``.

    The ``=`` substitution must run first; otherwise the ``=`` introduced by
    ``=2C`` would be escaped a second time.
    """
    substitutions = ((b"=", b"=3D"), (b",", b"=2C"))
    escaped = s
    for plain, encoded in substitutions:
        escaped = escaped.replace(plain, encoded)
    return escaped
def _scram_unescape(s: bytes) -> bytes:
    """Reverse ``_scram_escape``: ``=2C`` -> ``,`` and ``=3D`` -> ``=``.

    ``=2C`` must be decoded first.  In escaped text every ``=`` starts an
    escape sequence, so matching ``=2C`` first only hits real comma escapes.
    Decoding ``=3D`` first would turn ``=3D2C`` (the escaped form of the
    literal text ``=2C``) into ``=2C`` and then wrongly on into ``,``.
    """
    return s.replace(b"=2C", b",").replace(b"=3D", b"=")
def _scram_xor(s1: bytes, s2: bytes) -> bytes:
    """XOR two byte strings position by position.

    Like ``zip``, this stops at the end of the shorter input, so the result
    is as long as the shorter of ``s1`` and ``s2``.
    """
    combined = bytearray()
    for left, right in zip(s1, s2):
        combined.append(left ^ right)
    return bytes(combined)
class SCRAMState(enum.Enum):
    """Progress of a ``SCRAM`` exchange, stored on ``SCRAM.state``."""

    # freshly constructed, no message produced yet
    Uninitialised = 0
    # client_first() has produced the client-first-message
    ClientFirst = 1
    # server_first() has produced the client-final-message
    ClientFinal = 2
    # server signature verified by server_final()
    Success = 3
    # authentication failed, e.g. the server returned an "e=" error
    Failed = 4
    # the server's "v=" signature did not match the expected value
    VerifyFailed = 5
class SCRAMError(Exception):
    """Exception type for SCRAM-specific failures."""
class SCRAM(object):
    """Client side of a SCRAM (RFC 5802) authentication exchange.

    Intended call order::

        scram = SCRAM("SHA-256", username, password)
        msg1 = scram.client_first()           # send to server
        msg2 = scram.server_first(reply1)     # send to server
        ok = scram.server_final(reply2)

    The GS2 header is always ``n,,``: no channel binding, no authzid.

    Public attributes:
        state: current ``SCRAMState`` of the exchange.
        error: the server's error value if it is one of ``SCRAM_ERRORS``,
            ``"other-error"`` for any other server error, else ``""``.
        raw_error: the server's error value exactly as received, else ``""``.
    """

    def __init__(self, algo: str, username: str, password: str):
        """Prepare an exchange.

        Args:
            algo: hash name from ``ALGORITHMS``, e.g. ``"SHA-256"``.
            username: account name; UTF-8 encoded, escaped when sent.
            password: account password; UTF-8 encoded.

        Raises:
            ValueError: if ``algo`` is not listed in ``ALGORITHMS``.
        """
        if algo not in ALGORITHMS:
            raise ValueError("Unknown SCRAM algorithm '%s'" % algo)
        # hashlib/hmac expect the name without the hyphen: SHA-1 -> SHA1
        self._algo = algo.replace("-", "")
        # NOTE(review): RFC 5802 calls for SASLprep normalisation of the
        # username and password; only UTF-8 encoding is applied here.
        self._username = username.encode("utf8")
        self._password = password.encode("utf8")

        self.state = SCRAMState.Uninitialised
        self.error = ""
        self.raw_error = ""

        # nonce we sent; the server's nonce must start with it
        self._client_nonce = b""
        # client-first-message-bare, needed for the AuthMessage
        self._client_first = b""
        self._salted_password = b""
        self._auth_message = b""

    def _get_pieces(self, data: bytes) -> typing.Dict[bytes, bytes]:
        """Parse ``a=1,b=2`` into ``{b"a": b"1", b"b": b"2"}``.

        Only the first ``=`` of each attribute splits name from value, so
        base64 padding inside values is preserved.  An attribute without
        any ``=`` raises IndexError.
        """
        pieces = (piece.split(b"=", 1) for piece in data.split(b","))
        return {piece[0]: piece[1] for piece in pieces}

    def _hmac(self, key: bytes, msg: bytes) -> bytes:
        """HMAC(key, msg) using the selected hash."""
        return hmac.new(key, msg, self._algo).digest()

    def _hash(self, msg: bytes) -> bytes:
        """H(msg) using the selected hash."""
        return hashlib.new(self._algo, msg).digest()

    def _constant_time_compare(self, b1: bytes, b2: bytes) -> bool:
        """Compare two byte strings without short-circuiting on a mismatch."""
        return hmac.compare_digest(b1, b2)

    def client_first(self) -> bytes:
        """Return the client-first-message ``n,,n=<username>,r=<nonce>``."""
        self.state = SCRAMState.ClientFirst
        self._client_nonce = _scram_nonce()
        self._client_first = b"n=%s,r=%s" % (
            _scram_escape(self._username), self._client_nonce)
        # n,,n=<username>,r=<nonce>
        return b"n,,%s" % self._client_first

    def server_first(self, data: bytes) -> bytes:
        """Consume the server-first-message; return the client-final-message.

        Args:
            data: server-first-message ``r=<nonce>,s=<b64 salt>,i=<count>``.

        Returns:
            ``c=<b64 "n,,">,r=<nonce>,p=<b64 client proof>``.

        Raises:
            ValueError: if the server nonce does not start with the nonce
                sent by ``client_first`` (``state`` becomes
                ``SCRAMState.Failed``), or if ``s``/``i`` are malformed.
            KeyError: if ``r``, ``s`` or ``i`` is missing.
        """
        self.state = SCRAMState.ClientFinal
        pieces = self._get_pieces(data)
        nonce = pieces[b"r"]  # server combines our nonce with its own
        # RFC 5802 section 5.1: the client MUST verify that the server's
        # nonce begins with the nonce the client sent.
        if not nonce.startswith(self._client_nonce):
            self.state = SCRAMState.Failed
            raise ValueError("SCRAM server nonce does not extend client nonce")
        salt = base64.b64decode(pieces[b"s"])  # salt is b64encoded
        iterations = int(pieces[b"i"])

        # SaltedPassword := Hi(password, salt, i), Hi being PBKDF2-HMAC
        salted_password = hashlib.pbkdf2_hmac(
            self._algo, self._password, salt, iterations)
        self._salted_password = salted_password
        client_key = self._hmac(salted_password, b"Client Key")
        stored_key = self._hash(client_key)

        channel = base64.b64encode(b"n,,")
        auth_noproof = b"c=%s,r=%s" % (channel, nonce)
        # AuthMessage := client-first-message-bare + "," +
        #                server-first-message + "," +
        #                client-final-message-without-proof
        auth_message = b"%s,%s,%s" % (self._client_first, data, auth_noproof)
        self._auth_message = auth_message

        client_signature = self._hmac(stored_key, auth_message)
        client_proof = base64.b64encode(
            _scram_xor(client_key, client_signature))
        # c=<b64encode("n,,")>,r=<nonce>,p=<proof>
        return b"%s,p=%s" % (auth_noproof, client_proof)

    def server_final(self, data: bytes) -> bool:
        """Check the server-final-message.

        Returns:
            True with state ``Success`` when the ``v=`` signature matches.
            False with state ``Failed`` when the server sent ``e=``; in that
            case ``error`` and ``raw_error`` are set.
            False with state ``VerifyFailed`` when the signature differs.

        Raises:
            KeyError: if the message has neither ``e`` nor ``v``.
        """
        pieces = self._get_pieces(data)
        if b"e" in pieces:
            error = pieces[b"e"].decode("utf8")
            self.raw_error = error
            self.error = error if error in SCRAM_ERRORS else "other-error"
            self.state = SCRAMState.Failed
            return False

        verifier = base64.b64decode(pieces[b"v"])
        server_key = self._hmac(self._salted_password, b"Server Key")
        server_signature = self._hmac(server_key, self._auth_message)
        # Constant-time compare: do not leak via timing how much of a forged
        # verifier matched.
        if self._constant_time_compare(server_signature, verifier):
            self.state = SCRAMState.Success
            return True
        self.state = SCRAMState.VerifyFailed
        return False
|