from collections import Counter

import matplotlib.pyplot as plt

from agent import Agent

class Statistics:
    '''
    Track, per strategy type, how many agents are alive and how many board
    squares they own, and keep one live plot line per strategy type updated.

    Attributes:
        time_scale: x-axis values; one entry (the iteration number) is appended
            per call to gather_board_statistics
        alive_cnt: strategy type -> list of alive-agent counts, one per
            recorded iteration
        territory: strategy type -> list of captured-square counts, one per
            recorded iteration
        alive_cnt_graphs / territory_graphs: strategy type -> the line object
            returned by ax.plot(...)[0], reused so every update edits a single
            line instead of adding a new one
    '''

    def __init__(self, ax_alive, ax_territory, strategy_cnt):
        '''
        ax_alive: axes (anything with a matplotlib-style plot method) that
            receives one alive-count line per strategy type
        ax_territory: axes that receives one territory line per strategy type
        strategy_cnt: number of strategy types; types are tracked as the
            integers 0 .. strategy_cnt - 1
        '''
        self.time_scale = []
        self.alive_cnt = {i: [] for i in range(strategy_cnt)}
        self.territory = {i: [] for i in range(strategy_cnt)}

        # Keep the single line returned by each plot() call so later updates
        # only replace its data. All alive lines are created before any
        # territory line; this keeps the colour cycle stable even if the same
        # axes object is passed for both.
        self.alive_cnt_graphs = {
            i: ax_alive.plot([0], [0], label=f'Strategy type: {i}')[0]
            for i in range(strategy_cnt)
        }
        self.territory_graphs = {
            i: ax_territory.plot([0], [0], label=f'Strategy type: {i}')[0]
            for i in range(strategy_cnt)
        }

    def gather_board_statistics(self, board, iter_cnt, *, agents=None):
        '''
        Record one sample of alive counts and territory, then refresh the lines.

        board: 2-D grid iterable as rows of cells (e.g. a 2-D numpy array or a
            list of lists). A cell value v > 0 is read as the 1-based index of
            the owning agent, i.e. agents[int(v) - 1]; values <= 0 are unowned.
        iter_cnt: current iteration number, appended to the x axis
        agents: sequence of agents to inspect, each with .alive and
            .strategy_type attributes; defaults to Agent.all_agents

        Strategy types outside 0 .. strategy_cnt - 1 are ignored. A tracked
        type with no living agents / no squares this iteration is recorded as 0.
        '''
        if agents is None:
            agents = Agent.all_agents

        self.time_scale.append(iter_cnt)  # update x-axis

        # number of living agents per strategy type
        cur_alive_cnt = Counter(
            agent.strategy_type for agent in agents if agent.alive
        )

        # number of owned squares per strategy type
        cur_territory = Counter(
            agents[int(cell) - 1].strategy_type
            for row in board
            for cell in row
            if cell > 0
        )

        # Append this iteration's value (Counter yields 0 for absent types)
        # and push the full history to the matching line.
        for key, history in self.alive_cnt.items():
            history.append(cur_alive_cnt[key])
            self.alive_cnt_graphs[key].set_data(self.time_scale, history)
        for key, history in self.territory.items():
            history.append(cur_territory[key])
            self.territory_graphs[key].set_data(self.time_scale, history)
        