Coverage for test_providers.py: 100%

129 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1import datetime 

2 

3from plain.auth import get_user_model 

4from plain.oauth.models import OAuthConnection 

5from plain.oauth.providers import OAuthProvider, OAuthToken, OAuthUser 

6 

7 

class DummyProvider(OAuthProvider):
    """Fake OAuth provider used to drive the OAuth views in tests.

    Every method returns fixed literals, so the authorization redirect URL,
    the issued tokens and the provider profile are fully predictable.
    """

    authorization_url = "https://example.com/oauth/authorize"

    def generate_state(self) -> str:
        # A constant state lets tests spell out the exact authorize URL and
        # send back a matching ``state`` on the callback.
        return "dummy_state"

    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Return a fixed "refreshed" token pair; ``oauth_token`` is ignored.

        The expiries (2029-01-01 / 2029-01-02 UTC) differ from the ones
        issued by ``get_oauth_token`` so tests can tell the two apart.
        """
        access_expiry = datetime.datetime(2029, 1, 1, 0, 0, tzinfo=datetime.UTC)
        refresh_expiry = datetime.datetime(2029, 1, 2, 0, 0, tzinfo=datetime.UTC)
        return OAuthToken(
            access_token="refreshed_dummy_access_token",
            refresh_token="refreshed_dummy_refresh_token",
            access_token_expires_at=access_expiry,
            refresh_token_expires_at=refresh_expiry,
        )

    def get_oauth_token(self, *, code, request) -> OAuthToken:
        """Return a fixed initial token pair regardless of ``code``/``request``.

        The expiries are 2020-01-01 / 2020-01-02 UTC.
        """
        access_expiry = datetime.datetime(2020, 1, 1, 0, 0, tzinfo=datetime.UTC)
        refresh_expiry = datetime.datetime(2020, 1, 2, 0, 0, tzinfo=datetime.UTC)
        return OAuthToken(
            access_token="dummy_access_token",
            refresh_token="dummy_refresh_token",
            access_token_expires_at=access_expiry,
            refresh_token_expires_at=refresh_expiry,
        )

    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Return the same provider profile for any token."""
        profile = OAuthUser(
            id="dummy_id",
            email="dummy@example.com",
            username="dummy_username",
        )
        return profile

44 

45 

def test_dummy_signup(db, client, settings):
    """First login through the dummy provider creates a user and one connection."""
    settings.OAUTH_LOGIN_PROVIDERS = {
        "dummy": {
            "class": "test_providers.DummyProvider",
            "kwargs": {
                "client_id": "dummy_client_id",
                "client_secret": "dummy_client_secret",
                "scope": "dummy_scope",
            },
        }
    }
    user_model = get_user_model()
    expected_authorize_url = "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"

    # The database starts empty.
    assert user_model.objects.count() == 0
    assert OAuthConnection.objects.count() == 0

    # Anonymous visitors to the home page are sent to the login page.
    resp = client.get("/")
    assert resp.status_code == 302
    assert resp.url == "/login/?next=/"

    # Submitting the provider's login form redirects to its authorize URL.
    resp = client.post("/oauth/dummy/login/")
    assert resp.status_code == 302
    assert resp.url == expected_authorize_url

    # The provider sends the browser back with a code and the matching state.
    resp = client.get("/oauth/dummy/callback/?code=test_code&state=dummy_state")
    assert resp.status_code == 302
    assert resp.url == "/"

    # The session is now authenticated.
    resp = client.get("/")
    assert resp.status_code == 200
    assert b"Hello dummy_username!\n" in resp.content

    # The new user mirrors the provider profile...
    new_user = resp.user
    assert new_user.username == "dummy_username"
    assert new_user.email == "dummy@example.com"

    # ...and owns exactly one connection carrying the issued token data.
    connections = new_user.oauth_connections.all()
    assert len(connections) == 1
    conn = connections[0]
    assert conn.provider_key == "dummy"
    assert conn.provider_user_id == "dummy_id"
    assert conn.access_token == "dummy_access_token"
    assert conn.refresh_token == "dummy_refresh_token"
    assert conn.access_token_expires_at == datetime.datetime(
        2020, 1, 1, 0, 0, tzinfo=datetime.UTC
    )
    assert conn.refresh_token_expires_at == datetime.datetime(
        2020, 1, 2, 0, 0, tzinfo=datetime.UTC
    )

    assert user_model.objects.count() == 1
    assert OAuthConnection.objects.count() == 1

103 

104 

def test_dummy_login_connection(db, client, settings):
    """Logging in through the dummy provider signs in the already-linked user.

    A user and an OAuthConnection for provider "dummy" / id "dummy_id" exist
    before the flow starts. The callback must log that same user in without
    creating another user or connection.
    """
    settings.OAUTH_LOGIN_PROVIDERS = {
        "dummy": {
            "class": "test_providers.DummyProvider",
            "kwargs": {
                "client_id": "dummy_client_id",
                "client_secret": "dummy_client_secret",
                "scope": "dummy_scope",
            },
        }
    }

    assert get_user_model().objects.count() == 0
    assert OAuthConnection.objects.count() == 0

    # Create a user already linked to the provider account "dummy_id".
    existing_user = get_user_model().objects.create(
        username="dummy_username", email="dummy@example.com"
    )
    OAuthConnection.objects.create(
        user=existing_user,
        provider_key="dummy",
        provider_user_id="dummy_id",
        access_token="dummy_access_token",
        refresh_token="dummy_refresh_token",
        access_token_expires_at=datetime.datetime(
            2020, 1, 1, 0, 0, tzinfo=datetime.UTC
        ),
        refresh_token_expires_at=datetime.datetime(
            2020, 1, 2, 0, 0, tzinfo=datetime.UTC
        ),
    )

    assert get_user_model().objects.count() == 1
    assert OAuthConnection.objects.count() == 1

    # Login required for this view
    response = client.get("/")
    assert response.status_code == 302
    assert response.url == "/login/?next=/"

    # User clicks the login link (form submit)
    response = client.post("/oauth/dummy/login/")
    assert response.status_code == 302
    assert (
        response.url
        == "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
    )

    # Provider redirects to the callback url
    response = client.get("/oauth/dummy/callback/?code=test_code&state=dummy_state")
    assert response.status_code == 302
    assert response.url == "/"

    # Now logged in
    response = client.get("/")
    assert response.status_code == 200
    assert b"Hello dummy_username!\n" in response.content

    # The session belongs to the pre-existing user, not a freshly created one.
    user = response.user
    assert user.id == existing_user.id
    assert user.username == "dummy_username"
    assert user.email == "dummy@example.com"

    # The pre-existing connection is still the user's only connection.
    connections = user.oauth_connections.all()
    assert len(connections) == 1
    assert connections[0].provider_key == "dummy"
    assert connections[0].provider_user_id == "dummy_id"
    assert connections[0].access_token == "dummy_access_token"
    assert connections[0].refresh_token == "dummy_refresh_token"
    assert connections[0].access_token_expires_at == datetime.datetime(
        2020, 1, 1, 0, 0, tzinfo=datetime.UTC
    )
    assert connections[0].refresh_token_expires_at == datetime.datetime(
        2020, 1, 2, 0, 0, tzinfo=datetime.UTC
    )

    # Nothing new was created by the login flow.
    assert get_user_model().objects.count() == 1
    assert OAuthConnection.objects.count() == 1

183 

184 

def test_dummy_login_without_connection(db, client, settings):
    """A local user with no dummy-provider connection cannot log in via the provider.

    The user exists with the provider's username and email but has no
    OAuthConnection. The callback is answered with a 400 "OAuth Error" page
    instead of a redirect.
    """
    settings.OAUTH_LOGIN_PROVIDERS = {
        "dummy": {
            "class": "test_providers.DummyProvider",
            "kwargs": {
                "client_id": "dummy_client_id",
                "client_secret": "dummy_client_secret",
                "scope": "dummy_scope",
            },
        }
    }
    user_model = get_user_model()

    assert user_model.objects.count() == 0
    assert OAuthConnection.objects.count() == 0

    # A matching local account, deliberately left without any connection.
    user_model.objects.create(username="dummy_username", email="dummy@example.com")

    assert user_model.objects.count() == 1
    assert OAuthConnection.objects.count() == 0

    # The home page requires login.
    home = client.get("/")
    assert home.status_code == 302
    assert home.url == "/login/?next=/"

    # Starting the provider login redirects to the authorize URL.
    start = client.post("/oauth/dummy/login/")
    assert start.status_code == 302
    assert start.url == (
        "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
    )

    # The callback refuses to sign in the unlinked account.
    callback = client.get("/oauth/dummy/callback/?code=test_code&state=dummy_state")
    assert callback.status_code == 400
    assert b"OAuth Error" in callback.content

225 

226 

def test_dummy_connect(db, client, settings):
    """A logged-in user can connect the dummy provider to their account.

    The user is created and force-logged-in up front. Completing the
    connect flow attaches one new OAuthConnection to that same user and
    creates no additional user.
    """
    settings.OAUTH_LOGIN_PROVIDERS = {
        "dummy": {
            "class": "test_providers.DummyProvider",
            "kwargs": {
                "client_id": "dummy_client_id",
                "client_secret": "dummy_client_secret",
                "scope": "dummy_scope",
            },
        }
    }

    assert get_user_model().objects.count() == 0
    assert OAuthConnection.objects.count() == 0

    # Create a user (without any connection yet)
    existing_user = get_user_model().objects.create(
        username="dummy_username", email="dummy@example.com"
    )

    assert get_user_model().objects.count() == 1
    assert OAuthConnection.objects.count() == 0

    client.force_login(existing_user)

    # Start the connect flow; it redirects to the provider's authorize URL.
    response = client.post("/oauth/dummy/connect/")
    assert response.status_code == 302
    assert (
        response.url
        == "https://example.com/oauth/authorize?client_id=dummy_client_id&redirect_uri=https%3A%2F%2Ftestserver%2Foauth%2Fdummy%2Fcallback%2F&response_type=code&scope=dummy_scope&state=dummy_state"
    )

    # Provider redirects to the callback url
    response = client.get("/oauth/dummy/callback/?code=test_code&state=dummy_state")
    assert response.status_code == 302
    assert response.url == "/"

    # Still logged in as the same user after connecting.
    # (Previously this response was used without any checks.)
    response = client.get("/")
    assert response.status_code == 200
    assert b"Hello dummy_username!\n" in response.content

    # Check the connection that was created and attached to the existing user
    user = response.user
    assert user.id == existing_user.id
    connections = user.oauth_connections.all()
    assert len(connections) == 1
    assert connections[0].provider_key == "dummy"
    assert connections[0].provider_user_id == "dummy_id"
    assert connections[0].access_token == "dummy_access_token"
    assert connections[0].refresh_token == "dummy_refresh_token"
    assert connections[0].access_token_expires_at == datetime.datetime(
        2020, 1, 1, 0, 0, tzinfo=datetime.UTC
    )
    assert connections[0].refresh_token_expires_at == datetime.datetime(
        2020, 1, 2, 0, 0, tzinfo=datetime.UTC
    )

    assert get_user_model().objects.count() == 1
    assert OAuthConnection.objects.count() == 1

284 

285 

286# def test_dummy_disconnect_to_password(db, client, settings): 

287# settings.OAUTH_LOGIN_PROVIDERS = { 

288# "dummy": { 

289# "class": "test_providers.DummyProvider", 

290# "kwargs": { 

291# "client_id": "dummy_client_id", 

292# "client_secret": "dummy_client_secret", 

293# "scope": "dummy_scope", 

294# }, 

295# } 

296# } 

297 

298# assert get_user_model().objects.count() == 0 

299# assert OAuthConnection.objects.count() == 0 

300 

301# # Create a user 

302# user = get_user_model().objects.create( 

303# username="dummy_username", email="dummy@example.com", password="dummy_password" 

304# ) 

305# OAuthConnection.objects.create( 

306# user=user, 

307# provider_key="dummy", 

308# provider_user_id="dummy_id", 

309# access_token="dummy_access_token", 

310# refresh_token="dummy_refresh_token", 

311# access_token_expires_at=datetime.datetime( 

312# 2020, 1, 1, 0, 0, tzinfo=datetime.timezone.utc 

313# ), 

314# refresh_token_expires_at=datetime.datetime( 

315# 2020, 1, 2, 0, 0, tzinfo=datetime.timezone.utc 

316# ), 

317# ) 

318 

319# assert get_user_model().objects.count() == 1 

320# assert OAuthConnection.objects.count() == 1 

321 

322# client.force_login(user) 

323 

324# # Raises a BadRequest error - can't disconnect the last connection without a password 

325# response = client.post( 

326# "/oauth/dummy/disconnect/", data={"provider_user_id": "dummy_id"} 

327# ) 

328# assert response.status_code == 302 

329# assert response.url == "/" 

330 

331# assert get_user_model().objects.count() == 1 

332# assert OAuthConnection.objects.count() == 0 

333 

334 

335# def test_dummy_disconnect_to_connection(db, client, settings): 

336# settings.OAUTH_LOGIN_PROVIDERS = { 

337# "dummy": { 

338# "class": "test_providers.DummyProvider", 

339# "kwargs": { 

340# "client_id": "dummy_client_id", 

341# "client_secret": "dummy_client_secret", 

342# "scope": "dummy_scope", 

343# }, 

344# } 

345# } 

346 

347# assert get_user_model().objects.count() == 0 

348# assert OAuthConnection.objects.count() == 0 

349 

350# # Create a user 

351# user = get_user_model().objects.create( 

352# username="dummy_username", email="dummy@example.com" 

353# ) 

354# OAuthConnection.objects.create( 

355# user=user, 

356# provider_key="dummy", 

357# provider_user_id="dummy_id", 

358# access_token="dummy_access_token", 

359# refresh_token="dummy_refresh_token", 

360# access_token_expires_at=datetime.datetime( 

361# 2020, 1, 1, 0, 0, tzinfo=datetime.timezone.utc 

362# ), 

363# refresh_token_expires_at=datetime.datetime( 

364# 2020, 1, 2, 0, 0, tzinfo=datetime.timezone.utc 

365# ), 

366# ) 

367# OAuthConnection.objects.create( 

368# user=user, 

369# provider_key="dummy", 

370# provider_user_id="dummy_id2", 

371# access_token="dummy_access_token", 

372# refresh_token="dummy_refresh_token", 

373# access_token_expires_at=datetime.datetime( 

374# 2020, 1, 1, 0, 0, tzinfo=datetime.timezone.utc 

375# ), 

376# refresh_token_expires_at=datetime.datetime( 

377# 2020, 1, 2, 0, 0, tzinfo=datetime.timezone.utc 

378# ), 

379# ) 

380 

381# assert get_user_model().objects.count() == 1 

382# assert OAuthConnection.objects.count() == 2 

383 

384# client.force_login(user) 

385 

386# # Raises a BadRequest error - can't disconnect the last connection without a password 

387# response = client.post( 

388# "/oauth/dummy/disconnect/", data={"provider_user_id": "dummy_id"} 

389# ) 

390# assert response.status_code == 302 

391# assert response.url == "/" 

392 

393# assert get_user_model().objects.count() == 1 

394# assert OAuthConnection.objects.count() == 1 

395 

396 

397# def test_dummy_disconnect_last(db, client, settings): 

398# settings.OAUTH_LOGIN_PROVIDERS = { 

399# "dummy": { 

400# "class": "test_providers.DummyProvider", 

401# "kwargs": { 

402# "client_id": "dummy_client_id", 

403# "client_secret": "dummy_client_secret", 

404# "scope": "dummy_scope", 

405# }, 

406# } 

407# } 

408 

409# assert get_user_model().objects.count() == 0 

410# assert OAuthConnection.objects.count() == 0 

411 

412# # Create a user 

413# user = get_user_model().objects.create( 

414# username="dummy_username", email="dummy@example.com" 

415# ) 

416# OAuthConnection.objects.create( 

417# user=user, 

418# provider_key="dummy", 

419# provider_user_id="dummy_id", 

420# access_token="dummy_access_token", 

421# refresh_token="dummy_refresh_token", 

422# access_token_expires_at=datetime.datetime( 

423# 2020, 1, 1, 0, 0, tzinfo=datetime.timezone.utc 

424# ), 

425# refresh_token_expires_at=datetime.datetime( 

426# 2020, 1, 2, 0, 0, tzinfo=datetime.timezone.utc 

427# ), 

428# ) 

429 

430# assert get_user_model().objects.count() == 1 

431# assert OAuthConnection.objects.count() == 1 

432 

433# client.force_login(user) 

434 

435# # Raises a BadRequest error - can't disconnect the last connection without a password 

436# response = client.post( 

437# "/oauth/dummy/disconnect/", data={"provider_user_id": "dummy_id"} 

438# ) 

439# assert response.status_code == 400 

440# assert response.templates[0].name == "oauth/error.html" 

441 

442# assert get_user_model().objects.count() == 1 

443# assert OAuthConnection.objects.count() == 1 

444 

445 

def test_dummy_refresh(db, settings):
    """``refresh_access_token`` replaces the connection's tokens with the provider's.

    The connection starts with the 2020-expiring tokens. After the refresh
    the in-memory object holds the values returned by
    ``DummyProvider.refresh_oauth_token`` (2029 expiries). Provider key and
    user id are unchanged.
    """
    # NOTE: the previously requested ``monkeypatch`` fixture was never used;
    # pytest injects fixtures by argument name, so it is simply not requested.
    settings.OAUTH_LOGIN_PROVIDERS = {
        "dummy": {
            "class": "test_providers.DummyProvider",
            "kwargs": {
                "client_id": "dummy_client_id",
                "client_secret": "dummy_client_secret",
                "scope": "dummy_scope",
            },
        }
    }

    user = get_user_model().objects.create(
        username="dummy_username", email="dummy@example.com"
    )
    connection = OAuthConnection.objects.create(
        user=user,
        provider_key="dummy",
        provider_user_id="dummy_id",
        access_token="dummy_access_token",
        refresh_token="dummy_refresh_token",
        access_token_expires_at=datetime.datetime(
            2020, 1, 1, 0, 0, tzinfo=datetime.UTC
        ),
        refresh_token_expires_at=datetime.datetime(
            2020, 1, 2, 0, 0, tzinfo=datetime.UTC
        ),
    )

    connection.refresh_access_token()

    # Identity of the connection is preserved...
    assert connection.provider_key == "dummy"
    assert connection.provider_user_id == "dummy_id"
    # ...while the token data now comes from the refresh response.
    assert connection.access_token == "refreshed_dummy_access_token"
    assert connection.refresh_token == "refreshed_dummy_refresh_token"
    assert connection.access_token_expires_at == datetime.datetime(
        2029, 1, 1, 0, 0, tzinfo=datetime.UTC
    )
    assert connection.refresh_token_expires_at == datetime.datetime(
        2029, 1, 2, 0, 0, tzinfo=datetime.UTC
    )