package org.jpos.tcpay.connection;

import org.apache.commons.lang3.StringUtils;
import org.jpos.iso.ISOException;
import org.jpos.iso.ISOMsg;
import org.jpos.iso.ISOUtil;
import org.jpos.iso.packager.ISO87BPackager;
import org.jpos.tcpay.CustomOutputStream;
import org.jpos.tcpay.db.entity.AcquirerConnection;
import org.jpos.util.AppLogger;
import org.jpos.util.ByteUtil;
import org.jpos.util.LogEvent;
import org.jpos.util.StringParsingUtil;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Objects;

/**
 * {@link AcquirerChannel} that exchanges ISO 8583 messages, packed with {@link ISO87BPackager}, with an acquirer
 * host over a plain or TLS TCP socket.
 *
 * <p>Outgoing frame: {@code [length prefix from getLengthByte, if any][ADR/CB][5-byte BCD TPDU, if configured][ISO body]}.
 * Incoming frames are parsed assuming {@code [2-byte length][2-byte ADR/CB][5-byte TPDU, if configured][ISO body]}.
 *
 * <p>The connection is cached in {@link #acqSocket} and re-created when it is missing or closed. That field is
 * mutated without any synchronization. NOTE(review): presumably each instance is used by one thread at a time;
 * confirm with the callers.
 */
public class YspChannel implements AcquirerChannel {
    /** Connect timeout and read timeout applied to acquirer sockets, in milliseconds. */
    private static final int SOCKET_TIMEOUT_MS = 15 * 1000;
    /** Size of the buffer used for the single read of a response. */
    private static final int RECEIVE_BUFFER_SIZE = 4096;
    /** Number of length-prefix bytes skipped at the start of a response. */
    private static final int RECV_LENGTH_PREFIX_BYTES = 2;
    /** Size of the TPDU once the configured 10-digit string is BCD-packed. */
    private static final int TPDU_BYTES = 5;

    /** Cached acquirer connection; {@code null} when no connection is open. */
    Socket acqSocket = null;

    private final AppLogger logger = new AppLogger();

    /** ADR and CB bytes written in front of every outgoing message; the same number of bytes is skipped in responses. */
    byte[] ADR = {0x30, 0x22};

    /** Creates a channel with no open connection; the socket is opened lazily by {@link #transceive}. */
    public YspChannel() {
        super();
    }

    /**
     * Sends {@code req} to the acquirer and returns the unpacked response, opening a connection first if needed.
     *
     * @param req                request to send; its header is set to the TPDU when one is configured
     * @param acquirerConnection connection settings (host, port, TLS material, TPDU, length mode)
     * @return the response, or {@code null} when the connection could not be opened or the exchange failed
     * @throws SocketTimeoutException if no response arrived within the read timeout after the request was sent
     */
    @Override
    public ISOMsg transceive(ISOMsg req, AcquirerConnection acquirerConnection) throws SocketTimeoutException {
        LogEvent evt = new LogEvent("transceive");
        try {
            if (Objects.isNull(acqSocket) || acqSocket.isClosed()) {
                logger.log("Creating Socket");
                closeSocket();
                acqSocket = getSocket(acquirerConnection);
            }

            ISOMsg response = new ISOMsg();
            response.setPackager(new ISO87BPackager());
            int ret = sendReceive(req, response, acquirerConnection);
            if (ret != 0) {
                evt.addMessage("YSP Transceive Failed " + ret);
                return null;
            }
            evt.addMessage("sendReceive ret = " + ret);
            return response;
        } catch (SocketTimeoutException e) {
            throw e;
        } catch (Exception e) {
            evt.addMessage("Exception : " + e.getMessage());
            evt.addMessage(e);
            // Signal failure the same way as a non-zero sendReceive result, rather than returning an empty message.
            return null;
        } finally {
            logger.log(evt);
        }
    }

    /** Closes the cached socket, if any, and clears the reference so the next call reconnects. */
    private void closeSocket() {
        if (acqSocket == null) {
            return;
        }

        try {
            acqSocket.close();
        } catch (IOException e) {
            logger.log("YSPChannel", "Error while closing socket: " + e.getMessage());
        } finally {
            acqSocket = null;
            logger.log("YSPChannel", "Socket closed");
        }
    }

    /**
     * Opens a new plain or TLS connection to the acquirer, stores it in {@link #acqSocket} and returns it.
     *
     * <p>Connect, TLS handshake and subsequent reads are all bounded by {@link #SOCKET_TIMEOUT_MS}. A timeout while
     * connecting is rethrown as a plain {@link IOException}. That way a {@link SocketTimeoutException} escaping
     * {@link #transceive} always means "request sent, no response", never "connection not established".
     *
     * @throws Exception if TLS material cannot be loaded or the connection cannot be established
     */
    private Socket getSocket(AcquirerConnection acquirerConnection) throws Exception {
        InetSocketAddress endpoint =
                new InetSocketAddress(acquirerConnection.getIpAddress(), acquirerConnection.getPortNumber());

        Socket socket;
        if (acquirerConnection.isSsl()) {
            logger.log("SSL Connection");
            // Create the socket unconnected so connect() can take a timeout (supported by the JDK's SunJSSE factory).
            SSLSocketFactory socketFactory = createSslContext(acquirerConnection).getSocketFactory();
            socket = socketFactory.createSocket();
        } else {
            logger.log("Plain Connection " + acquirerConnection.getIpAddress() + ":" + acquirerConnection.getPortNumber());
            socket = new Socket();
        }

        try {
            // Set the read timeout before connecting, so the TLS handshake cannot block forever either.
            socket.setSoTimeout(SOCKET_TIMEOUT_MS);
            socket.connect(endpoint, SOCKET_TIMEOUT_MS);
            if (socket instanceof SSLSocket) {
                ((SSLSocket) socket).startHandshake();
            }
        } catch (SocketTimeoutException e) {
            closeAfterFailure(socket, e);
            throw new IOException("Timed out connecting to " + endpoint, e);
        } catch (Exception e) {
            closeAfterFailure(socket, e);
            throw e;
        }

        acqSocket = socket;
        return socket;
    }

    /**
     * Builds the TLS context for the acquirer. It uses the configured truststore, or the provider's default trust
     * material when no truststore path is set.
     */
    private static SSLContext createSslContext(AcquirerConnection acquirerConnection) throws Exception {
        // NOTE(review): the keystore is loaded (so a bad path/password still fails fast, as before), but it is not
        // passed to SSLContext.init, so no client certificate is presented. If the acquirer requires mutual TLS,
        // wire it in through a KeyManagerFactory.
        loadJksIfConfigured(acquirerConnection.getKeyStorePath(), acquirerConnection.getKeyStorePassword());
        KeyStore trustStore =
                loadJksIfConfigured(acquirerConnection.getTrustStorePath(), acquirerConnection.getTrustStorePassword());

        TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        tmf.init(trustStore); // null -> provider default trust store

        SSLContext sslContext = SSLContext.getInstance("TLS");
        sslContext.init(null, tmf.getTrustManagers(), new SecureRandom());
        return sslContext;
    }

    /**
     * Loads a JKS store from {@code path}, closing the file afterwards.
     *
     * @return the loaded store, or {@code null} when {@code path} is empty
     */
    private static KeyStore loadJksIfConfigured(String path, String password) throws Exception {
        if (StringUtils.isEmpty(path)) {
            return null;
        }
        KeyStore store = KeyStore.getInstance("JKS");
        try (InputStream in = Files.newInputStream(Paths.get(path))) {
            store.load(in, password == null ? null : password.toCharArray());
        }
        return store;
    }

    /** Closes a socket whose setup failed, attaching any close error to the original failure. */
    private static void closeAfterFailure(Socket socket, Exception failure) {
        try {
            socket.close();
        } catch (IOException closeFailure) {
            failure.addSuppressed(closeFailure);
        }
    }

    /**
     * Builds the bytes to write for {@code sendIsoMsg}.
     *
     * <p>Layout: {@code [length prefix, if getLengthByte returns one][ADR][TPDU, if configured][packed message]}.
     * The length passed to {@code getLengthByte} counts everything after the prefix. When a TPDU is configured, it
     * is also set as the header of {@code sendIsoMsg}.
     *
     * @throws ISOException if the TPDU is invalid or the message cannot be packed
     */
    public byte[] getPacketToSend(ISOMsg sendIsoMsg, AcquirerConnection acquirerConnection) throws ISOException {
        byte[] tpduArray = null;
        if (!StringUtils.isEmpty(acquirerConnection.getTpdu())) {
            tpduArray = ISOUtil.str2bcd(ISOUtil.padleft(acquirerConnection.getTpdu(), TPDU_BYTES * 2, '0'), true);
            sendIsoMsg.setHeader(tpduArray);
        }

        byte[] body = sendIsoMsg.pack();
        int tpduLength = Objects.isNull(tpduArray) ? 0 : tpduArray.length;
        int packetLength = ADR.length + tpduLength + body.length;

        byte[] lengthArray = getLengthByte(packetLength, acquirerConnection.getLengthMode());
        int prefixLength = Objects.isNull(lengthArray) ? 0 : lengthArray.length;

        ByteBuffer packet = ByteBuffer.allocate(prefixLength + packetLength);
        if (Objects.nonNull(lengthArray)) {
            packet.put(lengthArray);
        }
        packet.put(ADR);
        if (Objects.nonNull(tpduArray)) {
            packet.put(tpduArray);
        }
        packet.put(body);
        return packet.array();
    }

    /**
     * Writes {@code sendIsoMsg} on the cached socket and unpacks one response into {@code receiveIsoMsg}.
     *
     * @return 0 on success; -1 if the peer closed the connection or the response is shorter than the framing
     *         header; 1 on {@link SocketException}; 2 on other {@link IOException}; 3 on {@link ISOException}
     * @throws SocketTimeoutException if no response arrived within the socket read timeout (the socket is closed)
     */
    public int sendReceive(ISOMsg sendIsoMsg, ISOMsg receiveIsoMsg, AcquirerConnection acquirerConnection) throws SocketTimeoutException {
        LogEvent evt = new LogEvent("sendReceive");

        boolean isTpdu = !StringUtils.isEmpty(acquirerConnection.getTpdu());
        int tpduOffset = RECV_LENGTH_PREFIX_BYTES + ADR.length;
        int headerLength = tpduOffset + (isTpdu ? TPDU_BYTES : 0);
        int retVal = -1;
        try {
            byte[] sendByteArray = getPacketToSend(sendIsoMsg, acquirerConnection);
            OutputStream socketOutputStream = acqSocket.getOutputStream();
            InputStream socketInputStream = acqSocket.getInputStream();

            dumpIsoMsg("Send To " + acquirerConnection.getAcquirer().getName(), sendIsoMsg, evt);
            evt.addMessage("SendDump ", StringParsingUtil.hexString(sendByteArray));
            socketOutputStream.write(sendByteArray);
            socketOutputStream.flush();

            // NOTE(review): this assumes one read() returns the whole response, but TCP may split or merge segments.
            // Reading exactly the length announced in the prefix would be more robust. Its encoding depends on
            // getLengthMode, which is not visible here.
            byte[] receiveBuffer = new byte[RECEIVE_BUFFER_SIZE];
            int readLength = socketInputStream.read(receiveBuffer);
            evt.addMessage("ReadLength " + readLength);
            if (readLength < 0) {
                // End of stream: the acquirer closed the connection. isClosed() stays false for a peer-side close,
                // so drop the socket explicitly to force a reconnect on the next call.
                closeSocket();
                return -1;
            }
            evt.addMessage("Recv Dump: " + StringParsingUtil.hexString(Arrays.copyOf(receiveBuffer, readLength)));
            if (readLength < headerLength) {
                // Too short to hold the framing header; the stream is out of sync, so start fresh next time.
                evt.addMessage("Response shorter than header length " + headerLength);
                closeSocket();
                return -1;
            }

            receiveIsoMsg.unpack(Arrays.copyOfRange(receiveBuffer, headerLength, readLength));
            receiveIsoMsg.setDirection(ISOMsg.INCOMING);
            if (isTpdu) {
                // Only a configured TPDU occupies these bytes; otherwise they belong to the ISO body.
                receiveIsoMsg.setHeader(Arrays.copyOfRange(receiveBuffer, tpduOffset, headerLength));
            }
            dumpIsoMsg("Recv From " + acquirerConnection.getAcquirer().getName(), receiveIsoMsg, evt);

            retVal = 0;
        } catch (SocketTimeoutException e) {
            closeSocket();
            evt.addMessage(e.getMessage());
            throw e;
        } catch (SocketException e) {
            closeSocket();
            evt.addMessage(e.getMessage());
            retVal = 1;
        } catch (IOException e) {
            closeSocket();
            evt.addMessage(e.getMessage());
            retVal = 2;
        } catch (ISOException e) {
            evt.addMessage(e.getMessage());
            retVal = 3;
        } finally {
            logger.log(evt);
        }

        return retVal;
    }

    /** Appends a textual dump of {@code msg} to {@code evt} under {@code tag}; failures are logged, not thrown. */
    private void dumpIsoMsg(String tag, ISOMsg msg, LogEvent evt) {
        try (CustomOutputStream outputStream = new CustomOutputStream(); PrintStream printStream = new PrintStream(outputStream)) {
            msg.dump(printStream, "\t");
            printStream.flush();
            evt.addMessage(tag, outputStream.toString());
        } catch (IOException e) {
            // Dumping is diagnostic only and must not affect the exchange.
            evt.addMessage("Exception while dumping Isomsg");
        }
    }

    /**
     * Computes a 16-bit CRC over {@code data}. It uses polynomial 0x1021 and initial value 0xFFFF, processes bits most
     * significant first, and XORs the result with 0xFFFF (the CRC-16/GENIBUS parameter set).
     *
     * <p>NOTE(review): the HDLC/X.25 FCS uses the same polynomial but processes bits least significant first, so the
     * result here differs from it. This method is not used within this class; confirm against the acquirer spec
     * before relying on it.
     *
     * @param data bytes to checksum; must not be {@code null}
     * @return the CRC as two bytes, high byte first
     */
    public byte[] computeHDLCCRC(byte[] data) {
        final int polynomial = 0x1021;
        int crc = 0xFFFF;

        for (byte b : data) {
            for (int shift = 7; shift >= 0; shift--) {
                boolean dataBit = ((b >> shift) & 1) == 1;
                boolean topBit = ((crc >> 15) & 1) == 1;
                crc <<= 1;
                if (topBit ^ dataBit) {
                    crc ^= polynomial;
                }
            }
        }

        // Final XOR, then trim to 16 bits (bits shifted above bit 15 are discarded).
        crc = (crc ^ 0xFFFF) & 0xFFFF;
        return new byte[] {(byte) (crc >>> 8), (byte) crc};
    }


}
