'''
Created on Sep 26, 2012

@author: elnio
'''
from __future__ import division
import numpy as np
import Global as gl
from numpy.lib.scimath import sqrt
import BasicWindow as b

'''
PUBLIC METHODS
'''
                    
def computeFinalRandomVectorSW(controlVector, randomVector):
    """Expand per-row random vectors by the matching control vectors.

    For every row ``i``, each entry of ``controlVector[i]`` scales a copy of
    ``randomVector[i]``. The scaled copies of all rows are stacked and the
    result is reshaped to ``(gl.numSketch, gl.sw)``::

        randomVector          controlVector
        ============          =============
        x,x,...,x             o,o,...,o
        (k rows of bw)        (k rows of nb)

                  Output (sketch)
                  ===============
                   k rows of sw

    Raises AssertionError when the two inputs have a different number of rows.
    """
    assert len(controlVector) == len(randomVector), "The control vector and the random vector must have equal number of rows."
    scaledRows = [
        [randomVector[row] * sign for sign in signs]
        for row, signs in enumerate(controlVector)
    ]
    return np.reshape(np.array(scaledRows), (gl.numSketch, gl.sw))
            
def generateRandomVector(nSeries, lengthSeries):
    """Return an ``(nSeries, lengthSeries)`` integer array of random +1/-1 entries.

    Each entry is the sign of an independent standard-normal draw. A draw of
    exactly 0 maps to -1, so +1 and -1 are equally likely. The same generator
    serves for random vectors and, analogously, for control vectors.
    """
    draws = np.random.normal(0, 1, nSeries * lengthSeries)
    # Vectorised sign. Unlike a list comprehension fed to np.array, this keeps
    # an integer dtype even when the requested size is zero.
    signs = np.where(draws > 0, 1, -1)
    return np.reshape(signs, (nSeries, lengthSeries))

def batchConv(chunk):
    """Project every row of ``chunk`` onto the rows of ``gl.uk``.

    Input: a 2-D array with as many columns as ``gl.uk``. ``build`` passes the
    gl.nb basic windows of one sliding window.
    Output: ``chunk . gl.uk^T``. Entry (i, j) is the inner product of row i of
    ``chunk`` with row j of ``gl.uk``.

    Raises AssertionError when the column counts differ.
    """
    projection = gl.uk
    assert chunk.shape[1] == projection.shape[1], "The two vectors must have at least one dimension equal"
    return np.dot(chunk, projection.T)

def convPseudo(newbw):
    """Inner products of one basic window with the rows of ``gl.uk``.

    Input: a basic window of data. It presumably is 1-D with
    ``gl.uk.shape[-1]`` values (TODO confirm against callers).
    Output: a vector whose j-th entry is the dot product of ``newbw`` with
    row j of ``gl.uk``. This is the single-window counterpart of ``batchConv``.
    """
    # Same quantity as an elementwise product followed by a Python-level sum
    # over columns, but computed in one native dot product. This also keeps it
    # consistent with batchConv, which is a dot product as well.
    return np.dot(gl.uk, newbw)





class Sketch:
    """Per-series state for the sliding-window sketch computation.

    ``build`` and ``batchUpdate`` read and rebind these attributes. The
    class-level values are kept for backward compatibility. ``__init__`` gives
    every instance its own copies of the mutable ones, so in-place changes
    can never leak from one series into another.
    """
    # raw values of the current sliding window
    series = np.zeros(gl.sw)
    # sketch vector of the current window
    sketches = np.zeros(gl.numSketch)
    # one b.Bw statistics object per basic window; created in __init__
    dotprods = None
    # dotprods slot holding the newest basic window
    head = gl.nb - 1
    # dotprods slot holding the oldest basic window
    tail = 0
    # dotprods slot indices ordered oldest -> newest
    order = range(int(gl.nb))
    index = 0
    # running sum of the window's values
    sumSketch = 0
    # running sum of the window's squared values
    sumSketchSquared = 0
    # sqrt of the window's sum of squared deviations (plus a small epsilon)
    std = 0
    norm = np.zeros(gl.sw)

    def __init__(self):
        nbw = int(gl.nb)
        # Fresh per-instance containers. Without these, every Sketch shares
        # the class-level arrays (and, on Python 2, the `order` list).
        self.series = np.zeros(gl.sw)
        self.sketches = np.zeros(gl.numSketch)
        self.norm = np.zeros(gl.sw)
        self.order = range(nbw)
        # int() here as well: np.empty rejects a float size
        self.dotprods = np.empty(nbw, dtype=object)
        for i in range(nbw):
            self.dotprods[i] = b.Bw()
        
def build(chunk, idx):
    """Initialise the sketch state of series ``idx`` from one full window.

    The chunk is of size sw. It is reshaped into gl.nb basic windows of gl.bw
    values. Each basic window's sum, sum of squares and projection
    (``batchConv``) go into ``gl.data[idx].dotprods``. The window is then
    centred, divided by ``std`` and projected with ``gl.uswk`` to obtain
    ``gl.data[idx].sketches``.

    :param chunk: numpy array holding one full sliding window.
    :param idx: index of the target Sketch in ``gl.data``.
    """
    state = gl.data[idx]
    chops = chunk.reshape(gl.nb, gl.bw)

    # series for a sliding window
    state.series = chunk

    tmpSUM = np.sum(chops, axis=1)
    tmpSUMSQ = np.sum(chops * chops, axis=1)
    tmpBWU = batchConv(chops)

    assert len(tmpSUM) == len(tmpSUMSQ) == len(tmpBWU), "error when creating the statistics for basic window"

    for i in range(int(gl.nb)):
        state.dotprods[i].init(tmpSUM[i], tmpSUMSQ[i], tmpBWU[i])

    state.sumSketch = sum([x.sumBw for x in state.dotprods])
    state.sumSketchSquared = sum([x.squaredSum for x in state.dotprods])

    # Sum of squared deviations from the mean. Floating-point cancellation can
    # make it slightly negative for (near-)constant windows, and scimath.sqrt
    # would then return a complex number that poisons std and the sketches.
    # Clamp at zero first (NaN still propagates).
    ssd = state.sumSketchSquared - np.square(state.sumSketch) / gl.sw
    state.std = sqrt(max(ssd, 0.0))
    # small epsilon avoids division by zero for constant windows
    state.std += 0.000001

    tmp = (state.series - state.sumSketch / gl.sw) / state.std

    state.sketches = np.dot(gl.uswk, tmp)
    
        
           
def batchUpdate(idx, chunk):
    """Slide the window of series ``idx`` forward by one basic window.

    The chunk is of size bw. It becomes the newest basic window and the
    oldest basic window is dropped. The running sums, ``std``, the circular
    buffer of basic-window statistics and ``sketches`` of ``gl.data[idx]``
    are all updated incrementally.

    :param idx: index of the target Sketch in ``gl.data``.
    :param chunk: numpy array with the gl.bw newest values of the series.
    """
    state = gl.data[idx]

    # slide the raw series: append the new values, drop the oldest gl.bw
    state.series = np.append(state.series, chunk)[gl.bw:]

    # statistics of the incoming basic window (np.sum, as in build)
    bwTmp = b.Bw()
    bwTmp.init(np.sum(chunk), np.sum(chunk * chunk), convPseudo(chunk))

    # replace the oldest basic window (slot `tail`) by the new one in the sums
    oldest = state.dotprods[state.tail]
    state.sumSketch += bwTmp.sumBw - oldest.sumBw
    state.sumSketchSquared += bwTmp.squaredSum - oldest.squaredSum

    # Sum of squared deviations from the mean. Cancellation and drift in the
    # running sums can make it slightly negative for (near-)constant data,
    # where scimath.sqrt would return a complex number. Clamp at zero first.
    ssd = state.sumSketchSquared - np.square(state.sumSketch) / gl.sw
    state.std = sqrt(max(ssd, 0.0))
    # small epsilon avoids division by zero for constant windows
    state.std += 0.000001

    # Advance the circular-buffer pointers. head trails tail by one slot, so
    # the new window lands in the slot that held the oldest one.
    state.head = state.head + 1 if state.head < gl.nb - 1 else 0
    state.tail = state.tail + 1 if state.tail < gl.nb - 1 else 0
    state.dotprods[state.head] = bwTmp

    # keep `order` listing the dotprods slots oldest -> newest
    state.order = np.append(np.array(state.order), state.head)[1:]

    # one row per basic window (oldest -> newest), one column per row of gl.uk
    tmpDotProds = np.array([state.dotprods[i].bwu for i in state.order])
    tmpSumu = gl.sumu * state.sumSketch / gl.sw

    temp = (tmpDotProds - tmpSumu) / state.std

    # weight row i by column i of gl.bi and add the weighted rows together
    state.sketches = sum(temp[i, :] * gl.bi[:, i] for i in range(len(temp)))
    
            