/*
 * Decompiled with CFR 0.152.
 */
package org.denom.scp81;

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;
import org.denom.Arr;
import org.denom.Binary;
import org.denom.Ex;
import org.denom.log.ILog;
import org.denom.log.LogDummy;
import org.denom.scp81.TlsConst;
import org.denom.scp81.TlsHelloClient;
import org.denom.scp81.TlsHelloServer;
import org.denom.scp81.TlsPSKCrypt;
import org.denom.scp81.http.HttpReq;
import org.denom.scp81.http.HttpResp;

/**
 * TLS 1.2 record layer with pre-shared-key (PSK) authentication, usable as either the client or
 * the server side of a session (selected by the constructor flag).
 *
 * <p>Incoming bytes are fed through {@link #appendIncomingData}. They are split into TLS records,
 * decrypted once a decoder has been installed, and dispatched by content type. Outgoing records
 * are framed by {@link #sendRecord} and encrypted once an encoder has been installed. They are
 * then handed to the callback registered with {@link #setSendData}. Interpretation of individual
 * handshake messages is left to subclasses through {@link #onHandshakeMessage}.
 *
 * <p>NOTE(review): no synchronization is performed; presumably each session is driven from a
 * single thread — confirm with callers.
 */
public class TlsPSKSession {
    // Log colours passed to ILog.writeln. The server logs its own traffic in the blue shades and
    // the peer's traffic in pink; the client does the opposite (see assignLogColors).
    protected static final int COLOR_BLUE1 = -12541697;
    protected static final int COLOR_BLUE2 = -13189889;
    protected static final int COLOR_BLUE3 = -15666945;
    protected static final int COLOR_PINK1 = -1027851;
    protected static final int COLOR_PINK2 = -1011472;
    protected static final int COLOR_PINK3 = -16129;
    /** Colour for record separators and HTTP section headers. */
    private static final int COLOR_GRAY = -7303024;
    /** Colour for logged HTTP requests. */
    private static final int COLOR_HTTP_REQ = -16352;
    /** Colour for logged HTTP responses. */
    private static final int COLOR_HTTP_RESP = -14614592;

    // TLS wire constants (RFC 5246 unless noted otherwise).
    /** ProtocolVersion {3, 3}, i.e. TLS 1.2. */
    private static final int PROTOCOL_TLS_1_2 = 0x0303;
    /** Record header: ContentType (1) + ProtocolVersion (2) + length (2). */
    private static final int RECORD_HEADER_SIZE = 5;
    /** Maximum TLSPlaintext fragment length, 2^14. */
    private static final int MAX_PLAINTEXT_LEN = 16384;
    /** RFC 5246 §6.2.3: a TLSCiphertext fragment may be up to 2048 bytes longer than 2^14. */
    private static final int MAX_CIPHERTEXT_EXPANSION = 2048;
    /** Largest handshake message body this implementation accepts. */
    private static final int MAX_HANDSHAKE_MSG_LEN = 32768;

    private static final int CONTENT_CHANGE_CIPHER_SPEC = 20;
    private static final int CONTENT_ALERT = 21;
    private static final int CONTENT_HANDSHAKE = 22;
    private static final int CONTENT_APPLICATION_DATA = 23;

    private static final int HANDSHAKE_FINISHED = 20;

    private static final int ALERT_LEVEL_WARNING = 1;
    private static final int ALERT_LEVEL_FATAL = 2;
    private static final int ALERT_CLOSE_NOTIFY = 0;
    private static final int ALERT_UNEXPECTED_MESSAGE = 10;
    private static final int ALERT_RECORD_OVERFLOW = 22;
    private static final int ALERT_ILLEGAL_PARAMETER = 47;
    private static final int ALERT_DECODE_ERROR = 50;

    // Hello extension types: max_fragment_length and truncated_hmac (RFC 6066),
    // encrypt_then_mac (RFC 7366).
    private static final int EXT_MAX_FRAGMENT_LENGTH = 1;
    private static final int EXT_TRUNCATED_HMAC = 4;
    private static final int EXT_ENCRYPT_THEN_MAC = 22;

    protected int protocolVersion = PROTOCOL_TLS_1_2;
    protected int cipherSuite = 0;
    protected Binary clientRandom = null;
    protected Binary serverRandom = null;
    protected boolean isEncryptThenMAC = false;
    protected boolean isTruncateHMac = false;
    /** Maximum plaintext fragment length; lowered when max_fragment_length is negotiated. */
    protected int plainLimit = MAX_PLAINTEXT_LEN;
    /** Maximum accepted (possibly encrypted) record fragment length. */
    protected int cryptLimit = MAX_PLAINTEXT_LEN + MAX_CIPHERTEXT_EXPANSION;
    /** Raw handshake messages collected through {@link #addHandshakeMessage}. */
    protected Binary handshakeMessages = Binary.Bin().reserve(256);
    protected Binary identity;
    protected Binary psk;
    private TlsPSKCrypt encoder = null;
    private TlsPSKCrypt decoder = null;
    protected boolean handshakeDone = false;
    protected Function<Binary, Binary> funcGetPSK;
    protected Consumer<Binary> funcSendData;
    protected Consumer<Binary> funcRecievedAppData;
    protected ILog log = new LogDummy();
    protected static final int STATE_WAIT_HELLO = 1;
    protected static final int STATE_WAIT_KEY_EXCHANGE = 2;
    protected static final int STATE_WAIT_HELLO_DONE = 3;
    protected static final int STATE_WAIT_CHANGE_CIPHER = 10;
    protected static final int STATE_WAIT_FINISHED = 11;
    protected static final int STATE_HANDSHAKE_DONE = 20;
    protected static final int STATE_ERROR = -1;
    protected int state;
    /** Received bytes not yet consumed as complete records. */
    protected Binary inBuf = Binary.Bin().reserve(MAX_PLAINTEXT_LEN + MAX_CIPHERTEXT_EXPANSION + RECORD_HEADER_SIZE);
    /** Body of the most recently processed record, after decryption. */
    protected Binary recordBody = Binary.Bin();
    /** Handshake bytes that do not yet form a complete handshake message. */
    protected Binary bufHandshake = Binary.Bin();
    protected boolean isServer;
    protected Map<Integer, Binary> extensions = new HashMap<Integer, Binary>();
    protected Arr<Integer> cipherSuites = new Arr<>();
    protected int colorRecv1;
    protected int colorRecv2;
    protected int colorRecv3;
    protected int colorSend1;
    protected int colorSend2;
    protected int colorSend3;

    /**
     * Creates a session.
     *
     * @param isServer {@code true} for the server side, {@code false} for the client side
     */
    public TlsPSKSession(boolean isServer) {
        this.isServer = isServer;
        assignLogColors();
    }

    /**
     * Sets the log sink and (re)selects the send/receive log colours for this side.
     *
     * @param log log to write to; must not be {@code null}
     * @return this session, for chaining
     */
    public TlsPSKSession setLog(ILog log) {
        Ex.MUST(log != null, "Log must be not null");
        this.log = log;
        assignLogColors();
        return this;
    }

    /** Server: sends in blue, receives in pink. Client: the reverse. */
    private void assignLogColors() {
        if (this.isServer) {
            this.colorSend1 = COLOR_BLUE1;
            this.colorSend2 = COLOR_BLUE2;
            this.colorSend3 = COLOR_BLUE3;
            this.colorRecv1 = COLOR_PINK1;
            this.colorRecv2 = COLOR_PINK2;
            this.colorRecv3 = COLOR_PINK3;
        } else {
            this.colorSend1 = COLOR_PINK1;
            this.colorSend2 = COLOR_PINK2;
            this.colorSend3 = COLOR_PINK3;
            this.colorRecv1 = COLOR_BLUE1;
            this.colorRecv2 = COLOR_BLUE2;
            this.colorRecv3 = COLOR_BLUE3;
        }
    }

    /** Appends a cipher suite code to {@link #cipherSuites}. */
    public TlsPSKSession addCipherSuite(int cipherSuite) {
        this.cipherSuites.add(cipherSuite);
        return this;
    }

    /** Stores (or replaces) extension data under the given extension type. */
    public TlsPSKSession addExtension(int extensionType, Binary extData) {
        this.extensions.put(extensionType, extData);
        return this;
    }

    /** Sets the PSK lookup callback; must not be {@code null}. */
    public TlsPSKSession setGetPSK(Function<Binary, Binary> funcGetPSK) {
        this.funcGetPSK = Objects.requireNonNull(funcGetPSK, "funcGetPSK");
        return this;
    }

    /** Sets the callback that receives each framed outgoing record; must not be {@code null}. */
    public TlsPSKSession setSendData(Consumer<Binary> funcSendData) {
        this.funcSendData = Objects.requireNonNull(funcSendData, "funcSendData");
        return this;
    }

    /** Sets the callback that receives decrypted application data; must not be {@code null}. */
    public TlsPSKSession setRecievedAppData(Consumer<Binary> funcRecievedAppData) {
        this.funcRecievedAppData = Objects.requireNonNull(funcRecievedAppData, "funcRecievedAppData");
        return this;
    }

    /**
     * Resets negotiated parameters, record size limits, collected handshake messages and both
     * record ciphers. {@code handshakeDone}, {@code state} and buffered input are left untouched.
     */
    public void clearHandshakeData() {
        this.clientRandom = null;
        this.serverRandom = null;
        this.cipherSuite = 0;
        this.isEncryptThenMAC = false;
        this.isTruncateHMac = false;
        this.plainLimit = MAX_PLAINTEXT_LEN;
        this.cryptLimit = MAX_PLAINTEXT_LEN + MAX_CIPHERTEXT_EXPANSION;
        this.handshakeMessages = Binary.Bin().reserve(256);
        this.encoder = null;
        this.decoder = null;
    }

    public boolean isHandshakeDone() {
        return this.handshakeDone;
    }

    /**
     * Sends data as application_data records of at most {@link #plainLimit} bytes each.
     *
     * @throws Ex if the handshake has not completed
     */
    public void sendAppData(Binary data) {
        Ex.MUST(this.handshakeDone, "Cannot write application data until handshake completed.");
        int offset = 0;
        int len = data.size();
        while (len > 0) {
            int partSize = Math.min(len, this.plainLimit);
            this.sendRecord(CONTENT_APPLICATION_DATA, data.slice(offset, partSize));
            offset += partSize;
            len -= partSize;
        }
    }

    /** Logs the request and sends its binary form as application data. */
    public void sendHttpReq(HttpReq req) {
        this.log.writeln(COLOR_GRAY, "HTTP Request:");
        this.log.writeln(COLOR_HTTP_REQ, req.toString());
        this.sendAppData(req.toBin());
    }

    /** Logs the response and sends its binary form as application data. */
    public void sendHttpResp(HttpResp resp) {
        this.log.writeln(COLOR_GRAY, "HTTP Response:");
        this.log.writeln(COLOR_HTTP_RESP, resp.toString());
        this.sendAppData(resp.toBin());
    }

    /** Sends a warning-level close_notify alert; any failure is deliberately ignored. */
    public void sendCloseNotify() {
        try {
            this.sendRecord(CONTENT_ALERT, Binary.Bin(1, ALERT_LEVEL_WARNING).add(ALERT_CLOSE_NOTIFY));
        } catch (Throwable ignored) {
            // Best effort: the session is being shut down and the transport may already be gone.
        }
    }

    /**
     * Appends a raw handshake message to {@link #handshakeMessages}.
     * NOTE(review): presumably consumed by TlsPSKCrypt for Finished verify_data — confirm.
     */
    protected void addHandshakeMessage(Binary handshakeMsg) {
        this.handshakeMessages.add(handshakeMsg);
    }

    /** Records the protocol version and client random from a received ClientHello. */
    protected void gotClientHello(TlsHelloClient clientHello) {
        this.protocolVersion = clientHello.protocol;
        this.clientRandom = clientHello.random;
    }

    /**
     * Records the parameters chosen by the server and applies the extensions this class
     * understands. Those are max_fragment_length, encrypt_then_mac and truncated_hmac.
     *
     * @throws Ex with code illegal_parameter (47) if an understood extension is malformed
     */
    public void gotServerHello(TlsHelloServer serverHello) {
        this.protocolVersion = serverHello.protocol;
        this.serverRandom = serverHello.random;
        this.cipherSuite = serverHello.cipherSuite;
        for (Map.Entry<Integer, Binary> entry : serverHello.extensions.entrySet()) {
            Binary data = entry.getValue();
            switch (entry.getKey()) {
                case EXT_MAX_FRAGMENT_LENGTH: {
                    Ex.MUST(data.size() == 1, ALERT_ILLEGAL_PARAMETER, "Wrong extension");
                    int maxFragmLen = data.get(0);
                    // RFC 6066 §4: codes 1..4 select 2^9, 2^10, 2^11 and 2^12 bytes.
                    Ex.MUST(maxFragmLen >= 1 && maxFragmLen <= 4, ALERT_ILLEGAL_PARAMETER);
                    this.plainLimit = 1 << (8 + maxFragmLen);
                    break;
                }
                case EXT_ENCRYPT_THEN_MAC: {
                    Ex.MUST(data.size() == 0, ALERT_ILLEGAL_PARAMETER);
                    this.isEncryptThenMAC = true;
                    break;
                }
                case EXT_TRUNCATED_HMAC: {
                    Ex.MUST(data.size() == 0, ALERT_ILLEGAL_PARAMETER);
                    this.isTruncateHMac = true;
                    break;
                }
                default:
                    // Other extensions are not interpreted here.
                    break;
            }
        }
    }

    /**
     * Installs the outgoing record cipher.
     *
     * @param isServerKeys passed through to {@code TlsPSKCrypt.createEncoder}
     * @throws Ex if {@link #psk} is not set
     */
    public void initEncrypter(boolean isServerKeys) {
        Ex.MUST(this.psk != null, "TLS PSK Session: psk not set");
        this.encoder = TlsPSKCrypt.createEncoder(this, isServerKeys);
    }

    /**
     * Installs the incoming record cipher.
     *
     * @param isServerKeys passed through to {@code TlsPSKCrypt.createDecoder}
     * @throws Ex if {@link #psk} is not set
     */
    public void initDecrypter(boolean isServerKeys) {
        Ex.MUST(this.psk != null, "TLS PSK Session: psk not set");
        this.decoder = TlsPSKCrypt.createDecoder(this, isServerKeys);
    }

    /**
     * Computes Finished verify_data. It uses the decoder if one is installed, otherwise the
     * encoder.
     *
     * @throws Ex if neither record cipher has been installed yet
     */
    public Binary calcVerifyData(boolean isFromServer) {
        TlsPSKCrypt tlsCrypt = this.decoder != null ? this.decoder : this.encoder;
        Ex.MUST(tlsCrypt != null, "TLS PSK Session: no record cipher initialized");
        return tlsCrypt.calcVerifyData(isFromServer);
    }

    /** Protects a record body with the encoder, or returns it unchanged before one is installed. */
    public Binary encodeRecord(int contentType, Binary recordBody) {
        if (this.encoder == null) {
            return recordBody;
        }
        return this.encoder.encode(contentType, recordBody);
    }

    /** Unprotects a record body with the decoder, or returns it unchanged before one is installed. */
    public Binary decodeRecord(int recordType, Binary encryptedBody) {
        if (this.decoder == null) {
            return encryptedBody;
        }
        return this.decoder.decode(recordType, encryptedBody);
    }

    /**
     * Feeds received transport bytes into the session and processes every complete record.
     *
     * @throws Ex on a protocol error (after a fatal alert has been sent) or on a received alert
     */
    public void appendIncomingData(Binary data) {
        int offset = 0;
        while (offset < data.size()) {
            // Feed input in chunks of at most one maximal record, draining complete records after each.
            int partSize = Math.min(this.cryptLimit + RECORD_HEADER_SIZE, data.size() - offset);
            this.inBuf.add(data, offset, partSize);
            offset += partSize;
            while (this.processIncomingData()) {
                // keep going while complete records remain in inBuf
            }
        }
    }

    /**
     * Processes at most one complete record from {@link #inBuf}.
     *
     * @return {@code true} if a record was consumed, {@code false} if more input is needed
     * @throws Ex the original exception if the failure arose while handling an alert record;
     *         otherwise {@code Ex("Wrong TLS")}, thrown after a fatal alert carrying the failure's
     *         code has been sent
     */
    private boolean processIncomingData() {
        boolean handlingAlert = false;
        try {
            if (this.inBuf.size() < RECORD_HEADER_SIZE) {
                return false;
            }
            int recordType = this.inBuf.get(0);
            int tlsVersion = this.inBuf.getU16(1);
            int fragmentLength = this.inBuf.getU16(3);
            // Checked from the header alone, so an oversized record is refused before it is buffered.
            Ex.MUST(fragmentLength <= this.cryptLimit, ALERT_RECORD_OVERFLOW);
            int fullLen = RECORD_HEADER_SIZE + fragmentLength;
            if (this.inBuf.size() < fullLen) {
                return false;
            }

            // Use a fresh Binary per record. The decoded body may be handed to funcRecievedAppData,
            // and reusing one buffer would let the next record overwrite data the application holds.
            Binary body = Binary.Bin();
            body.assign(this.inBuf, RECORD_HEADER_SIZE, fragmentLength);
            this.log.writeln(COLOR_GRAY, "---------------------------------");
            this.log.writeln(this.colorRecv1, String.format(Locale.US, "recv [%4d] : %02X %04X %04X %s",
                    fullLen, recordType, tlsVersion, fragmentLength, body.Hex()));
            this.inBuf.assign(this.inBuf, fullLen, this.inBuf.size() - fullLen);

            this.recordBody = this.decodeRecord(recordType, body);
            if (this.handshakeDone) {
                this.log.writeln(this.colorRecv2, "    Decrypted: " + this.recordBody.Hex());
            }
            Ex.MUST(this.recordBody.size() <= this.plainLimit, ALERT_RECORD_OVERFLOW);

            handlingAlert = recordType == CONTENT_ALERT;
            this.dispatchRecord(recordType, this.recordBody);
            return true;
        } catch (Ex ex) {
            if (handlingAlert) {
                // The peer has already signalled termination; answering with another alert is
                // pointless, and the original exception carries the received alert code.
                this.state = STATE_ERROR;
                throw ex;
            }
            this.sendFatal(ex.code);
            this.state = STATE_ERROR;
            throw new Ex("Wrong TLS");
        }
    }

    /** Routes a decoded record body to the handler for its content type. */
    private void dispatchRecord(int recordType, Binary body) {
        switch (recordType) {
            case CONTENT_ALERT:
                this.processRecordAlert(body);
                break;
            case CONTENT_HANDSHAKE:
                this.processRecordHandshake(body);
                break;
            case CONTENT_APPLICATION_DATA:
                Ex.MUST(this.handshakeDone, ALERT_UNEXPECTED_MESSAGE);
                this.funcRecievedAppData.accept(body);
                break;
            case CONTENT_CHANGE_CIPHER_SPEC:
                this.processChangeCipher(body);
                break;
            default:
                throw new Ex(ALERT_UNEXPECTED_MESSAGE);
        }
    }

    /**
     * Hook invoked for every complete received handshake message; the default does nothing.
     *
     * @param type    handshake message type
     * @param message the whole message including its 4-byte header
     * @param data    the message body without the header
     */
    protected void onHandshakeMessage(int type, Binary message, Binary data) {
    }

    /**
     * Buffers handshake record content and passes each complete handshake message to
     * {@link #onHandshakeMessage}. Messages may span or share records.
     */
    protected void processRecordHandshake(Binary recBody) {
        Ex.MUST(recBody.size() > 0, ALERT_DECODE_ERROR);
        this.bufHandshake.add(recBody);
        while (this.bufHandshake.size() >= 4) {
            int type = this.bufHandshake.get(0);
            int dataLen = this.bufHandshake.getU24(1);
            Ex.MUST(dataLen <= MAX_HANDSHAKE_MSG_LEN, ALERT_RECORD_OVERFLOW);
            int fullLen = 4 + dataLen;
            if (this.bufHandshake.size() < fullLen) {
                break;
            }
            Binary message = this.bufHandshake.first(fullLen);
            Binary data = message.slice(4, message.size() - 4);
            this.bufHandshake.assign(this.bufHandshake, fullLen, this.bufHandshake.size() - fullLen);
            this.log.writeln(this.colorRecv2, String.format(Locale.US, "    Handshake msg:  type: %02X (%s), data:  %s",
                    type, TlsConst.HandshakeType.toStr(type), data.Hex()));
            this.onHandshakeMessage(type, message, data);
        }
    }

    /**
     * Handles ChangeCipherSpec. It validates the record and the session state, then installs the
     * decoder for the peer's keys and moves to {@code STATE_WAIT_FINISHED}.
     */
    protected void processChangeCipher(Binary recBody) {
        this.log.writeln(this.colorRecv2, "    Change Cipher Spec");
        Ex.MUST(recBody.size() == 1, ALERT_DECODE_ERROR);
        Ex.MUST(recBody.get(0) == 1, ALERT_DECODE_ERROR);
        // A cipher change must not split a handshake message.
        Ex.MUST(this.bufHandshake.size() == 0, ALERT_UNEXPECTED_MESSAGE);
        Ex.MUST(this.state == STATE_WAIT_CHANGE_CIPHER, ALERT_UNEXPECTED_MESSAGE);
        this.initDecrypter(!this.isServer);
        this.state = STATE_WAIT_FINISHED;
    }

    /**
     * Logs a received alert and always throws.
     *
     * @throws Ex with code {@code (level << 8) | description}
     */
    protected void processRecordAlert(Binary recBody) {
        Ex.MUST(recBody.size() == 2, ALERT_DECODE_ERROR);
        this.log.writeln(this.colorRecv2, String.format(Locale.US, "ALERT:  type:  %02X,  description:  %02X (%s)",
                recBody.get(0), recBody.get(1), TlsConst.Alert.toStr(recBody.get(1))));
        throw new Ex(recBody.asU16(), "Alert recieved");
    }

    /**
     * Encodes the record body (if an encoder is installed), frames it with a 5-byte record header
     * and passes it to {@link #funcSendData}.
     *
     * @throws Ex if the body exceeds {@link #plainLimit} or no send callback is set
     */
    protected void sendRecord(int recordType, Binary recordData) {
        Ex.MUST(recordData.size() <= this.plainLimit);
        Binary cipheredData = this.encodeRecord(recordType, recordData);
        Binary rec = new Binary().reserve(recordData.size() + RECORD_HEADER_SIZE + 16);
        rec.add(recordType);
        rec.addU16(this.protocolVersion);
        rec.addU16(cipheredData.size());
        rec.add(cipheredData);
        this.log.writeln(this.colorSend1, String.format(Locale.US, "send [%4d] : %02X %04X %04X %s\n",
                rec.size(), recordType, this.protocolVersion, cipheredData.size(), cipheredData.Hex()));
        Ex.MUST(this.funcSendData != null, "funcSendData == null");
        this.funcSendData.accept(rec);
    }

    /**
     * Sends ChangeCipherSpec, installs the encoder for this side's keys, then sends Finished
     * (already protected by the new encoder).
     */
    protected void finishHandshake() {
        this.log.writeln(this.colorSend2, "    Change Cipher Spec");
        this.sendRecord(CONTENT_CHANGE_CIPHER_SPEC, Binary.Bin(1, 1));
        this.initEncrypter(this.isServer);
        this.sendHandshakeMsg(HANDSHAKE_FINISHED, this.calcVerifyData(this.isServer));
    }

    /**
     * Wraps the body in a handshake header (type + 24-bit length). It adds the message to
     * {@link #handshakeMessages} and sends it as a single handshake record.
     */
    protected void sendHandshakeMsg(int msgType, Binary body) {
        Binary b = new Binary().reserve(64);
        b.add(msgType);
        b.addU24(body.size());
        b.add(body);
        this.log.writeln(this.colorSend2, String.format(Locale.US, "    Handshake msg:  type: %02X (%s), data:  %s",
                msgType, TlsConst.HandshakeType.toStr(msgType), body.Hex()));
        this.addHandshakeMessage(b);
        this.sendRecord(CONTENT_HANDSHAKE, b);
    }

    /** Sends a fatal alert with the given description; any failure is deliberately ignored. */
    private void sendFatal(int alert) {
        try {
            this.log.writeln(this.colorSend2, String.format(Locale.US, "FATAL ALERT:  description:  %02X (%s)",
                    alert, TlsConst.Alert.toStr(alert)));
            this.sendRecord(CONTENT_ALERT, Binary.Bin(1, ALERT_LEVEL_FATAL).add(alert));
        } catch (Throwable ignored) {
            // Best effort: we are already failing; the caller's exception must not be masked.
        }
    }
}

