package org.denom.crypt.blockcipher;

import java.util.Arrays;

import org.denom.Binary;

import static org.denom.Binary.Bin;
import static org.denom.Ex.MUST;
import static java.lang.System.arraycopy;

/**
 * GOST R 34.12-2015: Block Cipher "Kuznyechik".
 * RFC 7801.
 * <p>
 * 128-bit block, 256-bit key. Encryption is 9 full rounds of LSX (key addition X, substitution S,
 * linear transform L) followed by a final key addition with the 10th round key.
 * <p>
 * Blocks are transformed in place. Instances hold a mutable scratch buffer and the expanded
 * round keys, so a single instance must not be used concurrently from several threads.
 */
public class Kuznyechik extends BlockCipher
{
	private static final int BLOCK_SIZE = 16;
	private static final int KEY_SIZE = 32;

	/** Number of round keys K1..K10 produced by the key schedule. */
	private static final int ROUND_KEYS = 10;

	/** Expanded round keys: wKey[0] = K1 ... wKey[9] = K10. */
	private final byte[][] wKey = new byte[ ROUND_KEYS ][ BLOCK_SIZE ];

	/** Scratch buffer for the key schedule and inverse L; reused to avoid per-block allocation. */
	private final byte[] tmp = new byte[ BLOCK_SIZE ];

	// -----------------------------------------------------------------------------------------------------------------
	/** GFMul[a][b] = a * b in GF(2^8) reduced modulo x^8 + x^7 + x^6 + x + 1. */
	private static final byte[][] GFMul = new byte[ 256 ][ 256 ];

	// -----------------------------------------------------------------------------------------------------------------
	static
	{
		for( int a = 0; a < 256; a++ )
			for( int b = 0; b < 256; b++ )
				GFMul[ a ][ b ] = mulGF256( (byte)a, (byte)b );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Multiplies two elements of GF(2^8) with the Kuznyechik reduction polynomial
	 * x^8 + x^7 + x^6 + x + 1 (0x1C3).
	 */
	private static byte mulGF256( byte a, byte b )
	{
		// Work on unsigned int values so that the right shift is logical and the loop ends once b is exhausted.
		int aa = a & 0xFF;
		int bb = b & 0xFF;
		int product = 0;
		while( (aa != 0) && (bb != 0) )
		{
			if( (bb & 1) != 0 )
				product ^= aa;

			aa <<= 1;
			if( (aa & 0x100) != 0 )
				aa ^= 0x1C3; // reduce: clears bit 8 and folds in x^7 + x^6 + x + 1

			bb >>>= 1;
		}
		return (byte)product;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** Creates a cipher with an all-zero 32-byte key (as produced by {@code Bin( KEY_SIZE )}). */
	public Kuznyechik()
	{
		this( Bin( KEY_SIZE ) );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @param key [32 bytes].
	 */
	public Kuznyechik( final Binary key )
	{
		super.initialize( BLOCK_SIZE );
		setKey( key );
	}

	// -----------------------------------------------------------------------------------------------------------------
	@Override
	public Kuznyechik clone()
	{
		return new Kuznyechik( this.key );
	}

	// -----------------------------------------------------------------------------------------------------------------
	@Override
	public String getAlgName()
	{
		return "Kuznyechik";
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Generates a random 32-byte key with {@code randomSecure}, installs it and returns it.
	 */
	@Override
	public Binary generateKey()
	{
		Binary k = new Binary().randomSecure( KEY_SIZE );
		setKey( k );
		return k;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** X transform: res ^= a over one 16-byte block. */
	private static void XOR( byte[] res, byte[] a )
	{
		for( int i = 0; i < BLOCK_SIZE; i++ )
			res[ i ] ^= a[ i ];
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** Applies the byte substitution {@code box} (PI for S, inversePI for S^-1) to one block in place. */
	private static void substitute( byte[] data, byte[] box )
	{
		for( int i = 0; i < BLOCK_SIZE; i++ )
			data[ i ] = box[ data[ i ] & 0xFF ];
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Linear functional l from RFC 7801: XOR-sum of data[i] * lFactors[i] in GF(2^8).
	 * data[0] is the most significant byte (a15). lFactors[15] == 1, so data[15] seeds the accumulator directly.
	 */
	private static byte l( byte[] data )
	{
		byte acc = data[ 15 ];
		for( int i = 14; i >= 0; i-- )
			acc ^= GFMul[ data[ i ] & 0xFF ][ lFactors[ i ] & 0xFF ];
		return acc;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * L transform = R applied 16 times, in place.
	 * R(a15..a0) = l(a15..a0) || a15 || ... || a1: shift the block right by one byte and prepend l().
	 */
	private static void L( byte[] data )
	{
		for( int i = 0; i < BLOCK_SIZE; i++ )
		{
			byte z = l( data );
			arraycopy( data, 0, data, 1, BLOCK_SIZE - 1 ); // overlapping copy is safe with System.arraycopy
			data[ 0 ] = z;
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Inverse L transform = R^-1 applied 16 times, in place.
	 * R^-1(a15..a0) = a14 || ... || a0 || l(a14, ..., a0, a15). Uses the instance scratch buffer {@code tmp}.
	 */
	private void inverseL( byte[] data )
	{
		for( int round = 0; round < BLOCK_SIZE; round++ )
		{
			arraycopy( data, 1, tmp, 0, BLOCK_SIZE - 1 );
			tmp[ BLOCK_SIZE - 1 ] = data[ 0 ];
			byte z = l( tmp );
			arraycopy( data, 1, data, 0, BLOCK_SIZE - 1 );
			data[ BLOCK_SIZE - 1 ] = z;
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Installs the key and expands it into the round keys K1..K10.
	 * K1 and K2 are the first and second 16-byte halves of the key. Each further pair comes from
	 * 8 Feistel rounds F[C_i] with constants C_i = L(Vec128(i)), i = 1..32.
	 * @param key [32 bytes].
	 */
	@Override
	public void setKey( Binary key )
	{
		MUST( key.size() == KEY_SIZE, "Invalid key size" );
		this.key = key.clone();
		byte[] keyArr = key.getDataRef();

		// Feistel state (a1, a0), initialised with (K1, K2).
		byte[] a1 = new byte[ BLOCK_SIZE ];
		byte[] a0 = new byte[ BLOCK_SIZE ];
		arraycopy( keyArr, 0, a1, 0, BLOCK_SIZE );
		arraycopy( keyArr, BLOCK_SIZE, a0, 0, BLOCK_SIZE );
		arraycopy( a1, 0, wKey[ 0 ], 0, BLOCK_SIZE );
		arraycopy( a0, 0, wKey[ 1 ], 0, BLOCK_SIZE );

		for( int k = 1; k < ROUND_KEYS / 2; k++ )
		{
			for( int j = 1; j <= 8; j++ )
			{
				// Round constant C_i = L(Vec128(i)), where i sits in the least significant byte.
				Arrays.fill( tmp, (byte)0 );
				tmp[ BLOCK_SIZE - 1 ] = (byte)(8 * (k - 1) + j);
				L( tmp );

				// F[C](a1, a0) = (LSX[C](a1) xor a0, a1)
				XOR( tmp, a1 );
				substitute( tmp, PI );
				L( tmp );
				XOR( tmp, a0 );

				arraycopy( a1, 0, a0, 0, BLOCK_SIZE );
				arraycopy( tmp, 0, a1, 0, BLOCK_SIZE );
			}

			arraycopy( a1, 0, wKey[ 2 * k ], 0, BLOCK_SIZE );
			arraycopy( a0, 0, wKey[ 2 * k + 1 ], 0, BLOCK_SIZE );
		}

		// Do not leave intermediate key-schedule material lying around in scratch buffers.
		Arrays.fill( a1, (byte)0 );
		Arrays.fill( a0, (byte)0 );
		Arrays.fill( tmp, (byte)0 );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Encrypts one 16-byte block in place (through {@code block.getDataRef()}).
	 * E = X[K10] LSX[K9] ... LSX[K1].
	 */
	@Override
	public void encryptBlock( Binary block )
	{
		MUST( block.size() == BLOCK_SIZE, "Incorrect block size" );
		byte[] arr = block.getDataRef();

		for( int i = 0; i < ROUND_KEYS - 1; i++ )
		{
			XOR( arr, wKey[ i ] );
			substitute( arr, PI );
			L( arr );
		}
		XOR( arr, wKey[ ROUND_KEYS - 1 ] );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Decrypts one 16-byte block in place (through {@code block.getDataRef()}).
	 * D = X[K1] S^-1 L^-1 X[K2] ... S^-1 L^-1 X[K10].
	 */
	@Override
	public void decryptBlock( Binary block )
	{
		MUST( block.size() == BLOCK_SIZE, "Incorrect block size" );
		byte[] arr = block.getDataRef();

		for( int i = ROUND_KEYS - 1; i > 0; i-- )
		{
			XOR( arr, wKey[ i ] );
			inverseL( arr );
			// Bounded by BLOCK_SIZE (not arr.length) so bytes beyond the block in the backing array are never touched.
			substitute( arr, inversePI );
		}
		XOR( arr, wKey[ 0 ] );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** S-box pi from GOST R 34.12-2015 / RFC 7801 (signed-byte encoding). */
	private static final byte[] PI = new byte[] {
			-4, -18, -35, 17, -49, 110, 49, 22, -5, -60, -6, -38, 35, -59, 4, 77, -23, 119, -16, -37, -109, 46, -103, -70,
			23, 54, -15, -69, 20, -51, 95, -63, -7, 24, 101, 90, -30, 92, -17, 33, -127, 28, 60, 66, -117, 1, -114, 79, 5,
			-124, 2, -82, -29, 106, -113, -96, 6, 11, -19, -104, 127, -44, -45, 31, -21, 52, 44, 81, -22, -56, 72, -85, -14,
			42, 104, -94, -3, 58, -50, -52, -75, 112, 14, 86, 8, 12, 118, 18, -65, 114, 19, 71, -100, -73, 93, -121, 21,
			-95, -106, 41, 16, 123, -102, -57, -13, -111, 120, 111, -99, -98, -78, -79, 50, 117, 25, 61, -1, 53, -118, 126,
			109, 84, -58, -128, -61, -67, 13, 87, -33, -11, 36, -87, 62, -88, 67, -55, -41, 121, -42, -10, 124, 34, -71,
			3, -32, 15, -20, -34, 122, -108, -80, -68, -36, -24, 40, 80, 78, 51, 10, 74, -89, -105, 96, 115, 30, 0, 98, 68,
			26, -72, 56, -126, 100, -97, 38, 65, -83, 69, 70, -110, 39, 94, 85, 47, -116, -93, -91, 125, 105, -43, -107,
			59, 7, 88, -77, 64, -122, -84, 29, -9, 48, 55, 107, -28, -120, -39, -25, -119, -31, 27, -125, 73, 76, 63, -8,
			-2, -115, 83, -86, -112, -54, -40, -123, 97, 32, 113, 103, -92, 45, 43, 9, 91, -53, -101, 37, -48, -66, -27,
			108, 82, 89, -90, 116, -46, -26, -12, -76, -64, -47, 102, -81, -62, 57, 75, 99, -74 };

	// -----------------------------------------------------------------------------------------------------------------
	/** Inverse S-box pi^-1 (signed-byte encoding). */
	private static final byte[] inversePI = new byte[]{
			-91, 45, 50, -113, 14, 48, 56, -64, 84, -26, -98, 57, 85, 126, 82, -111, 100, 3, 87, 90, 28, 96, 7, 24, 33, 114,
			-88, -47, 41, -58, -92, 63, -32, 39, -115, 12, -126, -22, -82, -76, -102, 99, 73, -27, 66, -28, 21, -73, -56, 6,
			112, -99, 65, 117, 25, -55, -86, -4, 77, -65, 42, 115, -124, -43, -61, -81, 43, -122, -89, -79, -78, 91, 70, -45,
			-97, -3, -44, 15, -100, 47, -101, 67, -17, -39, 121, -74, 83, 127, -63, -16, 35, -25, 37, 94, -75, 30, -94, -33,
			-90, -2, -84, 34, -7, -30, 74, -68, 53, -54, -18, 120, 5, 107, 81, -31, 89, -93, -14, 113, 86, 17, 106, -119,
			-108, 101, -116, -69, 119, 60, 123, 40, -85, -46, 49, -34, -60, 95, -52, -49, 118, 44, -72, -40, 46, 54, -37,
			105, -77, 20, -107, -66, 98, -95, 59, 22, 102, -23, 92, 108, 109, -83, 55, 97, 75, -71, -29, -70, -15, -96, -123,
			-125, -38, 71, -59, -80, 51, -6, -106, 111, 110, -62, -10, 80, -1, 93, -87, -114, 23, 27, -105, 125, -20, 88, -9,
			31, -5, 124, 9, 13, 122, 103, 69, -121, -36, -24, 79, 29, 78, 4, -21, -8, -13, 62, 61, -67, -118, -120, -35, -51,
			11, 19, -104, 2, -109, -128, -112, -48, 36, 52, -53, -19, -12, -50, -103, 16, 68, 64, -110, 58, 1, 38, 18, 26,
			72, 104, -11, -127, -117, -57, -42, 32, 10, 8, 0, 76, -41, 116 };

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Coefficients of the linear functional l, indexed by byte position (data[0] = a15):
	 * 148, 32, 133, 16, 194, 192, 1, 251, 1, 192, 194, 16, 133, 32, 148, 1.
	 */
	private static final byte[] lFactors = { -108, 32, -123, 16, -62, -64, 1, -5, 1, -64, -62, 16, -123, 32, -108, 1 };

}
