// Denom.org
// bouncycastle.org

package org.denom.crypt.streamcipher;

import java.nio.charset.StandardCharsets;

import org.denom.Binary;

import static java.lang.Integer.rotateLeft;
import static org.denom.Binary.getIntLE;
import static org.denom.Binary.setIntLE;
import static org.denom.Binary.Bin;
import static org.denom.Ex.MUST;

/**
 * Daniel J. Bernstein's Salsa20 stream cipher, Snuffle 2005.
 * Key size = 16 or 32 bytes.
 * IV size = 8 bytes.
 */
public class Salsa20 extends StreamCipher
{
	/** Number of rounds used when none is specified. */
	public final static int DEFAULT_ROUNDS = 20;

	/**
	 * "expand 16-byte k" (words 0..3, used with 16-byte keys) followed by
	 * "expand 32-byte k" (words 4..7, used with 32-byte keys).
	 * Encoded as US-ASCII explicitly so the constants never depend on the platform default charset.
	 */
	private final static int[] TAU_SIGMA = littleEndianToInt(
		"expand 16-byte kexpand 32-byte k".getBytes( StandardCharsets.US_ASCII ), 0, 8 );

	/** Number of rounds; validated in the constructor to be even and at least 2. */
	protected int rounds;

	/** Read position inside {@link #keyStream}, always in 0..63. */
	private int index = 0;

	/** Cipher input state: constants, key, IV and the 64-bit block counter (word 8 = low, word 9 = high). */
	protected int[] engineState = new int[ 16 ];

	/** Scratch buffer that receives the output words of {@link #salsaCore}. */
	protected int[] x = new int[ 16 ];

	/** Current 64-byte key-stream block (16 words serialized little-endian). */
	private byte[] keyStream = new byte[ 16 * 4 ];

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Writes the four constant words matching a key length into {@code state[stateOffset .. stateOffset + 3]}.
	 * Not called within this class.
	 *
	 * @param keyLength 16 or 32; selects "expand 16-byte k" or "expand 32-byte k".
	 * @param state destination array.
	 * @param stateOffset index of the first word to write.
	 */
	protected void packTauOrSigma( int keyLength, int[] state, int stateOffset )
	{
		int tsOff = (keyLength - 16) / 4;
		state[ stateOffset ] = TAU_SIGMA[ tsOff ];
		state[ stateOffset + 1 ] = TAU_SIGMA[ tsOff + 1 ];
		state[ stateOffset + 2 ] = TAU_SIGMA[ tsOff + 2 ];
		state[ stateOffset + 3 ] = TAU_SIGMA[ tsOff + 3 ];
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * 20 rounds, all-zero 16-byte placeholder key.
	 * Set key later.
	 */
	public Salsa20()
	{
		this( DEFAULT_ROUNDS, Bin(16) );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * All-zero 16-byte placeholder key.
	 * @param rounds the number of rounds (must be an even number).
	 * Set key later.
	 */
	public Salsa20( int rounds )
	{
		this( rounds, Bin(16) );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @param rounds the number of rounds (must be an even number, at least 2).
	 * @param key [16 or 32 bytes]
	 */
	public Salsa20( int rounds, final Binary key )
	{
		MUST( rounds > 1 && ((rounds & 1) == 0), "'rounds' must be a positive, even number" );
		this.rounds = rounds;
		// NOTE(review): calls the overridable setKey() from the constructor; kept because subclasses
		// may rely on their override being invoked here.
		setKey( key );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * @return "Salsa20", or "Salsa20/&lt;rounds&gt;" when a non-default round count is used.
	 */
	@Override
	public String getAlgName()
	{
		String name = "Salsa20";
		if( rounds != DEFAULT_ROUNDS )
			name += "/" + rounds;
		return name;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Stores a copy of the key. The new key takes effect at the next
	 * {@link #startEncrypt} / {@link #startDecrypt}; the current stream state is not touched.
	 *
	 * @param key [16 or 32 bytes], not null.
	 */
	@Override
	public void setKey( final Binary key )
	{
		MUST( (key != null) && ((key.size() == 16) || (key.size() == 32)), "Wrong key size" );
		this.key = key.clone();
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Initializes the state from the current key and {@code iv} and rewinds to position 0.
	 * @param iv - 8 bytes.
	 * @return this
	 */
	@Override
	public Salsa20 startEncrypt( final Binary iv )
	{
		return startSalsa( iv );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Initializes the state from the current key and {@code iv} and rewinds to position 0.
	 * @param iv - 8 bytes.
	 * @return this
	 */
	@Override
	public Salsa20 startDecrypt( final Binary iv )
	{
		return startSalsa( iv );
	}

	// -----------------------------------------------------------------------------------------------------------------
	// Shared by both directions: encryption and decryption XOR the same key stream.
	private Salsa20 startSalsa( final Binary iv )
	{
		MUST( (iv != null) && (iv.size() == 8), "Wrong IV size" );
		initSalsa( iv );
		reset();
		return this;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Loads constants, key and IV into {@link #engineState}. The counter words (8, 9) are left untouched.
	 * A 16-byte key is loaded into both key halves (words 1..4 and 11..14).
	 *
	 * @param iv 8-byte nonce, loaded into words 6 and 7.
	 */
	protected void initSalsa( final Binary iv )
	{
		int tsOff = (key.size() - 16) / 4;
		engineState[ 0 ] = TAU_SIGMA[ tsOff ];
		engineState[ 5 ] = TAU_SIGMA[ tsOff + 1 ];
		engineState[ 10 ] = TAU_SIGMA[ tsOff + 2 ];
		engineState[ 15 ] = TAU_SIGMA[ tsOff + 3 ];

		littleEndianToInt( key.getDataRef(), 0, engineState, 1, 4 );
		littleEndianToInt( key.getDataRef(), key.size() - 16, engineState, 11, 4 );

		littleEndianToInt( iv.getDataRef(), 0, engineState, 6, 2 );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * XORs one byte with the key stream and advances the position.
	 * Before the first startEncrypt/startDecrypt the key-stream buffer holds zero bytes.
	 */
	@Override
	public byte process( byte in )
	{
		byte out = (byte)(keyStream[ index ] ^ in);
		index = (index + 1) & 63;

		if( index == 0 )
		{
			// Current block exhausted: move to the next block.
			advanceCounter();
			generateKeyStream( keyStream );
		}

		return out;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Adds {@code diff} blocks to the 64-bit block counter (word 8 = low, word 9 = high), wrapping modulo 2^64.
	 * Done as a single 64-bit addition: a per-word signed comparison mis-detects the carry
	 * whenever the low word crosses 0x80000000.
	 */
	protected void advanceCounter( long diff )
	{
		long counter = ((long)engineState[ 9 ] << 32) | (engineState[ 8 ] & 0xffffffffL);
		counter += diff;
		engineState[ 8 ] = (int)counter;
		engineState[ 9 ] = (int)(counter >>> 32);
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** Increments the 64-bit block counter by one, carrying into the high word. */
	protected void advanceCounter()
	{
		if( ++engineState[ 8 ] == 0 )
			++engineState[ 9 ];
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Subtracts {@code diff} (read as unsigned) from the 64-bit block counter.
	 * Fails without modifying the counter if the result would be negative.
	 */
	protected void retreatCounter( long diff )
	{
		int hi = (int)(diff >>> 32);
		int lo = (int)diff;

		// Borrow from the high word when the low word is (unsigned) smaller than 'lo'.
		int borrow = ((engineState[ 8 ] & 0xffffffffL) < (lo & 0xffffffffL)) ? 1 : 0;

		// Validate before mutating so that a failed call leaves the counter intact.
		MUST( (engineState[ 9 ] & 0xffffffffL) >= (hi & 0xffffffffL) + borrow, "attempt to reduce counter past zero." );

		engineState[ 9 ] -= hi + borrow;
		engineState[ 8 ] -= lo;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** Decrements the 64-bit block counter by one; fails if it is already zero. */
	protected void retreatCounter()
	{
		MUST( (engineState[ 8 ] != 0) || (engineState[ 9 ] != 0), "attempt to reduce counter past zero." );

		if( --engineState[ 8 ] == -1 )
			--engineState[ 9 ];
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Moves the stream position forward (positive) or backward (negative) and regenerates the current block.
	 *
	 * @param numberOfBytes signed distance in bytes.
	 * @return {@code numberOfBytes}
	 */
	public long skip( long numberOfBytes )
	{
		if( numberOfBytes >= 0 )
		{
			long blocks = numberOfBytes >>> 6;
			int tail = (int)(numberOfBytes & 63);

			if( blocks != 0 )
				advanceCounter( blocks );

			int oldIndex = index;
			index = (index + tail) & 63;
			if( index < oldIndex )
				advanceCounter(); // the tail crossed into the next block
		}
		else
		{
			// For Long.MIN_VALUE the negation overflows, but read as unsigned it is exactly 2^63,
			// so the unsigned shift and mask below still split it correctly.
			long distance = -numberOfBytes;
			long blocks = distance >>> 6;
			int tail = (int)(distance & 63);

			if( blocks != 0 )
				retreatCounter( blocks );

			if( tail > index )
				retreatCounter(); // stepping back past the start of the current block
			index = (index - tail) & 63;
		}

		generateKeyStream( keyStream );

		return numberOfBytes;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Rewinds to position 0 and then skips to {@code position}.
	 * @return {@code position}
	 */
	public long seekTo( long position )
	{
		reset();
		return skip( position );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** @return current position in bytes: counter * 64 + offset inside the block. */
	public long getPosition()
	{
		return getCounter() * 64 + index;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** Rewinds to position 0 (counter and in-block offset) and regenerates the first key-stream block. */
	public void reset()
	{
		index = 0;
		resetCounter();
		generateKeyStream( keyStream );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** @return the 64-bit block counter assembled from words 9 (high) and 8 (low). */
	protected long getCounter()
	{
		return ((long)engineState[ 9 ] << 32) | (engineState[ 8 ] & 0xffffffffL);
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** Sets the 64-bit block counter to zero. */
	protected void resetCounter()
	{
		engineState[ 8 ] = engineState[ 9 ] = 0;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Computes the key-stream block for the current state and writes it to {@code output} (64 bytes, little-endian).
	 */
	protected void generateKeyStream( byte[] output )
	{
		salsaCore( rounds, engineState, x );
		intToLittleEndian( x, output, 0 );
	}

	// -----------------------------------------------------------------------------------------------------------------
	/**
	 * Salsa20 core: applies {@code rounds} rounds (column round + row round per iteration) to {@code input}
	 * and writes {@code rounds(input) + input} (word-wise, mod 2^32) into {@code x}.
	 *
	 * @param rounds even number of rounds.
	 * @param input 16-word state; not modified.
	 * @param x 16-word output.
	 */
	static void salsaCore( int rounds, int[] input, int[] x )
	{
		MUST( input.length == 16 );
		MUST( x.length == 16 );
		MUST( rounds % 2 == 0 );

		int x00 = input[ 0 ];
		int x01 = input[ 1 ];
		int x02 = input[ 2 ];
		int x03 = input[ 3 ];
		int x04 = input[ 4 ];
		int x05 = input[ 5 ];
		int x06 = input[ 6 ];
		int x07 = input[ 7 ];
		int x08 = input[ 8 ];
		int x09 = input[ 9 ];
		int x10 = input[ 10 ];
		int x11 = input[ 11 ];
		int x12 = input[ 12 ];
		int x13 = input[ 13 ];
		int x14 = input[ 14 ];
		int x15 = input[ 15 ];

		for( int i = rounds; i > 0; i -= 2 )
		{
			// Column round.
			x04 ^= rotateLeft( x00 + x12, 7 );
			x08 ^= rotateLeft( x04 + x00, 9 );
			x12 ^= rotateLeft( x08 + x04, 13 );
			x00 ^= rotateLeft( x12 + x08, 18 );
			x09 ^= rotateLeft( x05 + x01, 7 );
			x13 ^= rotateLeft( x09 + x05, 9 );
			x01 ^= rotateLeft( x13 + x09, 13 );
			x05 ^= rotateLeft( x01 + x13, 18 );
			x14 ^= rotateLeft( x10 + x06, 7 );
			x02 ^= rotateLeft( x14 + x10, 9 );
			x06 ^= rotateLeft( x02 + x14, 13 );
			x10 ^= rotateLeft( x06 + x02, 18 );
			x03 ^= rotateLeft( x15 + x11, 7 );
			x07 ^= rotateLeft( x03 + x15, 9 );
			x11 ^= rotateLeft( x07 + x03, 13 );
			x15 ^= rotateLeft( x11 + x07, 18 );

			// Row round.
			x01 ^= rotateLeft( x00 + x03, 7 );
			x02 ^= rotateLeft( x01 + x00, 9 );
			x03 ^= rotateLeft( x02 + x01, 13 );
			x00 ^= rotateLeft( x03 + x02, 18 );
			x06 ^= rotateLeft( x05 + x04, 7 );
			x07 ^= rotateLeft( x06 + x05, 9 );
			x04 ^= rotateLeft( x07 + x06, 13 );
			x05 ^= rotateLeft( x04 + x07, 18 );
			x11 ^= rotateLeft( x10 + x09, 7 );
			x08 ^= rotateLeft( x11 + x10, 9 );
			x09 ^= rotateLeft( x08 + x11, 13 );
			x10 ^= rotateLeft( x09 + x08, 18 );
			x12 ^= rotateLeft( x15 + x14, 7 );
			x13 ^= rotateLeft( x12 + x15, 9 );
			x14 ^= rotateLeft( x13 + x12, 13 );
			x15 ^= rotateLeft( x14 + x13, 18 );
		}

		// Feed-forward: add the original input words.
		x[ 0 ] = x00 + input[ 0 ];
		x[ 1 ] = x01 + input[ 1 ];
		x[ 2 ] = x02 + input[ 2 ];
		x[ 3 ] = x03 + input[ 3 ];
		x[ 4 ] = x04 + input[ 4 ];
		x[ 5 ] = x05 + input[ 5 ];
		x[ 6 ] = x06 + input[ 6 ];
		x[ 7 ] = x07 + input[ 7 ];
		x[ 8 ] = x08 + input[ 8 ];
		x[ 9 ] = x09 + input[ 9 ];
		x[ 10 ] = x10 + input[ 10 ];
		x[ 11 ] = x11 + input[ 11 ];
		x[ 12 ] = x12 + input[ 12 ];
		x[ 13 ] = x13 + input[ 13 ];
		x[ 14 ] = x14 + input[ 14 ];
		x[ 15 ] = x15 + input[ 15 ];
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** Decodes {@code count} little-endian ints from {@code bs[bOff..]} into {@code ns[nOff..]}. */
	protected static void littleEndianToInt( byte[] bs, int bOff, int[] ns, int nOff, int count )
	{
		for( int i = 0; i < count; ++i )
		{
			ns[ nOff + i ] = getIntLE( bs, bOff );
			bOff += 4;
		}
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** @return a new array of {@code count} little-endian ints decoded from {@code bs[off..]}. */
	protected static int[] littleEndianToInt( byte[] bs, int off, int count )
	{
		int[] ns = new int[ count ];
		littleEndianToInt( bs, off, ns, 0, count );
		return ns;
	}

	// -----------------------------------------------------------------------------------------------------------------
	/** Encodes every element of {@code ns} little-endian into {@code bs} starting at {@code off}. */
	protected static void intToLittleEndian( int[] ns, byte[] bs, int off )
	{
		for( int i = 0; i < ns.length; ++i )
		{
			setIntLE( bs, off, ns[ i ] );
			off += 4;
		}
	}

}
