package io.hops.hopsworks.common.serving.inference;

import io.hops.common.Pair;
import io.hops.hopsworks.common.integrations.LocalhostStereotype;
import io.hops.hopsworks.common.serving.LocalhostServingController;
import io.hops.hopsworks.exceptions.InferenceException;
import io.hops.hopsworks.persistence.entity.serving.ModelServer;
import io.hops.hopsworks.persistence.entity.serving.Serving;
import io.hops.hopsworks.restutils.RESTCodes;
import java.net.URISyntaxException;
import java.nio.charset.UnsupportedCharsetException;
import java.util.logging.Level;
import javax.ejb.ConcurrencyManagement;
import javax.ejb.ConcurrencyManagementType;
import javax.ejb.EJB;
import javax.ejb.Singleton;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import org.apache.http.client.protocol.HttpClientContext;

/**
 * {@link ServingInferenceController} for servings that are reached on {@code localhost}.
 *
 * <p>An inference call is turned into an HTTP request against {@code localhost} on the serving's
 * local port. The request path is chosen by the serving's model server (TensorFlow Serving or
 * Flask); any other model server is rejected.
 *
 * <p>The bean is a singleton with bean-managed concurrency and is annotated
 * {@code TransactionAttributeType.NEVER}.
 */
@ConcurrencyManagement(ConcurrencyManagementType.BEAN)
@Singleton
@LocalhostStereotype
@TransactionAttribute(TransactionAttributeType.NEVER)
public class LocalhostInferenceController implements ServingInferenceController {

    /** Builds the HTTP request for an inference call. */
    @EJB
    private ServingInferenceUtils servingInferenceUtils;

    /** Executes the HTTP request and converts the HTTP response into the returned pair. */
    @EJB
    private InferenceHttpClient inferenceHttpClient;

    /** Path builder used when the serving's model server is TensorFlow Serving. */
    @EJB
    private LocalhostTfInferenceUtils localhostTfInferenceUtils;

    /** Path builder used when the serving's model server is Flask. */
    @EJB
    private LocalhostSkLearnInferenceUtils localhostSkLearnInferenceUtils;

    /**
     * Forwards an inference request to a locally running serving and returns the handled reply.
     *
     * @param serving the serving to query; must not be marked as stopped
     * @param modelVersion model version handed to the TensorFlow Serving path builder
     * @param verb inference verb handed to the path builder
     * @param inferenceRequestJson request payload handed unchanged to the request builder
     * @return whatever {@code InferenceHttpClient#handleInferenceResponse} produces for the reply
     *     (presumably status code and body — confirm against that class)
     * @throws InferenceException {@code SERVING_NOT_RUNNING} when the serving's container id equals
     *     {@code LocalhostServingController.CID_STOPPED}; {@code REQUEST_ERROR} when the request URI
     *     is malformed; {@code BAD_REQUEST} on an unsupported charset; or as propagated by the HTTP
     *     client helpers
     * @throws UnsupportedOperationException when the serving's model server cannot be served locally
     */
    @Override
    public Pair<Integer, String> infer(Serving serving, Integer modelVersion, String verb,
            String inferenceRequestJson) throws InferenceException {
        // NOTE(review): getCid() is dereferenced directly, so a serving without a container id
        // fails with a NullPointerException here rather than SERVING_NOT_RUNNING.
        boolean stopped = serving.getCid().equals(LocalhostServingController.CID_STOPPED);
        if (stopped) {
            throw new InferenceException(RESTCodes.InferenceErrorCode.SERVING_NOT_RUNNING, Level.FINE);
        }

        try {
            // Resolved in the same order the request builder receives them.
            int localPort = serving.getLocalPort().intValue();
            String inferencePath = getInferencePath(serving, modelVersion, verb);

            return inferenceHttpClient.handleInferenceResponse(
                    inferenceHttpClient.execute(
                            servingInferenceUtils.buildInferenceRequest(
                                    "localhost", localPort, inferencePath, inferenceRequestJson),
                            HttpClientContext.create()));
        } catch (URISyntaxException malformedUri) {
            throw new InferenceException(RESTCodes.InferenceErrorCode.REQUEST_ERROR, Level.SEVERE,
                    (String) null, malformedUri.getMessage(), malformedUri);
        } catch (UnsupportedCharsetException badCharset) {
            throw new InferenceException(RESTCodes.InferenceErrorCode.BAD_REQUEST, Level.INFO,
                    (String) null, badCharset.getMessage(), badCharset);
        }
    }

    /**
     * Selects the request path for the serving according to its model server.
     *
     * @param serving the serving whose model server and name drive the path
     * @param modelVersion model version, used only for TensorFlow Serving
     * @param verb inference verb handed to the selected path builder
     * @return the path produced by the model-server-specific helper
     * @throws UnsupportedOperationException if the model server is neither TensorFlow Serving nor Flask
     */
    private String getInferencePath(Serving serving, Integer modelVersion, String verb) {
        ModelServer modelServer = serving.getModelServer();
        String path;
        if (modelServer == ModelServer.TENSORFLOW_SERVING) {
            path = localhostTfInferenceUtils.getPath(serving.getName(), modelVersion, verb);
        } else if (modelServer == ModelServer.FLASK) {
            path = localhostSkLearnInferenceUtils.getPath(verb);
        } else {
            throw new UnsupportedOperationException("Model server not supported as local serving");
        }
        return path;
    }
}
