wallaroo.auth

Handles authentication to the Wallaroo platform.

Performs a "device code"-style OAuth login flow.

The code is organized as follows:

  • Auth objects returned by create() should be placed on each request to platform APIs. Currently, we have the following types:

    • NoAuth: Does not modify requests
    • PlatformAuth: Places Authorization: Bearer XXX headers on each outgoing request
  • Objects derived from TokenFetcher know how to obtain an AccessToken from a particular provider:

    • KeycloakTokenFetcher: Fetches a token from Keycloak using a device code login flow
    • CachedTokenFetcher: Wraps another TokenFetcher and caches the value to a JSON file to reduce the number of user logins needed.
  1"""Handles authentication to the Wallaroo platform.
  2
  3Performs a "device code"-style OAuth login flow.
  4
  5The code is organized as follows:
  6
  7* Auth objects returned by `create()` should be placed on each request to
  8  platform APIs. Currently, we have the following types:
  9  * NoAuth: Does not modify requests
 10  * PlatformAuth: Places `Authorization: Bearer XXX` headers on each outgoing
 11    request
 12
 13* Objects derived from TokenFetcher know how to obtain an AccessToken from a
 14  particular provider:
 15  * KeycloakTokenFetcher: Fetches a token from Keycloak using a device code
 16    login flow
 17  * CachedTokenFetcher: Wraps another TokenFetcher and caches the value to a
 18    JSON file to reduce the number of user logins needed.
 19"""
 20import abc
 21import datetime
 22import enum
 23import json
 24import logging as log
 25import os
 26import pathlib
 27import posixpath
 28import shutil
 29import time
 30from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
 31
 32import appdirs  # type: ignore
 33import jwt
 34import requests
 35
 36from .version import _user_agent
 37
 38AUTH_PATH = "WALLAROO_SDK_CREDENTIALS"
 39USER_VAR = "WALLAROO_USER"
 40PASSWORD_VAR = "WALLAROO_PASSWORD"
 41KEYCLOAK_CLIENT_NAME = "sdk-client"
 42
 43KEYCLOAK_REALM = "master"
 44
 45
 46################################################################################
 47# Module public API
 48################################################################################
 49
 50
 51class AuthType(enum.Enum):
 52    """Defines all the supported auth types.
 53
 54    Handles conversions from string names to enum values.
 55    """
 56
 57    NONE = "none"
 58    SSO = "sso"
 59    USER_PASSWORD = "user_password"
 60    TEST_AUTH = "test_auth"
 61    TOKEN = "token"
 62
 63
 64class TokenData(NamedTuple):
 65    token: str
 66    user_email: str
 67    user_id: str
 68
 69    def to_dict(self) -> Dict[str, str]:
 70        return self._asdict()
 71
 72
 73class _WallarooAuth(requests.auth.AuthBase):
 74    """Add a user_id function to base auth class"""
 75
 76    def user_id(self) -> Optional[str]:
 77        pass
 78
 79    def user_email(self) -> Optional[str]:
 80        pass
 81
 82    def _bearer_token_str(self) -> str:
 83        pass
 84
 85    def _access_token(self) -> "_AccessToken":
 86        pass
 87
 88
 89def create(keycloak_addr: str, auth_type: Union[AuthType, str, None]) -> _WallarooAuth:
 90    """Returns an auth object of the corresponding type.
 91
 92    :param str keycloak_addr: Address of the Keycloak instance to auth against
 93    :param AuthType or str auth_type: Type of authentication to use
 94    :return: Auth object that can be passed to all `requests` calls
 95    :rtype: AuthBase
 96    :raises NotImplementedError: if auth_type is not recognized
 97    """
 98    if isinstance(auth_type, str):
 99        auth_type = AuthType(auth_type.lower())
100    elif os.getenv(AUTH_PATH) or os.getenv(USER_VAR):
101        auth_type = AuthType.USER_PASSWORD
102    elif auth_type == None or auth_type == AuthType.NONE:
103        return _NoAuth()
104    else:
105        # TODO: Error?
106        print("Unknown auth type.")
107
108    cached_token_path = (
109        pathlib.Path(
110            appdirs.user_cache_dir(appname="wallaroo_sdk", appauthor="wallaroo")
111        )
112        / "auth"
113        / "keycloak.json"
114    )
115    fetcher: _TokenFetcher
116    if auth_type == AuthType.SSO:
117        fetcher = _CachedTokenFetcher(
118            path=cached_token_path,
119            fetcher=_KeycloakTokenFetcher(
120                address=keycloak_addr,
121                realm=KEYCLOAK_REALM,
122                client_id=KEYCLOAK_CLIENT_NAME,
123            ),
124        )
125    elif auth_type == AuthType.USER_PASSWORD:
126        (username, password) = _GetUserPasswordCreds()
127        fetcher = _CachedTokenFetcher(
128            path=cached_token_path,
129            fetcher=_PasswordTokenFetcher(
130                address=keycloak_addr,
131                realm=KEYCLOAK_REALM,
132                client_id=KEYCLOAK_CLIENT_NAME,
133                username=username,
134                password=password,
135            ),
136        )
137    elif auth_type == AuthType.TEST_AUTH:
138        return _TestAuth()
139    elif auth_type == AuthType.TOKEN:
140        env_key = "WALLAROO_SDK_CREDENTIALS"
141        if env_key not in os.environ:
142            raise Exception("Passed token auth, but no token provided.")
143        if not os.path.exists(os.environ[env_key]):
144            raise Exception("Token file provided, but file does not exist.")
145        data = _get_token_from_file(os.environ[env_key])
146        fetcher = _RawTokenFetcher(data)
147    else:
148        raise NotImplementedError(f"Unsupported auth type: {auth_type}")
149    return _PlatformAuth(fetcher=fetcher)
150
151
def logout():
    """Deletes every cached credential written by this module.

    Only the ``auth`` folder under the SDK's user cache directory is removed;
    a missing folder or removal errors are ignored. Auth objects already
    returned by `create()` are not invalidated.

    :rtype: None
    """
    auth_cache = pathlib.Path(
        appdirs.user_cache_dir(appname="wallaroo_sdk", appauthor="wallaroo"),
        "auth",
    )
    shutil.rmtree(auth_cache, ignore_errors=True)


class AuthError(Exception):
    """Root of the exception hierarchy raised by this module.

    When a truthy HTTP status `code` is given, the message is prefixed with
    ``[HTTP <code>]``.
    """

    def __init__(self, message: str, code: Optional[int] = None) -> None:
        text = f"[HTTP {code}] {message}" if code else message
        super().__init__(text)


class TokenFetchError(AuthError):
    """Raised when obtaining a new token (i.e. logging in) fails."""


class TokenRefreshError(AuthError):
    """Raised when exchanging a refresh token for a new access token fails."""


################################################################################
# Module private classes
################################################################################


def _GetUserPasswordCreds() -> Tuple[str, str]:
    """Returns username/password credentials discovered via the environment.

    Lookup order:

    1. If $WALLAROO_USER is set (non-empty), it is the username and
       $WALLAROO_PASSWORD must hold the password.
    2. Otherwise $WALLAROO_SDK_CREDENTIALS must point to a JSON file
       containing the following shape:

       {
           "username": "some_keycloak_username",
           "password": "some_password"
       }

    Returns: (username, password)

    Raises: TokenFetchError if the variables are unset or incomplete, the file
    is not found, or the file does not hold both keys.
    """
    username = os.getenv(USER_VAR)
    if username:
        password = os.getenv(PASSWORD_VAR)
        if password is None:
            # Previously a bare KeyError escaped here.
            raise TokenFetchError(f"${USER_VAR} is set but ${PASSWORD_VAR} is not")
        return (username, password)

    path = os.getenv(AUTH_PATH)
    if not path:
        raise TokenFetchError(f"${AUTH_PATH} is not set")
    p = pathlib.Path(path)
    if not p.is_file():
        raise TokenFetchError(f"{AUTH_PATH} does not point to a file")
    with p.open("r") as f:
        creds = json.load(f)
    if not isinstance(creds, dict) or "username" not in creds or "password" not in creds:
        raise TokenFetchError(
            f"{path} must contain a JSON object with 'username' and 'password'"
        )
    return (creds["username"], creds["password"])


def _get_token_from_file(path: str) -> TokenData:
    """Reads a TokenData record from the JSON file at `path`.

    The file must contain "token", "user_email" and "user_id" keys; any other
    keys are ignored.

    Raises: Exception naming the first required key that is missing.
    """
    with pathlib.Path(path).open("r") as f:
        creds = json.load(f)
    # Checked in this order so the first missing field is the one reported.
    required = (
        ("token", "Token property field not in json."),
        ("user_email", "User email property field not in json."),
        ("user_id", "User id property field not in json."),
    )
    for key, message in required:
        if key not in creds:
            raise Exception(message)
    # Pass only the known fields: TokenData(**creds) raised TypeError whenever
    # the file carried any extra key.
    return TokenData(
        token=creds["token"],
        user_email=creds["user_email"],
        user_id=creds["user_id"],
    )


class _AccessToken(NamedTuple):
    """An OAuth access token plus what is needed to refresh and cache it.

    `ToDict`/`FromDict` convert to and from the JSON-friendly form stored in
    the on-disk token cache.
    """

    # Bearer token payload placed in Authorization headers.
    token: str
    # When `token` stops being valid (naive local time).
    expiry: datetime.datetime
    # Token used to obtain a replacement for `token`.
    refresh_token: str
    # When `refresh_token` stops being valid (naive local time).
    refresh_token_expiry: datetime.datetime
    # User id from Keycloak.
    user_id: str
    # User email from Keycloak.
    user_email: str

    def ToDict(self) -> Dict[str, Any]:
        """Serializes to a dict; datetimes become POSIX timestamps."""
        serialized: Dict[str, Any] = {"access_token": self.token}
        serialized["expiry"] = self.expiry.timestamp()
        serialized["refresh_token"] = self.refresh_token
        serialized["refresh_token_expiry"] = self.refresh_token_expiry.timestamp()
        serialized["user_id"] = self.user_id
        serialized["user_email"] = self.user_email
        return serialized

    @classmethod
    def FromDict(cls, d):
        """Inverse of `ToDict`: rebuilds an instance from its dict form."""
        to_datetime = datetime.datetime.fromtimestamp
        return cls(
            d["access_token"],
            to_datetime(d["expiry"]),
            d["refresh_token"],
            to_datetime(d["refresh_token_expiry"]),
            d["user_id"],
            d["user_email"],
        )


class _TokenFetcher(abc.ABC):
    """Interface for strategies that obtain and refresh `_AccessToken`s."""

    @abc.abstractmethod
    def Fetch(self) -> _AccessToken:
        """Obtains a new token through this provider's own (possibly manual) flow.

        Raises: TokenFetchError if the flow fails.
        """

    @abc.abstractmethod
    def Refresh(self, access_token: _AccessToken) -> _AccessToken:
        """Exchanges `access_token` for a fresh token via a refresh flow.

        Raises: TokenRefreshError if the flow fails.
        """

    def Reset(self) -> None:
        """Clears internal state; the default implementation does nothing.

        Higher layers may call this before retrying a token exchange when the
        fetcher is suspected to be in a bad state.
        """


class _CachedTokenFetcher(_TokenFetcher):
    """Wraps another TokenFetcher; persists its token to a file.

    If the file named by `path` already exists and holds a token the wrapped
    fetcher can refresh, this TokenFetcher returns that (refreshed) token.
    If not, it delegates to the wrapped fetcher's `Fetch` and caches that
    value in said file. The file holds bearer and refresh tokens, so it is
    written with owner-only permissions (mode 0o600).
    """

    def __init__(self, path: pathlib.Path, fetcher: _TokenFetcher) -> None:
        # Path to the JSON file in which AccessToken should be stored.
        self.path = path
        # Fetcher to wrap.
        self.fetcher = fetcher
        # Fetch and cache the underlying value ASAP.
        self.Fetch()

    def Fetch(self) -> _AccessToken:
        """Returns the cached token, or fetches and caches a new one."""
        cached = self._LoadCached()
        if cached is not None:
            return cached

        access_token = self.fetcher.Fetch()
        self._RemoveFile()
        self._WriteFile(access_token)
        return access_token

    def Refresh(self, access_token: _AccessToken) -> _AccessToken:
        """Refreshes the token using the underlying fetcher and saves the result."""
        access_token = self.fetcher.Refresh(access_token)
        self._WriteFile(access_token)
        return access_token

    def Reset(self) -> None:
        """Deletes the cached file before delegating to the underlying Fetcher."""
        self._RemoveFile()
        self.fetcher.Reset()

    def _LoadCached(self) -> Optional[_AccessToken]:
        """Returns the cached token after refreshing it, or None if unusable.

        Any failure (missing/corrupt file, refresh error) is logged at INFO
        and treated as a cache miss, so the caller falls back to a new fetch.
        """
        try:
            with self.path.open("r") as f:
                token = json.load(f, object_hook=_AccessToken.FromDict)
            refreshed = self.fetcher.Refresh(token)
        except Exception as e:
            log.info(
                "Couldn't load access token from '%s' (%s); re-fetching", self.path, e
            )
            return None

        if refreshed != token:
            # Persist the refreshed token. Previously it was dropped, leaving
            # the file with a refresh token the provider may have rotated out,
            # which could force an unnecessary new login on the next load.
            try:
                self._WriteFile(refreshed)
            except OSError as e:
                log.warning(
                    "Couldn't save refreshed access token to '%s' (%s)", self.path, e
                )
        return refreshed

    def _RemoveFile(self) -> None:
        """Deletes the cache file if it exists."""
        if self.path.exists():
            self.path.unlink()

    def _WriteFile(self, access_token: _AccessToken) -> None:
        """Writes `access_token` as JSON to `self.path` with mode 0o600."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        # Create the file with owner-only permissions from the start. Opening
        # it normally and chmod-ing afterwards leaves the tokens readable by
        # other users in between under a permissive umask.
        fd = os.open(self.path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            f.write(json.dumps(access_token.ToDict()))
        # The os.open mode only applies on creation; also tighten an existing file.
        self.path.chmod(0o600)


class _RawTokenFetcher(_TokenFetcher):
    """Serves a pre-issued token supplied by the caller.

    Meant for nested calls and short-lived requests: the token is never
    refreshed and nothing is written to disk.
    """

    access_token: _AccessToken

    def __init__(self, token_data: TokenData) -> None:
        # Nominal five-minute lifetime; nothing in this class checks it.
        valid_until = datetime.datetime.now() + datetime.timedelta(minutes=5)
        self.access_token = _AccessToken(
            expiry=valid_until,
            refresh_token="unknown",
            refresh_token_expiry=valid_until,
            **token_data.to_dict(),
        )

    def Fetch(self) -> _AccessToken:
        """Returns the stored token."""
        return self.access_token

    def Refresh(self, access_token: _AccessToken) -> _AccessToken:
        """Ignores `access_token` and returns the stored token unchanged."""
        return self.access_token

    def Reset(self) -> None:
        """There is no state to reset."""
        return None

    def _WriteFile(self, access_token: _AccessToken) -> None:
        """No-op: raw tokens are never persisted."""


class _KeycloakTokenFetcher(_TokenFetcher):
    """Fetches tokens from Keycloak using the OAuth 2.0 device code flow.

    `Fetch` prints a verification URL for the user to open in a browser, then
    polls Keycloak's token endpoint until the login completes or the device
    code expires. `Refresh` uses the refresh token to obtain a new access
    token once the current one is within 10 seconds of expiring.
    """

    # RFC 8628 section 3.2: "interval" is optional and defaults to 5 seconds.
    _DEFAULT_POLL_INTERVAL = 5
    # RFC 8628 section 3.5: on "slow_down" the interval must grow by 5 seconds.
    _SLOW_DOWN_INCREMENT = 5

    def __init__(self, address: str, realm: str, client_id: str) -> None:
        self.address = address
        self.realm = realm
        self.client_id = client_id

    def _Endpoint(self, suffix: str) -> str:
        """Returns `<address>/auth/realms/<realm>/protocol/openid-connect/<suffix>`."""
        return posixpath.join(
            self.address,
            "auth/realms",
            self.realm,
            "protocol/openid-connect",
            suffix,
        )

    @staticmethod
    def _ErrorCode(res: requests.Response) -> Optional[str]:
        """Returns the OAuth "error" field of `res`, or None if there is none."""
        try:
            body = json.loads(res.text)
        except ValueError:  # includes json.JSONDecodeError
            return None
        return body.get("error") if isinstance(body, dict) else None

    @staticmethod
    def _TokenFromResponse(res_data: Dict[str, Any]) -> _AccessToken:
        """Builds an _AccessToken from a successful token-endpoint response."""
        # The signature is not checked: the JWT is only read for its "sub"
        # and "email" claims.
        decoded = jwt.decode(
            res_data["access_token"], options={"verify_signature": False}
        )
        now = datetime.datetime.now()
        return _AccessToken(
            token=res_data["access_token"],
            expiry=now + datetime.timedelta(seconds=int(res_data["expires_in"])),
            refresh_token=res_data["refresh_token"],
            refresh_token_expiry=now
            + datetime.timedelta(seconds=int(res_data["refresh_expires_in"])),
            user_id=decoded["sub"],
            user_email=decoded["email"],
        )

    def Fetch(self) -> _AccessToken:
        """Runs the interactive device-code login.

        Raises: TokenFetchError if Keycloak rejects a request or the device
        code expires before the user logs in.
        """
        headers = {"User-Agent": _user_agent}
        res = requests.post(
            self._Endpoint("auth/device"),
            data={
                "client_id": self.client_id,
            },
            headers=headers,
        )
        if res.status_code != 200:
            log.error(
                "Keycloak device code fetch got error: %d - %s",
                res.status_code,
                res.text,
            )
            raise TokenFetchError(res.text, code=res.status_code)
        res_data = json.loads(res.text)
        device_code = res_data["device_code"]
        poll_interval = res_data.get("interval", self._DEFAULT_POLL_INTERVAL)
        expire_time = datetime.datetime.now() + datetime.timedelta(
            seconds=int(res_data["expires_in"])
        )

        # Replace the leading Keycloak address in the verification URL with
        # $WALLAROO_SDK_AUTH_ENDPOINT when that is set (presumably the
        # browser-reachable address - confirm with deployment docs).
        external_url = os.environ.get("WALLAROO_SDK_AUTH_ENDPOINT", self.address)
        verification_uri_complete = res_data["verification_uri_complete"].replace(
            self.address, external_url, 1
        )

        print(
            f"Please log into the following URL in a web browser:\n\n\t{verification_uri_complete}\n"
        )

        token_endpoint = self._Endpoint("token")
        while datetime.datetime.now() < expire_time:
            log.debug("Polling Keycloak for access_code...")
            res = requests.post(
                token_endpoint,
                data={
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                    "device_code": device_code,
                    "client_id": self.client_id,
                },
                headers=headers,
            )
            if res.status_code == 200:
                log.debug("got access_token from Keycloak")
                print("Login successful!")
                return self._TokenFromResponse(json.loads(res.text))

            error = self._ErrorCode(res)
            if error in ("authorization_pending", "slow_down"):
                log.debug("keycloak authorization is still pending")
                if error == "slow_down":
                    poll_interval += self._SLOW_DOWN_INCREMENT
                time.sleep(poll_interval)
                continue

            log.error(
                "unknown error while polling for access_token: %d - %s",
                res.status_code,
                res.text,
            )
            raise TokenFetchError(res.text, code=res.status_code)

        raise TokenFetchError("Device code expired while waiting for user login")

    def Refresh(self, access_token: _AccessToken) -> _AccessToken:
        """Returns `access_token` if still valid, otherwise a refreshed token.

        Raises: TokenRefreshError if the refresh token has expired or Keycloak
        rejects the refresh.
        """
        # Reuse the current token until it is within 10 seconds of expiring.
        if datetime.datetime.now() < access_token.expiry - datetime.timedelta(
            seconds=10
        ):
            return access_token

        if datetime.datetime.now() >= access_token.refresh_token_expiry:
            raise TokenRefreshError("refresh token has expired")

        log.debug("refreshing access_token via Keycloak...")

        res = requests.post(
            self._Endpoint("token"),
            data={
                "client_id": self.client_id,
                "grant_type": "refresh_token",
                "refresh_token": access_token.refresh_token,
            },
            headers={"User-Agent": _user_agent},
        )
        if res.status_code != 200:
            log.error(
                "Keycloak token refresh got error: %d - %s", res.status_code, res.text
            )
            raise TokenRefreshError(res.text, code=res.status_code)

        new_access_token = self._TokenFromResponse(json.loads(res.text))
        log.debug("keycloak token refresh successful")
        return new_access_token


class _PasswordTokenFetcher(_TokenFetcher):
    """Fetches tokens from Keycloak using the OAuth 2.0 password grant.

    The username and password are supplied at construction, so no user
    interaction is needed. `Refresh` uses the refresh token once the access
    token is within 10 seconds of expiring, and keeps the user id/email of
    the token being refreshed.
    """

    # Email recorded when the access token carries no "email" claim.
    _FALLBACK_EMAIL = "admin@keycloak"

    def __init__(
        self,
        address: str,
        realm: str,
        client_id: str,
        username: str,
        password: str,
    ):
        self.address = address
        self.realm = realm
        self.client_id = client_id
        self.username = username
        self.password = password

    def _PostToken(self, data: Dict[str, str]) -> requests.Response:
        """POSTs `data` as a form to this realm's OpenID Connect token endpoint."""
        token_endpoint = posixpath.join(
            self.address,
            "auth/realms",
            self.realm,
            "protocol/openid-connect/token",
        )
        return requests.post(
            token_endpoint,
            data=data,
            headers={"User-Agent": _user_agent},
        )

    @staticmethod
    def _BuildToken(
        res_data: Dict[str, Any], user_id: str, user_email: str
    ) -> _AccessToken:
        """Builds an _AccessToken from a successful token-endpoint response."""
        now = datetime.datetime.now()
        return _AccessToken(
            token=res_data["access_token"],
            expiry=now + datetime.timedelta(seconds=res_data["expires_in"]),
            refresh_token=res_data["refresh_token"],
            refresh_token_expiry=now
            + datetime.timedelta(seconds=res_data["refresh_expires_in"]),
            user_id=user_id,
            user_email=user_email,
        )

    def Fetch(self) -> _AccessToken:
        """Exchanges the stored username/password for a new token.

        Raises: TokenFetchError if Keycloak rejects the credentials.
        """
        res = self._PostToken(
            {
                "client_id": self.client_id,
                "username": self.username,
                "password": self.password,
                "grant_type": "password",
            }
        )
        if res.status_code != 200:
            # Previously logged as "token refresh", which misattributed the failure.
            log.error(
                "Keycloak password token fetch got error: %d - %s",
                res.status_code,
                res.text,
            )
            raise TokenFetchError(res.text, code=res.status_code)
        res_data = json.loads(res.text)

        # The signature is not checked: the JWT is only read for identity claims.
        decoded = jwt.decode(
            res_data["access_token"], options={"verify_signature": False}
        )
        return self._BuildToken(
            res_data,
            user_id=decoded["sub"],
            user_email=decoded.get("email", self._FALLBACK_EMAIL),
        )

    def Refresh(self, access_token: _AccessToken) -> _AccessToken:
        """Returns `access_token` if still valid, otherwise a refreshed token.

        Raises: TokenRefreshError if the refresh token has expired or Keycloak
        rejects the refresh.
        """
        # Reuse the current token until it is within 10 seconds of expiring.
        if datetime.datetime.now() < access_token.expiry - datetime.timedelta(
            seconds=10
        ):
            return access_token

        if datetime.datetime.now() >= access_token.refresh_token_expiry:
            raise TokenRefreshError("refresh token has expired")

        log.debug("refreshing access_token via Keycloak...")

        res = self._PostToken(
            {
                "client_id": self.client_id,
                "grant_type": "refresh_token",
                "refresh_token": access_token.refresh_token,
            }
        )
        if res.status_code != 200:
            log.error(
                "Keycloak token refresh got error: %d - %s", res.status_code, res.text
            )
            raise TokenRefreshError(res.text, code=res.status_code)

        new_access_token = self._BuildToken(
            json.loads(res.text),
            user_id=access_token.user_id,
            user_email=access_token.user_email,
        )
        log.debug("keycloak token refresh successful")
        return new_access_token


class _NoAuth(_WallarooAuth):
    """Pass-through auth hook for platforms running without authentication."""

    def __init__(self) -> None:
        """No configuration is needed."""

    def __call__(self, req: requests.PreparedRequest) -> requests.PreparedRequest:
        """Hands `req` back without touching it."""
        return req

    def user_id(self) -> None:
        """No authenticated user exists, so there is no id."""

    def user_email(self) -> str:
        """Placeholder email reported when auth is disabled."""
        return "default@ex.co"

    def _bearer_token_str(self) -> str:
        """Placeholder bearer string; `__call__` never attaches it."""
        return "no token"

    def _access_token(self) -> "_AccessToken":
        """Builds a dummy token whose expiry times are its creation time."""
        fields = dict(
            token="definitely_an_access_token",
            expiry=datetime.datetime.now(),
            refresh_token="none",
            refresh_token_expiry=datetime.datetime.now(),
            user_email="test",
            user_id="test",
        )
        return _AccessToken(**fields)


class _PlatformAuth(_WallarooAuth):
    """Attaches Keycloak bearer tokens to requests on auth-enabled platforms.

    Build one instance and pass it as the `auth` argument of every `requests`
    call. Token acquisition is delegated to the supplied TokenFetcher, which
    differs per third-party provider.
    """

    def __init__(self, fetcher: _TokenFetcher):
        # Source of access tokens; consulted for every header built.
        self.fetcher = fetcher

    def __call__(self, req: requests.PreparedRequest) -> requests.PreparedRequest:
        """Sets the Authorization header of `req` to a current bearer token."""
        bearer = self._bearer_token_str()
        req.headers["Authorization"] = bearer
        return req

    def auth_header(self) -> Dict[str, str]:
        """Returns a header dict holding only the Authorization entry."""
        return {"Authorization": self._bearer_token_str()}

    def _bearer_token_str(self) -> str:
        """Formats the current access token as `Bearer <token>`."""
        return f"Bearer {self._access_token().token}"

    def _access_token(self) -> "_AccessToken":
        """Returns a usable token, redoing the full exchange if refresh fails."""
        fetcher = self.fetcher
        try:
            return fetcher.Refresh(fetcher.Fetch())
        except TokenRefreshError:
            # The refresh token may have expired: discard any fetcher state
            # and obtain a brand-new token instead.
            fetcher.Reset()
            return fetcher.Fetch()

    def user_id(self) -> Optional[str]:
        """Id of the user the current token belongs to."""
        return self._access_token().user_id

    def user_email(self) -> Optional[str]:
        """Email of the user the current token belongs to."""
        return self._access_token().user_email


class _TestAuth(_WallarooAuth):
    """Deterministic auth hook for unit tests; performs no network calls."""

    def __call__(self, req: requests.PreparedRequest) -> requests.PreparedRequest:
        """Sets a fixed placeholder Authorization header on `req`."""
        bearer = self._bearer_token_str()
        req.headers["Authorization"] = bearer
        return req

    def user_id(self) -> Optional[str]:
        """Fixed user id used in tests."""
        return "99cace15-e0d4-4bb6-bf14-35efee181b90"

    def user_email(self) -> Optional[str]:
        """Fixed user email used in tests."""
        return "jane@ex.co"

    def _bearer_token_str(self) -> str:
        """Fixed placeholder bearer string."""
        return "definitely_a_bearer_token_str"

    def _access_token(self) -> "_AccessToken":
        """Builds a dummy token whose expiry times are its creation time."""
        fields = dict(
            token="definitely_an_access_token",
            expiry=datetime.datetime.now(),
            refresh_token="none",
            refresh_token_expiry=datetime.datetime.now(),
            user_email="test",
            user_id="test",
        )
        return _AccessToken(**fields)
class AuthType(enum.Enum):
class AuthType(enum.Enum):
    """Enumerates the authentication mechanisms understood by `create()`.

    Member values are the lowercase names `create()` accepts as strings, so
    e.g. `AuthType("sso")` is `AuthType.SSO`.
    """

    NONE = "none"  # no credentials attached to requests
    SSO = "sso"  # interactive Keycloak device-code login
    USER_PASSWORD = "user_password"  # username/password from the environment
    TEST_AUTH = "test_auth"  # fixed credentials for unit tests
    TOKEN = "token"  # pre-issued token read from a JSON file

Defines all the supported auth types.

Handles conversions from string names to enum values.

NONE = <AuthType.NONE: 'none'>
SSO = <AuthType.SSO: 'sso'>
USER_PASSWORD = <AuthType.USER_PASSWORD: 'user_password'>
TEST_AUTH = <AuthType.TEST_AUTH: 'test_auth'>
TOKEN = <AuthType.TOKEN: 'token'>
Inherited Members
enum.Enum
name
value
class TokenData(typing.NamedTuple):
class TokenData(NamedTuple):
    """A raw token together with the identity of the user it belongs to."""

    # Raw token string.
    token: str
    # Email address of the token's user.
    user_email: str
    # Identifier of the token's user.
    user_id: str

    def to_dict(self) -> Dict[str, str]:
        """Return the fields as a ``{field_name: value}`` mapping."""
        return {name: getattr(self, name) for name in self._fields}

TokenData(token, user_email, user_id)

TokenData(token: str, user_email: str, user_id: str)

Create new instance of TokenData(token, user_email, user_id)

token: str

Alias for field number 0

user_email: str

Alias for field number 1

user_id: str

Alias for field number 2

def to_dict(self) -> Dict[str, str]:
70    def to_dict(self) -> Dict[str, str]:
71        return self._asdict()
Inherited Members
builtins.tuple
index
count
def create( keycloak_addr: str, auth_type: Union[wallaroo.auth.AuthType, str, NoneType]) -> wallaroo.auth._WallarooAuth:
 90def create(keycloak_addr: str, auth_type: Union[AuthType, str, None]) -> _WallarooAuth:
 91    """Returns an auth object of the corresponding type.
 92
 93    :param str keycloak_addr: Address of the Keycloak instance to auth against
 94    :param AuthType or str auth_type: Type of authentication to use
 95    :return: Auth object that can be passed to all `requests` calls
 96    :rtype: AuthBase
 97    :raises NotImplementedError: if auth_type is not recognized
 98    """
 99    if isinstance(auth_type, str):
100        auth_type = AuthType(auth_type.lower())
101    elif os.getenv(AUTH_PATH) or os.getenv(USER_VAR):
102        auth_type = AuthType.USER_PASSWORD
103    elif auth_type == None or auth_type == AuthType.NONE:
104        return _NoAuth()
105    else:
106        # TODO: Error?
107        print("Unknown auth type.")
108
109    cached_token_path = (
110        pathlib.Path(
111            appdirs.user_cache_dir(appname="wallaroo_sdk", appauthor="wallaroo")
112        )
113        / "auth"
114        / "keycloak.json"
115    )
116    fetcher: _TokenFetcher
117    if auth_type == AuthType.SSO:
118        fetcher = _CachedTokenFetcher(
119            path=cached_token_path,
120            fetcher=_KeycloakTokenFetcher(
121                address=keycloak_addr,
122                realm=KEYCLOAK_REALM,
123                client_id=KEYCLOAK_CLIENT_NAME,
124            ),
125        )
126    elif auth_type == AuthType.USER_PASSWORD:
127        (username, password) = _GetUserPasswordCreds()
128        fetcher = _CachedTokenFetcher(
129            path=cached_token_path,
130            fetcher=_PasswordTokenFetcher(
131                address=keycloak_addr,
132                realm=KEYCLOAK_REALM,
133                client_id=KEYCLOAK_CLIENT_NAME,
134                username=username,
135                password=password,
136            ),
137        )
138    elif auth_type == AuthType.TEST_AUTH:
139        return _TestAuth()
140    elif auth_type == AuthType.TOKEN:
141        env_key = "WALLAROO_SDK_CREDENTIALS"
142        if env_key not in os.environ:
143            raise Exception("Passed token auth, but no token provided.")
144        if not os.path.exists(os.environ[env_key]):
145            raise Exception("Token file provided, but file does not exist.")
146        data = _get_token_from_file(os.environ[env_key])
147        fetcher = _RawTokenFetcher(data)
148    else:
149        raise NotImplementedError(f"Unsupported auth type: {auth_type}")
150    return _PlatformAuth(fetcher=fetcher)

Returns an auth object of the corresponding type.

Parameters
  • str keycloak_addr: Address of the Keycloak instance to auth against
  • AuthType or str auth_type: Type of authentication to use
Returns

Auth object that can be passed to all `requests` calls

Raises
  • NotImplementedError: if auth_type is not recognized
def logout():
def logout():
    """Deletes every cached credential written by this module.

    Only the ``auth`` folder under the SDK's user cache directory is removed;
    a missing folder or removal errors are ignored. Auth objects already
    returned by `create()` are not invalidated.

    :rtype: None
    """
    auth_cache = pathlib.Path(
        appdirs.user_cache_dir(appname="wallaroo_sdk", appauthor="wallaroo"),
        "auth",
    )
    shutil.rmtree(auth_cache, ignore_errors=True)

Removes cached values for all third-party auth providers.

This will not invalidate auth objects already created with create().

class AuthError(builtins.Exception):
class AuthError(Exception):
    """Root of the exception hierarchy raised by this module.

    When a truthy HTTP status `code` is given, the message is prefixed with
    ``[HTTP <code>]``.
    """

    def __init__(self, message: str, code: Optional[int] = None) -> None:
        text = f"[HTTP {code}] {message}" if code else message
        super().__init__(text)

Base type for all errors in this module.

AuthError(message: str, code: Optional[int] = None)
172    def __init__(self, message: str, code: Optional[int] = None) -> None:
173        if code:
174            super().__init__(f"[HTTP {code}] {message}")
175        else:
176            super().__init__(message)
Inherited Members
builtins.BaseException
with_traceback
args
class TokenFetchError(AuthError):
class TokenFetchError(AuthError):
    """Raised when obtaining a new token (i.e. logging in) fails."""

Errors encountered while performing a login.

Inherited Members
AuthError
AuthError
builtins.BaseException
with_traceback
args
class TokenRefreshError(AuthError):
class TokenRefreshError(AuthError):
    """Raised when exchanging a refresh token for a new access token fails."""

Errors encountered while refreshing an AccessToken.

Inherited Members
AuthError
AuthError
builtins.BaseException
with_traceback
args