import qualified Data.Map as Map
import Debug.Trace
import ModArithmetic
import IntArithmetic

-- sample data
-- p :: Integer
-- p = 0xF3C68DAD0EBF3115BD89E3A22CE330FEA16A127D27E1343E1D076C3E6D8A3910BB0B19D7A953E1136E897CB6310187600F0A50C3398EB5240567EEA87B053F41
--
-- Sample factor list for findExponent, which expects pairwise-coprime factors
-- of p-1 and rebuilds the modulus as product ps + 1.
-- Note that 9 = 3^2 and 64 = 2^6 are prime powers rather than primes.
ps :: [Integer]
ps =
  [ 5, 7, 9, 11, 13, 17, 19, 23, 29, 31
  , 37, 41, 43, 47, 53, 59, 61, 64, 67, 71
  , 73, 79, 83, 89, 97, 101, 103, 107, 109, 113
  , 127, 131, 137, 139, 149, 151, 157, 163, 167, 173
  , 179, 181, 191, 193, 197, 199, 211, 223, 227, 229
  , 233, 239, 241, 251, 257, 263, 269, 271, 277, 281
  , 283, 293, 307, 311, 313, 317, 331
  , 4294967311, 4831839503
  ]
--
-- Sample public key (the y argument of findExponent).
y :: Integer
y = 0x46AD64F32F5C4179046CF5AB6C7D8FA5D0BDFC95DA71FB0A061FB32B565BD2FC97D046F2CF1B7954A570276E17FEBABED681AC9F3DDE0D348DB30F54A0E52F3D
-- 
-- Sample generator (the w argument of findExponent).
w :: Integer
w = 0xA00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004D3
--
-- sample solution
-- x :: Integer
-- x = 10863043687268440265725600742148764004943658996724261846980639736824676775277692568759835703032229550033577398472776592021184412939318080811891212674355023
-- 
-- xs :: [Integer]
-- xs = [1,4,6,4,12,11,12,17,15,1,32,31,28,32,35,26,49,19,9,26,25,60,63,63,9,23,99,30,106,104,91,56,40,24,25,55,92,42,1,83,101,82,144,97,108,21,174,21,115,2,99,91,85,118,108,30,116,103,265,117,121,18,280,184,308,229,279,1887567846,3867139166]
--
----------------------------------------------------

-- findExponent: loosely speaking, recovers the private key from a public key
-- (Pohlig-Hellman).
--
-- Given the generator w, the public key y and a list ps whose product is p-1,
-- it returns an exponent x such that
--   w ^ x == y  (mod p).
-- The modulus is not passed in; it is rebuilt as p = product ps + 1.
--
-- For every factor p_i let p_i' = (p-1) / p_i, i.e. p_1*..*p_(i-1)*p_(i+1)*..*p_n,
-- and form the sub-problem
--   w_i = w ^ p_i'            (mod p)
--   y_i = (y ^ p_i') ^ p_i'   (mod p)
-- babyStepGiantStep finds x_i with w_i ^ x_i == y_i (mod p). Since
-- y_i = w_i ^ (x * p_i'), x_i is congruent to x * p_i' modulo p_i (the order
-- of w_i when w generates the group), hence
--   x == x_i / p_i'  (mod p_i).
-- These simultaneous congruences are combined with the Chinese remainder
-- theorem into x (mod p-1).
--
-- Assumptions (not checked here): p is prime, w generates the multiplicative
-- group mod p, and the entries of ps are pairwise coprime (required by modinv
-- and by the Chinese remainder step).
--
-- ps - the pairwise-coprime factorization of p-1
-- w  - the generator
-- y  - the public key
findExponent :: Integer -> Integer -> [Integer] -> Integer
findExponent w y ps = chineseRemainder residues ps
  where
    modulus = product ps + 1
    -- cofactor of p_i: (p-1) / p_i = p_1*..*p_(i-1)*p_(i+1)*..*p_n
    cofactors = map ((modulus - 1) `div`) ps
    subGens = map (modexp modulus w) cofactors
    subKeys = map (\c -> modexp modulus (modexp modulus y c) c) cofactors
    subLogs = map (babyStepGiantStep modulus) (zip3 subKeys subGens ps)
    -- undo the extra cofactor factor: x_i' = x_i / cofactor_i mod p_i
    residues =
      zipWith3
        (\x_i p_i c -> modmult p_i x_i (modinv p_i c))
        subLogs
        ps
        cofactors

-- babyStepGiantStep solves one sub-problem: given (y_i, w_i, p_i) it returns an
-- exponent x_i with w_i ^ x_i == y_i (mod p), looking for 0 <= x_i < p_i.
--
-- Shanks' meet-in-the-middle: with q_i = 1 + squareRoot p_i, every candidate
-- exponent is written as i*q_i + j with 0 <= j < q_i. The baby steps (w_i^j, j)
-- form the lookup table. The giant steps start at y_i and repeatedly multiply
-- by w_i^(p-1-q_i) (w_i^(-q_i) when p is prime); they are the search keys.
-- The first key found in the table yields (i, j).
--
-- Any exponent below p_i has i <= (p_i - 1) `div` q_i, so only giant steps
-- 0 .. p_i `div` q_i are tried. This bound does not depend on how exactly
-- squareRoot rounds. Previously the giant-step list was infinite, so a
-- sub-problem with no solution looped forever; it now fails with a runtime
-- error from find1stMatch instead. When a solution exists, the result is
-- unchanged.
babyStepGiantStep :: Integer -> (Integer, Integer, Integer) -> Integer
babyStepGiantStep p (y_i, w_i, p_i) = i * q_i + j
  where
    q_i = 1 + squareRoot p_i
    table = Map.fromList (babySteps w_i q_i p)
    -- largest giant-step index that can be needed for an exponent below p_i
    maxGiant = p_i `div` q_i
    searchKeys = take (fromInteger (maxGiant + 1)) (giantSteps w_i q_i p y_i)
    (i, j) = find1stMatch table searchKeys
			
-- giantSteps w q p y produces the (infinite) list of giant-step search keys.
-- It starts with y itself; each following key is the previous one passed
-- through modmult p with the factor w^(p-1-q) mod p. That factor equals
-- w^(-q) mod p when p is prime (Fermat). Callers decide how much of the list
-- to consume.
giantSteps :: Integer -> Integer -> Integer -> Integer -> [Integer]
giantSteps w q p = iterate (modmult p backStep)
  where
    -- one giant step "backwards" by q in the exponent
    backStep = modexp p w (p - 1 - q)

-- babySteps w q p lists the baby-step table entries (modexp p w j, j) for
-- j = 0 .. q-1. Each power is paired with its exponent, so a lookup of the
-- power yields the exponent.
babySteps :: Integer -> Integer -> Integer -> [(Integer, Integer)]
babySteps w q p = zip powers exponents
  where
    exponents = [0 .. q - 1]
    powers = map (modexp p w) exponents

-- find1stMatch is the "meet" step of the baby-step/giant-step search. It looks
-- each search key up in the table, in order, and stops at the first hit.
-- The result (i, j) holds the zero-based position i of that first matching key
-- in searchKeys, and the table value j stored under it.
-- If the key list is infinite and nothing ever matches, it never returns.
find1stMatch :: (Ord t, Num t1) => Map.Map t a -> [t] -> (t1, a)
find1stMatch table = ijFromList (0, 0 :: Integer) . map (`Map.lookup` table)

-- ijFromList (i, _) results walks the lookup results and returns (i', j) for
-- the first `Just j`. Here i' is the starting index i plus the number of
-- `Nothing`s skipped before that hit. The second component of the starting
-- pair is ignored; it is kept only for the existing signature.
--
-- The index is forced at each step so a long run of misses does not build a
-- chain of (+1) thunks. Running out of results means no key matched. This is
-- now reported with an explicit error rather than a non-exhaustive-pattern
-- failure.
ijFromList :: (Num t, Num t2) => (t, t2) -> [Maybe t1] -> (t, t1)
ijFromList (i, _) (Nothing : rest) = let i' = i + 1 in i' `seq` ijFromList (i', 0) rest
ijFromList (i, _) (Just j : _) = (i, j)
ijFromList _ [] = error "ijFromList: no search key matched the table"




             		



