// Denom.org
// Author:  Sergey Novochenko,  Digrol@gmail.com

package org.denom.crypt.blockcipher;

import org.denom.Binary;

import static org.denom.Binary.Bin;
import static org.denom.Ex.MUST;
import static org.denom.Binary.*;
import static java.lang.Integer.rotateLeft;

/**
 * Block cipher "Magma" from GOST R 34.12-2015 (also described in RFC 8891).
 * <p>
 * Operates on {@link #BLOCK_SIZE}-byte blocks with a {@link #KEY_SIZE}-byte key. The key is read as eight
 * big-endian 32-bit words; every time the key changes these words are expanded into two 32-entry
 * round-key schedules, one for each direction, so that both directions share the same round loop.
 */
public class Magma extends BlockCipher
{
	public static final int BLOCK_SIZE = 8;
	public static final int KEY_SIZE = 32;

	/** Total number of Feistel rounds applied to one block. */
	private static final int ROUNDS = 32;

	/** Number of 32-bit words contained in the key. */
	private static final int KEY_WORDS = KEY_SIZE / 4;

	// Round keys in the order they are consumed. Both arrays are (re)built by setKey().
	private int[] encSchedule;
	private int[] decSchedule;

	// -----------------------------------------------------------------------------------------------------------------
	// Nibble substitutions pi0 .. pi7 (row index == substitution number).
	private static final byte[][] PI =
	{
		{ 12, 4, 6, 2, 10, 5, 11, 9, 14, 8, 13, 7, 0, 3, 15, 1 },
		{ 6, 8, 2, 3, 9, 10, 5, 12, 1, 14, 4, 7, 11, 13, 0, 15 },
		{ 11, 3, 5, 8, 2, 15, 10, 13, 14, 1, 7, 4, 12, 9, 6, 0 },
		{ 12, 8, 2, 1, 13, 4, 15, 6, 7, 0, 10, 5, 3, 14, 9, 11 },
		{ 7, 15, 5, 10, 8, 1, 6, 13, 0, 9, 3, 14, 11, 4, 2, 12 },
		{ 5, 13, 15, 6, 9, 2, 12, 10, 11, 7, 8, 1, 4, 3, 14, 0 },
		{ 8, 14, 2, 5, 6, 9, 1, 12, 15, 4, 11, 0, 13, 10, 3, 7 },
		{ 1, 7, 14, 13, 0, 5, 8, 3, 4, 15, 10, 6, 9, 12, 11, 2 },
	};

	// Byte-wide lookup tables: each entry is the substituted byte, placed at its position
	// inside the 32-bit word and already rotated left by 11 bits.
	private static final int[] pi10 = buildTable( PI[ 0 ], PI[ 1 ], 0 );
	private static final int[] pi32 = buildTable( PI[ 2 ], PI[ 3 ], 8 );
	private static final int[] pi54 = buildTable( PI[ 4 ], PI[ 5 ], 16 );
	private static final int[] pi76 = buildTable( PI[ 6 ], PI[ 7 ], 24 );

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Builds one 256-entry table that substitutes a whole byte at once.
	 * @param loBox substitution applied to the low nibble of the byte.
	 * @param hiBox substitution applied to the high nibble of the byte.
	 * @param shift bit offset of the byte inside the 32-bit word (0, 8, 16 or 24).
	 * @return table indexed by the byte value; entries are pre-rotated left by 11 bits.
	 */
	private static int[] buildTable( byte[] loBox, byte[] hiBox, int shift )
	{
		int[] table = new int[ 256 ];
		for( int b = 0; b < table.length; ++b )
		{
			int substituted = (hiBox[ b >>> 4 ] << 4) | loBox[ b & 0x0F ];
			table[ b ] = rotateLeft( substituted << shift, 11 );
		}
		return table;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Creates a cipher initialized with a key produced by {@code Bin( KEY_SIZE )}.
	 */
	public Magma()
	{
		this( Bin( KEY_SIZE ) );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Creates a cipher initialized with the given key.
	 * @param key [32 bytes].
	 */
	public Magma( final Binary key )
	{
		super.initialize( BLOCK_SIZE );
		setKey( key );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @return a new Magma instance constructed from the key currently stored in this object.
	 */
	@Override
	public Magma clone()
	{
		return new Magma( key );
	}

	// -----------------------------------------------------------------------------------------------------------------
	@Override
	public String getAlgName()
	{
		return "Magma";
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Generates {@link #KEY_SIZE} bytes with {@code randomSecure} and installs them as the current key.
	 * @return the generated key.
	 */
	@Override
	public Binary generateKey()
	{
		Binary generated = new Binary().randomSecure( KEY_SIZE );
		setKey( generated );
		return generated;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Stores a copy of the key and rebuilds both round-key schedules.
	 * The size check happens before any state is modified.
	 * @param key [32 bytes].
	 */
	@Override
	public void setKey( Binary key )
	{
		MUST( key.size() == KEY_SIZE, "Invalid key size" );
		this.key = key.clone();

		byte[] raw = key.getDataRef();

		// Encryption order: key words 0..7 three times, then once more as 7..0.
		int[] forward = new int[ ROUNDS ];
		for( int w = 0; w < KEY_WORDS; ++w )
		{
			int word = getIntBE( raw, 4 * w );
			forward[ w ] = word;
			forward[ w + KEY_WORDS ] = word;
			forward[ w + 2 * KEY_WORDS ] = word;
			forward[ ROUNDS - 1 - w ] = word;
		}

		// Decryption consumes exactly the same round keys back to front.
		int[] backward = new int[ ROUNDS ];
		for( int r = 0; r < ROUNDS; ++r )
		{
			backward[ r ] = forward[ ROUNDS - 1 - r ];
		}

		encSchedule = forward;
		decSchedule = backward;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Round function without the key addition: every byte of {@code m} is looked up in its own table
	 * and the four results are OR-ed together (the tables already include the 11-bit left rotation).
	 */
	private static int g( int m )
	{
		int out = pi10[ m & 0xFF ];
		out |= pi32[ (m >>> 8) & 0xFF ];
		out |= pi54[ (m >>> 16) & 0xFF ];
		out |= pi76[ m >>> 24 ];
		return out;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Runs all rounds over one block in place, taking round keys from {@code schedule} in index order.
	 * @param block [8 bytes], overwritten with the result.
	 * @param schedule {@link #ROUNDS} round keys.
	 */
	private static void transform( Binary block, int[] schedule )
	{
		MUST( block.size() == BLOCK_SIZE, "Incorrect block size" );

		byte[] data = block.getDataRef();
		int upper = getIntBE( data, 0 );
		int lower = getIntBE( data, 4 );

		// Two rounds per iteration: the halves alternate roles instead of being swapped.
		for( int r = 0; r < ROUNDS; r += 2 )
		{
			upper ^= g( lower + schedule[ r ] );
			lower ^= g( upper + schedule[ r + 1 ] );
		}

		// The halves are written back in exchanged positions.
		setIntBE( data, 0, lower );
		setIntBE( data, 4, upper );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Encrypts one block in place.
	 * @param block [8 bytes].
	 */
	@Override
	public void encryptBlock( Binary block )
	{
		transform( block, encSchedule );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Decrypts one block in place.
	 * @param block [8 bytes].
	 */
	@Override
	public void decryptBlock( Binary block )
	{
		transform( block, decSchedule );
	}
}
