# -*- coding: utf-8 -*-

import math
import sys
import json
from nupack_ops import *
from r_plotter_ops import *
import copy

class State(object):
  """Holds the strand list of one state of a DNA program.

  Attributes:
    program: the program object this state was created for.
    strands: ordered unpseudoknotted strands in the complex.
  """

  def __init__(self, program):
    # Every state starts with no strands; callers fill the list in place.
    self.strands = []
    self.program = program


class DNAProgram(object):

  def __init__(self, out_dir, program_name, temperature):
    """Set up an empty program; nothing is read from disk here.

    Args:
      out_dir: directory prefix used to build ``design_filename``.
      program_name: base name of the program's files (no extension).
      temperature: value handed to the Nupack operations.
    """
    # File locations
    self.out_dir = out_dir
    self.program_name = program_name
    self.design_filename = out_dir + '/' + program_name

    # Design parameters
    self.temperature = temperature
    self.fstop = 0.001
    self.instruction_size = 20
    self.clock_size = 16
    # The clock region is split in two halves; alpha takes the floor and
    # beta the ceiling so the two always add up to clock_size.
    half_clock = self.clock_size / 2.0
    self.alpha_size = int(math.floor(half_clock))
    self.beta_size = int(math.ceil(half_clock))

    # Per-strand lookup tables, keyed by strand name
    self.instr_compiled = {}
    self.instr_seq = {}
    self.instr_domains = {}

    self.design = []
    # All strands, mentioned once, in the order of an unpseudoknotted complex
    self.ordered_uniq_instructions = []
    # Just the user-specified instructions as read from file
    self.instructions = []

    # External tool wrappers
    self.nupack = NupackOps()
    self.rplotter = RPlotterOps()

    # State tracking
    self.desired_structure = ""
    self.states = []
    self.current_state = -1
    self.iteration = 0
    self.iteration_state = -1
    self.current_state_pairings = None
    # Handlers invoked cyclically, in this order, by next_state()
    self.state_machine = [
        self.initial_state,
        self.half_tick_state,
        self.tick_state,
        self.post_tick_state,
        self.half_tock_state,
        self.tock_state,
    ]


  # This is only here because we don't have list comprehension in the Jinja
  # template language
  def current_strand_domains(self):
    """Return the domain lists of the current strands as a JSON string.

    One entry per strand of the current state, in strand order, each looked
    up in ``instr_domains``.
    """
    domains = []
    for strand in self.current_strands():
      domains.append(self.instr_domains[strand])
    return json.dumps(domains)

  def step_with_tick(self):
    """ Adds a TICK strand to the program and to the current state.

        The TICK sequence is cut from the first instruction's sequence: the
        beta-sized slice that follows the alpha slice, then the alpha-sized
        slice that follows the instruction body. """

    first_seq = self.instr_seq[self.instructions[0]]
    alpha_start = self.instruction_size
    beta_start = alpha_start + self.alpha_size
    beta_part = first_seq[beta_start:beta_start + self.beta_size]
    alpha_part = first_seq[alpha_start:beta_start]

    self.instr_seq['TICK'] = beta_part + alpha_part
    self.instr_domains['TICK'] = [u'β*', u'α*']
    # insert(-1, ...) places TICK just before the last unique strand
    self.ordered_uniq_instructions.insert(-1, 'TICK')
    # In the current state TICK goes right after the instruction strands
    strands = self.current_strands()
    strands.insert(len(self.instructions), 'TICK')


  def next_state(self):
    """Append a copy of the current state and run the next state handler.

    The handler index cycles through ``state_machine``; ``iteration`` is
    bumped each time the cycle wraps back to index 0.
    """
    # NOTE(review): State keeps a reference to its program, so this deep copy
    # presumably duplicates the whole program object as well - confirm that
    # this is intended.
    snapshot = copy.deepcopy(self.states[self.current_state])
    self.states.append(snapshot)
    self.current_state += 1

    step = (self.iteration_state + 1) % len(self.state_machine)
    self.iteration_state = step
    if step == 0:
      self.iteration += 1

    handler = self.state_machine[step]
    handler()


  def step_with_tock(self):
    """ Adds a TOCK strand to the program and to the current state.

        The TOCK sequence is derived from the ANTI-TICK sequence the first
        time it is needed and reused afterwards. """

    anti_tick = self.instr_seq['ANTI-TICK']
    if 'TOCK' not in self.instr_seq:
      split = self.alpha_size
      beta_part = anti_tick[split:split + self.beta_size]
      alpha_part = anti_tick[0:split]
      self.instr_seq['TOCK'] = beta_part + alpha_part
      self.instr_domains['TOCK'] = [u'β', u'α']

    # insert(-1, ...) places TOCK just before the last unique strand
    self.ordered_uniq_instructions.insert(-1, 'TOCK')
    # In the current state TOCK goes right after the instruction strands
    strands = self.current_strands()
    strands.insert(len(self.instructions), 'TOCK')


  def read_program_file(self, design_filename):
    """ Reads a program file composed of one instruction name per line.

        Args:
          design_filename: path of the program file without its '.stack'
            suffix.

        Returns:
          A list with one whitespace-stripped entry per line of the file, in
          file order (blank lines yield empty strings).

        Raises:
          IOError: if the file cannot be opened or read. """

    # The context manager closes the file even when reading fails midway.
    with open(design_filename + '.stack', 'r') as program_file:
      return [line.strip() for line in program_file]


  # TODO: Make this (and read_program_file) into static functions
  def write_design_file(self, design, design_filename):
    """ Writes the design molecules to a '.fold' file.

        Args:
          design: list of structure strings; they are joined with '+' and
            written as a single newline-terminated line.
          design_filename: output path without its '.fold' suffix.

        Raises:
          IOError: if the file cannot be opened or written. """

    # The context manager closes the file even when the write fails.
    with open(design_filename + '.fold', 'w') as design_file:
      design_file.write('+'.join(design) + '\n')


  def current_strands(self):
    """Return the strand list of the current state.

    The list object itself is returned, so in-place changes made by callers
    modify the state.
    """
    state = self.states[self.current_state]
    return state.strands


  def read_program(self):
    """ Initializes the DNA Program from the stack file.

        Reads the instructions, creates the first state, compiles one design
        entry per distinct instruction (walking the program backwards), adds
        the anti-tick strands and finally advances to the initial state.
        A read failure is reported on stderr and the method carries on with
        whatever ``instructions`` already holds. """
    try:
      self.instructions = self.read_program_file(self.design_filename)
      # TODO: Find a logging facility
      sys.stdout.write('Read program file\n')
    except IOError:
      sys.stderr.write('problem reading: ' + self.design_filename + '\n')

    self.states.append(State(self))
    self.current_state = 0
    strands = self.current_strands()

    # Only the first instruction being designed gets the clock (toehold)
    # domains in its compiled structure.
    clock_toehold = ('(' * self.alpha_size) + ('(' * self.beta_size)
    needs_toehold = True
    for name in reversed(self.instructions):
      strands.append(name)
      if name in self.instr_compiled:
        continue  # already designed; only the strand list grows

      self.ordered_uniq_instructions.append(name)
      structure = '.' * self.instruction_size
      if needs_toehold:
        structure += clock_toehold
        needs_toehold = False
      self.instr_compiled[name] = structure
      self.instr_domains[name] = [name, u'α*', u'β*']
      self.design.append(structure)

    # Append the anti-tick strand
    self.design.append(')' * self.alpha_size + ')' * self.beta_size)
    # One more anti-tick than there are instructions
    strands += ['ANTI-TICK'] * (len(self.instructions) + 1)
    self.instr_domains['ANTI-TICK'] = [u'α', u'β']

    self.desired_structure = self.initial_desired_structure()
    self.next_state() # Prepare the initial state


  def render(self, renderer):
    """Ask ``renderer`` to render this program at the current state index."""
    state_index = self.current_state
    renderer.render(self, state_index)

  
  def write_design(self):
    """Write the current design to the program's '.fold' file.

    Success is reported on stdout. An IOError is reported on stderr and
    swallowed, so callers always get ``None`` back.
    """
    try:
      self.write_design_file(self.design, self.design_filename)
    except IOError:
      # The filename lives on the instance; a bare ``design_filename`` here
      # would raise NameError instead of reporting the failure.
      sys.stderr.write('Cannot write to: ' + self.design_filename + '\n')
    else:
      # TODO: Logging facility
      sys.stdout.write("Wrote design file\n")


  def design_sequences(self):
    """ Uses Nupack to figure out the sequences for the program's design, and
        saves them in instr_seq.

        The '+'-separated result is consumed in the same backward order used
        when the design was built: the first part belongs to the last
        instruction, the following parts to the remaining distinct
        instructions, and the last part to the anti-tick. """
    self.write_design()
    complex_seq = self.nupack.design_sequences(
        self.temperature, self.fstop, self.design_filename)
    parts = complex_seq.split('+')

    # The anti-tick is reversed in order to correct pseudoknotting: its two
    # segments are stored swapped relative to the designed strand.
    designed_anti_tick = parts[-1]
    split = self.beta_size
    self.instr_seq['ANTI-TICK'] = (
        designed_anti_tick[split:split + self.alpha_size] +
        designed_anti_tick[0:split])
    self.ordered_uniq_instructions.append('ANTI-TICK')

    # The last instruction's sequence is used exactly as designed.
    first_designed = parts[0]
    self.instr_seq[self.instructions[-1]] = first_designed
    # Its trailing clock_size bases are appended to every other instruction.
    binding_part = first_designed[-self.clock_size:]

    next_part = 1
    for name in reversed(self.instructions[:-1]):
      if name in self.instr_seq:
        continue  # repeated instruction; sequence already assigned
      self.instr_seq[name] = parts[next_part] + binding_part
      next_part += 1


  def pretty_sequences(self):
    """Return the string form of the strand-name -> sequence mapping."""
    text = str(self.instr_seq)
    return text


  def write_analysis_file(self, design_filename, instr_seq,
      ordered_uniq_instructions, strands):
    """ Writes the program's .in file for processing by Nupack.

        The file holds the number of distinct strands, one sequence per line
        in the given order, and finally the 1-based number of each strand of
        the complex, each followed by a single space.

        Args:
          design_filename: output path without its '.in' suffix.
          instr_seq: mapping from strand name to sequence.
          ordered_uniq_instructions: distinct strand names, in file order.
          strands: strand names making up the complex, in order.

        Raises:
          ValueError: if a strand is not in ordered_uniq_instructions.
          KeyError: if a listed strand has no sequence.
          IOError: if the file cannot be opened or written. """

    # Number each strand once up front instead of calling list.index() per
    # strand. setdefault keeps the first position, matching list.index().
    numbering = {}
    for number, name in enumerate(ordered_uniq_instructions, 1):
      numbering.setdefault(name, number)

    # Build the whole content before touching the file so that a bad strand
    # name cannot leave a truncated file behind.
    chunks = [str(len(ordered_uniq_instructions)) + "\n"]
    for name in ordered_uniq_instructions:
      chunks.append(instr_seq[name] + "\n")
    for strand in strands:
      if strand not in numbering:
        raise ValueError(repr(strand) + " is not in list")
      chunks.append(str(numbering[strand]) + " ")

    # The context manager closes the file even when the write fails.
    with open(design_filename + '.in', 'w') as analysis_file:
      analysis_file.write("".join(chunks))


  def desired_structures(self):
    """Return the desired structure string of every state, t=1 through t=6."""
    return [
        self.initial_desired_structure(),
        self.desired_structure_half_tick(),
        self.desired_structure_tick(),
        self.desired_structure_post_tick(),
        self.desired_structure_half_tock(),
        self.desired_structure_tock(),
    ]


  # t=1
  def initial_state(self):
    """Enter the t=1 state.

    From the second iteration on, the TOCK strand is dropped from the unique
    strand list and two entries are removed from the current strands (the
    last instruction slot and the one after it).
    """
    if self.iteration > 1:
      self.ordered_uniq_instructions.remove('TOCK')
      # Remove the hybridized tock and instruction
      strands = self.current_strands()
      count = len(self.instructions)
      del strands[count - 1:count + 1]
      # TODO: Make a state for instructions and delete the last instruction as well

    self.current_state_pairings = self.initial_state_pairings

  # t=2
  def half_tick_state(self):
    """Enter the t=2 state: add the TICK strand and select its pairings."""
    self.step_with_tick()
    pairings_for_state = self.half_tick_state_pairings
    self.current_state_pairings = pairings_for_state

  # t=3
  def tick_state(self):
    """Enter the t=3 state; only the pairing function changes."""
    pairings_for_state = self.tick_state_pairings
    self.current_state_pairings = pairings_for_state

  # t=4
  def post_tick_state(self):
    """Enter the t=4 state.

    Deletes the two strand entries that follow the instructions in the
    current state and drops TICK from the unique strand list.
    """
    # Remove the hybridized tick
    strands = self.current_strands()
    first_after_instructions = len(self.instructions)
    del strands[first_after_instructions:first_after_instructions + 2]
    self.ordered_uniq_instructions.remove('TICK')
    self.current_state_pairings = self.post_tick_state_pairings

  # t=5
  def half_tock_state(self):
    """Enter the t=5 state: add the TOCK strand and select its pairings."""
    self.step_with_tock()
    pairings_for_state = self.half_tock_state_pairings
    self.current_state_pairings = pairings_for_state

  # t=6
  def tock_state(self):
    """Enter the t=6 state; only the pairing function changes."""
    pairings_for_state = self.tock_state_pairings
    self.current_state_pairings = pairings_for_state

  # t=1
  def initial_state_pairings(self):
    """Return the index pairs for the t=1 state.

    Every instruction i contributes two pairs: slots ``3*i + 1`` and
    ``3*i + 2`` are matched with two consecutive slots counted back from the
    end of the strands that follow the instructions.
    NOTE(review): the numbers look like 0-based positions in the flattened
    per-strand domain lists (three per instruction, two per other strand) -
    confirm against the consumer of these pairings.

    Returns:
      A list of ``[index, index]`` lists; empty when there are no
      instructions.
    """
    count = len(self.instructions)
    base = count * 3  # first slot after all instruction strands
    pairings = []
    # range() works on both Python 2 and 3; xrange() is Python 2 only.
    for i in range(count):
      partner = base + (count - i) * 2
      pairings.append([i * 3 + 1, partner])
      pairings.append([i * 3 + 2, partner - 1])

    return pairings

  # t=1
  def initial_desired_structure(self):
    """Return the t=1 target structure as a dot-parenthesis string.

    Strands are separated by '+'. Every instruction has its clock region
    fully opened; the strands after them close those regions, with the first
    one leaving its alpha-sized prefix unpaired and the last one leaving its
    beta-sized suffix unpaired.
    """
    # TODO: This can be done in a less specific way, by iterating over and using the
    # compiled instructions. However, it's pretty easy for this task to be specific.
    count = len(self.instructions)
    bound_instruction = "." * self.instruction_size + "(" * self.clock_size + "+"
    bound_anti_tick = ")" * self.clock_size + "+"
    return "".join([
        bound_instruction * count,
        "." * self.alpha_size, ")" * self.beta_size, "+",
        bound_anti_tick * (count - 1),
        ")" * self.alpha_size, "." * self.beta_size,
    ])


  # t=2
  def half_tick_state_pairings(self):
    """Return the index pairs for the t=2 state.

    Same shape as the t=1 pairings, but every partner index is shifted by two
    (presumably to make room for the two-domain TICK strand inserted after
    the instructions - confirm), plus one extra pair for the half-tick
    binding.

    Returns:
      A list of ``[index, index]`` lists.
    """
    count = len(self.instructions)
    base = count * 3  # first slot after all instruction strands
    pairings = []
    # range() works on both Python 2 and 3; xrange() is Python 2 only.
    for i in range(count):
      partner = base + 2 + (count - i) * 2
      pairings.append([i * 3 + 1, partner])
      pairings.append([i * 3 + 2, partner - 1])

    # This is the extra half-tick binding
    pairings.append([base + 1, base + 2])

    return pairings


  # t=2
  def desired_structure_half_tick(self):
    """Return the t=2 target structure as a dot-parenthesis string.

    Strands are separated by '+'. After the fully opened instructions comes
    a strand with a beta-sized unpaired prefix and an alpha-sized opened
    suffix, then one closing strand per instruction and a final strand that
    closes only its alpha-sized prefix.
    """
    count = len(self.instructions)
    bound_instruction = "." * self.instruction_size + "(" * self.clock_size + "+"
    bound_anti_tick = ")" * self.clock_size + "+"
    return "".join([
        bound_instruction * count,
        "." * self.beta_size, "(" * self.alpha_size, "+",
        bound_anti_tick * count,
        ")" * self.alpha_size, "." * self.beta_size,
    ])


  # t=3
  def tick_state_pairings(self):
    """Return the index pairs for the t=3 state.

    All instructions but the last keep both of their pairs (partner indices
    shifted by two, as in the t=2 state). The last instruction keeps only
    its first pair, and two extra pairs are added around the slots that
    follow the instructions.

    Returns:
      A list of ``[index, index]`` lists.
    """
    count = len(self.instructions)
    base = count * 3  # first slot after all instruction strands
    pairings = []
    # range() works on both Python 2 and 3; xrange() is Python 2 only.
    for i in range(count - 1):
      partner = base + 2 + (count - i) * 2
      pairings.append([i * 3 + 1, partner])
      pairings.append([i * 3 + 2, partner - 1])

    last_instr = count - 1
    pairings.append([last_instr * 3 + 1, base + 2 + (count - last_instr) * 2])
    # This is the extra half-tick binding
    pairings.append([base, base + 3])
    pairings.append([base + 1, base + 2])

    return pairings


  # t=3
  def desired_structure_tick(self):
    """Return the t=3 target structure as a dot-parenthesis string.

    Strands are separated by '+'. The last instruction keeps only its
    alpha-sized region opened; it is followed by one fully opened strand, one
    fully closing strand, the closing strands of the other instructions and a
    final strand that closes only its alpha-sized prefix.
    """
    count = len(self.instructions)
    instruction_body = "." * self.instruction_size
    clock_open = "(" * self.clock_size
    clock_close = ")" * self.clock_size
    return "".join([
        (instruction_body + clock_open + "+") * (count - 1),
        instruction_body, "(" * self.alpha_size, "." * self.beta_size, "+",
        clock_open, "+", clock_close, "+",
        (clock_close + "+") * (count - 1),
        ")" * self.alpha_size, "." * self.beta_size,
    ])

    
  # t=4
  def post_tick_state_pairings(self):
    """Return the index pairs for the t=4 state.

    All instructions but the last keep both of their pairs, with partner
    indices counted back from one strand fewer than in the t=1 state. The
    last instruction keeps a single pair with the first slot after the
    instructions.

    Returns:
      A list of ``[index, index]`` lists.
    """
    count = len(self.instructions)
    base = count * 3  # first slot after all instruction strands
    pairings = []
    # range() works on both Python 2 and 3; xrange() is Python 2 only.
    for i in range(count - 1):
      partner = base + (count - 1 - i) * 2
      pairings.append([i * 3 + 1, partner])
      pairings.append([i * 3 + 2, partner - 1])

    last_instr = count - 1
    pairings.append([last_instr * 3 + 1, base])

    return pairings


  # t=4
  def desired_structure_post_tick(self):
    """Return the t=4 target structure as a dot-parenthesis string.

    Strands are separated by '+'. The last instruction keeps only its
    alpha-sized region opened; the other instructions are closed by one full
    strand each and the final strand closes only its alpha-sized prefix.
    """
    count = len(self.instructions)
    instruction_body = "." * self.instruction_size
    bound_instruction = instruction_body + "(" * self.clock_size + "+"
    bound_anti_tick = ")" * self.clock_size + "+"
    return "".join([
        bound_instruction * (count - 1),
        instruction_body, "(" * self.alpha_size, "." * self.beta_size, "+",
        bound_anti_tick * (count - 1),
        ")" * self.alpha_size, "." * self.beta_size,
    ])
       

  # t=5
  def half_tock_state_pairings(self):
    """Return the index pairs for the t=5 state.

    All instructions but the last are paired as in the t=1 state. The last
    instruction's second slot is paired with the first slot after the
    instructions and its first slot with the slot two positions further.

    Returns:
      A list of ``[index, index]`` lists.
    """
    count = len(self.instructions)
    base = count * 3  # first slot after all instruction strands
    pairings = []
    # range() works on both Python 2 and 3; xrange() is Python 2 only.
    for i in range(count - 1):
      partner = base + (count - i) * 2
      pairings.append([i * 3 + 1, partner])
      pairings.append([i * 3 + 2, partner - 1])

    last_instr = count - 1
    pairings.append([last_instr * 3 + 2, base])
    pairings.append([last_instr * 3 + 1, base + 2])

    return pairings


  # t=5
  def desired_structure_half_tock(self):
    """Return the t=5 target structure as a dot-parenthesis string.

    Strands are separated by '+'. After the fully opened instructions comes
    a strand that closes a beta-sized prefix and leaves an alpha-sized suffix
    unpaired, then the full closing strands and a final strand that closes
    only its alpha-sized prefix.
    """
    count = len(self.instructions)
    bound_instruction = "." * self.instruction_size + "(" * self.clock_size + "+"
    bound_anti_tick = ")" * self.clock_size + "+"
    return "".join([
        bound_instruction * count,
        ")" * self.beta_size, "." * self.alpha_size, "+",
        bound_anti_tick * (count - 1),
        ")" * self.alpha_size, "." * self.beta_size,
    ])

  # t=6
  def tock_state_pairings(self):
    """Return the index pairs for the t=6 state.

    All instructions but the last are paired as in the t=1 state. The last
    instruction's second and first slots are paired with the first and
    second slot after the instructions, respectively.

    Returns:
      A list of ``[index, index]`` lists.
    """
    count = len(self.instructions)
    base = count * 3  # first slot after all instruction strands
    pairings = []
    # range() works on both Python 2 and 3; xrange() is Python 2 only.
    for i in range(count - 1):
      partner = base + (count - i) * 2
      pairings.append([i * 3 + 1, partner])
      pairings.append([i * 3 + 2, partner - 1])

    last_instr = count - 1
    pairings.append([last_instr * 3 + 2, base])
    pairings.append([last_instr * 3 + 1, base + 1])

    return pairings

   
  # t=6
  def desired_structure_tock(self):
    """Return the t=6 target structure as a dot-parenthesis string.

    Strands are separated by '+'. After the fully opened instructions comes
    one fully closing strand, a strand that leaves its alpha-sized prefix
    unpaired, the remaining full closing strands (two fewer than there are
    instructions) and a final strand that closes only its alpha-sized prefix.
    """
    count = len(self.instructions)
    bound_instruction = "." * self.instruction_size + "(" * self.clock_size + "+"
    bound_anti_tick = ")" * self.clock_size + "+"
    return "".join([
        bound_instruction * count,
        bound_anti_tick,
        "." * self.alpha_size, ")" * self.beta_size, "+",
        bound_anti_tick * (count - 2),
        ")" * self.alpha_size, "." * self.beta_size,
    ])


  def write_analysis(self, filename):
    """Write the current state's '.in' analysis file.

    Args:
      filename: output path without its '.in' suffix.

    Success is reported on stdout; an IOError is reported on stderr and
    swallowed.
    """
    try:
      # Write the program's .in file
      sequences = self.instr_seq
      unique_strands = self.ordered_uniq_instructions
      strands = self.current_strands()
      self.write_analysis_file(filename, sequences, unique_strands, strands)
      sys.stdout.write("Wrote analysis file to " + filename + ".in\n")
    except IOError:
      sys.stderr.write("Cannot write analysis file to " +
          filename + ".in\n")


  def run_design_analysis(self):
    """Analyse and plot the current state.

    Writes the '.in' file named after the design file and the current state
    index, runs the Nupack pairs step on it and then plots the result with
    the R plotter. Progress is reported on stdout.
    """
    state_suffix = str(self.current_state)
    plot_filename = self.design_filename + '-' + state_suffix

    self.write_analysis(plot_filename)

    self.nupack.generate_pairs_data(self.temperature, plot_filename)
    sys.stdout.write("Ran Nupack pairs program\n")

    self.rplotter.plot_pairs(plot_filename)
    sys.stdout.write("Plotted pairs PDF using R\n")



