Coverage for src/edwh_auth_rbac/model.py: 78%

195 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-22 15:42 +0200

1import copy 

2import datetime as dt 

3import hashlib 

4import hmac 

5import typing 

6import uuid 

7from typing import Optional 

8from uuid import UUID 

9 

10import dateutil.parser 

11from pydal import DAL, Field, SQLCustomType 

12from pydal.objects import SQLALL, Query, Table 

13 

14from .helpers import IS_IN_LIST 

15 

16 

class DEFAULT:
    """Sentinel: the class object itself is used as a "no value supplied" default."""

19 

20 

# Default validity window for permissions that do not specify their own.
DEFAULT_STARTS = dt.datetime(year=2000, month=1, day=1)
DEFAULT_ENDS = dt.datetime(year=3000, month=1, day=1)

23 

24 

25def unstr_datetime(s: dt.datetime | str) -> dt.datetime: 

26 """json helper... might values arrive as str""" 

27 return dateutil.parser.parse(s) if isinstance(s, str) else s 

28 

29 

30class Password: 

31 """ 

32 Encode a password using: Password.encode('secret') 

33 """ 

34 

35 @classmethod 

36 def hmac_hash(cls, value: str, key: str, salt: str = None) -> str: 

37 digest_alg = hashlib.sha512 

38 d = hmac.new(str(key).encode(), str(value).encode(), digest_alg) 

39 if salt: 

40 d.update(str(salt).encode()) 

41 return d.hexdigest() 

42 

43 @classmethod 

44 def validate(cls, password: str, candidate: str) -> bool: 

45 salt, hashed = candidate.split(":", 1) 

46 return cls.hmac_hash(value=password, key="secret_start", salt=salt) == hashed 

47 

48 @classmethod 

49 def encode(cls, password: str) -> str: 

50 salt = uuid.uuid4().hex 

51 return salt + ":" + cls.hmac_hash(value=password, key="secret_start", salt=salt) 

52 

53 

def is_uuid(s) -> bool:
    """Report whether ``UUID(s)`` can be constructed without raising.

    NOTE(review): the argument is handed to ``UUID`` as its hex string, so a
    value that is already a ``UUID`` instance appears to be rejected here --
    confirm before relying on this for non-string keys.
    """
    try:
        UUID(s)
    except Exception:
        return False
    else:
        return True

60 

61 

62IdentityKey: typing.TypeAlias = str | int | UUID 

63ObjectTypes = typing.Literal["user", "group", "item"] 

64 

65 

def key_lookup_query(db: DAL, identity_key: IdentityKey, object_type: ObjectTypes = None) -> Query:
    """Build the pydal query that selects the identity matching *identity_key*.

    The key is interpreted, in this order, as:

    * an email address when its string form contains ``@`` (matched lower-cased),
    * the numeric row id when it is an ``int``,
    * the ``object_id`` when it is a UUID (matched lower-cased),
    * otherwise the ``firstname``.

    :param db: dal db connection
    :param identity_key: email, id, object_id (str or UUID) or firstname
    :param object_type: when given, additionally restrict to this object type
    :return: the query; it is not executed here
    """
    if isinstance(identity_key, UUID):
        # A UUID instance cannot be re-parsed by UUID(), so the uuid check below
        # would reject it and the key would wrongly be compared with firstname.
        identity_key = str(identity_key)

    if "@" in str(identity_key):
        query = db.identity.email == str(identity_key).lower()
    elif isinstance(identity_key, int):
        query = db.identity.id == identity_key
    elif is_uuid(identity_key):
        query = db.identity.object_id == identity_key.lower()
    else:
        query = db.identity.firstname == identity_key

    if object_type:
        query &= db.identity.object_type == object_type

    return query

80 

81 

def key_lookup(db: DAL, identity_key: IdentityKey, object_type: ObjectTypes = None) -> str | None:
    """Resolve *identity_key* to the ``object_id`` of the single matching identity.

    :param db: dal db connection
    :param identity_key: email, id, object_id or firstname (see key_lookup_query)
    :param object_type: when given, only identities of this type are considered
    :return: the object_id, or None when nothing matches
    :raises ValueError: when more than one identity matches the key
    """
    query = key_lookup_query(db, identity_key, object_type)
    rowset = db(query).select(db.identity.object_id)

    if not rowset:
        return None
    if len(rowset) > 1:
        raise ValueError(f"Key lookup for {identity_key} returned {len(rowset)} results.")

    return rowset.first().object_id

93 

94 

# Custom column type: datetimes are written as ISO text with a space between
# date and time into a char(35) column, and parsed back with dateutil on read.
my_datetime = SQLCustomType(
    type="string",
    native="char(35)",
    encoder=(lambda value: value.isoformat(" ")),
    decoder=(lambda text: dateutil.parser.parse(text)),
)

98 

99 

class RbacKwargs(typing.TypedDict, total=False):
    """Options accepted by define_auth_rbac_model; every key is optional for the type checker."""

    # values offered to IS_IN_LIST for identity.object_type
    allowed_types: list[str]
    # passed as ``migrate=`` to the regular (non-view) tables
    migrate: bool

103 

104 

class Identity(typing.Protocol):
    """Structural type describing the attributes of a row of the ``identity`` table."""

    object_id: str
    object_type: str
    created: dt.datetime
    email: str
    # also used as short code for groups
    firstname: str
    fullname: str
    # stored password value, read by authenticate_user
    encoded_password: str

113 

114 

def define_auth_rbac_model(db: DAL, other_args: RbacKwargs):
    """Declare the RBAC tables on *db*.

    Tables defined: ``identity``, ``membership``, ``permission`` and the two
    read-only views ``recursive_memberships`` and ``recursive_members``.

    :param db: dal db connection
    :param other_args: ``allowed_types`` (required here; a missing key raises KeyError)
                       lists the valid object types; ``migrate`` (default False) is
                       passed to the three regular tables
    """
    migrate = other_args.get("migrate", False)

    db.define_table(
        "identity",
        # std uuid from uuid libs are 36 chars long
        # The default must be a callable: a plain value would be computed once, at
        # definition time, and shared by every insert that relies on the default,
        # which the unique constraint would then reject.
        Field("object_id", "string", length=36, unique=True, notnull=True, default=lambda: str(uuid.uuid4())),
        Field("object_type", "string", requires=(IS_IN_LIST(other_args["allowed_types"]))),
        Field("created", "datetime", default=dt.datetime.now),
        # email needn't be unique, groups can share email addresses, and with people too
        Field("email", "string"),
        Field("firstname", "string", comment="also used as short code for groups"),
        Field("fullname", "string"),
        Field("encoded_password", "string"),
        migrate=migrate,
    )

    db.define_table(
        "membership",
        # both are really: reference:identity.object_id
        Field("subject", "string", length=36, notnull=True),
        Field("member_of", "string", length=36, notnull=True),
        migrate=migrate,
    )

    db.define_table(
        "permission",
        Field("privilege", "string", length=20),
        # reference:identity.object_id
        Field("identity_object_id", "string", length=36),
        Field("target_object_id", "string", length=36),
        # No 'scope' field for now: every one needs a rule for everything,
        # just to make sure, no 'wildcards' and 'every dossier for org x' etc ...
        Field("starts", type=my_datetime, default=DEFAULT_STARTS),
        Field("ends", type=my_datetime, default=DEFAULT_ENDS),
        migrate=migrate,
    )

    # Both views share the same columns; fresh Field objects are created per table.
    for view_name in ("recursive_memberships", "recursive_members"):
        db.define_table(
            view_name,
            Field("root"),
            Field("object_id"),
            Field("object_type"),
            Field("level", "integer"),
            Field("email"),
            Field("firstname"),
            Field("fullname"),
            migrate=False,  # view
            primarykey=["root", "object_id"],  # composed, no primary key
        )

179 

180 

def add_identity(
    db: DAL,
    email: str,
    member_of: list[IdentityKey],
    name: str = None,
    firstname: str = None,
    fullname: str = None,
    password: str = None,
    gid: str | UUID = None,
    object_type: ObjectTypes = None,
) -> str:
    """Insert a new identity and add it to the given groups.

    Parameters name and firstname are equal; name wins when both are given.

    :param db: dal db connection
    :param email: email address; stored lower-cased and stripped
    :param member_of: keys of groups to join; keys that match no group are skipped
    :param password: plain-text password; when None an unguessable random one is stored
    :param gid: object_id to use; a new uuid4 is generated when omitted
    :param object_type: required despite its default
    :return: the object_id of the identity, as a string
    :raises ValueError: when object_type is None
    """
    import secrets  # local import: only needed for identities created without a password

    email = email.lower().strip()
    if object_type is None:
        raise ValueError("object_type parameter expected")

    # Always work with the string form: the key lookup used by add_membership
    # does not recognise UUID instances as object ids.
    object_id = str(gid) if gid else str(uuid.uuid4())

    if password is None:
        # Hashing None would hash the literal text "None", which is a guessable
        # password. Store the hash of a random secret nobody knows instead.
        password = secrets.token_urlsafe(32)

    db.identity.validate_and_insert(
        object_id=object_id,
        object_type=object_type,
        email=email,
        firstname=name or firstname or None,
        fullname=fullname,
        encoded_password=Password.encode(password),
    )
    db.commit()

    for key in member_of:
        group_id = key_lookup(db, key, "group")
        if group_id is None:
            # Unknown group: skip it. Looking up None would instead match a
            # group whose firstname is NULL.
            continue
        add_membership(db, identity_key=object_id, group_key=group_id)
    db.commit()
    return object_id

213 

214 

def add_group(db: DAL, email: str, name: str, member_of: list[IdentityKey]):
    """Create an identity of type "group" and return what add_identity returns."""
    return add_identity(db=db, email=email, member_of=member_of, name=name, object_type="group")

217 

218 

def remove_identity(db: DAL, object_id: IdentityKey):
    """Delete the identity with this object_id and commit.

    :return: True when at least one row was deleted
    """
    matching = db.identity.object_id == object_id
    deleted_count = db(matching).delete()
    # todo: remove permissions and group memberships
    db.commit()
    return deleted_count > 0

224 

225 

def get_identity(db: DAL, key: IdentityKey, object_type: ObjectTypes = None):
    """
    Fetch the first identity matching the key.

    :param db: dal db connection
    :param key: can be the email, id, or object_id
    :param object_type: when given, only identities of this type match
    :return: identity record or None when not found
    """
    return db(key_lookup_query(db, key, object_type)).select().first()

235 

236 

def get_user(db: DAL, key: IdentityKey):
    """
    Fetch an identity of type "user".

    :param db: dal db connection
    :param key: can be the email, id, or object_id
    :return: user record or None when not found
    """
    user_row = get_identity(db, key, "user")
    return user_row

244 

245 

def get_group(db: DAL, key: IdentityKey):
    """
    Fetch an identity of type "group".

    :param db: dal db connection
    :param key: can be the name of the group, the id, object_id or email_address
    :return: group record or None when not found
    """
    group_row = get_identity(db, key, "group")
    return group_row

254 

255 

def authenticate_user(db: DAL, password: str = None, user: Identity = None, key: IdentityKey = None):
    """Check *password* for a user, given either the user record or a lookup key.

    :param db: dal db connection
    :param password: plain-text password; an empty or missing password never authenticates
    :param user: the user record, when the caller already has it
    :param key: email, id or object_id used to fetch the user when *user* is not given
    :return: True when the password matches the stored one, False otherwise
             (including unknown users and users without a stored password)
    """
    if not password:
        return False
    if not user:
        if key is None:
            # Nothing to look up; a None key would be matched against firstname.
            return False
        user = get_user(db, key)
    if user is None or not user.encoded_password:
        return False
    return Password.validate(password, user.encoded_password)

262 

263 

def add_membership(db: DAL, identity_key: IdentityKey, group_key: IdentityKey):
    """Make the identity a direct member of the group, unless it already is.

    :param db: dal db connection
    :param identity_key: key of the identity (user, group or item) to add
    :param group_key: key of the group to add it to
    :raises ValueError: when the identity or the group cannot be found
    """
    identity_oid = key_lookup(db, identity_key)
    if identity_oid is None:
        # f-string rather than "+": keys may be int or UUID, which cannot be
        # concatenated to a str and would mask this error with a TypeError.
        raise ValueError(f"invalid identity_oid key: {identity_key}")
    group = get_group(db, group_key)
    if not group:
        raise ValueError(f"invalid group key: {group_key}")

    query = db.membership.subject == identity_oid
    query &= db.membership.member_of == group.object_id
    if db(query).count() == 0:
        db.membership.validate_and_insert(
            subject=identity_oid,
            member_of=group.object_id,
        )
    db.commit()

279 

280 

def remove_membership(db: DAL, identity_key: IdentityKey, group_key: IdentityKey):
    """Remove the direct membership of an identity in a group and commit.

    :param db: dal db connection
    :param identity_key: key of the member
    :param group_key: key of the group
    :return: the number of membership rows deleted
    :raises ValueError: when the identity or the group cannot be found
    """
    identity = get_identity(db, identity_key)
    if identity is None:
        raise ValueError(f"invalid identity key: {identity_key}")
    group = get_group(db, group_key)
    if group is None:
        raise ValueError(f"invalid group key: {group_key}")

    query = db.membership.subject == identity.object_id
    query &= db.membership.member_of == group.object_id
    deleted = db(query).delete()
    db.commit()
    return deleted

289 

290 

def get_memberships(db: DAL, object_id: IdentityKey, bare: bool = True):
    """Select the rows of the recursive_memberships view whose root is *object_id*.

    :param bare: when True only object_id and object_type are selected,
                 otherwise no explicit field list is passed to select()
    """
    view = db.recursive_memberships
    if bare:
        columns = [view.object_id, view.object_type]
    else:
        columns = []
    return db(view.root == object_id).select(*columns)

295 

296 

def get_members(db: DAL, object_id: IdentityKey, bare: bool = True):
    """Select the rows of the recursive_members view whose root is *object_id*.

    :param bare: when True only object_id and object_type are selected,
                 otherwise no explicit field list is passed to select()
    """
    view = db.recursive_members
    if bare:
        columns = [view.object_id, view.object_type]
    else:
        columns = []
    return db(view.root == object_id).select(*columns)

301 

302 

def add_permission(
    db: DAL,
    identity_key: IdentityKey,
    target_oid: IdentityKey,
    privilege: str,
    starts: dt.datetime | str = DEFAULT_STARTS,
    ends: dt.datetime | str = DEFAULT_ENDS,
):
    """Grant *privilege* on a target to an identity for the given period.

    Nothing is inserted when has_permission already reports the privilege
    at *starts*; a notice is printed instead.

    :param db: dal db connection
    :param identity_key: key of the user or group receiving the permission
    :param target_oid: object_id the permission applies to
    :param privilege: name of the privilege
    :param starts: start of validity, datetime or parseable string
    :param ends: end of validity, datetime or parseable string
    """
    identity_oid = key_lookup(db, identity_key)
    starts = unstr_datetime(starts)
    ends = unstr_datetime(ends)
    if has_permission(db, identity_oid, target_oid, privilege, when=starts):
        # permission already granted. just skip it
        # (formatted from the actual local names; the former "{user_or_group_key}"
        # placeholder does not exist in this function and raised KeyError)
        print(f"{privilege} permission already granted to {identity_key} on {target_oid} @ {starts} ")
        return
    db.permission.validate_and_insert(
        privilege=privilege,
        identity_object_id=identity_oid,
        target_object_id=target_oid,
        starts=starts,
        ends=ends,
    )
    db.commit()

331 

332 

def remove_permission(
    db: DAL, identity_key: IdentityKey, target_oid: IdentityKey, privilege: str, when: dt.datetime | str = DEFAULT
):
    """Delete the permission rows of an identity on a target that are valid at *when*.

    :param db: dal db connection
    :param identity_key: key of the user or group holding the permission
    :param target_oid: object_id the permission applies to
    :param privilege: name of the privilege
    :param when: moment the permission must cover; defaults to now
    :return: True when at least one row was deleted
    """
    identity_oid = key_lookup(db, identity_key)
    when = dt.datetime.now() if when is DEFAULT else unstr_datetime(when)

    table = db.permission
    # only rows directly attached to this identity, whose period contains `when`
    query = (
        (table.identity_object_id == identity_oid)
        & (table.target_object_id == target_oid)
        & (table.privilege == privilege)
        & (table.starts <= when)
        & (table.ends >= when)
    )
    removed_any = db(query).delete() > 0
    db.commit()
    return removed_any

352 

353 

def with_alias(db: DAL, source: Table, alias: str):
    """Register a copy of *source* on *db* under the table name *alias* and return it.

    The copy gets its own ``ALL`` object and a clone of every field, bound to the
    copy. Used by has_permission as a stand-in for pydal's own ``with_alias``
    on keyed tables / views.
    """
    aliased = copy.copy(source)
    aliased["ALL"] = SQLALL(aliased)
    aliased["_tablename"] = alias

    for name in aliased.fields:
        cloned = source[name].clone()
        cloned.bind(aliased)
        aliased[name] = cloned

    if "id" in source and "id" not in aliased.fields:
        aliased["id"] = aliased[source.id.name]

    source_id = getattr(source, "_id", None)
    if source_id:
        aliased._id = aliased[source_id.name]

    db[alias] = aliased
    return aliased

369 

370 

def has_permission(
    db: DAL, user_or_group_key: IdentityKey, target_oid: IdentityKey, privilege: str, when: dt.datetime | str = DEFAULT
):
    """Tell whether a user or group holds *privilege* on a target at *when*.

    The recursive_memberships view is joined in twice: once rooted at the
    identity and once rooted at the target. A permission row counts when it
    links an object_id from the first set to one from the second.

    :param db: dal db connection
    :param user_or_group_key: key of the user or group to check
    :param target_oid: object_id of the target
    :param privilege: name of the privilege
    :param when: moment to check; defaults to now
    :return: True when at least one matching permission row exists
    """
    root_oid = key_lookup(db, user_or_group_key)
    when = dt.datetime.now() if when is DEFAULT else unstr_datetime(when)

    table = db.permission
    # ugly hack to satisfy pydal aliasing keyed tables / views:
    # db.recursive_memberships.with_alias(...) is not used here
    left = with_alias(db, db.recursive_memberships, "left")
    right = with_alias(db, db.recursive_memberships, "right")

    query = (
        (left.root == root_oid)
        & (right.root == target_oid)
        & (table.identity_object_id == left.object_id)
        & (table.target_object_id == right.object_id)
        & (table.privilege == privilege)
        & (table.starts <= when)
        & (table.ends >= when)
    )
    return db(query).count() > 0

397 return db(query).count() > 0