package de.unifreiburg.unet;

import caffe.Caffe;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:de/unifreiburg/unet/PoolingLayer.class */
/**
 * Pooling layer of the network model: derives the output blob shape of an N-d pooling
 * operation from the input blob shape and the layer's {@code pooling_param}.
 *
 * <p>Kernel size, padding and stride are stored per spatial dimension, i.e. per input axis after
 * the first two (which this class treats as samples and channels).
 */
public class PoolingLayer extends NetworkLayer {
    /** Pooling window extent per spatial dimension. */
    private final int[] _kernelShape;
    /** Padding per spatial dimension (0 when not configured). */
    private final int[] _pad;
    /** Window step per spatial dimension (1 when not configured). */
    private final int[] _stride;

    /**
     * Creates the layer, validates its pooling parameters and creates its single output blob.
     *
     * <p>{@code kernel_size}, {@code pad} and {@code stride} may each list one value per spatial
     * dimension or fewer; missing trailing entries repeat the last given value. {@code pad}
     * defaults to 0 and {@code stride} to 1 when absent. The output extent of each spatial
     * dimension is {@code (int) ceil((float) (in + 2 * pad - kernel) / stride) + 1}; when any
     * padding is used, a last window that would start in the trailing padding is dropped (the same
     * rule as Caffe's pooling layer).
     *
     * @param layerParameter layer definition; its {@code pooling_param} and first top name are read
     * @param net the network this layer belongs to (forwarded to the superclass)
     * @param caffeBlobArr input blobs; only {@code caffeBlobArr[0]} determines the output shape,
     *     but every entry is flagged with {@code setOnGPU(true)}
     * @throws BlobException if the input has fewer than two axes, a parameter lists more values
     *     than there are spatial dimensions, {@code kernel_size} is missing, a kernel or stride is
     *     not positive, a pad is negative, or the resulting output extent is invalid
     */
    public PoolingLayer(Caffe.LayerParameter layerParameter, Net net, CaffeBlob[] caffeBlobArr) throws BlobException {
        super(layerParameter, net, caffeBlobArr);
        Caffe.PoolingParameter poolingParam = layerParameter.getPoolingParam();
        CaffeBlob input = caffeBlobArr[0];
        int nDims = input.shape().length - 2;
        if (nDims < 0) {
            throw new BlobException("Pooling input blob needs at least 2 axes, got " + input.shape().length);
        }

        // kernel_size has no default: with no values the repeat-last rule has nothing to repeat.
        if (nDims > 0 && poolingParam.getKernelSizeCount() == 0) {
            throw new BlobException("Invalid pooling parameters given: kernel_size is missing");
        }
        // The kernel default (0) is never used: an empty list is rejected above when nDims > 0.
        this._kernelShape = expandToDims("kernel_size", poolingParam.getKernelSizeList(), nDims, 0, 1);
        this._pad = expandToDims("pad", poolingParam.getPadList(), nDims, 0, 0);
        this._stride = expandToDims("stride", poolingParam.getStrideList(), nDims, 1, 1);

        // The last-window clipping below only applies when some padding is present.
        boolean padded = Arrays.stream(this._pad).anyMatch(p -> p > 0);

        long[] outShape = new long[nDims + 2];
        outShape[0] = input.nSamples();
        outShape[1] = input.nChannels();
        for (int d = 0; d < nDims; d++) {
            long inExtent = input.shape()[d + 2];
            // Single-precision ceil division, kept identical to the original formula.
            long outExtent = ((int) Math.ceil(
                    (float) (inExtent + 2 * this._pad[d] - this._kernelShape[d]) / this._stride[d])) + 1;
            if (padded) {
                // Drop a last window that would start in the trailing padding instead of the data.
                if ((outExtent - 1) * this._stride[d] >= inExtent + this._pad[d]) {
                    outExtent--;
                }
                if ((outExtent - 1) * this._stride[d] >= inExtent + this._pad[d]) {
                    throw new BlobException("Invalid pooling parameters given");
                }
            }
            // A kernel larger than the (padded) input leaves no valid window at all.
            if (outExtent < 1) {
                throw new BlobException("Invalid pooling parameters given: output extent "
                        + outExtent + " in spatial dimension " + d);
            }
            outShape[d + 2] = outExtent;
        }

        this._out[0] = new CaffeBlob(layerParameter.getTop(0), outShape, this, true, input.gradientRequired());
        for (CaffeBlob blob : caffeBlobArr) {
            blob.setOnGPU(true);
        }
    }

    /**
     * Expands a repeated per-dimension parameter to exactly {@code nDims} entries.
     *
     * <p>An empty list yields {@code defaultValue} in every dimension. Otherwise the given values
     * are used in order and the last one is repeated for the remaining dimensions.
     *
     * @param name parameter name, used in error messages only
     * @param values values given in the layer definition
     * @param nDims number of spatial dimensions
     * @param defaultValue value used for every dimension when {@code values} is empty
     * @param minValue smallest accepted value for any given entry
     * @return a new array of length {@code nDims}
     * @throws BlobException if more than {@code nDims} values are given or one is below
     *     {@code minValue}
     */
    private static int[] expandToDims(String name, List<Integer> values, int nDims, int defaultValue, int minValue)
            throws BlobException {
        if (values.size() > nDims) {
            throw new BlobException("Invalid pooling parameters given: " + name + " has "
                    + values.size() + " values for " + nDims + " spatial dimensions");
        }
        int[] result = new int[nDims];
        if (values.isEmpty()) {
            Arrays.fill(result, defaultValue);
            return result;
        }
        for (int d = 0; d < nDims; d++) {
            int value = values.get(Math.min(d, values.size() - 1));
            // uint32 proto values >= 2^31 arrive as negative ints and are rejected here too.
            if (value < minValue) {
                throw new BlobException("Invalid pooling parameters given: " + name + " = "
                        + value + " (must be >= " + minValue + ")");
            }
            result[d] = value;
        }
        return result;
    }

    /**
     * Returns a one-line summary of the pooling parameters, e.g.
     * {@code "kernelShape: [ 2 2 ] pad: [ 0 0 ] stride: [ 2 2 ]"}.
     */
    @Override // de.unifreiburg.unet.NetworkLayer
    public String paramString() {
        StringBuilder sb = new StringBuilder();
        appendValues(sb, "kernelShape: [ ", this._kernelShape);
        appendValues(sb, " pad: [ ", this._pad);
        appendValues(sb, " stride: [ ", this._stride);
        return sb.toString();
    }

    /** Appends {@code prefix}, each value followed by a space, and a closing bracket. */
    private static void appendValues(StringBuilder sb, String prefix, int[] values) {
        sb.append(prefix);
        for (int v : values) {
            sb.append(v).append(' ');
        }
        sb.append(']');
    }

    /**
     * Returns the additional memory estimate reported by this layer: 4 bytes times (output element
     * count + 5 * number of spatial dimensions + 1).
     *
     * <p>NOTE(review): the factor 4 presumably is the size of a 32-bit element and the extra terms
     * presumably cover per-dimension parameter buffers; confirm against the backend.
     */
    @Override // de.unifreiburg.unet.NetworkLayer
    public long memoryOther() {
        int nDims = this._kernelShape.length;
        // long arithmetic throughout so large output blobs cannot overflow an int intermediate.
        return 4L * (this._out[0].count() + 4L * nDims + nDims + 1L);
    }
}
