import numpy as np

class Pather:
    """Q-learning agent that learns a route through one of eight fixed 9x9 mazes.

    Typical use: construct with a maze number, call ``init_env()``, then
    ``train_model()``, and finally ``get_path()`` to follow the learned policy.
    """

    # Maze layouts keyed by ``maze_choice``: (grid rows, start cell, destination cell).
    # In each grid row '#' is a wall (map value 1) and '.' is open space (map value 0);
    # the first string is row 0 and cells are given as (row, col).
    _MAZE_LAYOUTS = {
        # Maze 1: Top Left
        1: ((
            "#########",
            "#.......#",
            "#.#####.#",
            "#.#.....#",
            "#.#.###.#",
            "#.#...#.#",
            "#.#.#####",
            "#.#.....#",
            "#######.#",
        ), (1, 1), (8, 7)),
        # Maze 2: Top Right
        2: ((
            "#########",
            "#.#.#...#",
            "#.#.###.#",
            "#...#...#",
            "###.#.###",
            "..#...#.#",
            "#.###.#.#",
            "#.......#",
            "#########",
        ), (1, 3), (5, 0)),
        # Maze 3: Bottom Right
        3: ((
            "#########",
            "#...#...#",
            "###.###.#",
            "..#.....#",
            "#.#####.#",
            "#.....#.#",
            "#.###.#.#",
            "#...#...#",
            "#########",
        ), (1, 1), (3, 0)),
        # Maze 4: Bottom Left
        4: ((
            "#########",
            "#...#...#",
            "#.###.###",
            "#.......#",
            "#######.#",
            "#.....#.#",
            "###.#.#.#",
            "#...#...#",
            "#.#######",
        ), (1, 7), (8, 1)),
        # Maze 1'
        5: ((
            "#########",
            "#.....#.#",
            "#####.#.#",
            "#.#...#.#",
            "#.###.#.#",
            "#.....#.#",
            "#.#####.#",
            "#.......#",
            "#######.#",
        ), (1, 1), (8, 7)),
        # Maze 2'
        6: ((
            "#########",
            "#...#...#",
            "#.###.###",
            "#.#......",
            "#.#.#####",
            "#.....#.#",
            "#.###.#.#",
            "#...#...#",
            "#########",
        ), (1, 3), (3, 8)),
        # Maze 3'
        7: ((
            "#########",
            "#.....#..",
            "#.#.###.#",
            "#.#.#...#",
            "###.#.###",
            "#...#.#.#",
            "#.###.#.#",
            "#.......#",
            "#########",
        ), (1, 5), (1, 8)),
        # Maze 4'
        8: ((
            "#########",
            "#...#...#",
            "#.#.#.###",
            "#.#.....#",
            "#.#######",
            "#.......#",
            "###.###.#",
            "#...#...#",
            "#.#######",
        ), (1, 7), (8, 1)),
    }

    def __init__(self, maze_choice):
        """Create the agent for maze number ``maze_choice`` (1-8).

        Raises:
            ValueError: if ``maze_choice`` is not one of the defined mazes.
        """
        self.maze_map = self.create_map(maze_choice)

    def create_map(self, maze_choice):
        """Build the maze grid for ``maze_choice`` and record its start and exit cells.

        Sets ``start_row``, ``start_col``, ``destination_row`` and
        ``destination_col`` on the instance.

        Args:
            maze_choice: maze number 1-8 (5-8 are the primed variants of 1-4).

        Returns:
            A 9x9 list of lists of ints: 1 for a wall, 0 for open space.

        Raises:
            ValueError: if ``maze_choice`` is not a known maze number.
        """
        try:
            rows, start, destination = self._MAZE_LAYOUTS[maze_choice]
        except (KeyError, TypeError):
            # Previously an unknown choice produced an all-wall map with no
            # start/destination, which only failed later with AttributeError.
            raise ValueError(
                f"unknown maze_choice {maze_choice!r}; "
                f"expected one of {sorted(self._MAZE_LAYOUTS)}"
            ) from None
        self.start_row, self.start_col = start
        self.destination_row, self.destination_col = destination
        return [[1 if cell == '#' else 0 for cell in row] for row in rows]

    def init_env(self):
        """Initialise the Q-table and the reward grid for the current maze.

        Must be called before training, path-finding or terminal-state checks.
        Open cells get reward -1 (a per-step cost), walls -100 and the
        destination +100.
        """
        self.maze_rows = len(self.maze_map)
        self.maze_cols = len(self.maze_map[0])
        # Action indices 0-3 index this list; get_next_state maps each name to a move.
        self.actions = ['forward', 'left', 'right', 'backward']
        # One Q-value per (row, col, action), all starting at zero.
        self.q_values = np.zeros((self.maze_rows, self.maze_cols, len(self.actions)))
        self.rewards = np.where(np.asarray(self.maze_map) == 0, -1., -100.)
        self.set_destination(self.destination_row, self.destination_col)

    def set_destination(self, destination_row, destination_col):
        """Give the cell (destination_row, destination_col) the +100 goal reward."""
        self.rewards[destination_row][destination_col] = 100.

    def print_maze(self):
        """Print the maze map one row per line, from the last row up to row 0."""
        # Iterates the map itself so this also works before init_env() is called.
        for row in reversed(self.maze_map):
            print(row)

    def is_terminal_state(self, current_row, current_col):
        """Return True unless the cell's reward is the -1 open-space step cost.

        Walls (-100) and the destination (+100) are therefore terminal.
        """
        return self.rewards[current_row, current_col] != -1.

    def get_start_state(self):
        """Pick a uniformly random non-terminal (open) cell as an episode start."""
        # Rejection sampling: redraw (row, col) until a non-terminal cell is hit.
        while True:
            current_row = np.random.randint(self.maze_rows)
            current_col = np.random.randint(self.maze_cols)
            if not self.is_terminal_state(current_row, current_col):
                return current_row, current_col

    def get_next_action(self, current_row, current_col, epsilon):
        """Choose an action index for the given cell.

        ``epsilon`` is the probability of taking the greedy action (highest
        Q-value; ties go to the lowest index). Otherwise a uniformly random
        action is returned. ``epsilon=1`` gives a fully greedy choice.
        """
        if np.random.random() < epsilon:
            return np.argmax(self.q_values[current_row, current_col])
        return np.random.randint(len(self.actions))

    def get_next_state(self, current_row, current_col, action):
        """Return the cell reached by taking ``action`` from (current_row, current_col).

        'forward' decreases the row, 'backward' increases it, and 'left'/'right'
        decrease/increase the column. A move that would leave the grid keeps
        the agent in place. Walls are not blocked here; moving into one lands
        on a terminal cell.
        """
        action_name = self.actions[action]
        new_row, new_col = current_row, current_col
        if action_name == 'forward' and current_row > 0:
            new_row -= 1
        elif action_name == 'right' and current_col < self.maze_cols - 1:
            new_col += 1
        elif action_name == 'backward' and current_row < self.maze_rows - 1:
            new_row += 1
        elif action_name == 'left' and current_col > 0:
            new_col -= 1
        return new_row, new_col

    def get_path(self):
        """Follow the greedy policy from the start cell until a terminal cell.

        Returns:
            ``(maze_map, path)`` where ``path`` is a list of ``[row, col]``
            cells beginning at the start cell. The path is empty if the start
            cell is itself terminal. The last cell may be a wall rather than
            the destination if the policy walks into one. If the greedy policy
            revisits a cell, it would cycle forever. In that case a message is
            printed and the path up to (not including) the repeat is returned.
        """
        path = list()
        if self.is_terminal_state(self.start_row, self.start_col):
            print('This is a terminal state')
            return self.maze_map, path

        current_row, current_col = self.start_row, self.start_col
        path.append([current_row, current_col])
        visited = {(current_row, current_col)}
        while not self.is_terminal_state(current_row, current_col):
            # epsilon=1 makes the choice deterministic, so revisiting a cell
            # means the walk would repeat the same cycle indefinitely.
            action = self.get_next_action(current_row, current_col, 1)
            current_row, current_col = self.get_next_state(current_row, current_col,
                                                           action)
            if (current_row, current_col) in visited:
                print(f'Greedy policy loops at ({current_row}, {current_col}); '
                      f'stopping path search')
                break
            visited.add((current_row, current_col))
            path.append([current_row, current_col])
        return self.maze_map, path

    def train_model(self, epsilon=0.9, discount_factor=0.9, learning_rate=0.9,
                    iterations=1000):
        """Run ``iterations`` Q-learning episodes from random open start cells.

        Args:
            epsilon: probability of taking the greedy action at each step
                (otherwise a uniformly random action is taken).
            discount_factor: weight given to the best Q-value of the next state.
            learning_rate: step size of each temporal-difference update.
            iterations: number of training episodes.
        """
        for _ in range(iterations):
            row_index, col_index = self.get_start_state()

            # An episode ends once the agent moves onto a terminal cell
            # (a wall or the destination).
            while not self.is_terminal_state(row_index, col_index):
                action_index = self.get_next_action(row_index, col_index, epsilon)

                # Perform the chosen action and move to the next cell.
                old_row_index, old_col_index = row_index, col_index
                row_index, col_index = self.get_next_state(row_index, col_index,
                                                           action_index)

                # Q-learning target: reward for the new cell plus the discounted
                # best Q-value available from it.
                reward = self.rewards[row_index, col_index]
                old_q_value = self.q_values[old_row_index, old_col_index, action_index]
                temporal_difference = (reward
                                       + discount_factor * np.max(self.q_values[row_index, col_index])
                                       - old_q_value)

                # Move the old estimate a learning_rate fraction toward the target.
                new_q_value = old_q_value + learning_rate * temporal_difference
                self.q_values[old_row_index, old_col_index, action_index] = new_q_value

        print('Model training complete!')

