'''
Created on 23/03/2014, tested for Python 2.7
@authors: Juan Pablo Posadas, Grigori Sidorov
Class for obtaining syntactic n-grams from dependency trees using Stanford parser output
NOTE: Input tree should NOT be collapsed: -outputFormat "wordsAndTags, typedDependencies" -outputFormatOptions "basicDependencies"
NOTE: Since ",", "[", "]", and "\" are part of our metalanguage, we prepend a backslash to them when they appear in the sentence itself (sn-grams), e.g., "\,", "\[", "\\"
'''
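# A minimal sketch of the expected input for one sentence, assuming the uncollapsed
# (basicDependencies) typed-dependency output of the Stanford parser, one relation per
# word; the words and indices below are hypothetical:
#
#   det(dog-2, The-1)
#   nsubj(barks-3, dog-2)
#   root(ROOT-0, barks-3)
#
# Sentences in the input file are separated by an empty line.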
import copy, sys
import codecs

class SNgrams(object):
  '''
  Class for obtaining syntactic n-grams from dependency trees using Stanford parser output
  Both continuous and non-continuous n-grams
  '''
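  # A minimal usage sketch (the variable names below are hypothetical):
  #   sn = SNgrams()
  #   sn.process_sentence(parsed_lines, 2)  # parsed_lines: dependency lines of one sentence
  #   sn.print_all_sn_grams(2)              # option 2: sn-grams of both words and SR tags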

  def __init__(self):
    self.words          = {} #Dictionary of original words indexed by their positions
    self.deps           = {} #Dictionary mapping each word index to the index of its head (governor)
    self.rels           = {} #Dictionary with dependency relations
    self.children       = {} #Dictionary mapping a (head) word index to the indices of its dependents
    self.leafs          = [] #List of indices of words that are leaves
    self.root_idx       = 0  #Index of the root
    self.DepNgrams      = [] #Auxiliary container for the sn-grams
    self.listDep        = [] #List of n-grams of words organized by size
    self.listSRTags     = [] #List of n-grams of SR tags organized by size
    self.listContDep    = [] #List of CONTINUOUS n-grams of words organized by size
    self.listContSRTags = [] #List of CONTINUOUS n-grams of SR tags organized by size

  ######################
  def reset_vars (self):
    self.words   .clear()    
    self.deps    .clear() 
    self.rels    .clear() 
    self.children.clear() 
    del self.leafs [:]  
    self.root_idx     = 0  

    del self.DepNgrams      [:] 
    del self.listDep        [:]           
    del self.listSRTags     [:]
    del self.listContDep    [:]
    del self.listContSRTags [:]

  ######################
  def process_sentence (self, lines, option):
    '''
    This method calls the specific methods (general steps) for producing sn-grams according to the parameter "option" (values 0-5, documented in the main block)
    '''
    self.reset_vars()
    self.prepare_indices (lines)        
    self.print_parsed_sentence ()     

    if option in [0,1,2]:
      self.get_all_DepNgrams()
      if option in [0,1]:
        self.store_all_DepNgrams(option)
      else:
        self.store_all_DepNgrams(0)
        self.store_all_DepNgrams(1)              
    elif option in [3,4,5]:
      max_height = self.get_cont_DepNgrams()
      if option in [3,4]:
        self.store_cont_DepNgrams(max_height,option)
      else:
        self.store_cont_DepNgrams(max_height,3)
        self.store_cont_DepNgrams(max_height,4)
    else:
      print "Invalid option."
          
         
  ######################        
  def prepare_indices(self, lines):
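    # Each input line is split into the relation name, the head index, and the dependent
    # word. For a hypothetical line such as
    #   det(dog-2, The-1)
    # at position idx = 0, this loop stores rels[1] = "det", deps[1] = 2 and
    # words[1] = "The" (lines are assumed to be ordered by the dependent index).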
    for idx, line in enumerate(lines):            
      self.rels[idx + 1] = line[ : line.find("(")]
      line = line[line.find("("):]
                    
      p_idx = int(line [line.find("-") + 1 : line.find(",",line.find("-"))] )
      self.deps [idx + 1]  = p_idx
      self.words[idx + 1]  = line[line.find(",") + 2 : line.rfind("-")]
      
      if self.words[idx+1] == ',':        
        self.words[idx+1] = "\,"
      elif self.words[idx+1] == '[':
        self.words[idx+1] = "\["
      elif self.words[idx+1] == ']':
        self.words[idx+1] = "\]"
      elif self.words[idx+1] == '\\':
        self.words[idx+1] = "\\\\"
        
      self.children[p_idx] = self.children.get(p_idx, [])
      self.children[p_idx].append(idx + 1)

      if self.deps [idx + 1] == 0:
        self.root_idx = idx + 1
        self.rels[idx+1] = "root" #Line added to handle FreeLing output (ensures the root relation is labeled "root")
      
    #Determine if a word is a leaf
    for i in self.words.keys():            
      if i not in self.children.keys():
        self.leafs.append(i)

    #Initialize the containers of n-grams grouped by size
    for i in xrange(0, len(lines)):
      self.listSRTags.append([])
      self.listDep   .append([])

  ######################            
  def print_parsed_sentence(self):
    line = ""
    for i in sorted(self.words.keys()):
      line += self.words[i]+" "
    print line.rstrip(" ")

    for i in sorted(self.words.keys()):
      line  = str(i)             + "\t"
      line +=     self.words[i]  + "\t"
      line +=     self.rels [i]  + "\t"
      line += str(self.deps [i]) + "\t"
      if i not in self.leafs:
        line += str(self.children[i])
      print line

    print "Leaf nodes are:"
    line = ""
    for i in self.leafs:
      line += self.words[i] + ", "
    print line.rstrip(", ")

  
  #################################
  def get_cont_DepNgrams(self):
    '''
    Method that starts the search for continuous sn-grams
    '''
    del self.DepNgrams  [:]
    path = [] #Auxiliary container for all the possible paths
    max_height = 0
    self.DepNgrams.append([self.root_idx])
    path.append(self.root_idx)
    max_height = self.traverse_pathDep(self.root_idx, path, max_height)
    return (max_height)
                
  def store_cont_DepNgrams(self,max_height,op):
    '''
    Method that translates the word indices of each continuous sn-gram into words or SR tags
    '''
    if op == 3: #Case of continuous sn-grams of words
      for i in xrange(0, max_height): #Initialize the final container of the sn-grams
        self.listContDep.append([])
      for sngram in self.DepNgrams: #Classify the sn-gram by its size
        line = ""
        for item in sngram: #Translate each index into a word
          line += self.words[item] + " "
        line = line.rstrip(" ")
        self.listContDep[len(sngram)-1].append(line)

    elif op == 4: #Case of continuous sn-grams of SR tags
      for i in xrange(0, max_height): #Initialize the final container of the sn-grams
        self.listContSRTags.append([])
      for sngram in self.DepNgrams: #Classify the sn-gram by its size
        line = ""
        for item in sngram: #Translate each index into an SR tag
          line += self.rels[item] + " "
        line = line.rstrip(" ")
        self.listContSRTags[len(sngram)-1].append(line)
                                   
  ######################
  def traverse_pathDep(self,idx,path,max_height):                                
    '''
    Method that builds all the possible paths of the tree
    '''
    if idx not in self.leafs:
      for child in self.children[idx]:
        path.append(child)
        for tam in xrange(0, len(path)): #Sizes of the possible sn-grams that can be obtained from the path
          sngram = []
          for i in xrange(tam, len(path)):
            sngram.append(path[i]) #Generate the possible sn-grams according to the size

          self.DepNgrams.append(copy.deepcopy(sngram)) #Add the new sn-gram to the list

          size = len(sngram)
          if size > max_height:
            max_height = size

        max_height = self.traverse_pathDep(child, path, max_height)
        path.pop() #Remove the last node that was added to the path
    return (max_height)
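  # A small illustration of the continuous case (hypothetical sentence): for the
  # dependency path barks -> dog -> The, traverse_pathDep collects every contiguous
  # sub-path that ends at the newly visited node, e.g.
  #   [barks, dog], [dog], [barks, dog, The], [dog, The], [The]
  # (plus [barks] itself, which is added in get_cont_DepNgrams).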

  ######################            
  def prepare_SNgram(self, line, op):
    #Builds the metalanguage string of an sn-gram, translating indices into words (op == 0) or SR tags (otherwise)
    ngram = ""
    for item in line:
      if type(item) is str:
        ngram += item
      elif type(item) is int:
        if op == 0:
          ngram += self.words[item]
        else:
          ngram += self.rels[item]            
      else:
        ngram += self.prepare_SNgram(item,op)
    return ngram

  ######################              
  def len_Ngram(self, ngram):
    #Returns the size (number of words) of an sn-gram by counting metalanguage "[" and ",", ignoring escaped "\[" and "\,"
    n = 1
    n += ngram.count("[")
    n += ngram.count(",")
    n -= ngram.count("\[")
    n -= ngram.count("\,")
    return n    
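  # For example (hypothetical sn-gram), len_Ngram("barks[dog[The],loudly]") returns
  # 1 + 2 brackets + 1 comma = 4, i.e., the number of words in the sn-gram; escaped
  # "\[" and "\," coming from the sentence itself are not counted.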

  ########################
  def get_all_DepNgrams(self):
    '''
    This method begins the process of getting all the sn-grams of the dependency tree
    '''
    del self.DepNgrams  [:]
    
    if self.root_idx > 0:
      self.DepNgrams.append (copy.deepcopy([self.root_idx])) #Add the first sn-gram found (the root element)
      self.traverse_DepTree (self.root_idx) #Call the method to get the rest of the sn-grams
    else:
      print "Error, no root found"


  def store_all_DepNgrams(self,op):
    '''
    This method stores the sn-grams in the container specified by the parameter op
    '''
    for item in self.DepNgrams:
      ngram = self.prepare_SNgram (item,op)
      size  = self.len_Ngram      (ngram)
      if op == 0: #According to the parameter, the sn-grams are stored in the corresponding container
        self.listDep[size - 1].append (ngram)
      elif op == 1:
        self.listSRTags[size - 1].append (ngram)
                
            
  ######################              
  def traverse_DepTree (self, idx):   # Recursive function that traverses the syntactic tree and generates syntactic n-grams
    sig          = [] # List of sn-grams for the current node
    combinations = [] # List of all possible combinations of nodes and their children
    children = self.children [idx]
  
    #First sn-grams in a subtree
    for child in children:
      sig = []
      sig.append (child)
      self.DepNgrams.append (copy.deepcopy(sig))
                                    
    combinations = copy.deepcopy(self.get_next_combinations(idx, children))
    self.DepNgrams.extend(copy.deepcopy(combinations)) # We save the new sn-grams in the global list
                
    for num, child in enumerate(children):
      if child in self.leafs:
        if num + 1 == len(children):
          return(combinations)
      else:
        sig = []                
        sig = self.traverse_DepTree (child)

        while len(sig) > 0:
          temp  = []
          ngram = sig.pop(0)                    

          for idx2 in xrange(0, len(combinations)): 
            aux = copy.deepcopy (combinations [idx2])

            for pos, item in enumerate(aux):
              if item == ngram[0]:
                aux.pop (pos)
                aux.insert (pos, copy.deepcopy (ngram))
                self.DepNgrams.append (copy.deepcopy (aux))
                temp.append (copy.deepcopy (aux))
                
          if len(temp) > 0: 
            combinations.extend (copy.deepcopy(temp))
            
    return(combinations) 
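  # A brief illustration (hypothetical sentence "The dog barks loudly"): with head
  # "barks" having children "dog" and "loudly", and "dog" having child "The", the
  # sn-grams produced include, in the bracket metalanguage,
  #   barks[dog], barks[loudly], barks[dog,loudly],
  #   dog[The], barks[dog[The]], barks[dog[The],loudly]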

  ######################                  
  def get_next_combinations (self, value, children):
    ngram         = []
    combinations  = [] # Auxiliary variable for generating a combination
    lista         = [] # Auxiliary variable for all sn-grams during analysis of a sub-tree
            
    # Initialize the list of combinations    
    for p in xrange(0, len(children)):
      combinations.append (0)
      
    #Generating sn-grams    
    for r in xrange (1, len(children) + 1):                 
      for j in xrange (1, r + 1):
        combinations [j - 1] = j - 1

      #################### The first combination
      ngram = []
      ngram.append (value)
      ngram.append ("[")
      for z in xrange (0, r):
        ngram.append(children [combinations [z]])  
        ngram.append (",")
      ngram.pop (len(ngram) - 1)          
      ngram.append ("]")            
      lista.append (copy.deepcopy(ngram))

      ################### The rest
      top = self.Combination (len(children), r)
      
      for j in xrange(2, top + 1):
        m = r
        val_max = len(children)

        while combinations [m - 1] + 1 == val_max:
          m       -= 1
          val_max -= 1

        combinations [m - 1] += 1

        for k in xrange (m + 1, r + 1):
          combinations [k - 1] = combinations [k - 2] + 1
            
        ngram = []
        ngram.append(value)
        ngram.append("[")                
        for z in xrange(0, r):
          ngram.append (children [combinations [z]])                
          ngram.append (",")
        ngram.pop (len(ngram) - 1)
        ngram.append ("]")
        lista.append (copy.deepcopy(ngram))
      
    return (lista)          
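  # For a hypothetical head "barks" with children "dog", "fast" and "today",
  # get_next_combinations yields every non-empty subset of the children in the bracket
  # metalanguage: barks[dog], barks[fast], barks[today], barks[dog,fast],
  # barks[dog,today], barks[fast,today], barks[dog,fast,today]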
    
  ######################                  
  def Combination (self, sz, r):
    #Computes the binomial coefficient C(sz, r) = sz! / (r! * (sz - r)!), assuming 1 <= r <= sz
    if sz == r:
      numerator = 1
    else:
      numerator = sz
      for i in xrange (1, sz):
        numerator *= sz - i
        
      aux = r
      for i in xrange (1, r):
        aux *= r - i
        
      divisor = sz - r
      for i in xrange (1, sz - r):
        divisor *= sz - r - i
        
      numerator = numerator / (aux * divisor)
      
    return (numerator)
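  # Sanity check (worked example): Combination(4, 2) = 4! / (2! * 2!) = 6, which matches
  # the number of 2-element child subsets generated in get_next_combinations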
  
  ###################### 
  def print_all_sn_grams (self,option):
    #This version of the method prints only the kind(s) of sn-grams selected by the option
    if option == 0:
      self.print_Words_sngrams()
    elif option == 1:      
      self.print_SR_sngrams()
    elif option == 2:
      self.print_Words_sngrams()
      self.print_SR_sngrams()
    elif option == 3:      
      self.print_cont_Words_sngrams()
    elif option == 4:      
      self.print_cont_SR_sngrams()
    elif option == 5:
      self.print_cont_Words_sngrams()
      self.print_cont_SR_sngrams()
            

  def print_Words_sngrams(self):
    print "************SNgrams of words/POS tags:"
    for idx, sn_list in enumerate(self.listDep):
      print "************Size: " + str(idx + 1) 
      for item in sn_list:
        print item

  def print_SR_sngrams(self):
    print "************SNgrams of tags of syntactic relations (SR tags):"
    for idx, sn_list in enumerate(self.listSRTags):
      print "************Size: " + str(idx + 1) 
      for item in sn_list:
        print item
    
  def print_cont_Words_sngrams(self):
    print "************Continuous SNgrams of words/POS tags:"
    for idx, sn_list in enumerate(self.listContDep):
      print "************Size: " + str(idx + 1) 
      for item in sn_list:
        print item

  def print_cont_SR_sngrams(self):
    print "************Continuous SNgrams of tags of syntactic relations (SR tags):"
    for idx, sn_list in enumerate(self.listContSRTags):
      print "************Size: " + str(idx + 1) 
      for item in sn_list:
        print item
        
  ###################### 
  def write_all_sn_grams (self, f2, option):
    '''
    This method writes to a file the kind(s) of sn-grams selected by the parameter "option"
    '''
    if option == 0:
      self.write_WordSngrams(f2)
    elif option == 1:
      self.write_SRSngrams(f2)
    elif option == 2:
      self.write_WordSngrams(f2)
      self.write_SRSngrams(f2)
    elif option == 3:
      self.write_cont_WordSngrams(f2)
    elif option == 4:
      self.write_cont_SRSngrams(f2)
    elif option == 5:
      self.write_cont_WordSngrams(f2)
      self.write_cont_SRSngrams(f2)

  def write_WordSngrams(self,f2):
    f2.write("************SNgrams of words/POS tags:\n")
    for idx, sn_list in enumerate(self.listDep):
      f2.write ("************Size: " + str(idx + 1) + "\n") 
      for item in sn_list:
        f2.write (item + "\n")
    f2.write("\n")
        
  def write_SRSngrams(self,f2):
    f2.write("************SNgrams of tags of syntactic relations (SR tags):\n")
    for idx, sn_list in enumerate(self.listSRTags):
      f2.write ("************Size: " + str(idx + 1) + "\n") 
      for item in sn_list:
        f2.write (item + "\n")
    f2.write("\n")

  def write_cont_WordSngrams(self,f2):
    f2.write("************Continuous SNgrams of words/POS tags:\n")
    for idx, sn_list in enumerate(self.listContDep):
      f2.write ("************Size: " + str(idx + 1) + "\n") 
      for item in sn_list:
        f2.write (item + "\n")
    f2.write("\n")
        
  def write_cont_SRSngrams(self, f2):
    f2.write("************Continuous SNgrams of tags of syntactic relations (SR tags):\n")
    for idx, sn_list in enumerate(self.listContSRTags):
      f2.write ("************Size: " + str(idx + 1) + "\n") 
      for item in sn_list:
        f2.write (item + "\n")
    f2.write("\n")

############
def process_one_sentence (result, lines, sent_num, f2, option):
  print "Sentence " + str(sent_num)
  result.process_sentence (lines, option)
  result.print_all_sn_grams (option)
  result.write_all_sn_grams (f2, option)
  return sent_num + 1
  
############### MAIN ################################
if __name__ == '__main__':
  if len(sys.argv) < 3:
    print "Usage with at least two parameters:"
    print "python SNGrams.py input output"
    exit(1)
    
  fname_in  = sys.argv[1] #'input.txt'
  fname_out = sys.argv[2] #'output.txt'
  option = 0
  if len(sys.argv) > 3:
    option = int(sys.argv [3])
                # value 0: for sn-grams of words
                # value 1: for sn-grams of SR tags
                # value 2: for sn-grams of words and SR tags (equivalent to calling with option 0 and then with option 1)
                # value 3: for continuous sn-grams of words
                # value 4: for continuous sn-grams of SR tags
                # value 5: for continuous sn-grams of words and SR tags (equivalent to calling with option 3 and then with option 4)
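  # Hypothetical invocation (the file names below are placeholders):
  #   python SNGrams.py parsed_input.txt sngrams_output.txt 2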

  encod = 'utf-8'  # 'utf-8' or '1252'
  print "You are assuming enconding: " + encod
  
  try:
    f1 = codecs.open (fname_in,  "rU", encoding = encod)
  except IOError as e:
    print "I/O error({0}): {1}".format(e.errno, e.strerror)
    exit(1)
    
  try:
    f2 = codecs.open (fname_out, "wb", encoding = encod)  # b - Binary, for Unix line endings
  except IOError as e:
    print "I/O error({0}): {1}".format(e.errno, e.strerror)
    exit(1)
 
  sent_num = 1
  result = SNgrams()
  lines  = []
  
  while True:
    ln = f1.readline ()
    if not ln:   # End of file
      break

    ln = ln.strip()
    if ln == "":   # Sentences are separated by EMPTY line
      if len (lines) > 0:
        sent_num = process_one_sentence (result, lines, sent_num, f2, option)
        del lines [:]
    else:
      lines.append (ln)

  if len(lines) > 0: # Last piece
    sent_num = process_one_sentence (result, lines, sent_num, f2, option)

  f1.close ()
  f2.close ()
            
      
        
