// Denom.org
// bouncycastle.org

package org.denom.crypt.streamcipher;

import org.denom.Binary;

import static java.lang.Integer.rotateLeft;
import static org.denom.Ex.MUST;

/**
 * Daniel J. Bernstein's ChaCha stream cipher.
 */
/**
 * Daniel J. Bernstein's ChaCha stream cipher (original variant with a 64-bit counter and an 8-byte IV).
 * <p>
 * Layout of the 16-word {@code engineState} used by this class:
 * <ul>
 *   <li>words 0..3   - constants chosen by key size ({@code packTauOrSigma});</li>
 *   <li>words 4..11  - key material (a 16-byte key is loaded into both halves);</li>
 *   <li>words 12..13 - 64-bit counter, low word first;</li>
 *   <li>words 14..15 - IV (8 bytes, little-endian).</li>
 * </ul>
 */
public class ChaCha extends Salsa20
{
	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Creates a ChaCha instance; the round count and other defaults are chosen by {@link Salsa20}.
	 */
	public ChaCha()
	{
		super();
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @param rounds the number of rounds (must be an even number).
	 */
	public ChaCha( int rounds )
	{
		super( rounds );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @param rounds the number of rounds (must be an even number).
	 * @param key [16 or 32 bytes]
	 */
	public ChaCha( int rounds, final Binary key )
	{
		super( rounds, key );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @return algorithm name including the round count, e.g. "ChaCha20".
	 */
	public String getAlgName()
	{
		return "ChaCha" + rounds;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Loads key and IV into the state and resets the keystream position.
	 * @param iv - 8 bytes.
	 * @return this
	 */
	@Override
	public Salsa20 startEncrypt( final Binary iv )
	{
		initChaCha( iv );
		reset();
		return this;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Loads key and IV into the state and resets the keystream position.
	 * A stream cipher decrypts exactly like it encrypts, so this mirrors {@link #startEncrypt(Binary)}.
	 * @param iv - 8 bytes.
	 * @return this
	 */
	@Override
	public Salsa20 startDecrypt( final Binary iv )
	{
		initChaCha( iv );
		reset();
		return this;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Validates the IV and writes constants, key and IV into {@code engineState} (words 0..11 and 14..15).
	 * The counter words (12..13) are not touched here.
	 */
	private void initChaCha( final Binary iv )
	{
		MUST( (iv != null) && (iv.size() == 8), "Wrong IV size" );

		packTauOrSigma( key.size(), engineState, 0 );

		// First 16 key bytes -> words 4..7; last 16 key bytes -> words 8..11.
		// For a 16-byte key both ranges start at offset 0, so the key is repeated.
		littleEndianToInt( key.getDataRef(), 0, engineState, 4, 4 );
		littleEndianToInt( key.getDataRef(), key.size() - 16, engineState, 8, 4 );

		littleEndianToInt( iv.getDataRef(), 0, engineState, 14, 2 );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Reads words 12 (low) and 13 (high) as one unsigned 64-bit counter value.
	 */
	private long readCounter64()
	{
		return ((long)engineState[ 13 ] << 32) | (engineState[ 12 ] & 0xffffffffL);
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Stores a 64-bit counter value back into words 12 (low) and 13 (high).
	 */
	private void writeCounter64( long counter )
	{
		engineState[ 12 ] = (int)counter;
		engineState[ 13 ] = (int)(counter >>> 32);
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Advances the 64-bit counter by {@code diff} (arithmetic wraps modulo 2^64).
	 * Done in 64-bit arithmetic so the carry from the low word is always correct; a signed 32-bit
	 * comparison mis-detects the carry when the low word crosses 0x7FFFFFFF or 0xFFFFFFFF.
	 */
	protected void advanceCounter( long diff )
	{
		writeCounter64( readCounter64() + diff );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Advances the 64-bit counter by one, carrying into word 13 when word 12 wraps to zero.
	 */
	protected void advanceCounter()
	{
		if( ++engineState[ 12 ] == 0 )
		{
			++engineState[ 13 ];
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Moves the 64-bit counter back by {@code diff} (both treated as unsigned).
	 * The check is done before any word is modified, so a failed call leaves the state unchanged.
	 */
	protected void retreatCounter( long diff )
	{
		long counter = readCounter64();
		MUST( Long.compareUnsigned( counter, diff ) >= 0, "attempt to reduce counter past zero." );
		writeCounter64( counter - diff );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Moves the 64-bit counter back by one, borrowing from word 13 when word 12 wraps to -1.
	 */
	protected void retreatCounter()
	{
		MUST( (engineState[ 12 ] != 0) || (engineState[ 13 ] != 0), "attempt to reduce counter past zero." );

		if( --engineState[ 12 ] == -1 )
		{
			--engineState[ 13 ];
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @return current counter: word 13 as the high 32 bits, word 12 (unsigned) as the low 32 bits.
	 */
	protected long getCounter()
	{
		return readCounter64();
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Sets both counter words (12 and 13) to zero.
	 */
	protected void resetCounter()
	{
		engineState[ 12 ] = 0;
		engineState[ 13 ] = 0;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Runs the ChaCha core on the current state into {@code x} and serializes it little-endian into {@code output}.
	 */
	protected void generateKeyStream( byte[] output )
	{
		chachaCore( rounds, engineState, x );
		intToLittleEndian( x, output, 0 );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * ChaCha block function: applies {@code rounds} rounds (processed as column + diagonal pairs) to
	 * {@code input} and writes the result plus the original input (feed-forward) into {@code x}.
	 * {@code input} is not modified. Assumes both arrays hold at least 16 words - not checked here.
	 * An odd {@code rounds} value is effectively rounded up, because the loop steps by two.
	 */
	static void chachaCore( int rounds, int[] input, int[] x )
	{
		int x00 = input[ 0];
		int x01 = input[ 1];
		int x02 = input[ 2];
		int x03 = input[ 3];
		int x04 = input[ 4];
		int x05 = input[ 5];
		int x06 = input[ 6];
		int x07 = input[ 7];
		int x08 = input[ 8];
		int x09 = input[ 9];
		int x10 = input[10];
		int x11 = input[11];
		int x12 = input[12];
		int x13 = input[13];
		int x14 = input[14];
		int x15 = input[15];

		for( int i = rounds; i > 0; i -= 2 )
		{
			// Column round: quarter-rounds on (0,4,8,12), (1,5,9,13), (2,6,10,14), (3,7,11,15).
			x00 += x04; x12 = rotateLeft(x12 ^ x00, 16);
			x08 += x12; x04 = rotateLeft(x04 ^ x08, 12);
			x00 += x04; x12 = rotateLeft(x12 ^ x00, 8);
			x08 += x12; x04 = rotateLeft(x04 ^ x08, 7);
			x01 += x05; x13 = rotateLeft(x13 ^ x01, 16);
			x09 += x13; x05 = rotateLeft(x05 ^ x09, 12);
			x01 += x05; x13 = rotateLeft(x13 ^ x01, 8);
			x09 += x13; x05 = rotateLeft(x05 ^ x09, 7);
			x02 += x06; x14 = rotateLeft(x14 ^ x02, 16);
			x10 += x14; x06 = rotateLeft(x06 ^ x10, 12);
			x02 += x06; x14 = rotateLeft(x14 ^ x02, 8);
			x10 += x14; x06 = rotateLeft(x06 ^ x10, 7);
			x03 += x07; x15 = rotateLeft(x15 ^ x03, 16);
			x11 += x15; x07 = rotateLeft(x07 ^ x11, 12);
			x03 += x07; x15 = rotateLeft(x15 ^ x03, 8);
			x11 += x15; x07 = rotateLeft(x07 ^ x11, 7);
			// Diagonal round: quarter-rounds on (0,5,10,15), (1,6,11,12), (2,7,8,13), (3,4,9,14).
			x00 += x05; x15 = rotateLeft(x15 ^ x00, 16);
			x10 += x15; x05 = rotateLeft(x05 ^ x10, 12);
			x00 += x05; x15 = rotateLeft(x15 ^ x00, 8);
			x10 += x15; x05 = rotateLeft(x05 ^ x10, 7);
			x01 += x06; x12 = rotateLeft(x12 ^ x01, 16);
			x11 += x12; x06 = rotateLeft(x06 ^ x11, 12);
			x01 += x06; x12 = rotateLeft(x12 ^ x01, 8);
			x11 += x12; x06 = rotateLeft(x06 ^ x11, 7);
			x02 += x07; x13 = rotateLeft(x13 ^ x02, 16);
			x08 += x13; x07 = rotateLeft(x07 ^ x08, 12);
			x02 += x07; x13 = rotateLeft(x13 ^ x02, 8);
			x08 += x13; x07 = rotateLeft(x07 ^ x08, 7);
			x03 += x04; x14 = rotateLeft(x14 ^ x03, 16);
			x09 += x14; x04 = rotateLeft(x04 ^ x09, 12);
			x03 += x04; x14 = rotateLeft(x14 ^ x03, 8);
			x09 += x14; x04 = rotateLeft(x04 ^ x09, 7);
		}

		// Feed-forward: add the original input words to the permuted state.
		x[ 0 ] = x00 + input[ 0 ];
		x[ 1 ] = x01 + input[ 1 ];
		x[ 2 ] = x02 + input[ 2 ];
		x[ 3 ] = x03 + input[ 3 ];
		x[ 4 ] = x04 + input[ 4 ];
		x[ 5 ] = x05 + input[ 5 ];
		x[ 6 ] = x06 + input[ 6 ];
		x[ 7 ] = x07 + input[ 7 ];
		x[ 8 ] = x08 + input[ 8 ];
		x[ 9 ] = x09 + input[ 9 ];
		x[ 10 ] = x10 + input[ 10 ];
		x[ 11 ] = x11 + input[ 11 ];
		x[ 12 ] = x12 + input[ 12 ];
		x[ 13 ] = x13 + input[ 13 ];
		x[ 14 ] = x14 + input[ 14 ];
		x[ 15 ] = x15 + input[ 15 ];
	}
}
