package com.google.mediapipe.tasks.vision.interactivesegmenter;

import android.content.Context;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.proto.CalculatorOptionsProto;
import com.google.mediapipe.proto.CalculatorProto;
import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto;
import com.google.mediapipe.tasks.components.containers.NormalizedKeypoint;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto;
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.SegmenterOptionsProto;
import com.google.mediapipe.tasks.vision.interactivesegmenter.AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions;
import com.google.mediapipe.util.proto.ColorProto;
import com.google.mediapipe.util.proto.RenderDataProto;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;

/* loaded from: classes2.dex */
public final class InteractiveSegmenter extends BaseVisionTaskApi {
    private static final String IMAGE_IN_STREAM_NAME = "image_in";
    private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
    private static final String ROI_IN_STREAM_NAME = "roi_in";
    private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
    private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = "mediapipe.tasks.TensorsToSegmentationCalculator";
    private boolean hasResultListener;
    private List<String> labels;
    private static final String TAG = InteractiveSegmenter.class.getSimpleName();
    private static final List<String> INPUT_STREAMS = Collections.unmodifiableList(Arrays.asList("IMAGE:image_in", "ROI:roi_in", "NORM_RECT:norm_rect_in"));

    /* loaded from: classes2.dex */
    public static abstract class InteractiveSegmenterOptions extends TaskOptions {

        /* loaded from: classes2.dex */
        public static abstract class Builder {
            abstract InteractiveSegmenterOptions autoBuild();

            public final InteractiveSegmenterOptions build() {
                return autoBuild();
            }

            public abstract Builder setBaseOptions(BaseOptions baseOptions);

            public abstract Builder setErrorListener(ErrorListener errorListener);

            public abstract Builder setOutputCategoryMask(boolean z);

            public abstract Builder setOutputConfidenceMasks(boolean z);

            public abstract Builder setResultListener(OutputHandler.ResultListener<ImageSegmenterResult, MPImage> resultListener);
        }

        public static Builder builder() {
            return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder().setOutputConfidenceMasks(true).setOutputCategoryMask(false);
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public abstract BaseOptions baseOptions();

        @Override // com.google.mediapipe.tasks.core.TaskOptions
        public CalculatorOptionsProto.CalculatorOptions convertToCalculatorOptionsProto() {
            ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.Builder baseOptions = ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.newBuilder().setBaseOptions((BaseOptionsProto.BaseOptions) BaseOptionsProto.BaseOptions.newBuilder().setUseStreamMode(false).mergeFrom((BaseOptionsProto.BaseOptions.Builder) convertBaseOptionsToProto(baseOptions())).build());
            baseOptions.setSegmenterOptions(SegmenterOptionsProto.SegmenterOptions.newBuilder());
            return (CalculatorOptionsProto.CalculatorOptions) CalculatorOptionsProto.CalculatorOptions.newBuilder().setExtension(ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.ext, (ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions) baseOptions.build()).build();
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public abstract Optional<ErrorListener> errorListener();

        /* JADX INFO: Access modifiers changed from: package-private */
        public abstract boolean outputCategoryMask();

        /* JADX INFO: Access modifiers changed from: package-private */
        public abstract boolean outputConfidenceMasks();

        /* JADX INFO: Access modifiers changed from: package-private */
        public abstract Optional<OutputHandler.ResultListener<ImageSegmenterResult, MPImage>> resultListener();
    }

    /* loaded from: classes2.dex */
    public static class RegionOfInterest {
        private NormalizedKeypoint keypoint;
        private List<NormalizedKeypoint> scribble;

        private RegionOfInterest() {
        }

        public static RegionOfInterest create(NormalizedKeypoint normalizedKeypoint) {
            RegionOfInterest regionOfInterest = new RegionOfInterest();
            regionOfInterest.keypoint = normalizedKeypoint;
            return regionOfInterest;
        }

        public static RegionOfInterest create(List<NormalizedKeypoint> list) {
            RegionOfInterest regionOfInterest = new RegionOfInterest();
            regionOfInterest.scribble = list;
            return regionOfInterest;
        }
    }

    static {
        System.loadLibrary("mediapipe_tasks_vision_jni");
        ProtoUtil.registerTypeName(RenderDataProto.RenderData.class, "mediapipe.RenderData");
    }

    private InteractiveSegmenter(TaskRunner taskRunner, boolean z) {
        super(taskRunner, RunningMode.IMAGE, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
        this.hasResultListener = false;
        this.labels = new ArrayList();
        this.hasResultListener = z;
        populateLabels();
    }

    private static RenderDataProto.RenderData convertToRenderData(RegionOfInterest regionOfInterest) {
        RenderDataProto.RenderData.Builder newBuilder = RenderDataProto.RenderData.newBuilder();
        if (regionOfInterest.keypoint != null) {
            return (RenderDataProto.RenderData) newBuilder.addRenderAnnotations(RenderDataProto.RenderAnnotation.newBuilder().setColor(ColorProto.Color.newBuilder().setR(255)).setPoint(RenderDataProto.RenderAnnotation.Point.newBuilder().setX(regionOfInterest.keypoint.x()).setY(regionOfInterest.keypoint.y()))).build();
        }
        if (regionOfInterest.scribble == null) {
            throw new IllegalArgumentException("RegionOfInterest does not include a valid user interaction");
        }
        RenderDataProto.RenderAnnotation.Scribble.Builder newBuilder2 = RenderDataProto.RenderAnnotation.Scribble.newBuilder();
        for (NormalizedKeypoint normalizedKeypoint : regionOfInterest.scribble) {
            newBuilder2.addPoint(RenderDataProto.RenderAnnotation.Point.newBuilder().setX(normalizedKeypoint.x()).setY(normalizedKeypoint.y()));
        }
        return (RenderDataProto.RenderData) newBuilder.addRenderAnnotations(RenderDataProto.RenderAnnotation.newBuilder().setColor(ColorProto.Color.newBuilder().setR(255)).setScribble(newBuilder2)).build();
    }

    public static InteractiveSegmenter createFromOptions(Context context, final InteractiveSegmenterOptions interactiveSegmenterOptions) {
        if (!interactiveSegmenterOptions.outputConfidenceMasks() && !interactiveSegmenterOptions.outputCategoryMask()) {
            throw new IllegalArgumentException("At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
        }
        ArrayList arrayList = new ArrayList();
        if (interactiveSegmenterOptions.outputConfidenceMasks()) {
            arrayList.add("CONFIDENCE_MASKS:confidence_masks");
        }
        final int size = arrayList.size() - 1;
        if (interactiveSegmenterOptions.outputCategoryMask()) {
            arrayList.add("CATEGORY_MASK:category_mask");
        }
        final int size2 = arrayList.size() - 1;
        arrayList.add("QUALITY_SCORES:quality_scores");
        final int size3 = arrayList.size() - 1;
        arrayList.add("IMAGE:image_out");
        final int size4 = arrayList.size() - 1;
        final OutputHandler outputHandler = new OutputHandler();
        outputHandler.setOutputPacketConverter(new OutputHandler.OutputPacketConverter<ImageSegmenterResult, MPImage>() { // from class: com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.1
            /* JADX WARN: Can't rename method to resolve collision */
            @Override // com.google.mediapipe.tasks.core.OutputHandler.OutputPacketConverter
            public MPImage convertToTaskInput(List<Packet> list) {
                return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(list.get(size4))).build();
            }

            @Override // com.google.mediapipe.tasks.core.OutputHandler.OutputPacketConverter
            public /* bridge */ /* synthetic */ MPImage convertToTaskInput(List list) {
                return convertToTaskInput((List<Packet>) list);
            }

            @Override // com.google.mediapipe.tasks.core.OutputHandler.OutputPacketConverter
            public /* bridge */ /* synthetic */ ImageSegmenterResult convertToTaskResult(List list) {
                return convertToTaskResult2((List<Packet>) list);
            }

            @Override // com.google.mediapipe.tasks.core.OutputHandler.OutputPacketConverter
            /* renamed from: convertToTaskResult, reason: avoid collision after fix types in other method */
            public ImageSegmenterResult convertToTaskResult2(List<Packet> list) throws MediaPipeException {
                ByteBuffer imageDataDirectly;
                if (list.get(size4).isEmpty()) {
                    return ImageSegmenterResult.create(Optional.empty(), Optional.empty(), new ArrayList(), list.get(size4).getTimestamp());
                }
                boolean z = !interactiveSegmenterOptions.resultListener().isPresent();
                Optional empty = Optional.empty();
                if (interactiveSegmenterOptions.outputConfidenceMasks()) {
                    empty = Optional.of(new ArrayList());
                    int imageWidthFromImageList = PacketGetter.getImageWidthFromImageList(list.get(size));
                    int imageHeightFromImageList = PacketGetter.getImageHeightFromImageList(list.get(size));
                    int imageListSize = PacketGetter.getImageListSize(list.get(size));
                    ByteBuffer[] byteBufferArr = new ByteBuffer[imageListSize];
                    if (z) {
                        for (int i = 0; i < imageListSize; i++) {
                            byteBufferArr[i] = ByteBuffer.allocateDirect(imageWidthFromImageList * imageHeightFromImageList * 4);
                        }
                    }
                    if (!PacketGetter.getImageList(list.get(size), byteBufferArr, z)) {
                        throw new MediaPipeException(MediaPipeException.StatusCode.INTERNAL.ordinal(), "There is an error getting confidence masks.");
                    }
                    for (int i2 = 0; i2 < imageListSize; i2++) {
                        ((List) empty.get()).add(new ByteBufferImageBuilder(byteBufferArr[i2], imageWidthFromImageList, imageHeightFromImageList, 10).build());
                    }
                }
                Optional empty2 = Optional.empty();
                if (interactiveSegmenterOptions.outputCategoryMask()) {
                    int imageWidth = PacketGetter.getImageWidth(list.get(size2));
                    int imageHeight = PacketGetter.getImageHeight(list.get(size2));
                    if (z) {
                        imageDataDirectly = ByteBuffer.allocateDirect(imageWidth * imageHeight);
                        if (!PacketGetter.getImageData(list.get(size2), imageDataDirectly)) {
                            throw new MediaPipeException(MediaPipeException.StatusCode.INTERNAL.ordinal(), "There is an error getting category mask.");
                        }
                    } else {
                        imageDataDirectly = PacketGetter.getImageDataDirectly(list.get(size2));
                    }
                    empty2 = Optional.of(new ByteBufferImageBuilder(imageDataDirectly, imageWidth, imageHeight, 8).build());
                }
                float[] float32Vector = PacketGetter.getFloat32Vector(list.get(size3));
                ArrayList arrayList2 = new ArrayList(float32Vector.length);
                for (float f : float32Vector) {
                    arrayList2.add(Float.valueOf(f));
                }
                return ImageSegmenterResult.create(empty, empty2, arrayList2, BaseVisionTaskApi.generateResultTimestampMs(RunningMode.IMAGE, list.get(size4)));
            }
        });
        interactiveSegmenterOptions.resultListener().ifPresent(new Consumer() { // from class: com.google.mediapipe.tasks.vision.interactivesegmenter.-$$Lambda$InteractiveSegmenter$6Amtv4dDgujpaxQXKD2mnkhbPvo
            @Override // java.util.function.Consumer
            public final void accept(Object obj) {
                OutputHandler.this.setResultListener((OutputHandler.ResultListener) obj);
            }
        });
        interactiveSegmenterOptions.errorListener().ifPresent(new Consumer() { // from class: com.google.mediapipe.tasks.vision.interactivesegmenter.-$$Lambda$InteractiveSegmenter$tnuyGhYJOxuTeOmJXKp5KUbehkA
            @Override // java.util.function.Consumer
            public final void accept(Object obj) {
                OutputHandler.this.setErrorListener((ErrorListener) obj);
            }
        });
        return new InteractiveSegmenter(TaskRunner.create(context, TaskInfo.builder().setTaskName(InteractiveSegmenter.class.getSimpleName()).setTaskRunningModeName(RunningMode.IMAGE.name()).setTaskGraphName(TASK_GRAPH_NAME).setInputStreams(INPUT_STREAMS).setOutputStreams(arrayList).setTaskOptions(interactiveSegmenterOptions).setEnableFlowLimiting(false).build(), outputHandler), interactiveSegmenterOptions.resultListener().isPresent());
    }

    private void populateLabels() {
        boolean z = false;
        for (CalculatorProto.CalculatorGraphConfig.Node node : this.runner.getCalculatorGraphConfig().getNodeList()) {
            if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) {
                if (z) {
                    throw new MediaPipeException(MediaPipeException.StatusCode.INTERNAL.ordinal(), "The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator.");
                }
                TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions tensorsToSegmentationCalculatorOptions = (TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions) node.getOptions().getExtension(TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions.ext);
                for (int i = 0; i < tensorsToSegmentationCalculatorOptions.getLabelItemsMap().size(); i++) {
                    Long valueOf = Long.valueOf(i);
                    if (!tensorsToSegmentationCalculatorOptions.getLabelItemsMap().containsKey(valueOf)) {
                        throw new MediaPipeException(MediaPipeException.StatusCode.INTERNAL.ordinal(), "The lablemap have no expected key: " + valueOf);
                    }
                    this.labels.add(tensorsToSegmentationCalculatorOptions.getLabelItemsMap().get(valueOf).getName());
                }
                z = true;
            }
        }
    }

    private ImageSegmenterResult processImageWithRoi(MPImage mPImage, RegionOfInterest regionOfInterest, ImageProcessingOptions imageProcessingOptions) {
        if (this.runningMode != RunningMode.IMAGE) {
            throw new MediaPipeException(MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the image mode. Current running mode:" + this.runningMode.name());
        }
        HashMap hashMap = new HashMap();
        hashMap.put(IMAGE_IN_STREAM_NAME, this.runner.getPacketCreator().createImage(mPImage));
        hashMap.put(ROI_IN_STREAM_NAME, this.runner.getPacketCreator().createProto(convertToRenderData(regionOfInterest)));
        hashMap.put(NORM_RECT_IN_STREAM_NAME, this.runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions, mPImage)));
        return (ImageSegmenterResult) this.runner.process(hashMap);
    }

    private static void validateImageProcessingOptions(ImageProcessingOptions imageProcessingOptions) {
        if (imageProcessingOptions.regionOfInterest().isPresent()) {
            throw new IllegalArgumentException("InteractiveSegmenter doesn't support region-of-interest.");
        }
    }

    List<String> getLabels() {
        return this.labels;
    }

    public ImageSegmenterResult segment(MPImage mPImage, RegionOfInterest regionOfInterest) {
        return segment(mPImage, regionOfInterest, ImageProcessingOptions.builder().build());
    }

    public ImageSegmenterResult segment(MPImage mPImage, RegionOfInterest regionOfInterest, ImageProcessingOptions imageProcessingOptions) {
        if (this.hasResultListener) {
            throw new MediaPipeException(MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "ResultListener is provided in the InteractiveSegmenterOptions, but this method will return an ImageSegmentationResult.");
        }
        validateImageProcessingOptions(imageProcessingOptions);
        return processImageWithRoi(mPImage, regionOfInterest, imageProcessingOptions);
    }

    public void segmentWithResultListener(MPImage mPImage, RegionOfInterest regionOfInterest) {
        segmentWithResultListener(mPImage, regionOfInterest, ImageProcessingOptions.builder().build());
    }

    public void segmentWithResultListener(MPImage mPImage, RegionOfInterest regionOfInterest, ImageProcessingOptions imageProcessingOptions) {
        if (!this.hasResultListener) {
            throw new MediaPipeException(MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "ResultListener is not set in the InteractiveSegmenterOptions, but this method expects a ResultListener to process ImageSegmentationResult.");
        }
        validateImageProcessingOptions(imageProcessingOptions);
        processImageWithRoi(mPImage, regionOfInterest, imageProcessingOptions);
    }
}
