// Denom.org
// Author:  Sergey Novochenko,  Digrol@gmail.com

package org.denom.crypt.blockcipher;

import org.denom.Binary;
import org.denom.Ex;

import static org.denom.Binary.Bin;
import static org.denom.Ex.MUST;

/**
 * Base class for block ciphers.
 * <p>
 * Subclasses provide the single-block transform ({@link #encryptBlock}, {@link #decryptBlock}) and key handling,
 * and are expected to call {@link #initialize(int)} with their block size. On top of that this class implements:
 * <ul>
 *   <li>modes of operation ECB, CBC, CFB, OFB and GOST 28147-89 counter mode, one-shot or chunked;</li>
 *   <li>ISO 9797-1 MAC algorithm 1 (CBC-MAC, MAC = last ciphered block);</li>
 *   <li>CTR mode per ISO 10116:2006;</li>
 *   <li>CMAC per NIST SP 800-38B;</li>
 *   <li>padding 0x80 followed by zero bytes ({@link AlignMode#BLOCK}).</li>
 * </ul>
 * Instances keep the state of the current operation (IV, unprocessed tail, operation kind) in fields without any
 * synchronization, so one instance must not be used for several operations at the same time.
 */
public abstract class BlockCipher
{
	/** Block size in bytes, set by {@link #initialize(int)}. */
	private int blockSize;

	/** Current key. */
	protected Binary key;

	/** Default IV: blockSize zero bytes, used when the caller passes iv == null. */
	protected Binary iv0;

	// State of the current chunked operation (see startEncrypt / startDecrypt / startMAC).
	private AlignMode tAlignMode;
	private CryptoMode tCryptMode;
	private Binary tIV;     // running chaining value / keystream block of the current operation
	private Binary prevIV;  // scratch copy of the incoming ciphered block during CBC / CFB decryption

	private int tOperation = 0; // 0 - no operation started
	private final static int OP_ENCRYPTION = 1;
	private final static int OP_DECRYPTION = 2;
	private final static int OP_MAC        = 3;

	/** Bytes received but not yet processed because they do not form a full block. */
	private Binary tail;

	// -----------------------------------------------------------------------------------------------------------------
	// For GOST 28147-89 counter (gamming) mode: the two 32-bit counter halves and the standard constants.
	private int N3 = 0;
	private int N4 = 0;
	private static final int GOST_C1 = 0x01010104;
	private static final int GOST_C2 = 0x01010101;

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Create copy of block cipher with same key.
	 */
	@Override
	public abstract BlockCipher clone();

	/**
	 * @return name of the algorithm.
	 */
	public abstract String getAlgName();

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @return block size in bytes.
	 */
	public int getBlockSize()
	{
		return blockSize;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @return size of the current key in bytes.
	 */
	public int getKeySize()
	{
		return key.size();
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @return copy of the current key.
	 */
	public Binary getKey()
	{
		return key.clone();
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Generate random key. Generated key will be set as current.
	 */
	public abstract Binary generateKey();

	/**
	 * Set the current key.
	 */
	public abstract void setKey( final Binary key );

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Set the current key given as a hex string.
	 */
	public void setKey( String keyHex )
	{
		setKey( Bin( keyHex ) );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Set the block size and allocate the per-operation buffers (zero IV, tail).
	 * @param blockSize - block size in bytes.
	 */
	protected void initialize( int blockSize )
	{
		this.blockSize = blockSize;
		iv0 = new Binary( blockSize );
		tIV = iv0.clone();
		prevIV = iv0.clone();
		tail = Bin().reserve( blockSize );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Encryption/Decryption of one data block in place.
	 * @param block [ getBlockSize() ].
	 */
	public abstract void encryptBlock( Binary block );
	public abstract void decryptBlock( Binary block );

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Encrypt data in one call.
	 * @param data - plain data.
	 * @param cryptMode - mode of operation.
	 * @param alignMode - padding. With 'NONE' no padding is added; ECB and CBC then require the data size to be
	 *   a multiple of getBlockSize(), CFB, OFB and Gost28147CTR accept any size.
	 * @param iv - Initial Vector [ getBlockSize() ] or null (iv = zeroes). Ignored for ECB mode.
	 * @return ciphered data.
	 */
	public Binary encrypt( final Binary data, CryptoMode cryptMode, AlignMode alignMode, Binary iv )
	{
		startEncrypt( cryptMode, alignMode, iv );
		return finish( data );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Decrypt data in one call. Parameters as in 'encrypt'.
	 * @param data - ciphered data.
	 * @return plain data with padding removed according to alignMode.
	 */
	public Binary decrypt( final Binary data, CryptoMode cryptMode, AlignMode alignMode, Binary iv )
	{
		startDecrypt( cryptMode, alignMode, iv );
		return finish( data );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Calc MAC - ISO 9797-1, algorithm 1 (Encryption in CBC, MAC = last ciphered block).
	 * @param iv - can be null.
	 * @return MAC [ getBlockSize() ].
	 */
	public Binary calcMAC( final Binary data, AlignMode alignMode, final Binary iv )
	{
		startMAC( alignMode, iv );
		return finish( data );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Set the working IV (copy of 'iv' or of the zero IV) and drop any leftover tail bytes.
	 */
	private void initIVAndTail( final Binary iv )
	{
		if( iv != null )
		{
			MUST( iv.size() == blockSize, "Wrong IV size" );
			tIV = iv.clone();
		}
		else
		{
			tIV = iv0.clone();
		}

		tail.resize( 0 );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Encryption, Decryption or calcMAC for chunked data.
	 * Start process with one of methods: 'startEncrypt', 'startDecrypt', 'startMAC';
	 * then process some data with 'process';
	 * call 'finish' for the last part.
	 * Examples:
	 *
	 * myAlg.startEncrypt( CryptoMode.CBC, AlignMode.BLOCK, null );
	 * cipheredPart1 = myAlg.process( plainDataPart1 );
	 * cipheredPart2 = myAlg.process( plainDataPart2 );
	 * cipheredPart3 = myAlg.finish( plainDataPart3 );
	 *
	 * Binary plain = Bin();
	 * myAlg.startDecrypt( CryptoMode.CFB, AlignMode.NONE, IV );
	 * plain.add( myAlg.process( cipheredPart1 ) );
	 * plain.add( myAlg.process( cipheredPart2 ) );
	 * plain.add( myAlg.finish( cipheredPart3 ) ); // Any size for CFB
	 *
	 * myAlg.startMAC( AlignMode.BLOCK, IV );
	 * myAlg.process( dataPart1 );
	 * myAlg.process( dataPart2 );
	 * mac = myAlg.finish( dataPart3 );
	 *
	 * @param iv can be null, then zeroed array will be used if needed by mode.
	 * For ECB mode IV is not used.
	 */
	public void startEncrypt( CryptoMode cryptMode, AlignMode alignMode, final Binary iv )
	{
		tOperation = OP_ENCRYPTION;
		tCryptMode = cryptMode;
		tAlignMode = alignMode;

		initIVAndTail( iv );

		if( cryptMode == CryptoMode.Gost28147CTR )
			startGostCTR();
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * GOST 28147-89 counter mode setup: encrypt the IV once and load its halves (little-endian)
	 * into the counters N3 (bytes 0..3) and N4 (bytes 4..7).
	 */
	private void startGostCTR()
	{
		MUST( blockSize == 8, "For Gost28147CTR mode need BlockSize = 8" );
		encryptBlock( tIV );
		N3 = tIV.getIntLE( 0 );
		N4 = tIV.getIntLE( 4 );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Advance the GOST counters and put the next gamma block into tIV:
	 * N3 = N3 + C2 (mod 2^32), N4 = N4 + C1 (mod 2^32 - 1), tIV = E( N3 || N4 ).
	 */
	private void calcGostCTRGamma()
	{
		N3 += GOST_C2;
		N4 += GOST_C1;
		// Addition mod (2^32 - 1) = 32-bit addition with end-around carry.
		// The 32-bit sum wrapped exactly when the unsigned result is below C1 (a result of 0 included),
		// and then the carry must be added back.
		if( Integer.compareUnsigned( N4, GOST_C1 ) < 0 )
			N4++;

		tIV.setIntLE( 0, N3 );
		tIV.setIntLE( 4, N4 );
		encryptBlock( tIV );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Encrypt one block in place according to the current mode, updating tIV as the mode requires.
	 * For CFB, OFB and Gost28147CTR the last block may be shorter than blockSize.
	 */
	private void encryptBlockModes( Binary block )
	{
		switch( tCryptMode )
		{
			case ECB:
				encryptBlock( block );
				break;

			case CBC:
				block.xor( tIV );
				encryptBlock( block );
				tIV.assign( block );
				break;

			case CFB:
				encryptBlock( tIV );
				block.xor( tIV );
				tIV.assign( block );
				break;

			case OFB:
				encryptBlock( tIV );
				block.xor( tIV );
				break;

			case Gost28147CTR:
				calcGostCTRGamma();
				block.xor( tIV );
				break;

			default:
				// Never let an unhandled mode pass data through unencrypted.
				throw new Ex( "Unsupported crypto mode: " + tCryptMode );
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Append 'plain' to the tail and encrypt every complete block.
	 * @return ciphered complete blocks; incomplete remainder stays in 'tail'.
	 */
	private Binary updateEncrypt( final Binary plain )
	{
		Binary ciphered = Bin();
		ciphered.reserve( plain.size() + ((tail.size() > 0) ? blockSize : 0) );

		int plainOffset = 0;
		int plainSize = plain.size();
		while( plainOffset < plainSize )
		{
			int sz = Math.min( blockSize - tail.size(), plainSize - plainOffset );

			tail.add( plain, plainOffset, sz );
			plainOffset += sz;

			if( tail.size() == blockSize )
			{
				encryptBlockModes( tail );
				ciphered.add( tail );
				tail.clear();
			}
		}

		return ciphered;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Encrypt the last part, pad according to tAlignMode and encrypt the remaining tail.
	 * ECB and CBC require the tail to be a whole block after padding.
	 */
	private Binary finishEncrypt( final Binary plain )
	{
		Binary ciphered = updateEncrypt( plain );

		pad( tail, blockSize, tAlignMode );
		if( tail.size() > 0 )
		{
			if( (tCryptMode == CryptoMode.ECB) || (tCryptMode == CryptoMode.CBC) )
				MUST( tail.size() == blockSize, "Wrong data size for encrypt" );
			encryptBlockModes( tail );
			ciphered.add( tail );
			tail.clear();
		}

		return ciphered;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * See 'startEncrypt' for comments.
	 */
	public void startDecrypt( CryptoMode cryptMode, AlignMode alignMode, final Binary iv )
	{
		tOperation = OP_DECRYPTION;
		tCryptMode = cryptMode;
		tAlignMode = alignMode;

		initIVAndTail( iv );

		if( cryptMode == CryptoMode.Gost28147CTR )
			startGostCTR();
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Decrypt one block in place according to the current mode, updating tIV as the mode requires.
	 * CFB, OFB and Gost28147CTR use the forward cipher (encryptBlock) to produce the keystream.
	 */
	private void decryptBlockModes( Binary block )
	{
		switch( tCryptMode )
		{
			case ECB:
				decryptBlock( block );
				break;

			case CBC:
				prevIV.assign( block ); // ciphered block becomes the next IV
				decryptBlock( block );
				block.xor( tIV );
				tIV.assign( prevIV );
				break;

			case CFB:
				prevIV.assign( block ); // ciphered block becomes the next IV
				encryptBlock( tIV );
				block.xor( tIV );
				tIV.assign( prevIV );
				break;

			case OFB:
				encryptBlock( tIV );
				block.xor( tIV );
				break;

			case Gost28147CTR:
				calcGostCTRGamma();
				block.xor( tIV );
				break;

			default:
				// Never let an unhandled mode pass data through undecrypted.
				throw new Ex( "Unsupported crypto mode: " + tCryptMode );
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Append 'ciphered' to the tail and decrypt every complete block.
	 * @return plain complete blocks; incomplete remainder stays in 'tail'.
	 */
	private Binary updateDecrypt( final Binary ciphered )
	{
		Binary plain = Bin();
		plain.reserve( ciphered.size() + ((tail.size() > 0) ? blockSize : 0) );

		int cipheredOffset = 0;
		int cipheredSize = ciphered.size();
		while( cipheredOffset < cipheredSize )
		{
			int sz = Math.min( blockSize - tail.size(), cipheredSize - cipheredOffset );

			tail.add( ciphered, cipheredOffset, sz );
			cipheredOffset += sz;

			if( tail.size() == blockSize )
			{
				decryptBlockModes( tail );
				plain.add( tail );
				tail.resize( 0 );
			}
		}

		return plain;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Decrypt the last part and remove padding according to tAlignMode.
	 * A partial final block is accepted only for stream-like modes (CFB, OFB, Gost28147CTR) with AlignMode.NONE.
	 */
	private Binary finishDecrypt( final Binary ciphered )
	{
		Binary plain = updateDecrypt( ciphered );

		if( ((tCryptMode == CryptoMode.CFB) || (tCryptMode == CryptoMode.OFB) || (tCryptMode == CryptoMode.Gost28147CTR))
				&& (tAlignMode == AlignMode.NONE) && (tail.size() > 0) )
		{
			decryptBlockModes( tail );
			plain.add( tail );
			tail.resize( 0 );
		}

		MUST( tail.size() == 0, "Wrong data size for decrypt" );

		unPad( plain, blockSize, tAlignMode );
		return plain;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * See 'startEncrypt' for comments.
	 * ISO 9797-1, algorithm 1 (Encryption in CBC, MAC = last ciphered block)
	 */
	public void startMAC( AlignMode alignMode, final Binary iv )
	{
		tOperation = OP_MAC;
		tCryptMode = CryptoMode.CBC;
		tAlignMode = alignMode;
		initIVAndTail( iv );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Feed data into the CBC-MAC; every complete block is chained into tIV.
	 */
	private void updateMAC( final Binary dataPart )
	{
		int plainOffset = 0;
		int plainSize = dataPart.size();
		while( plainOffset < plainSize )
		{
			int sz = Math.min( blockSize - tail.size(), plainSize - plainOffset );

			tail.add( dataPart, plainOffset, sz );
			plainOffset += sz;

			if( tail.size() == blockSize )
			{
				tail.xor( tIV );
				encryptBlock( tail );
				tIV.assign( tail );
				tail.resize( 0 );
			}
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Feed the last part, pad it and return the last CBC block as MAC.
	 * After padding the tail must be empty or a whole block.
	 */
	private Binary finishMAC( final Binary dataPart )
	{
		updateMAC( dataPart );

		pad( tail, blockSize, tAlignMode );
		if( tail.size() > 0 )
		{
			MUST( tail.size() == blockSize, "Wrong data size for MAC" );
			tail.xor( tIV );
			encryptBlock( tail );
			tIV.assign( tail );
			tail.resize( 0 );
		}

		return tIV.clone();
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Process part of incoming data, see 'startEncrypt', 'startDecrypt', 'startMAC'.
	 * @return for encryption/decryption - result data chunk, can be empty if dataPart is too small
	 *   to complete a block; for MAC calculation - null.
	 */
	public Binary process( final Binary dataPart )
	{
		switch( tOperation )
		{
			case OP_ENCRYPTION:
				return updateEncrypt( dataPart );

			case OP_DECRYPTION:
				return updateDecrypt( dataPart );

			case OP_MAC:
				updateMAC( dataPart );
				return null;

			default:
				throw new Ex( "Crypt operation not started" );
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Process last part of incoming data and end the current operation.
	 * @return last blocks or MAC.
	 */
	public Binary finish( final Binary lastDataPart )
	{
		int op = tOperation;
		tOperation = 0; // operation ends even if processing below fails
		switch( op )
		{
			case OP_ENCRYPTION:
				return finishEncrypt( lastDataPart );

			case OP_DECRYPTION:
				return finishDecrypt( lastDataPart );

			case OP_MAC:
				return finishMAC( lastDataPart );

			default:
				throw new Ex( "Crypt operation not started" );
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Counter (CTR) mode -- ISO 10116:2006.
	 * Same algorithm for encryption and decryption.
	 * @param data - data to process.
	 * @param SV - Starting Value [<= getBlockSize()]. Allowed less than getBlockSize()  -->  SV will be padded with
	 *   zeroes on the right.
	 * @param JBytes - size of part (in bytes), 0 < JBytes <= getBlockSize(). Only the first JBytes of each
	 *   encrypted counter block are used.
	 * @return C - cryptogram (if 'data' - plain) or plain text (if 'data' - cryptogram)
	 */
	public Binary cryptCTR( final Binary data, final Binary SV, int JBytes )
	{
		MUST( (JBytes > 0) && (JBytes <= blockSize) && (SV.size() <= blockSize), "Wrong params for CTR" );

		Binary Qi = SV.clone();
		Qi.resize( blockSize );
		Binary Ei = Bin( blockSize );
		Binary C = Bin( data.size() );
		byte[] CArr = C.getDataRef();
		byte[] dataArr = data.getDataRef();
		byte[] EArr = Ei.getDataRef();

		int PSize = data.size();
		for( int offset = 0; offset < PSize; offset += JBytes )
		{
			Ei.assign( Qi );
			encryptBlock( Ei );

			// Ci = Ei xor Pi
			int partSize = Math.min( JBytes, PSize - offset );
			for( int i = 0; i < partSize; ++i )
			{
				CArr[ offset + i ] = (byte)(dataArr[ offset + i ] ^ EArr[ i ]);
			}

			Qi.increment();
		}

		return C;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Comfort: JBytes = BlockSize.
	 * @param data - data to process.
	 * @param SV - Starting Value [getBlockSize()]
	 */
	public Binary cryptCTR( final Binary data, final Binary SV )
	{
		return cryptCTR( data, SV, getBlockSize() );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * CMAC subkey step (SP 800-38B, 6.1): K = K << 1, and if the bit shifted out was 1, K ^= Rb.
	 * Rb = 0x1B for 64-bit blocks (R64) and 0x87 for 128-bit blocks (R128).
	 * Other block sizes are not covered by SP 800-38B and keep Rb = 0x87.
	 */
	private void CMAC_SHL( Binary K )
	{
		byte[] k = K.getDataRef();
		int kSize = K.size();

		int carry = 0;
		for( int i = kSize - 1; i >= 0; --i )
		{
			int b = k[ i ] & 0xFF;
			k[ i ] = (byte)((b << 1) | carry);
			carry = b >>> 7;
		}

		int rb = (kSize == 8) ? 0x1B : 0x87;
		k[ kSize - 1 ] ^= (byte)(rb & (-carry)); // -carry is all ones when carry == 1, zero otherwise
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * SP 800-38B. Recommendation for Block Cipher Modes of Operation: the CMAC Mode for Authentication.
	 * https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-38B.pdf .
	 * The Key must already be set.
	 * @param data - message, any size (including empty).
	 * @param iv - IV or null. Null (zero IV) gives standard CMAC; a non-null IV is a non-standard extension.
	 * @return CMAC [blockSize bytes].
	 */
	public Binary calcCMAC( Binary data, Binary iv )
	{
		Binary res;
		if( iv != null )
		{
			MUST( iv.size() == blockSize, "Wrong IV length" );
			res = iv.clone();
		}
		else
		{
			res = iv0.clone();
		}

		byte[] resArr = res.getDataRef();
		byte[] dataArr = data.getDataRef();

		// number of blocks including any final partial block
		int numDataBlocks = ( data.size() + blockSize - 1 ) / blockSize;

		// process all data blocks (encrypt in CBC mode), except last block
		int offset = 0;
		for( int i = 1; i < numDataBlocks; i++ )
		{
			for( int j = 0; j < blockSize; ++j, ++offset )
				resArr[ j ] ^= dataArr[ offset ];

			encryptBlock( res );
		}

		// last part: 1..blockSize bytes, or 0 bytes for empty data
		int restSize = data.size() - offset;

		// K1 = dbl( E(0) )
		Binary K = Bin( blockSize );
		encryptBlock( K );
		CMAC_SHL( K );

		// XOR with last data part
		int j = 0;
		while( j < restSize )
			resArr[ j++ ] ^= dataArr[ offset++ ];

		if( restSize != blockSize )
		{
			// Incomplete (or empty) last block: pad with 0x80 00..00 and use K2 = dbl( K1 ).
			resArr[ j ] ^= 0x80;
			CMAC_SHL( K );
		}

		// XOR with K1 or K2
		res.xor( K );
		encryptBlock( res );

		return res;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Pad data to a multiple of the block size.
	 * The input Binary is modified in place.
	 * With AlignMode.BLOCK appends 0x80 followed by zero bytes; at least one byte is always added,
	 * so already aligned data gets a whole extra block.
	 * @param data - Input data.
	 * @param blockSize - Block size in bytes.
	 * @param alignMode - Alignment mode.
	 */
	public static void pad( Binary data, int blockSize, AlignMode alignMode )
	{
		switch( alignMode )
		{
			case NONE:
				break;

			case BLOCK:
				MUST( blockSize > 0, "Wrong data alignment" );
				// '%' instead of a bit mask: block size need not be a power of two.
				int padLen = blockSize - (data.size() % blockSize);
				data.reserve( data.size() + padLen );
				data.add( 0x80 );
				--padLen;
				for( int i = 0; i < padLen; ++i )
				{
					data.add( 0x00 );
				}
				break;

			default:
				throw new Ex( "Wrong data alignment" );
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Remove padding added by 'pad'.
	 * The input Binary is modified in place.
	 * With AlignMode.BLOCK the data must be non-empty, a multiple of blockSize, and end with 0x80 followed by
	 * zero bytes within the last blockSize bytes; otherwise MUST fails with "Wrong data alignment".
	 * @param data - Input Binary.
	 * @param blockSize - Block size in bytes.
	 * @param alignMode - Alignment mode.
	 */
	public static void unPad( Binary data, int blockSize, AlignMode alignMode )
	{
		switch( alignMode )
		{
			case NONE:
				break;

			case BLOCK:
				MUST( blockSize > 0, "Wrong data alignment" );
				MUST( !data.empty(), "Wrong data alignment" );
				// '%' instead of a bit mask: block size need not be a power of two.
				MUST( (data.size() % blockSize) == 0,
					"Wrong data alignment" );

				// Skip trailing zero bytes back to the 0x80 marker.
				int i = data.size();
				do
				{
					--i;
				}
				while( (i > 0) && (data.get( i ) == 0x00) );

				MUST( (data.get( i ) == 0x80) && ((data.size() - i) <= blockSize), "Wrong data alignment" );

				data.resize( i );
				break;

			default:
				throw new Ex( "Wrong data alignment" );
		}
	}
}
