package testblockcipher;

import org.denom.Binary;
import org.denom.Ex;
import org.denom.Ticker;
import org.denom.log.LogConsole;
import org.denom.crypt.blockcipher.*;

import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import static org.denom.Binary.Bin;
import static org.denom.Ex.MUST;

/**
 * Shared helpers for block cipher tests: known-answer checks, chunked (streaming) processing checks,
 * MAC / CTR / CMAC checks, speed measurement and comparison against a JCE provider.
 */
class CheckBlockCipher
{
	// Console log used by checkAndMeasure() to report timing and result.
	LogConsole log = new LogConsole();

	// Parameters of the speed measurement done in checkAndMeasure().
	// Both are plain (non-final) fields, so they can be reassigned before measuring.
	int ITERATIONS = 1000;
	int DATA_SIZE = 4096;

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Known-answer test: encrypts dataHex and decrypts the result back.
	 * The cryptogram is verified before decryption is attempted, so a wrong cryptogram
	 * is always reported as "Wrong cryptogram".
	 * @param alg - cipher under test; its key is replaced by keyHex.
	 * @param mode - cipher mode passed to encrypt / decrypt.
	 * @param alignMode - padding mode passed to encrypt / decrypt.
	 * @param keyHex - key as a hex string.
	 * @param IV - initialization vector, passed as is to encrypt / decrypt (the ECB overload passes null).
	 * @param dataHex - plain data as a hex string.
	 * @param cryptHex - expected cryptogram as a hex string.
	 */
	static void check( BlockCipher alg, CryptoMode mode, AlignMode alignMode, String keyHex, Binary IV, String dataHex, String cryptHex )
	{
		alg.setKey( keyHex );

		Binary crypt = alg.encrypt( Bin( dataHex ), mode, alignMode, IV );
		MUST( crypt.equals( cryptHex ), "Wrong cryptogram" );

		Binary data = alg.decrypt( crypt, mode, alignMode, IV );
		MUST( data.equals( dataHex ), "Wrong plain data" );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Known-answer test in ECB mode without padding and without IV.
	 */
	static void check( BlockCipher alg, String keyHex, String dataHex, String cryptHex )
	{
		check( alg, CryptoMode.ECB, AlignMode.NONE, keyHex, null, dataHex, cryptHex );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Iterated single-block test: applies encryptBlock 'iterations' times to the same block and
	 * compares it with cryptHex, then applies decryptBlock 'iterations' times and compares with dataHex.
	 * The code relies on encryptBlock / decryptBlock transforming the passed block in place.
	 */
	static void checkMonteCarlo( BlockCipher alg, int iterations, String keyHex, String dataHex, String cryptHex )
	{
		alg.setKey( keyHex );

		Binary block = Bin( dataHex );
		for( int i = 0; i < iterations; ++i )
		{
			alg.encryptBlock( block );
		}
		MUST( block.equals( cryptHex ) );

		for( int i = 0; i < iterations; ++i )
		{
			alg.decryptBlock( block );
		}
		MUST( block.equals( dataHex ) );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Feeds data to an already started operation in parts of growing size (0, 1, 2, ... bytes)
	 * via process(), then passes the rest of the data to finish().
	 * Splitting stops early when the next part would not fit into the remaining data,
	 * so a slice is never requested beyond the end of 'data'.
	 * @param partsLimit - upper bound (exclusive) for the part size.
	 * @return concatenation of everything returned by process() and finish().
	 */
	private static Binary processByParts( BlockCipher alg, Binary data, int partsLimit )
	{
		Binary result = Bin();
		int offset = 0;
		for( int partSize = 0; (partSize < partsLimit) && (partSize <= data.size() - offset); ++partSize )
		{
			result.add( alg.process( data.slice( offset, partSize ) ) );
			offset += partSize;
		}
		result.add( alg.finish( data.slice( offset, data.size() - offset ) ) );
		return result;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Checks that processing random data by chunks (start / process / finish) gives the same result
	 * as the one-call encrypt / decrypt, and that decryption restores the plain data.
	 * A new random key, plain data and IV are generated for every tested data length.
	 * @param dataSize - base size of the plain data; the tested lengths are dataSize + i.
	 */
	static void checkMode( BlockCipher alg, CryptoMode cryptMode, AlignMode alignMode, int dataSize )
	{
		int blockSize = alg.getBlockSize();

		// ECB and CBC without padding: only one length is tested, dataSize + blockSize.
		// All other combinations: every extra length from 0 to blockSize bytes.
		boolean singleLength = ((cryptMode == CryptoMode.ECB) || (cryptMode == CryptoMode.CBC))
			&& (alignMode == AlignMode.NONE);
		int i = singleLength ? blockSize : 0;

		for( ; i <= blockSize; ++i )
		{
			Binary key = Bin().random( alg.getKeySize() );
			alg.setKey( key );

			Binary plain = Bin().random( dataSize + i );
			Binary iv = Bin().random( blockSize );

			// ---------------------------------------------------------------------------------
			// Encrypt
			// ---------------------------------------------------------------------------------
			alg.startEncrypt( cryptMode, alignMode, iv );
			Binary crypt = processByParts( alg, plain, 32 + i );

			Binary crypt2 = alg.encrypt( plain, cryptMode, alignMode, iv );
			MUST( crypt2.equals( crypt ), "Wrong encryption by chunks" );

			// ---------------------------------------------------------------------------------
			// Decrypt
			// ---------------------------------------------------------------------------------
			alg.startDecrypt( cryptMode, alignMode, iv );
			Binary plain2 = processByParts( alg, crypt, 32 + i );
			MUST( plain2.equals( plain ), "Wrong decryption by chunks" );

			Binary plain3 = alg.decrypt( crypt, cryptMode, alignMode, iv );
			MUST( plain3.equals( plain ), "Wrong decryption" );
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Same as checkMode with 4 arguments, with base data size of 2048 bytes.
	 */
	static void checkMode( BlockCipher alg, CryptoMode cryptMode, AlignMode alignMode )
	{
		checkMode( alg, cryptMode, alignMode, 2048 );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Checks that MAC calculated by chunks (startMAC / process / finish) equals MAC calculated by calcMAC.
	 * Runs 16 rounds, each with a new random key, data and IV.
	 */
	static void checkMAC( BlockCipher alg, AlignMode alignMode )
	{
		int blockSize = alg.getBlockSize();
		for( int i = 0; i < 16; ++i )
		{
			Binary key = Bin().random( alg.getKeySize() );
			alg.setKey( key );

			Binary plain = Bin().random( 1500 + i );
			if( alignMode == AlignMode.NONE )
			{
				// Without padding the data is cut down to a whole number of blocks.
				plain.resize( plain.size() / blockSize * blockSize );
			}

			Binary iv = Bin().random( blockSize );

			alg.startMAC( alignMode, iv );
			int offset = 0;
			// Parts of growing size; stop when the next part would not fit into the remaining data.
			for( int partSize = 0; (partSize < (32 + i)) && (partSize <= plain.size() - offset); ++partSize )
			{
				alg.process( plain.slice( offset, partSize ) );
				offset += partSize;
			}
			Binary mac = alg.finish( plain.slice( offset, plain.size() - offset ) );

			Binary mac2 = alg.calcMAC( plain, alignMode, iv );
			MUST( mac2.equals( mac ), "Wrong MAC by chunks" );
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * CTR known-answer test; the block size of the algorithm is passed as jBytes.
	 */
	void checkCTR( BlockCipher alg, String key, String SV, String data, String crypt )
	{
		checkCTR( alg, alg.getBlockSize(), key, SV, data, crypt );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * CTR known-answer test: cryptCTR must turn data into crypt and crypt back into data
	 * with the same SV and jBytes.
	 * @param jBytes - third argument of cryptCTR, passed as is.
	 * @param key, SV, data, crypt - hex strings.
	 */
	void checkCTR( BlockCipher alg, int jBytes, String key, String SV, String data, String crypt )
	{
		alg.setKey( key );
		MUST( alg.cryptCTR( Bin(data), Bin(SV), jBytes ).equals( crypt ), "Wrong CTR" );
		MUST( alg.cryptCTR( Bin(crypt), Bin(SV), jBytes ).equals( data ), "Wrong CTR" );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * CMAC known-answer test: calcCMAC( data, iv ) with the given key must equal cmacHex.
	 */
	void checkCMAC( BlockCipher alg, String keyHex, Binary data, Binary iv, String cmacHex )
	{
		alg.setKey( keyHex );
		MUST( alg.calcCMAC( data, iv ).equals( cmacHex ), "Wrong CMAC" );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Measures encrypt + decrypt of random data (random key and IV, no padding) with Ticker.measureMs.
	 * @param iterations - passed to Ticker.measureMs.
	 * @param dataSize - size of the random data in bytes.
	 * @return value returned by Ticker.measureMs.
	 */
	static long measure( BlockCipher alg, CryptoMode mode, int iterations, int dataSize )
	{
		Binary key = Bin().random( alg.getKeySize() );
		alg.setKey( key );
		Binary data = Bin().random( dataSize );
		Binary iv = Bin().random( alg.getBlockSize() );

		return Ticker.measureMs( iterations, () ->
		{
			Binary crypt = alg.encrypt( data, mode, AlignMode.NONE, iv );
			alg.decrypt( crypt, mode, AlignMode.NONE, iv );
		} );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Runs the chunked CBC check (no padding, base data size 8192) and then logs the CBC timing
	 * measured with ITERATIONS and DATA_SIZE, followed by the "OK" line.
	 */
	void checkAndMeasure( BlockCipher alg )
	{
		checkMode( alg, CryptoMode.CBC, AlignMode.NONE, 8192 );

		long t = measure( alg, CryptoMode.CBC, ITERATIONS, DATA_SIZE );
		log.writeln( "Time " + alg.getAlgName() +": " + t + " ms" );

		log.writeln( "Test " + alg.getAlgName() + ": OK\n" );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @return suffix that jceOperation appends to the JCE algorithm name to build the transformation string;
	 * empty string for a mode that is not listed here.
	 */
	static String jceCryptoModeStr( CryptoMode cryptMode )
	{
		switch( cryptMode )
		{
			case ECB: return "/ECB/NoPadding";
			case CBC: return "/CBC/NoPadding";
			case OFB: return "/OFB/NoPadding";
			case CFB: return "/CFB/NoPadding";
			// NOTE(review): unlike the other entries, this one has no leading '/' and no padding part,
			// so it is glued directly to the algorithm name - confirm this is the intended transformation.
			case Gost28147CTR: return "GOST28147CTR";
		}
		return "";
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Performs one JCE cipher operation (Cipher.getInstance / init / doFinal) without padding.
	 * @param algName - JCE algorithm name; also used as the key algorithm name.
	 * @param opMode - Cipher.ENCRYPT_MODE or Cipher.DECRYPT_MODE.
	 * @param iv - used for every mode except ECB; ignored (may be null) for ECB.
	 * @return result of Cipher.doFinal over the whole data.
	 * @throws Ex on any failure; its message is toString() of the original throwable.
	 */
	static Binary jceOperation( String algName, int opMode, CryptoMode cryptMode, final Binary key, final Binary iv, final Binary data )
	{
		try
		{
			Cipher cipher = Cipher.getInstance( algName + jceCryptoModeStr( cryptMode ) );
			SecretKeySpec skey = new SecretKeySpec( key.getDataRef(), 0, key.size(), algName );

			if( cryptMode == CryptoMode.ECB )
			{
				cipher.init( opMode, skey );
			}
			else
			{
				// IV spec is built only where it is used, so ECB callers may pass null.
				IvParameterSpec ivSpec = new IvParameterSpec( iv.getDataRef(), 0, iv.size() );
				cipher.init( opMode, skey, ivSpec );
			}

			return new Binary( cipher.doFinal( data.getDataRef(), 0, data.size() ) );
		}
		catch( Throwable ex )
		{
			// NOTE(review): only the text of the original throwable is kept, its stack trace is lost -
			// consider passing it as a cause if Ex supports that.
			throw new Ex( ex.toString() );
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Measures JCE encrypt + decrypt of 'plain' with Ticker.measureMs.
	 * Every iteration goes through jceOperation twice, so Cipher creation and init are inside the measured code.
	 * @return value returned by Ticker.measureMs.
	 */
	static long jceMeasureTime( String algName, CryptoMode mode, int iterations, Binary key, Binary iv, Binary plain )
	{
		return Ticker.measureMs( iterations, () ->
		{
			Binary crypt = jceOperation( algName, Cipher.ENCRYPT_MODE, mode, key, iv, plain );
			jceOperation( algName, Cipher.DECRYPT_MODE, mode, key, iv, crypt );
		} );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Compares encryption by 'alg' (no padding) with encryption by JCE algorithm 'jceAlgName'.
	 * Runs 16 rounds with random key and IV; data length is (64 + i) blocks in round i.
	 */
	static void jceCompareEncrypt( BlockCipher alg, String jceAlgName, CryptoMode cryptMode )
	{
		int blockSize = alg.getBlockSize();

		for( int i = 0; i < 16; ++i )
		{
			Binary plain = Bin().random( 64 * blockSize + i * blockSize );
			Binary key = Bin().random( alg.getKeySize() );
			Binary iv = Bin().random( alg.getBlockSize() );

			alg.setKey( key );

			Binary crypt = alg.encrypt( plain, cryptMode, AlignMode.NONE, iv );
			Binary crypt2 = jceOperation( jceAlgName, Cipher.ENCRYPT_MODE, cryptMode, key, iv, plain );
			MUST( crypt2.equals( crypt ) );
		}
	}

}
