app/controllers/Auth.scala (439 lines of code) (raw):

package controllers

import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity, HttpMethods, HttpRequest, MediaRange, MediaTypes, ResponseEntity, StatusCodes}
import akka.stream.scaladsl.{Keep, Sink}
import akka.util.ByteString
import auth.{BearerTokenAuth, LoginResultOK}
import com.nimbusds.jwt.JWTClaimsSet
import org.slf4j.LoggerFactory
import play.api.Configuration
import play.api.mvc.{AbstractController, ControllerComponents, Cookie, DiscardingCookie, Request, ResponseHeader, Result, Session}
import scala.util.matching.Regex
import java.net.{URL, URLEncoder}
import java.nio.charset.StandardCharsets
import javax.inject.{Inject, Singleton}
import scala.concurrent.{ExecutionContext, Future}
import io.circe.generic.auto._
import io.circe.syntax._
import models.{OAuthTokenEntry, OAuthTokenEntryDAO, UserProfile, UserProfileDAO}
import play.api.libs.circe.Circe
import play.api.mvc.Cookie.SameSite
import responses.GenericErrorResponse
import auth.ClaimsSetExtensions._
import helpers.{HttpClientFactory, UserAvatarHelper}
import java.nio.ByteBuffer
import java.time.format.DateTimeFormatter
import java.time.{Duration, Instant, ZoneId, ZonedDateTime}
import java.util.{Base64}
import scala.util.Try

/**
 * Controller implementing the browser-facing side of an OAuth2 authorization-code login:
 * redirecting to the identity provider (`login`), exchanging the returned code for tokens
 * (`oauthCallback`), refreshing an expiring session (`refreshIfRequired`) and logging out (`logout`).
 *
 * Behaviour differs depending on whether the `oAuth.type` config value is `"Azure"` or anything else;
 * see the individual methods for the differences.
 */
@Singleton
class Auth @Inject() (config:Configuration,
                      bearerTokenAuth: BearerTokenAuth,
                      userProfileDAO:UserProfileDAO,
                      cc:ControllerComponents,
                      httpFactory:HttpClientFactory,
                      userAvatarHelper: UserAvatarHelper,
                      oAuthTokenEntryDAO: OAuthTokenEntryDAO)
                     (implicit actorSystem: ActorSystem) extends AbstractController(cc) with Circe {
  private implicit val ec:ExecutionContext = cc.executionContext
  private val logger = LoggerFactory.getLogger(getClass)

  import Auth._

  /**
   * allow overriding of the Http() object for testing
   * @return the HTTP client built by the injected `HttpClientFactory`
   */
  protected def http = httpFactory.build

  /**
   * true if the `oAuth.type` configuration value is exactly "Azure".
   * Throws (via `config.get`) if the key is not configured.
   */
  private def usingAzure: Boolean = config.get[String]("oAuth.type") == "Azure"

  //sometimes in development it's easier to run without https, this is indicated by the `enforceSecure` parameter in the config
  private def redirectProto =
    if(config.getOptional[Boolean]("oAuth.enforceSecure").getOrElse(true)) "https://" else "http://"

  /**
   * builds the callback URL that the IdP should send the user back to, based on the Host of the incoming request
   * @param request the incoming request
   * @return absolute URL of the `/oauthCallback` endpoint on this host
   */
  def redirectUri[T](request:Request[T]) = redirectProto + request.host + "/oauthCallback"

  /**
   * builds a URL to the oauth IdP and redirects the user there.
   * For Azure the `scope` and `code_challenge` parameters are sent; otherwise the `resource` parameter is sent.
   * The code challenge (or the literal "nothing" if absent) is stored in the session under `code_verifier`
   * so that `stageTwo` can send it back during the token exchange.
   * @param state optional value passed through the IdP and used as the redirect target after login; defaults to "/"
   * @param code_challenge optional PKCE code challenge
   * @return a temporary redirect to the IdP
   */
  def login(state:Option[String],code_challenge:Option[String]) = Action { request=>
    val args =
      if (!usingAzure) {
        Map(
          "response_type"->"code",
          "client_id"->config.get[String]("oAuth.clientId"),
          "resource"->config.get[String]("oAuth.resource"),
          "redirect_uri"->redirectUri(request),
          "state"->state.getOrElse("/")
        )
      } else {
        Map(
          "response_type"->"code",
          "client_id"->config.get[String]("oAuth.clientId"),
          "redirect_uri"->redirectUri(request),
          "state"->state.getOrElse("/"),
          "scope"->config.get[String]("oAuth.scope"),
          "code_challenge"->code_challenge.getOrElse("nothing")
        )
      }

    logger.debug(s"OAuth arguments before encoding: $args")
    val queryArgs = assembleFromMap(args)
    logger.debug(s"OAuth arguments after encoding: $queryArgs")

    val finalUrl = config.get[String]("oAuth.oAuthUri") + "?" + queryArgs
    TemporaryRedirect(finalUrl)
      .withSession(request.session + ("code_verifier" -> code_challenge.getOrElse("nothing")))
  }

  /**
   * internal method to take a Map of parameters and turn them into a urlencoded string.
   * Only the values are url-encoded; keys are emitted verbatim.
   * @param content a string->string map
   * @return a url-encoded string with all of the parameters from `content`, joined with `&`
   */
  private def assembleFromMap(content:Map[String,String]) =
    content
      .map({ case (key, value) => s"$key=${URLEncoder.encode(value, "UTF-8")}" })
      .mkString("&")

  /**
   * tries to extract and save a profile picture from either `thumbnailPhoto` or `jpegPhoto` claim fields.
   * The first of those claims that is present is base64-decoded and handed to `userAvatarHelper.writeAvatarData`.
   * Any failure (bad base64, claim of the wrong type, write failure) is logged and converted to a Left.
   * @param response result from `validateContent`
   * @return a Future with a Left containing an error string, or a Right with the optional result of the write
   *         (None if neither claim was present)
   */
  private def profilePicFromJWT(response: Either[String, JWTClaimsSet]) = {
    import cats.implicits._

    response match {
      case Left(err)=>
        logger.debug(s"Can't get user avatar because login attempt was not successful: ${err}")
        Future.successful(Left(err))
      case Right(claims)=>
        Future.fromTry(Try {
          Seq("thumbnailPhoto", "jpegPhoto")
            .map(key => Option(claims.getStringClaim(key)))
            .collectFirst { case Some(content) => content }
            .map(Base64.getDecoder.decode)
            .map(content=>{
              logger.debug(s"Got ${content.length} bytes of picture data for ${claims.getUserID}")
              content
            })
            .map(ByteBuffer.wrap)
            .map(buffer=>userAvatarHelper.writeAvatarData(claims.getUserID, buffer))
            .sequence   //Option[Future[X]] => Future[Option[X]]
        }).flatten
          .map(Right.apply)
          .recover({
            case err:Throwable=>
              logger.error(s"Could not get user avatar from claims: ${err.getMessage}", err)
              Left(err.getMessage)
          })
    }
  }

  /**
   * internal method, part of the step two exchange.
   * Given a decoded JWT, try to look up the user's profile in the database.
   * If there is no profile existing at the moment then create a base one. For Azure the admin flag comes from
   * `getIsMMAdminFromRole` and the first name from the `given_name` claim; otherwise they come from
   * `getIsMMAdmin` and the `first_name` claim.
   * @param response result from `validateContent`
   * @return a Future with a Left if an error occurred (with descriptive string) and a Right if we got the UserProfile
   */
  private def userProfileFromJWT(response: Either[String, JWTClaimsSet]): Future[Either[String, UserProfile]] =
    response match {
      case Left(err)=>
        Future.successful(Left(err))
      case Right(oAuthResponse)=>
        userProfileDAO.userProfileForEmail(oAuthResponse.getUserID).flatMap({
          case None=>
            logger.info(s"No user profile existing for ${oAuthResponse.getUserID}, creating one")
            //only build the profile variant that is actually going to be saved
            val newUserProfile =
              if (usingAzure) {
                UserProfile(
                  oAuthResponse.getUserID,
                  oAuthResponse.getIsMMAdminFromRole,
                  Option(oAuthResponse.getStringClaim("given_name")),
                  Option(oAuthResponse.getStringClaim("family_name")),
                  Seq(),
                  allCollectionsVisible=true,
                  None,
                  Option(oAuthResponse.getStringClaim("location")),
                  None,
                  None,
                  None,
                  None
                )
              } else {
                UserProfile(
                  oAuthResponse.getUserID,
                  oAuthResponse.getIsMMAdmin,
                  Option(oAuthResponse.getStringClaim("first_name")),
                  Option(oAuthResponse.getStringClaim("family_name")),
                  Seq(),
                  allCollectionsVisible=true,
                  None,
                  Option(oAuthResponse.getStringClaim("location")),
                  None,
                  None,
                  None,
                  None
                )
              }

            userProfileDAO
              .put(newUserProfile)
              .map(Right.apply)
              .recover({
                case err:Throwable=>
                  logger.error(s"Could not save user profile for ${newUserProfile.userEmail}: ${err.getMessage}", err)
                  Left(err.getMessage)
              })
          case Some(Left(dynamoErr))=>
            Future.successful(Left(dynamoErr.toString))
          case Some(Right(userProfile))=>
            Future.successful(Right(userProfile))
        })
    }

  /**
   * internal method, part of the step two exchange.
   * Given the response from the server, validate and decode the JWT present.
   * For Azure the `id_token` is validated, otherwise the `access_token`. If the relevant token is missing
   * from the response a Left is returned rather than an exception being thrown.
   * @param response result from `stageTwo`
   * @return a Future with either a Left with descriptive error string or Right with the decoded claims set
   */
  private def validateContent(response: Either[String, OAuthResponse]): Future[Either[String, JWTClaimsSet]] = Future(
    response.flatMap(oAuthResponse=>{
      val (tokenName, maybeToken) =
        if (!usingAzure) ("access_token", oAuthResponse.access_token)
        else ("id_token", oAuthResponse.id_token)

      maybeToken match {
        case None=>
          Left(s"Response from auth server did not contain an $tokenName")
        case Some(token)=>
          bearerTokenAuth.validateToken(LoginResultOK(token)) match {
            case Left(err)=>Left(err.toString)
            case Right(validated)=>Right(validated.content)
          }
      }
    })
  )

  /**
   * internal method, part of the step two exchange.
   *
   * Given the response from the server, the validated claims and the user profile, any of which could be errors,
   * formulate a response for the client.
   * If the oauth exchange itself failed, a 500 with a JSON error body is returned. Otherwise the given header and
   * entity are returned, with a fresh session holding `username` and `claimExpiry` (if the claims are available)
   * and `userProfile` as JSON (if the profile is available).
   * @param maybeOAuthResponse result from `stageTwo` or `requestRefresh`
   * @param maybeOAuthClaims result from `validateContent`
   * @param maybeUserProfile result from `userProfileFromJWT`
   * @param header ResponseHeader entity that is sent to the client on success
   * @param entity HttpEntity indicating the body of the response that is sent to the client on success
   * @return a Future containing a Play response
   */
  private def finalCallbackResponse(maybeOAuthResponse:Either[String, OAuthResponse],
                                    maybeOAuthClaims:Either[String, JWTClaimsSet],
                                    maybeUserProfile:Either[String, UserProfile],
                                    header:ResponseHeader,
                                    entity: play.api.http.HttpEntity): Future[Result] = Future(
    maybeOAuthResponse match {
      case Left(err)=>
        logger.error(s"Could not perform oauth exchange: $err")
        InternalServerError(GenericErrorResponse("error",err).asJson)
      case Right(oAuthResponse)=>
        //NOTE(review): this writes the raw tokens to the debug log - confirm that is acceptable
        logger.debug(s"oauth exchange successful, got $oAuthResponse")
        val baseSessionValues = Map[String,String]()

        val claimsSessionValues = maybeOAuthClaims match {
          case Left(err)=>
            logger.warn(s"Could not get claims: $err")
            baseSessionValues
          case Right(claims)=>
            baseSessionValues ++ Map(
              "username"->claims.getUserID,
              "claimExpiry"->ZonedDateTime
                .ofInstant(claims.getExpirationTime.toInstant, ZoneId.systemDefault())
                .format(DateTimeFormatter.ISO_DATE_TIME)
            )
        }

        val sessionValues = maybeUserProfile match {
          case Left(err)=>
            logger.warn(s"Could not load user profile: $err")
            claimsSessionValues
          case Right(profile)=>
            claimsSessionValues ++ Map("userProfile"->profile.asJson.noSpaces)
        }

        Result(
          header,
          entity
        ).withSession(Session(sessionValues))
    }
  )

  /**
   * internal method, part of the step two exchange.
   * If the server gave us a refresh token and the claims validated, save the refresh token against the user ID.
   * A failure to save is logged and reported as a Left, in the same way as `saveUpdatedRefreshToken`.
   * @param response result from `stageTwo`
   * @param maybeValidatedContent result from `validateContent`
   * @return a Future with a Left if either input was an error or the save failed; Right(()) if the token was saved
   *         or there was no refresh token to save
   */
  private def storeRefreshToken(response:Either[String, OAuthResponse],
                                maybeValidatedContent:Either[String, JWTClaimsSet]): Future[Either[String, Unit]] =
    (response, maybeValidatedContent) match {
      case (Left(err), _)=>
        Future.successful(Left(err))
      case (_, Left(err))=>
        Future.successful(Left(err))
      case (Right(oAuthResponse), Right(validatedContent))=>
        oAuthResponse.refresh_token match {
          case None=>
            Future.successful(Right( () ))
          case Some(refreshToken)=>
            oAuthTokenEntryDAO
              .saveToken(validatedContent.getUserID, ZonedDateTime.now(), refreshToken)
              .map(_=>Right( () ))
              .recover({
                case err:Throwable=>
                  logger.error(s"Could not save refresh token: ${err.getMessage}", err)
                  Left(err.getMessage)
              })
        }
    }

  /**
   * internal method, part of the step two exchange.
   * Builds the query string suffix that is appended to the post-login redirect target.
   * @param response result from `validateContent`
   * @return a Future containing `?error=...` if validation failed, where the error text is whatever appears
   *         between the first `(` and the last `)` of the error string (or "Unknown error.." if there are no
   *         parentheses); or an empty string if validation succeeded
   */
  private def isErrorPresent(response: Either[String, JWTClaimsSet]): Future[String] =
    response match {
      case Left(err)=>
        val stringPattern: Regex = "(?<=\\().*(?=\\))".r
        //NOTE(review): the error text is not url-encoded before it ends up in the Location header - confirm
        //what the frontend expects before changing this
        Future.successful("?error=%s".format((stringPattern findFirstIn err).getOrElse("Unknown error..")))
      case Right(_)=>
        Future.successful("")
    }

  /**
   * endpoint that the IdP redirects the user back to after a login attempt.
   * If a code is present it is exchanged for tokens, the token is validated, the user's avatar and profile
   * are saved/loaded, the refresh token is stored and the user is redirected to `state` (or "/") with the session set.
   * If only an error is present, or neither, a 500 is returned.
   * @param state redirect target that was passed to `login`
   * @param code authorization code from the IdP
   * @param error error string from the IdP
   */
  def oauthCallback(state:Option[String], code:Option[String], error:Option[String]) = Action.async { request=>
    (code, error) match {
      case (Some(actualCode), _)=>
        for {
          maybeOauthResponse <- stageTwo(actualCode, redirectUri(request), request)
          maybeValidatedContent <- validateContent(maybeOauthResponse)
          _ <- profilePicFromJWT(maybeValidatedContent)
          maybeUserProfile <- userProfileFromJWT(maybeValidatedContent)
          _ <- storeRefreshToken(maybeOauthResponse, maybeValidatedContent)
          maybeError <- isErrorPresent(maybeValidatedContent)
          result <- finalCallbackResponse(
            maybeOauthResponse,
            maybeValidatedContent,
            maybeUserProfile,
            ResponseHeader(
              StatusCodes.TemporaryRedirect.intValue,
              headers=Map("Location"->"%s%s".format(state.getOrElse("/"), maybeError))
            ),
            play.api.http.HttpEntity.NoEntity
          )
        } yield result
      case (_, Some(error))=>
        Future.successful(InternalServerError(s"Auth provider could not log you in: $error. Try refreshing the page."))
      case (None,None)=>
        Future.successful(InternalServerError("Invalid response from auth provider. Try refreshing the page."))
    }
  }

  /**
   * internal method to read in the content of the ResponseEntity and parse it as JSON.
   * The whole body is accumulated as bytes before being decoded as UTF-8, so that multi-byte characters
   * that straddle a chunk boundary are decoded correctly and an empty body yields a parse error (Left)
   * rather than a failed stream.
   * @param body the ResponseEntity
   * @tparam T data type to unmarshal the response into. A Left is returned if this unmarshalling fails.
   * @return a Future containing either a parser/decoder error or the decoded value
   */
  def consumeBody[T:io.circe.Decoder](body:ResponseEntity):Future[Either[io.circe.Error, T]] = {
    body.dataBytes
      .toMat(Sink.fold[ByteString, ByteString](ByteString.empty)(_ ++ _))(Keep.right)
      .run()
      .map(_.decodeString(StandardCharsets.UTF_8))
      .map(content=>{
        logger.debug(s"raw auth content is $content")
        content
      })
      .map(io.circe.parser.parse)
      .map(_.flatMap(_.as[T]))
  }

  /**
   * internal method, step two of the exchange: POSTs the authorization code to `oAuth.tokenUrl` to obtain tokens.
   * For Azure the `code_verifier` saved in the session by `login` is sent too ("none" if it's not in the session).
   * @param code authorization code from the IdP
   * @param redirectUri the same redirect URI that was sent in step one
   * @param request the incoming request, used for its session
   * @return a Future with a Left (descriptive string) if the server did not return 200 or the body could not be
   *         decoded, or a Right with the decoded OAuthResponse
   */
  protected def stageTwo(code:String, redirectUri:String,request: Request[Any]) = {
    val postdata =
      if (!usingAzure) {
        Map(
          "grant_type"->"authorization_code",
          "client_id"->config.get[String]("oAuth.clientId"),
          "redirect_uri"->redirectUri,
          "code"->code
        )
      } else {
        Map(
          "grant_type"->"authorization_code",
          "client_id"->config.get[String]("oAuth.clientId"),
          "redirect_uri"->redirectUri,
          "code"->code,
          "code_verifier"->request.session.get("code_verifier").getOrElse("none")
        )
      }

    val contentBody = HttpEntity(ContentType(MediaTypes.`application/x-www-form-urlencoded`) ,assembleFromMap(postdata))

    val headers = List(
      Accept(MediaRange(MediaTypes.`application/json`)),
      Origin(config.get[String]("oAuth.origin"))
    )

    logger.debug(s"oauth step2 exchange server url is ${config.get[String]("oAuth.tokenUrl")} and unformatted request content is $postdata")
    val rq = HttpRequest(HttpMethods.POST, uri=config.get[String]("oAuth.tokenUrl"), headers=headers, entity=contentBody)

    (
      for {
        response <- http.singleRequest(rq)
        bodyContent <- consumeBody[OAuthResponse](response.entity)
      } yield (response, bodyContent)
    ).map({
      case (response, Right(oAuthResponse))=>
        if(response.status==StatusCodes.OK) {
          Right(oAuthResponse)
        } else {
          Left(s"Server responded with an error ${response.status} ${oAuthResponse.toString}")
        }
      case (_, Left(decodingError))=>
        Left(s"Could not decode response from oauth server: $decodingError")
    })
  }

  /**
   * logs the user out by blanking the session and discarding the auth and refresh cookies (whose names come
   * from `oAuth.authCookieName` and `oAuth.refreshCookieName`), then redirects to "/".
   */
  def logout() = Action {
    TemporaryRedirect("/")
      .withSession(Session.emptyCookie)
      .discardingCookies(
        //each cookie needs its own DiscardingCookie; the second positional argument of DiscardingCookie is the path
        DiscardingCookie(config.get[String]("oAuth.authCookieName")),
        DiscardingCookie(config.get[String]("oAuth.refreshCookieName"))
      )
  }

  /**
   * gets a cookie from the request whose name is given by a config setting
   * @param request the incoming request
   * @param configPathToName config path of the setting that holds the cookie name
   * @return the Cookie, or None if either the setting could not be read or the cookie is not present
   */
  private def safeGetCookie[A](request:Request[A], configPathToName:String):Option[Cookie] =
    Try {
      request.cookies.get(config.get[String](configPathToName))
    }.toOption.flatten

  /**
   * POSTs the given refresh token to `oAuth.tokenUrl` with `grant_type=refresh_token`
   * @param refreshToken the refresh token value
   * @return a Future with a Right containing the decoded OAuthResponse if the server returned 200 with a decodable
   *         body, otherwise a Left with a message that depends on the status code
   */
  protected def requestRefresh(refreshToken:String) = {
    val params = Map(
      "grant_type"->"refresh_token",
      "refresh_token"->refreshToken
    )
    val encodedParams = assembleFromMap(params)

    val contentBody = HttpEntity(ContentTypes.`application/x-www-form-urlencoded`, encodedParams)
    val headers = scala.collection.immutable.Seq(
      Accept(MediaRange(MediaTypes.`application/json`)),
      Origin(HttpOrigin(config.get[String]("oAuth.origin")))
    )
    val req = HttpRequest(HttpMethods.POST, config.get[String]("oAuth.tokenUrl"), headers, contentBody)

    (for {
      response <- http.singleRequest(req)
      responseBody <- consumeBody[OAuthResponse](response.entity)
    } yield (response, responseBody)
    ).map({
      case (response, Right(oAuthResponse))=>
        response.status match {
          case StatusCodes.OK=>
            Right(oAuthResponse)
          case StatusCodes.BadGateway | StatusCodes.ServiceUnavailable=>
            Left("Authorization server is not available at the moment, hopefully refresh will work next time")
          case _=>
            Left(s"Server returned ${response.status}")
        }
      case (response, Left(err))=>
        logger.error(s"Could not parse response from server: $err")
        response.status match {
          case StatusCodes.BadGateway | StatusCodes.ServiceUnavailable=>
            Left("Authorization server is not available at the moment, hopefully refresh will work next time")
          case StatusCodes.BadRequest=>
            Left("Internal error, server rejected our request")
          case StatusCodes.InternalServerError=>
            Left("Authorization server failed trying to process our request, contact Infrastructure")
          case _=>
            Left(s"Server returned ${response.status}")
        }
    })
  }

  /**
   * reads the `claimExpiry` value that `finalCallbackResponse` put into the session
   * @param request the incoming request
   * @return the expiry time, or None if it's not present or can't be parsed as an ISO date-time
   */
  private def expiryFromSession(request:Request[Any]) =
    request
      .session
      .get("claimExpiry")
      .flatMap(expiryString=>Try { ZonedDateTime.parse(expiryString, DateTimeFormatter.ISO_DATE_TIME) }.toOption)

  /**
   * saves the refresh token from a refresh response against the `username` in the session
   * @param request the incoming request, used for its session
   * @param maybeOAuthResponse result from `requestRefresh`
   * @return a Future with a Right containing the result of the save, or a Left if there was no username, no
   *         successful response, no refresh token in the response or the save failed
   */
  private def saveUpdatedRefreshToken(request:Request[Any],
                                      maybeOAuthResponse:Either[String, OAuthResponse]):Future[Either[String, OAuthTokenEntry]] =
    (request.session.get("username"), maybeOAuthResponse) match {
      case (Some(username), Right(oAuthResponse))=>
        oAuthResponse.refresh_token match {
          case Some(refreshToken) =>
            oAuthTokenEntryDAO
              .saveToken(username, ZonedDateTime.now(), refreshToken)
              .map(Right.apply)
              .recover({
                case err:Throwable=>
                  logger.error(s"Could not save refresh token to dynamo: ${err.getMessage}", err)
                  Left(err.getMessage)
              })
          case None =>
            Future.successful(Left("no refresh token was present"))
        }
      case _=>
        Future.successful(Left("either there was no username or no valid refresh token from the server"))
    }

  /**
   * endpoint that refreshes the login if the expiry time in the session has passed or is about to.
   * When a refresh is needed, the stored refresh token for the session's username is looked up and exchanged;
   * the used token is removed, any new refresh token is saved and the session is rebuilt.
   * Responds 200 "not_needed" if no refresh is required, 400 if the session has no usable expiry or no stored
   * refresh token can be found, and 500 if the refresh fails with an exception (in which case the session is
   * also blanked if the token has fully expired).
   */
  def refreshIfRequired = Action.async { request =>
    import cats.implicits._

    expiryFromSession(request) match {
      case Some(expiry) =>
        if (Auth.claimIsExpired(expiry)) {
          request.session
            .get("username")
            .map(oAuthTokenEntryDAO.lookupToken)
            .sequence.map(_.flatten)    //Option[Future[Option[X]]] => Future[Option[X]]
            .flatMap({
              case Some(refreshToken) =>
                for {
                  maybeOauthResponse <- requestRefresh(refreshToken.value)
                  maybeValidatedContent <- validateContent(maybeOauthResponse)
                  maybeUserProfile <- userProfileFromJWT(maybeValidatedContent)
                  _ <- oAuthTokenEntryDAO.removeUsedToken(refreshToken)
                  _ <- saveUpdatedRefreshToken(request, maybeOauthResponse)
                  result <- finalCallbackResponse(
                    maybeOauthResponse,
                    maybeValidatedContent,
                    maybeUserProfile,
                    ResponseHeader(200),
                    play.api.http.HttpEntity.Strict(
                      ByteString(GenericErrorResponse("ok", "token refreshed").asJson.noSpaces),
                      Some("application/json")
                    )
                  )
                } yield result
              case None =>
                logger.error("Could not find a refresh token")
                Future.successful(BadRequest(GenericErrorResponse("error", "either no refresh token or server misconfigured").asJson))
            }).recover({
              case err: Throwable =>
                logger.error(s"Could not refresh token for ${request.session.get("username")}: ${err.getMessage}", err)
                val baseResponse = InternalServerError(GenericErrorResponse("error", err.getMessage).asJson)
                if(Auth.claimIsExpired(expiry, trueIfNear=false)) { //if we are fully expired then blank out the session
                  baseResponse.withSession(Session.emptyCookie)
                } else {
                  baseResponse
                }
            })
        } else {
          logger.info(s"${request.session.get("username")}: No token refresh required")
          Future.successful(Ok(GenericErrorResponse("not_needed", "no refresh required").asJson))
        }
      case None =>
        logger.error(s"either no login or no expiry was set in the session")
        Future.successful(BadRequest(GenericErrorResponse("session_problem", "either no expiry time or no login token in session").asJson))
    }
  }
}

object Auth {
  private val logger = LoggerFactory.getLogger(getClass)

  /**
   * model for the JSON body returned from the token endpoint. Every field is optional.
   */
  case class OAuthResponse(access_token:Option[String],
                           refresh_token:Option[String],
                           id_token:Option[String],
                           error:Option[String],
                           error_description:Option[String])

  /**
   * returns a boolean indicating if the given expiry time either has passed or is about to
   * @param expiryTime ZonedDateTime indicating the token expiry
   * @param trueIfNear if true (the default), also return true when the expiry is 2 minutes or less away.
   *                   If false, only return true once the expiry time has actually been reached.
   * @return true if the expiry time has been reached or (when `trueIfNear` is set) shortly will be
   */
  def claimIsExpired(expiryTime:ZonedDateTime, trueIfNear:Boolean=true): Boolean = {
    val expiryWindow = Duration.ofMinutes(2)  //attempt a refresh if the token is valid for less than this
    val expiresIn = Duration.between(Instant.now(), expiryTime)
    logger.debug(s"refresh check - access token expiry at ${expiryTime} which expires in $expiresIn")

    //compareTo - if window>expiresIn result =1, if == result=0 if < result=-1
    expiresIn.isNegative||expiresIn.isZero||(trueIfNear && expiryWindow.compareTo(expiresIn)>=0)
  }
}