/*
 * Decompiled with CFR 0.152.
 */
package io.hops.hopsworks.jwt;

import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTCreator;
import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import io.hops.hopsworks.jwt.AlgorithmFactory;
import io.hops.hopsworks.jwt.JsonWebToken;
import io.hops.hopsworks.jwt.SignatureAlgorithm;
import io.hops.hopsworks.jwt.dao.InvalidJwtFacade;
import io.hops.hopsworks.jwt.dao.JwtSigningKeyFacade;
import io.hops.hopsworks.jwt.exception.DuplicateSigningKeyException;
import io.hops.hopsworks.jwt.exception.InvalidationException;
import io.hops.hopsworks.jwt.exception.JWTException;
import io.hops.hopsworks.jwt.exception.NotRenewableException;
import io.hops.hopsworks.jwt.exception.SigningKeyNotFoundException;
import io.hops.hopsworks.jwt.exception.VerificationException;
import io.hops.hopsworks.persistence.entity.jwt.InvalidJwt;
import io.hops.hopsworks.persistence.entity.jwt.JwtSigningKey;
import java.security.NoSuchAlgorithmException;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.Calendar;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.ejb.AccessLocalException;
import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import org.apache.commons.lang3.tuple.Pair;

/**
 * Stateless EJB that creates, verifies, renews and invalidates JSON Web Tokens and manages the
 * signing keys used to sign them.
 *
 * <p>Invalidated token IDs are persisted through {@link InvalidJwtFacade}; signing keys are stored
 * through {@link JwtSigningKeyFacade}.
 */
@Stateless
@TransactionAttribute(TransactionAttributeType.NOT_SUPPORTED)
public class JWTController {
    private static final Logger LOGGER = Logger.getLogger(JWTController.class.getName());

    /** Format of one-time service signing key names: username, remote host, epoch millis. */
    private static final String SERVICE_ONE_TIME_SIGNING_KEYNAME = "%s_%s__%d";

    private static final String EXP_LEEWAY_CLAIM = "expLeeway";
    private static final String RENEWABLE_CLAIM = "renewable";
    private static final String ROLES_CLAIM = "roles";
    private static final String RENEWAL_KEY_ID_CLAIM = "renewal_key_id";
    private static final String ONE_TIME_KEY_NAME = "oneTimeKey";
    private static final String OLD_ONE_TIME_KEY_NAME = "oneTimeKey_old";
    private static final String ELK_SIGNING_KEY_NAME = "elk_jwt_signing_key";

    /** Expiration leeway (seconds) used when a token carries no positive leeway. */
    private static final int DEFAULT_EXP_LEEWAY_SEC = 60;
    /** Expiration leeway (seconds) given to one-time service renewal tokens. */
    private static final int SERVICE_RENEWAL_EXP_LEEWAY_SEC = 3600;
    /** Number of one-time renewal tokens generated per service token renewal. */
    private static final int NUM_SERVICE_RENEWAL_TOKENS = 5;

    @EJB
    private InvalidJwtFacade invalidJwtFacade;
    @EJB
    private AlgorithmFactory algorithmFactory;
    @EJB
    private JwtSigningKeyFacade jwtSigningKeyFacade;

    /**
     * Creates a signed token from the fields of {@code jwt} plus the given extra claims.
     *
     * @param jwt token specification (key id, issuer, audience, dates, subject, algorithm)
     * @param claims extra claims; not modified
     * @return the signed, compact token
     * @throws SigningKeyNotFoundException if the signing key cannot be resolved
     */
    public String createToken(JsonWebToken jwt, Map<String, Object> claims) throws SigningKeyNotFoundException {
        String[] audience = jwt.getAudience().toArray(new String[0]);
        return createToken(jwt.getKeyId(), jwt.getIssuer(), audience, jwt.getExpiresAt(), jwt.getNotBefore(),
            jwt.getSubject(), claims, jwt.getAlgorithm());
    }

    /**
     * Creates and signs a token with a fresh JTI and issued-at set to now.
     *
     * <p>The {@code expLeeway} claim is always present in the result. A missing, non-numeric or
     * non-positive value is replaced with {@value #DEFAULT_EXP_LEEWAY_SEC} seconds.
     *
     * @param claims extra claims; may be {@code null}. The map is copied, never modified.
     * @return the signed, compact token
     * @throws SigningKeyNotFoundException if the signing key {@code keyId} cannot be resolved
     */
    public String createToken(String keyId, String issuer, String[] audience, Date expiresAt, Date notBefore,
        String subject, Map<String, Object> claims, SignatureAlgorithm algorithm) throws SigningKeyNotFoundException {
        JWTCreator.Builder jwtBuilder = JWT.create()
            .withKeyId(keyId)
            .withIssuer(issuer)
            .withAudience(audience)
            .withIssuedAt(new Date())
            .withExpiresAt(expiresAt)
            .withNotBefore(notBefore)
            .withJWTId(generateJti())
            .withSubject(subject);
        // Work on a copy: the caller's map may be shared, immutable or null.
        Map<String, Object> tokenClaims = new HashMap<>();
        if (claims != null) {
            tokenClaims.putAll(claims);
        }
        Object requestedLeeway = tokenClaims.get(EXP_LEEWAY_CLAIM);
        int expLeeway = requestedLeeway instanceof Number ? ((Number) requestedLeeway).intValue() : -1;
        tokenClaims.put(EXP_LEEWAY_CLAIM, getExpLeewayOrDefault(expLeeway));
        jwtBuilder = addClaims(jwtBuilder, tokenClaims);
        return jwtBuilder.sign(algorithmFactory.getAlgorithm(algorithm, keyId));
    }

    /**
     * Adds every claim whose value type the JWT builder supports.
     *
     * <p>Supported types are String/Integer/Long arrays and Boolean, Integer, Long, Double, String
     * and Date scalars. Null values and values of any other type are skipped.
     */
    private JWTCreator.Builder addClaims(JWTCreator.Builder jwtCreator, Map<String, Object> claims) {
        JWTCreator.Builder builder = jwtCreator;
        for (Map.Entry<String, Object> entry : claims.entrySet()) {
            String name = entry.getKey();
            Object value = entry.getValue();
            if (value == null) {
                continue;
            }
            if (value instanceof String[]) {
                builder = builder.withArrayClaim(name, (String[]) value);
            } else if (value instanceof Integer[]) {
                builder = builder.withArrayClaim(name, (Integer[]) value);
            } else if (value instanceof Long[]) {
                builder = builder.withArrayClaim(name, (Long[]) value);
            } else if (value instanceof Boolean) {
                builder = builder.withClaim(name, (Boolean) value);
            } else if (value instanceof Integer) {
                builder = builder.withClaim(name, (Integer) value);
            } else if (value instanceof Long) {
                builder = builder.withClaim(name, (Long) value);
            } else if (value instanceof Double) {
                builder = builder.withClaim(name, (Double) value);
            } else if (value instanceof String) {
                builder = builder.withClaim(name, (String) value);
            } else if (value instanceof Date) {
                builder = builder.withClaim(name, (Date) value);
            }
        }
        return builder;
    }

    /**
     * Creates a token signed with the named key.
     *
     * @param createNewKey if {@code true}, a new key is created and an existing key with the same
     *     name is an error; otherwise the key is fetched or created
     * @throws DuplicateSigningKeyException if {@code createNewKey} is set and the key already exists
     */
    public String createToken(String keyName, boolean createNewKey, String issuer, String[] audience, Date expiresAt,
        Date notBefore, String subject, Map<String, Object> claims, SignatureAlgorithm algorithm)
        throws NoSuchAlgorithmException, SigningKeyNotFoundException, DuplicateSigningKeyException {
        JwtSigningKey signingKey = createNewKey
            ? createNewSigningKey(keyName, algorithm)
            : getOrCreateSigningKey(keyName, algorithm);
        return createToken(signingKey.getId().toString(), issuer, audience, expiresAt, notBefore, subject, claims,
            algorithm);
    }

    /**
     * Best-effort invalidation of a token.
     *
     * <p>Null, empty or unverifiable tokens are ignored.
     *
     * @throws InvalidationException if the invalidation record cannot be persisted
     */
    public void invalidate(String token) throws InvalidationException {
        if (token == null || token.isEmpty()) {
            return;
        }
        DecodedJWT jwt;
        try {
            jwt = verifyToken(token, null);
        } catch (Exception ex) {
            // Tokens that do not verify are not usable anyway; nothing to record.
            return;
        }
        invalidateJWT(jwt.getId(), jwt.getExpiresAt(), getExpLeewayClaim(jwt));
    }

    /**
     * Returns the token's {@code expLeeway} claim in seconds.
     *
     * <p>Falls back to {@value #DEFAULT_EXP_LEEWAY_SEC} when the claim is absent, non-numeric or
     * not positive.
     */
    public int getExpLeewayClaim(DecodedJWT jwt) {
        Claim expLeewayClaim = jwt.getClaim(EXP_LEEWAY_CLAIM);
        // For an absent claim java-jwt returns a placeholder Claim whose asInt() is null, so unbox defensively.
        Integer expLeeway = expLeewayClaim == null ? null : expLeewayClaim.asInt();
        return expLeeway == null ? DEFAULT_EXP_LEEWAY_SEC : getExpLeewayOrDefault(expLeeway);
    }

    /** Returns {@code expLeeway}, or {@value #DEFAULT_EXP_LEEWAY_SEC} if it is less than 1. */
    public int getExpLeewayOrDefault(int expLeeway) {
        return expLeeway < 1 ? DEFAULT_EXP_LEEWAY_SEC : expLeeway;
    }

    /** Returns the token's {@code renewable} claim; {@code false} when absent or not a boolean. */
    public boolean getRenewableClaim(DecodedJWT jwt) {
        Claim renewableClaim = jwt.getClaim(RENEWABLE_CLAIM);
        return renewableClaim != null && Boolean.TRUE.equals(renewableClaim.asBoolean());
    }

    /** Returns the token's {@code roles} claim; an empty array when the claim is absent or not an array. */
    public String[] getRolesClaim(DecodedJWT jwt) {
        Claim rolesClaim = jwt.getClaim(ROLES_CLAIM);
        String[] roles = rolesClaim == null ? null : rolesClaim.asArray(String.class);
        return roles == null ? new String[0] : roles;
    }

    /** Decodes {@code token} without verifying it; returns {@code null} for a null or empty token. */
    public DecodedJWT decodeToken(String token) {
        if (token == null || token.isEmpty()) {
            return null;
        }
        return JWT.decode(token);
    }

    /**
     * Verifies signature, issuer and expiration (with the token's own leeway), and rejects
     * invalidated tokens.
     *
     * @param issuer expected issuer; if null or empty the token's own issuer is accepted
     * @throws VerificationException if the token is malformed, fails verification, or was invalidated
     */
    public DecodedJWT verifyToken(String token, String issuer) throws SigningKeyNotFoundException, VerificationException {
        DecodedJWT unverified = decodeForVerification(token);
        String expectedIssuer = issuer == null || issuer.isEmpty() ? unverified.getIssuer() : issuer;
        int expLeeway = getExpLeewayClaim(unverified);
        DecodedJWT jwt = verifyToken(token, expectedIssuer, expLeeway, algorithmFactory.getAlgorithm(unverified));
        if (isTokenInvalidated(jwt)) {
            throw new VerificationException("Invalidated token.");
        }
        return jwt;
    }

    /** Verifies {@code token} like {@link #verifyToken(String, String)} and then invalidates it. */
    public DecodedJWT verifyOneTimeToken(String token, String issuer)
        throws SigningKeyNotFoundException, VerificationException, InvalidationException {
        DecodedJWT jwt = verifyToken(token, issuer);
        invalidateJWT(jwt.getId(), jwt.getExpiresAt(), getExpLeewayClaim(jwt));
        return jwt;
    }

    /**
     * Verifies {@code token} and additionally checks roles and audience.
     *
     * @param audiences if non-empty, the token's audience must contain at least one of them
     * @param roles if non-empty, the token's roles must contain at least one of them
     * @throws VerificationException if the token is malformed, fails verification, or was invalidated
     * @throws AccessLocalException if the role or audience check fails
     */
    public DecodedJWT verifyToken(String token, String issuer, Set<String> audiences, Set<String> roles)
        throws SigningKeyNotFoundException, VerificationException {
        JsonWebToken jwt = new JsonWebToken(decodeForVerification(token));
        String expectedIssuer = issuer == null || issuer.isEmpty() ? jwt.getIssuer() : issuer;
        DecodedJWT djwt = verifyToken(token, expectedIssuer, jwt.getExpLeeway(), algorithmFactory.getAlgorithm(jwt));
        if (isTokenInvalidated(djwt)) {
            throw new VerificationException("Invalidated token.");
        }
        if (roles != null && !roles.isEmpty() && !intersect(roles, jwt.getRole())) {
            throw new AccessLocalException("Client not authorized for this invocation.");
        }
        if (audiences != null && !audiences.isEmpty() && !intersect(audiences, jwt.getAudience())) {
            throw new AccessLocalException("Token not issued for this recipient.");
        }
        return djwt;
    }

    /**
     * Decodes a token that is about to be verified.
     *
     * <p>Malformed input is reported as {@link VerificationException}.
     */
    private DecodedJWT decodeForVerification(String token) throws VerificationException {
        if (token == null || token.isEmpty()) {
            throw new VerificationException("Token is empty.");
        }
        try {
            return JWT.decode(token);
        } catch (RuntimeException e) {
            // JWT.decode reports malformed tokens with unchecked exceptions (e.g. JWTDecodeException).
            throw new VerificationException(e.getMessage());
        }
    }

    /** Verifies signature, issuer and expiration, allowing {@code expLeeway} seconds past {@code exp}. */
    private DecodedJWT verifyToken(String token, String issuer, int expLeeway, Algorithm algorithm)
        throws VerificationException {
        try {
            JWTVerifier verifier = JWT.require(algorithm)
                .withIssuer(issuer)
                .acceptExpiresAt(expLeeway)
                .build();
            return verifier.verify(token);
        } catch (Exception e) {
            throw new VerificationException(e.getMessage());
        }
    }

    /** Returns {@code true} if both collections are non-empty and share at least one element. */
    private boolean intersect(Collection<?> list1, Collection<?> list2) {
        if (list1 == null || list1.isEmpty() || list2 == null || list2.isEmpty()) {
            return false;
        }
        return !Collections.disjoint(list1, list2);
    }

    /** Returns {@code true} if the token's JTI has been recorded as invalid. */
    public boolean isTokenInvalidated(DecodedJWT jwt) {
        return isTokenInvalidated(jwt.getId());
    }

    private boolean isTokenInvalidated(String id) {
        return invalidJwtFacade.find(id) != null;
    }

    /**
     * Renews an expired, renewable token.
     *
     * <p>The new token keeps the original lifetime (exp - iat), starts now, and the old token is
     * invalidated.
     *
     * @throws NotRenewableException if the token fails verification, is not renewable, or has not expired yet
     */
    public String autoRenewToken(String token)
        throws SigningKeyNotFoundException, NotRenewableException, InvalidationException {
        DecodedJWT jwt = verifyTokenForRenewal(token);
        if (!getRenewableClaim(jwt)) {
            throw new NotRenewableException("Token not renewable.");
        }
        if (new Date().before(jwt.getExpiresAt())) {
            throw new NotRenewableException("Token not expired.");
        }
        long lifetimeMs = jwt.getExpiresAt().getTime() - jwt.getIssuedAt().getTime();
        JsonWebToken renewed = new JsonWebToken(jwt);
        renewed.setExpiresAt(new Date(System.currentTimeMillis() + lifetimeMs));
        renewed.setNotBefore(new Date());
        Map<String, Object> claims = new HashMap<>(3);
        addDefaultClaimsIfMissing(claims, renewed.isRenewable(), getExpLeewayOrDefault(renewed.getExpLeeway()),
            renewed.getRole().toArray(new String[0]));
        String renewedToken = createToken(renewed, claims);
        invalidateJWT(jwt.getId(), jwt.getExpiresAt(), renewed.getExpLeeway());
        return renewedToken;
    }

    /** Same as {@link #renewToken(String, Date, Date, boolean, Map, boolean)} with {@code force = false}. */
    public String renewToken(String token, Date newExp, Date notBefore, boolean invalidate, Map<String, Object> claims)
        throws SigningKeyNotFoundException, NotRenewableException, InvalidationException {
        return renewToken(token, newExp, notBefore, invalidate, claims, false);
    }

    /**
     * Issues a copy of {@code token} with new expiration and not-before dates.
     *
     * @param invalidate whether to invalidate the old token after renewal
     * @param claims extra claims; defaults for renewable, expLeeway and roles are added to this map if missing
     * @param force renew even if the token has not expired yet
     * @throws NotRenewableException if the token fails verification or (unless forced) has not expired
     */
    public String renewToken(String token, Date newExp, Date notBefore, boolean invalidate, Map<String, Object> claims,
        boolean force) throws SigningKeyNotFoundException, NotRenewableException, InvalidationException {
        DecodedJWT jwt = verifyTokenForRenewal(token);
        if (!force && new Date().before(jwt.getExpiresAt())) {
            throw new NotRenewableException("Token not expired.");
        }
        JsonWebToken renewed = new JsonWebToken(jwt);
        renewed.setExpiresAt(newExp);
        renewed.setNotBefore(notBefore);
        Map<String, Object> allClaims = addDefaultClaimsIfMissing(claims, renewed.isRenewable(),
            getExpLeewayOrDefault(renewed.getExpLeeway()), renewed.getRole().toArray(new String[0]));
        String renewedToken = createToken(renewed, allClaims);
        if (invalidate) {
            invalidateJWT(jwt.getId(), jwt.getExpiresAt(), renewed.getExpLeeway());
        }
        return renewedToken;
    }

    /**
     * Renews a service token and issues a fresh batch of one-time renewal tokens for it.
     *
     * <p>The renewal tokens are signed with a new per-service key. The renewed service token
     * records that key's id in its {@code renewal_key_id} claim. The one-time token used for this
     * request is invalidated. On failure the newly named renewal key is deleted.
     *
     * @return the renewed service token and the new one-time renewal tokens
     */
    public Pair<String, String[]> renewServiceToken(String oneTimeRenewalToken, String serviceToken,
        Date newExpiration, Date newNotBefore, Long serviceJWTLifetimeMS, String username, List<String> userRoles,
        List<String> audience, String remoteHostname, String issuer, String defaultJWTSigningKeyName, boolean force)
        throws JWTException, NoSuchAlgorithmException {
        Map<String, Object> claims = new HashMap<>(4);
        claims.put(RENEWABLE_CLAIM, false);
        claims.put(EXP_LEEWAY_CLAIM, SERVICE_RENEWAL_EXP_LEEWAY_SEC);
        claims.put(ROLES_CLAIM, userRoles.toArray(new String[0]));

        String renewalKeyName = getServiceOneTimeJWTSigningKeyname(username, remoteHostname);
        LocalDateTime masterExpiration = newExpiration.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
        LocalDateTime notBefore = computeNotBefore4ServiceRenewalTokens(masterExpiration);
        LocalDateTime expiresAt = notBefore.plus(serviceJWTLifetimeMS, ChronoUnit.MILLIS);

        JsonWebToken jwtSpecs = new JsonWebToken();
        jwtSpecs.setSubject(username);
        jwtSpecs.setIssuer(issuer);
        jwtSpecs.setAudience(audience);
        jwtSpecs.setKeyId(renewalKeyName);
        jwtSpecs.setNotBefore(localDateTime2Date(notBefore));
        jwtSpecs.setExpiresAt(localDateTime2Date(expiresAt));
        try {
            String[] renewalTokens = generateOneTimeTokens4ServiceJWTRenewal(jwtSpecs, claims, defaultJWTSigningKeyName);
            String signingKeyId = getSignKeyID(renewalTokens[0]);
            DecodedJWT serviceJWT = decodeToken(serviceToken);

            claims.clear();
            claims.put(RENEWABLE_CLAIM, false);
            claims.put(RENEWAL_KEY_ID_CLAIM, signingKeyId);
            claims.put(EXP_LEEWAY_CLAIM, getExpLeewayClaim(serviceJWT));
            String renewedServiceToken = renewToken(serviceToken, newExpiration, newNotBefore, false, claims, force);
            invalidate(oneTimeRenewalToken);
            return Pair.of(renewedServiceToken, renewalTokens);
        } catch (JWTException | NoSuchAlgorithmException ex) {
            deleteSigningKey(renewalKeyName);
            throw ex;
        }
    }

    /**
     * Invalidates a service token (best effort) and deletes the renewal signing key it references.
     *
     * <p>The key is deleted only if it is neither the default key nor the shared one-time key.
     * Null or empty tokens are ignored.
     */
    public void invalidateServiceToken(String serviceToken2invalidate, String defaultJWTSigningKeyName) {
        DecodedJWT serviceJWT2invalidate = decodeToken(serviceToken2invalidate);
        if (serviceJWT2invalidate == null) {
            return;
        }
        try {
            invalidate(serviceToken2invalidate);
        } catch (InvalidationException ex) {
            LOGGER.log(Level.WARNING, "Could not invalidate service JWT with ID " + serviceJWT2invalidate.getId()
                + ". Continuing with deleting signing key");
        }
        Claim signingKeyIdClaim = serviceJWT2invalidate.getClaim(RENEWAL_KEY_ID_CLAIM);
        if (signingKeyIdClaim == null || signingKeyIdClaim.isNull()) {
            return;
        }
        String signingKeyId = signingKeyIdClaim.asString();
        if (signingKeyId == null) {
            return;
        }
        JwtSigningKey signingKey;
        try {
            signingKey = findSigningKeyById(Integer.parseInt(signingKeyId));
        } catch (NumberFormatException ex) {
            LOGGER.log(Level.WARNING, "Invalid renewal signing key ID " + signingKeyId + " in service JWT with ID "
                + serviceJWT2invalidate.getId());
            return;
        }
        if (signingKey == null || defaultJWTSigningKeyName == null
            || defaultJWTSigningKeyName.equals(signingKey.getName())
            || ONE_TIME_KEY_NAME.equals(signingKey.getName())) {
            return;
        }
        deleteSigningKey(signingKey.getName());
    }

    /** Returns the {@code kid} header of {@code token}, or {@code null} for a null or empty token. */
    public String getSignKeyID(String token) {
        DecodedJWT jwt = decodeToken(token);
        return jwt == null ? null : jwt.getKeyId();
    }

    /**
     * Generates {@value #NUM_SERVICE_RENEWAL_TOKENS} HS256 one-time renewal tokens signed with the
     * key named {@code jwtSpecs.getKeyId()}.
     *
     * <p>The first token creates a new key. If a key with that name already exists, the old key is
     * removed and creation is retried once. The old key is only removed when it is neither the
     * default key nor the shared one-time key.
     *
     * @throws IllegalStateException if a fresh signing key still cannot be created after the retry
     */
    public String[] generateOneTimeTokens4ServiceJWTRenewal(JsonWebToken jwtSpecs, Map<String, Object> claims,
        String defaultJWTSigningKeyName) throws NoSuchAlgorithmException, SigningKeyNotFoundException {
        String[] renewalTokens = new String[NUM_SERVICE_RENEWAL_TOKENS];
        SignatureAlgorithm algorithm = SignatureAlgorithm.HS256;
        String keyName = jwtSpecs.getKeyId();
        String[] audienceArray = jwtSpecs.getAudience().toArray(new String[0]);
        try {
            renewalTokens[0] = createToken(keyName, true, jwtSpecs.getIssuer(), audienceArray, jwtSpecs.getExpiresAt(),
                jwtSpecs.getNotBefore(), jwtSpecs.getSubject(), claims, algorithm);
        } catch (DuplicateSigningKeyException ex) {
            LOGGER.log(Level.FINE, "Signing key already exist for service JWT key " + keyName + ". Removing old one");
            if (defaultJWTSigningKeyName != null && !defaultJWTSigningKeyName.equals(keyName)
                && !ONE_TIME_KEY_NAME.equals(keyName)) {
                deleteSigningKey(keyName);
            }
            try {
                renewalTokens[0] = createToken(keyName, true, jwtSpecs.getIssuer(), audienceArray,
                    jwtSpecs.getExpiresAt(), jwtSpecs.getNotBefore(), jwtSpecs.getSubject(), claims, algorithm);
            } catch (DuplicateSigningKeyException retryEx) {
                // Returning a null first token would only fail later, far from the cause.
                throw new IllegalStateException("Could not create a new signing key for service JWT key " + keyName,
                    retryEx);
            }
        }
        for (int i = 1; i < renewalTokens.length; i++) {
            try {
                renewalTokens[i] = createToken(keyName, false, jwtSpecs.getIssuer(), audienceArray,
                    jwtSpecs.getExpiresAt(), jwtSpecs.getNotBefore(), jwtSpecs.getSubject(), claims, algorithm);
            } catch (DuplicateSigningKeyException ignored) {
                // createNewKey=false uses getOrCreateSigningKey, which does not declare this exception;
                // the catch only satisfies createToken's signature.
            }
        }
        return renewalTokens;
    }

    private Date localDateTime2Date(LocalDateTime localDateTime) {
        return Date.from(localDateTime.atZone(ZoneId.systemDefault()).toInstant());
    }

    /**
     * Computes the not-before time for service renewal tokens.
     *
     * <p>Normally this is 3 minutes before {@code masterExpiration}. If that moment is already in the
     * past, 3 milliseconds before {@code masterExpiration} is used instead.
     */
    public LocalDateTime computeNotBefore4ServiceRenewalTokens(LocalDateTime masterExpiration) {
        LocalDateTime threeMinutesBefore = masterExpiration.minus(3L, ChronoUnit.MINUTES);
        if (threeMinutesBefore.isBefore(LocalDateTime.now())) {
            // NOTE(review): offset unit switches from MINUTES to MILLIS here - confirm this is intended.
            return masterExpiration.minus(3L, ChronoUnit.MILLIS);
        }
        return threeMinutesBefore;
    }

    /** Builds a unique-per-millisecond signing key name for a service's one-time renewal tokens. */
    public String getServiceOneTimeJWTSigningKeyname(String username, String remoteHost) {
        return String.format(SERVICE_ONE_TIME_SIGNING_KEYNAME, username, remoteHost, System.currentTimeMillis());
    }

    /**
     * Adds the renewable, expLeeway and roles claims to {@code userClaims} unless already present.
     *
     * @return {@code userClaims}, or a new map if it was {@code null}
     */
    public Map<String, Object> addDefaultClaimsIfMissing(Map<String, Object> userClaims, boolean isRenewable,
        int leeway, String[] roles) {
        Map<String, Object> result = userClaims == null ? new HashMap<String, Object>(3) : userClaims;
        result.putIfAbsent(RENEWABLE_CLAIM, isRenewable);
        result.putIfAbsent(EXP_LEEWAY_CLAIM, leeway);
        result.putIfAbsent(ROLES_CLAIM, roles);
        return result;
    }

    /** Verifies {@code token}, reporting verification failures as {@link NotRenewableException}. */
    private DecodedJWT verifyTokenForRenewal(String token) throws SigningKeyNotFoundException, NotRenewableException {
        try {
            return verifyToken(token, null);
        } catch (VerificationException ex) {
            throw new NotRenewableException(ex.getMessage());
        }
    }

    /** Persists an invalidation record for the given JTI. */
    private void invalidateJWT(String id, Date exp, int leeway) throws InvalidationException {
        try {
            invalidJwtFacade.persist(new InvalidJwt(id, exp, leeway));
        } catch (Exception e) {
            // Keep the original exception as the cause; its own cause may be null.
            throw new InvalidationException("Could not persist token.", e);
        }
    }

    /** Returns {@code true} once the token's expiration plus its leeway lies in the past. */
    public boolean passedRenewal(DecodedJWT jwt) {
        return passedRenewal(jwt.getExpiresAt(), getExpLeewayClaim(jwt));
    }

    /** Returns {@code true} once {@code exp + expLeeway} seconds lies in the past. */
    public boolean passedRenewal(Date exp, int expLeeway) {
        // Multiply in long arithmetic so large leeways do not overflow int.
        Date expireOn = new Date(exp.getTime() + expLeeway * 1000L);
        return expireOn.before(new Date());
    }

    /** Generates a random UUID JTI that does not collide with any recorded invalid token. */
    public String generateJti() {
        String candidate;
        do {
            candidate = UUID.randomUUID().toString();
        } while (invalidJwtFacade.find(candidate) != null);
        return candidate;
    }

    public JwtSigningKey getOrCreateSigningKey(String keyName, SignatureAlgorithm alg) throws NoSuchAlgorithmException {
        return jwtSigningKeyFacade.getOrCreateSigningKey(keyName, alg);
    }

    public JwtSigningKey createNewSigningKey(String keyName, SignatureAlgorithm alg)
        throws NoSuchAlgorithmException, DuplicateSigningKeyException {
        return jwtSigningKeyFacade.createNewSigningKey(keyName, alg);
    }

    public void deleteSigningKey(String keyName) {
        jwtSigningKeyFacade.remove(keyName);
    }

    public JwtSigningKey findSigningKeyById(Integer id) {
        return jwtSigningKeyFacade.find(id);
    }

    /**
     * Removes invalidation records that are past their expiration plus leeway.
     *
     * @return the number of records removed
     */
    public int cleanupInvalidTokens() {
        int count = 0;
        for (InvalidJwt expiredToken : invalidJwtFacade.findExpired()) {
            if (passedRenewal(expiredToken.getExpirationTime(), expiredToken.getRenewableForSec())) {
                invalidJwtFacade.remove(expiredToken);
                count++;
            }
        }
        return count;
    }

    /**
     * Rotates the shared one-time signing key if it was created more than a day ago.
     *
     * <p>The previous "old" key is removed, the current key is renamed to the "old" name, and a
     * new one-time key is created.
     *
     * @return {@code true} if a rotation happened
     */
    public boolean markOldSigningKeys() {
        JwtSigningKey jwtSigningKey = jwtSigningKeyFacade.findByName(ONE_TIME_KEY_NAME);
        Calendar cal = Calendar.getInstance();
        cal.add(Calendar.DATE, -1);
        if (jwtSigningKey == null || !jwtSigningKey.getCreatedOn().before(cal.getTime())) {
            return false;
        }
        removeMarkedKeys();
        jwtSigningKeyFacade.renameSigningKey(jwtSigningKey, OLD_ONE_TIME_KEY_NAME);
        try {
            jwtSigningKeyFacade.getOrCreateSigningKey(ONE_TIME_KEY_NAME, SignatureAlgorithm.HS256);
        } catch (NoSuchAlgorithmException ex) {
            LOGGER.log(Level.SEVERE, "Could not create new one-time signing key", ex);
        }
        return true;
    }

    /** Deletes the previously rotated-out one-time signing key, if any. */
    public void removeMarkedKeys() {
        JwtSigningKey jwtSigningKey = jwtSigningKeyFacade.findByName(OLD_ONE_TIME_KEY_NAME);
        if (jwtSigningKey != null) {
            jwtSigningKeyFacade.remove(jwtSigningKey);
        }
    }

    /** Returns the secret of the ELK signing key, creating the key if needed. */
    public String getSigningKeyForELK(SignatureAlgorithm alg) throws NoSuchAlgorithmException {
        return getOrCreateSigningKey(ELK_SIGNING_KEY_NAME, alg).getSecret();
    }

    /** Creates a token for ELK signed with the ELK key; the subject is lower-cased locale-independently. */
    public String createTokenForELK(String subjectName, String issuer, Map<String, Object> claims, Date expiresAt,
        SignatureAlgorithm alg) throws DuplicateSigningKeyException, NoSuchAlgorithmException, SigningKeyNotFoundException {
        return createToken(ELK_SIGNING_KEY_NAME, false, issuer, null, expiresAt, null,
            subjectName.toLowerCase(Locale.ROOT), claims, alg);
    }
}

