package de.unifreiburg.unet;

import caffe.Caffe;
import java.util.Arrays;
import java.util.Iterator;
import java.util.UUID;
import java.util.Vector;

/* loaded from: input_file:de/unifreiburg/unet/Net.class */
/**
 * A Caffe network instantiated for a single {@link Caffe.Phase}.
 *
 * <p>It holds:
 * <ul>
 *   <li>the ordered list of {@link NetworkLayer}s;</li>
 *   <li>every blob those layers produce;</li>
 *   <li>the subset of produced blobs that no later-added layer consumes (the net outputs).</li>
 * </ul>
 *
 * <p>Besides construction from a {@code NetParameter}, it offers memory estimates summed over
 * its layers and blobs.
 */
public class Net {
    /** Layers in the order they were added. */
    private final Vector<NetworkLayer> _layers = new Vector<>();

    /** Every distinct blob produced by some added layer, in first-seen order. */
    private final Vector<CaffeBlob> _blobs = new Vector<>();

    /** Produced blobs not (yet) consumed by a later-added layer; see {@link #addLayer(NetworkLayer, boolean)}. */
    private final Vector<CaffeBlob> _outputBlobs = new Vector<>();

    private final Caffe.Phase _phase;

    /**
     * Creates an empty network for the given phase.
     *
     * @param phase the phase this net is instantiated for
     */
    public Net(Caffe.Phase phase) {
        this._phase = phase;
    }

    /**
     * Builds a network for {@code phase} from a Caffe network description.
     *
     * <p>Construction proceeds as follows:
     * <ol>
     *   <li>If {@code strArr} is non-null, a synthetic {@code Input} layer is added. It gets a
     *       random UUID name, its tops are {@code strArr}, and its shapes are {@code jArr}.</li>
     *   <li>The proto layers are scanned repeatedly. A layer is instantiated once it is active in
     *       {@code phase} and all of its bottom blobs already exist. Scanning stops after a pass
     *       that adds nothing, so layers whose bottoms never become available are left out.</li>
     *   <li>{@code HDF5Data} layers are replaced by {@code Input} layers fed with the shapes from
     *       {@code jArr}. If such a layer has more tops than known shapes, the last shape is
     *       repeated.</li>
     *   <li>A {@link SplitLayer} is appended for every blob read by more than one other layer.
     *       In-place layers that also write the blob are not counted. This method does not
     *       rewire the consumers to the split's tops.</li>
     * </ol>
     *
     * @param netParameter network description to instantiate
     * @param strArr names of the network input blobs, or {@code null} to add no explicit input
     *     layer
     * @param jArr one shape per input blob, also used for HDF5Data tops and to size
     *     {@code CreateDeformationLayer} outputs; {@code null} is treated as "no shapes"
     * @param phase phase whose include/exclude rules select the layers
     * @return the constructed network
     * @throws NotImplementedException as declared; presumably raised for unsupported layer types
     * @throws BlobException as declared; presumably raised by layer/blob construction
     * @throws IllegalArgumentException if an HDF5Data layer has tops but no input shape is known
     */
    public static Net createFromProto(Caffe.NetParameter netParameter, String[] strArr, long[][] jArr,
                                      Caffe.Phase phase) throws NotImplementedException, BlobException {
        Net net = new Net(phase);
        // Previously a null jArr threw an NPE here, although later code guarded against null.
        // Treat null consistently as "no shapes".
        long[][] inputShapes = (jArr != null) ? jArr : new long[0][];

        // Input description shared by the explicit input layer and every HDF5Data replacement.
        // The HDF5Data replacement may append shapes to it.
        Caffe.InputParameter.Builder inputParams = Caffe.InputParameter.newBuilder();
        for (long[] dims : inputShapes) {
            Caffe.BlobShape.Builder shape = Caffe.BlobShape.newBuilder();
            for (long dim : dims) {
                shape.addDim(dim);
            }
            inputParams.addShape(shape);
        }

        if (strArr != null) {
            Caffe.LayerParameter.Builder input = Caffe.LayerParameter.newBuilder();
            input.setType("Input");
            input.setName(UUID.randomUUID().toString());
            input.setInputParam(inputParams);
            for (String top : strArr) {
                input.addTop(top);
            }
            net.addLayer(new DataLayer(input.m1064build(), net));
        }

        // Fixed-point iteration: keep sweeping the proto until a full pass adds no layer.
        // This orders layers by their data dependencies regardless of declaration order.
        boolean added = true;
        while (added) {
            added = false;
            for (Caffe.LayerParameter layerParameter : netParameter.getLayerList()) {
                if (net.findLayer(layerParameter.getName()) != null
                        || !isActiveInPhase(layerParameter, phase)) {
                    continue;
                }

                if (layerParameter.getType().equals("HDF5Data")) {
                    net.addLayer(createHdf5InputLayer(layerParameter, inputParams, net));
                    added = true;
                    continue;
                }

                CaffeBlob[] bottoms = null;
                if (layerParameter.getBottomCount() > 0) {
                    bottoms = new CaffeBlob[layerParameter.getBottomCount()];
                    boolean ready = true;
                    for (int b = 0; b < bottoms.length && ready; b++) {
                        bottoms[b] = net.findBlob(layerParameter.getBottom(b));
                        ready = bottoms[b] != null;
                    }
                    if (!ready) {
                        // A bottom is not produced yet. Defer this layer to a later pass instead
                        // of constructing it with null inputs.
                        continue;
                    }
                }

                NetworkLayer layer = NetworkLayer.createFromProto(layerParameter, net, bottoms);
                if (layer instanceof CreateDeformationLayer && inputShapes.length > 0) {
                    // Copy the extents at indices 2, 3 (and 4 for 5-element shapes) of the first
                    // input shape into dimensions 1, 2 (and 3) of the layer's first output blob.
                    // Presumably the deformation field must match the input's spatial size.
                    // Assumes inputShapes[0] has at least 4 entries — TODO confirm.
                    CaffeBlob deformation = layer.outputBlobs()[0];
                    deformation.shape()[1] = inputShapes[0][2];
                    deformation.shape()[2] = inputShapes[0][3];
                    if (inputShapes[0].length == 5) {
                        deformation.shape()[3] = inputShapes[0][4];
                    }
                }
                net.addLayer(layer);
                added = true;
            }
        }

        for (NetworkLayer split : net.createSplitLayers()) {
            net.addLayer(split, true);
        }
        return net;
    }

    /**
     * Decides whether a layer takes part in {@code phase} based on its include/exclude rules.
     * Only the {@code phase} field of each rule is considered.
     *
     * <p>The rules are applied in this order:
     * <ul>
     *   <li>an exclude rule with a matching phase disables the layer;</li>
     *   <li>otherwise, an include rule with a matching phase enables it;</li>
     *   <li>otherwise, any include rule naming another phase disables it;</li>
     *   <li>a layer without phase rules is always active.</li>
     * </ul>
     */
    private static boolean isActiveInPhase(Caffe.LayerParameter layerParameter, Caffe.Phase phase) {
        for (Caffe.NetStateRule rule : layerParameter.getExcludeList()) {
            if (rule.hasPhase() && rule.getPhase().equals(phase)) {
                return false;
            }
        }
        boolean restrictedToOtherPhase = false;
        for (Caffe.NetStateRule rule : layerParameter.getIncludeList()) {
            if (rule.hasPhase()) {
                if (rule.getPhase().equals(phase)) {
                    return true;
                }
                restrictedToOtherPhase = true;
            }
        }
        return !restrictedToOtherPhase;
    }

    /**
     * Builds the {@code Input} layer that stands in for an {@code HDF5Data} layer.
     *
     * <p>The replacement keeps the original name and tops. {@code inputParams} is padded in place
     * by repeating its last shape until there is one shape per top.
     */
    private static NetworkLayer createHdf5InputLayer(Caffe.LayerParameter layerParameter,
                                                     Caffe.InputParameter.Builder inputParams,
                                                     Net net) throws NotImplementedException, BlobException {
        int topCount = layerParameter.getTopCount();
        if (topCount > 0 && inputParams.getShapeCount() == 0) {
            throw new IllegalArgumentException("HDF5Data layer '" + layerParameter.getName()
                    + "' has " + topCount + " tops but no input shapes were given");
        }
        // Pad relative to the builder's current size, not the caller's shape count. Otherwise
        // every further HDF5Data layer would keep growing the shared shape list.
        while (inputParams.getShapeCount() < topCount) {
            inputParams.addShape(inputParams.getShape(inputParams.getShapeCount() - 1));
        }

        Caffe.LayerParameter.Builder replacement = Caffe.LayerParameter.newBuilder();
        replacement.setType("Input");
        replacement.setName(layerParameter.getName());
        replacement.setInputParam(inputParams);
        for (int t = 0; t < topCount; t++) {
            replacement.addTop(layerParameter.getTop(t));
        }
        return new DataLayer(replacement.m1064build(), net);
    }

    /**
     * Creates one {@code Split} layer per produced blob with more than one consumer. The layers
     * are returned without being added.
     *
     * <p>Each split is named {@code <blob>_<producer>_<outputIndex>_split} and gets one top per
     * consumer. A blob is split at most once, even if several layers list it as an output.
     */
    private Vector<NetworkLayer> createSplitLayers() throws NotImplementedException, BlobException {
        Vector<NetworkLayer> splits = new Vector<>();
        Vector<CaffeBlob> alreadySplit = new Vector<>();
        for (NetworkLayer producer : this._layers) {
            CaffeBlob[] outputs = producer.outputBlobs();
            if (outputs == null) {
                continue;
            }
            for (int i = 0; i < outputs.length; i++) {
                CaffeBlob blob = outputs[i];
                if (alreadySplit.contains(blob)) {
                    continue;
                }
                int consumers = countConsumers(producer, blob);
                if (consumers > 1) {
                    String splitName = blob.name() + "_" + producer.name() + "_" + i + "_split";
                    Caffe.LayerParameter.Builder split = Caffe.LayerParameter.newBuilder();
                    split.setType("Split");
                    split.setName(splitName);
                    for (int t = 0; t < consumers; t++) {
                        split.addTop(splitName + "_" + t);
                    }
                    splits.add(new SplitLayer(split.m1064build(), this, new CaffeBlob[]{blob}));
                    alreadySplit.add(blob);
                }
            }
        }
        return splits;
    }

    /**
     * Counts the layers other than {@code producer} that read {@code blob}. In-place layers,
     * which both read and write the blob, are not counted.
     */
    private int countConsumers(NetworkLayer producer, CaffeBlob blob) {
        int consumers = 0;
        for (NetworkLayer layer : this._layers) {
            if (layer == producer || layer.inputBlobs() == null
                    || !Arrays.asList(layer.inputBlobs()).contains(blob)) {
                continue;
            }
            if (layer.outputBlobs() != null && Arrays.asList(layer.outputBlobs()).contains(blob)) {
                continue; // in-place layer
            }
            consumers++;
        }
        return consumers;
    }

    /**
     * Appends a layer and registers its outputs as net outputs.
     *
     * @param networkLayer layer to append
     */
    public void addLayer(NetworkLayer networkLayer) {
        addLayer(networkLayer, false);
    }

    /**
     * Appends a layer and updates the blob bookkeeping.
     *
     * <p>The layer's input blobs stop being net outputs. Its output blobs are recorded as blobs
     * and, unless {@code z} is set, as net outputs.
     *
     * @param networkLayer layer to append
     * @param z if {@code true}, the layer's outputs are not registered as net outputs (used for
     *     the automatically inserted split layers)
     */
    public void addLayer(NetworkLayer networkLayer, boolean z) {
        this._layers.add(networkLayer);
        CaffeBlob[] inputs = networkLayer.inputBlobs();
        if (inputs != null) {
            for (CaffeBlob consumed : inputs) {
                // Vector.remove(Object) is a no-op when the blob is absent.
                this._outputBlobs.remove(consumed);
            }
        }
        CaffeBlob[] outputs = networkLayer.outputBlobs();
        if (outputs != null) {
            for (CaffeBlob produced : outputs) {
                if (!z && !this._outputBlobs.contains(produced)) {
                    this._outputBlobs.add(produced);
                }
                if (!this._blobs.contains(produced)) {
                    this._blobs.add(produced);
                }
            }
        }
    }

    /**
     * Returns the live, mutable list of layers in insertion order.
     *
     * @return the backing layer list (not a copy)
     */
    public Vector<NetworkLayer> layers() {
        return this._layers;
    }

    /**
     * Finds the first layer whose name equals {@code str}.
     *
     * @param str layer name to look up
     * @return the matching layer, or {@code null} if none exists
     */
    public NetworkLayer findLayer(String str) {
        for (NetworkLayer layer : this._layers) {
            if (layer.name().equals(str)) {
                return layer;
            }
        }
        return null;
    }

    /**
     * Finds the first produced blob whose name equals {@code str}.
     *
     * @param str blob name to look up
     * @return the matching blob, or {@code null} if no added layer produces it
     */
    public CaffeBlob findBlob(String str) {
        for (CaffeBlob blob : this._blobs) {
            if (blob.name().equals(str)) {
                return blob;
            }
        }
        return null;
    }

    /**
     * Returns the phase this net was instantiated for.
     *
     * @return the net's phase
     */
    public Caffe.Phase phase() {
        return this._phase;
    }

    /**
     * Renders the phase and one line per layer.
     *
     * <p>Any phase other than {@code TRAIN} is shown as {@code TEST}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Net (phase = ")
                .append(this._phase.equals(Caffe.Phase.TRAIN) ? "TRAIN" : "TEST")
                .append(") { \n");
        for (NetworkLayer layer : this._layers) {
            sb.append("  ").append(layer).append("\n");
        }
        return sb.append("}").toString();
    }

    /**
     * Returns a snapshot of the current net outputs.
     *
     * @return the produced blobs not consumed by any later-added layer
     */
    public CaffeBlob[] outputBlobs() {
        return this._outputBlobs.toArray(new CaffeBlob[this._outputBlobs.size()]);
    }

    /**
     * Returns the parameter memory summed over all layers.
     *
     * @return total parameter memory in bytes
     */
    public long memoryParameters() {
        long total = 0;
        for (NetworkLayer layer : this._layers) {
            total += layer.memoryParameters();
        }
        return total;
    }

    /**
     * Returns the solver memory estimate.
     *
     * <p>In {@code TRAIN} this is three times the parameter memory (presumably per-parameter
     * solver buffers — confirm). In other phases it is 0.
     *
     * @return solver memory in bytes
     */
    public long memorySolver() {
        if (this._phase.equals(Caffe.Phase.TRAIN)) {
            return 3 * memoryParameters();
        }
        return 0L;
    }

    /**
     * Returns the sum of every layer's {@code memoryOther()}.
     *
     * @return miscellaneous layer memory in bytes
     */
    public long memoryOther() {
        long total = 0;
        for (NetworkLayer layer : this._layers) {
            total += layer.memoryOther();
        }
        return total;
    }

    /**
     * Returns the sum of every layer's {@code memoryOverhead(z)}.
     *
     * @param z whether cuDNN is used; {@link #printMemoryBreakdown(boolean)} uses it that way
     * @return layer overhead memory in bytes
     */
    public long memoryOverhead(boolean z) {
        long total = 0;
        for (NetworkLayer layer : this._layers) {
            total += layer.memoryOverhead(z);
        }
        return total;
    }

    /**
     * Returns the forward (data) memory summed over all produced blobs.
     *
     * @return forward blob memory in bytes
     */
    public long memoryBlobsForward() {
        long total = 0;
        for (CaffeBlob blob : this._blobs) {
            total += blob.memoryForward();
        }
        return total;
    }

    /**
     * Returns the backward (gradient) memory summed over all produced blobs.
     *
     * <p>This is only counted in {@code TRAIN}; other phases return 0.
     *
     * @return backward blob memory in bytes
     */
    public long memoryBlobsBackward() {
        if (!this._phase.equals(Caffe.Phase.TRAIN)) {
            return 0L;
        }
        long total = 0;
        for (CaffeBlob blob : this._blobs) {
            total += blob.memoryBackward();
        }
        return total;
    }

    /**
     * Returns the total memory estimate.
     *
     * <p>This is the sum of parameter, other, overhead, forward blob, backward blob and solver
     * memory.
     *
     * @param z whether cuDNN is used
     * @return total memory in bytes
     */
    public long memoryTotal(boolean z) {
        return memoryParameters() + memoryOther() + memoryOverhead(z) + memoryBlobsForward()
                + memoryBlobsBackward() + memorySolver();
    }

    /**
     * Returns {@link #memoryTotal(boolean)} plus another {@code memoryOther()} and
     * {@code memoryBlobsForward()}.
     *
     * <p>The extra terms presumably account for a forward-only validation pass — confirm.
     *
     * @param z whether cuDNN is used
     * @return memory in bytes
     */
    public long memoryTotalWithValidation(boolean z) {
        return memoryTotal(z) + memoryOther() + memoryBlobsForward();
    }

    /**
     * Prints the memory breakdown in MB (integer division by 1024 twice) to standard output.
     *
     * <p>In {@code TRAIN} it also prints the total including validation.
     *
     * @param z whether cuDNN is used; only affects the overhead term and the printed label
     */
    public void printMemoryBreakdown(boolean z) {
        System.out.println("Total memory used (" + (z ? "" : "no ") + "cuDNN) = "
                + toMegabytes(memoryTotal(z)) + " MB <= "
                + toMegabytes(memoryParameters()) + " MB (param) + "
                + toMegabytes(memoryBlobsForward()) + " MB (data) + "
                + toMegabytes(memoryBlobsBackward()) + " MB (gradient) + "
                + toMegabytes(memoryOverhead(z)) + " MB (conv) + "
                + toMegabytes(memoryOther()) + " MB (other) + "
                + toMegabytes(memorySolver()) + " MB (solver)");
        if (this._phase.equals(Caffe.Phase.TRAIN)) {
            System.out.println("  Training with validation set requires "
                    + toMegabytes(memoryTotalWithValidation(z)) + " MB");
        }
    }

    /** Converts bytes to whole mebibytes using truncating integer division. */
    private static long toMegabytes(long bytes) {
        return bytes / 1024 / 1024;
    }
}
