<!DOCTYPE html>
<html lang="en">

  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Snake DQN agent</title>
    <!-- Third-party libraries: p5.js, its DOM addon, and reinforcejs (provides RL.DQNAgent). -->
    <script data-require="p5.js@*" data-semver="0.5.14" src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.5.14/p5.min.js"></script>
    <script data-require="p5.dom.js@*" data-semver="0.5.7" src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.5.7/addons/p5.dom.min.js"></script>
    <script src="https://cs.stanford.edu/people/karpathy/reinforcejs/lib/rl.js"></script>
    <!-- Application scripts; keep this loading order. -->
    <script src="game.js"></script>
    <script src="snake.js"></script>
    <script src="agent.js"></script>
    <script src="script.js"></script>
    <style>
      input[type="number"] {
        max-width: 50px;
      }
    </style>
  </head>

  <body>
      <!-- The p5 canvas is attached to this element. -->
      <div id="sketch-holder"></div>

      <button type="button" onclick="enableWalls = !enableWalls; snakeGame.toggleEdges(enableWalls);">Toggle walls</button>
      <br>

      <!-- itt = game steps per frame, mod = only step on frames where frameCount % mod === 0. -->
      Speed : <button type="button" onclick="itt=1000; mod=1;">Very fast</button>
      <button type="button" onclick="itt=500; mod=1;">Fast</button>
      <button type="button" onclick="itt=1; mod=3;">Normal</button>
      <br>

      <label for="scoretxt">Score :</label> <input id="scoretxt" disabled>
      <br>

      <!-- Input mode, neighbour size, update mode and hidden units are only read when the brain is reset. -->
      Input mode :
      <input type="radio" name="inputMode" value="0" id="inputModeGrid"><label for="inputModeGrid">Grid mode</label>
      <input type="radio" name="inputMode" value="1" id="inputModeNeighbour"><label for="inputModeNeighbour">Neighbour mode</label><br>
      <label for="neighbourSize">NeighbourSize :</label> <input type="number" min="1" max="5" step="1" id="neighbourSize">
      <br>
      Update mode :
      <input type="radio" name="updateMode" value="qlearn" id="qlearn"><label for="qlearn">qlearn</label>
      <input type="radio" name="updateMode" value="sarsa" id="sarsa"><label for="sarsa">sarsa</label>
      <br>
      <!-- The "Update" buttons patch the running agent; values are converted to numbers because input.value is always a string. -->
      <label for="gamma">Gamma :</label> <input type="number" min="0" max="1" step="0.1" id="gamma"><button type="button" onclick="brain.agent.gamma = Number(document.getElementById('gamma').value)">Update</button><br>
      <label for="epsilon">Epsilon :</label> <input type="number" min="0" max="1" step="0.01" id="epsilon"><button type="button" onclick="brain.agent.epsilon = Number(document.getElementById('epsilon').value)">Update</button><br>
      <label for="alpha">Alpha :</label> <input type="number" min="0" max="1" step="0.001" id="alpha"><button type="button" onclick="brain.agent.alpha = Number(document.getElementById('alpha').value)">Update</button><br>
      <label for="experience_add_every">Add experience every :</label> <input type="number" min="1" step="1" id="experience_add_every"> step <button type="button" onclick="brain.agent.experience_add_every = Number(document.getElementById('experience_add_every').value)">Update</button><br>
      <label for="tderror_clamp">tderror_clamp :</label> <input type="number" min="0" max="1" step="0.1" id="tderror_clamp"><button type="button" onclick="brain.agent.tderror_clamp = Number(document.getElementById('tderror_clamp').value)">Update</button><br>
      <label for="num_hidden_units">Hidden units :</label> <input type="number" min="1"  step="10" id="num_hidden_units"><br>
      <button type="button" onclick="brain = new Agent(getSpecs());">Reset brain</button>
  </body>

</html>

// --- Sketch-wide state (also touched by the inline handlers of the page) ---
let enableWalls = true; // passed to snakeGame.toggleEdges()
let size = 20;          // the grid is size x size cells
let itt = 1;            // number of loop iterations per draw() call
let mod = 3;            // the game only steps on frames where frameCount % mod === 0
let snakeGame;          // SnakeGame instance, created in setup()

// --- Default input encoding ---
let inputMode = NEIGHBOURS_CELLS;
let neighboursCells = 2;

// --- Agent settings, mirrored in the form by setSpecsUI() / getSpecs() ---
let spec = {
  inputMode,
  neighboursCells, //Neighbour cells. 1 => 9 cells, 2 => 25 cells ...
  size,
  update: 'qlearn',
  gamma: 0.9,
  epsilon: 0.02,
  alpha: 0.005,
  experience_add_every: 5,
  tderror_clamp: 1.0,
  num_hidden_units: 100
};
let brain = null; // Agent instance, created in setup()

/**
 * Copies the current values of the global `spec` object into the form controls.
 */
function setSpecsUI() {
  const byId = (id) => document.getElementById(id);

  // Radio groups: tick the button that matches the configured mode.
  const inputModeRadio = spec.inputMode === GRID_MODE ? "inputModeGrid" : "inputModeNeighbour";
  byId(inputModeRadio).checked = true;
  const updateRadio = spec.update === 'qlearn' ? "qlearn" : "sarsa";
  byId(updateRadio).checked = true;

  // Plain value fields, as [element id, spec key] pairs.
  const valueFields = [
    ['gamma', 'gamma'],
    ['alpha', 'alpha'],
    ['epsilon', 'epsilon'],
    ['experience_add_every', 'experience_add_every'],
    ['tderror_clamp', 'tderror_clamp'],
    ['num_hidden_units', 'num_hidden_units'],
    ['neighbourSize', 'neighboursCells']
  ];
  for (const [elementId, specKey] of valueFields) {
    byId(elementId).value = spec[specKey];
  }
}
/**
 * Reads the form controls back into the global `spec` object and returns it.
 *
 * Input values are always strings, so every numeric setting is parsed before
 * it is stored; `spec` therefore keeps the numeric types it starts with.
 * A blank or unparsable field leaves the previous value of that setting in place.
 *
 * @returns {object} the (mutated) global `spec` object
 */
function getSpecs() {
  const byId = (id) => document.getElementById(id);

  // Parse one numeric field; fall back to the current spec value when invalid.
  const readNumber = (elementId, fallback, integer) => {
    const raw = byId(elementId).value;
    const parsed = integer ? parseInt(raw, 10) : parseFloat(raw);
    return Number.isFinite(parsed) ? parsed : fallback;
  };

  spec.inputMode = byId("inputModeGrid").checked ? GRID_MODE : NEIGHBOURS_CELLS;
  spec.update = byId("qlearn").checked ? "qlearn" : "sarsa";
  spec.gamma = readNumber('gamma', spec.gamma, false);
  spec.alpha = readNumber('alpha', spec.alpha, false);
  spec.epsilon = readNumber('epsilon', spec.epsilon, false);
  spec.experience_add_every = readNumber('experience_add_every', spec.experience_add_every, true);
  spec.tderror_clamp = readNumber('tderror_clamp', spec.tderror_clamp, false);
  spec.num_hidden_units = readNumber('num_hidden_units', spec.num_hidden_units, true);
  spec.neighboursCells = readNumber('neighbourSize', spec.neighboursCells, true);
  return spec;
}

/**
 * Sketch initialisation: creates the canvas, the game world and the agent.
 *
 * The canvas is sized from the cell size and the grid size so that it always
 * matches the grid drawn by drawChart() (20 px * 20 cells = 400 px by default).
 */
function setup() {
  const cellSize = 20; // pixel size of one grid cell
  createCanvas(cellSize * size, cellSize * size).parent('sketch-holder');
  snakeGame = new SnakeGame(cellSize, size, size);
  snakeGame.toggleEdges(enableWalls); //Enable edges
  background(0);
  setSpecsUI();
  brain = new Agent(spec);
}

/**
 * Per-frame loop: runs up to `itt` iterations, each of which (on frames where
 * frameCount is a multiple of `mod`) lets the agent pick a direction, advances
 * the game and feeds the resulting reward back to the agent.
 */
function draw() {
  background(0);
  const redrawEveryIteration = itt < 100;

  for (let step = 0; step < itt; step++) {
    if (frameCount % mod === 0) {
      snakeGame.snake.updateDirection(getBrainDecision());

      // The distance-based reward is computed before the game advances.
      const distanceReward = getReward();
      const outcome = snakeGame.update();

      let reward;
      switch (outcome) {
        case 1:
          reward = brain.rewards.apple;
          break;
        case -1:
          reward = brain.rewards.death;
          break;
        default:
          reward = distanceReward;
      }

      brain.rewardCount.gameReward += reward;
      brain.learn(reward);

      // A result of -1 ends the game: close the per-game statistics.
      if (outcome === -1) brain.rewardCount.recordReward();
    }
    // With many iterations per frame only the first one is drawn.
    if (redrawEveryIteration || step === 0) drawChart();
  }

  const meanReward = brain.rewardCount.totGamesMeanReward;
  document.getElementById('scoretxt').value = meanReward;
  if (itt > 100) console.log(meanReward);
}

/**
 * Distance-based reward: maps the head-to-apple distance from
 * [0, grid diagonal] onto [rewards.nearApple, rewards.farApple].
 */
function getReward() {
  const distance = snakeGame.getDistToApple();
  const gridDiagonal = Math.sqrt(2 * (size * size));
  const nearReward = brain.rewards.nearApple;
  const farReward = brain.rewards.farApple;
  return map(distance, 0, gridDiagonal, nearReward, farReward);
}

/**
 * Builds the agent's input vector from the current game state and returns
 * the action chosen by the agent.
 *
 * Input layout:
 *  - NEIGHBOURS_CELLS mode: the (2n+1) x (2n+1) cells centred on the snake head,
 *    where n is `brain.spec.neighboursCells`;
 *  - otherwise: every cell of the grid, line by line;
 *  - then, in both modes, the angle and the distance to the apple.
 * All values are mapped to [0, 1].
 *
 * The neighbourhood radius is taken from the brain's own spec (not from the
 * global default) so the vector length always matches the state count the
 * agent was created with, including after "Reset brain" with a new size.
 *
 * @returns {*} whatever `brain.act(input)` returns
 */
function getBrainDecision() {
  // Cell contents are codes in 0..4; scale them to [0, 1].
  const scaleCell = (content) => map(content, 0, 4, 0, 1);
  const input = [];

  if (brain.spec.inputMode === NEIGHBOURS_CELLS) {
    // Number(): the spec value may have been read from a form field as a string.
    const radius = Number(brain.spec.neighboursCells);
    const snakeHead = snakeGame.snake.getSnake()[0];
    for (let i = -radius; i <= radius; i++) {
      for (let j = -radius; j <= radius; j++) {
        input.push(scaleCell(snakeGame.getCellContent(snakeHead.x + i, snakeHead.y + j)));
      }
    }
  } else {
    for (let line = 0; line < snakeGame.HCells; line++) {
      for (let col = 0; col < snakeGame.WCells; col++) {
        input.push(scaleCell(snakeGame.getCellContent(col, line)));
      }
    }
  }

  input.push(map(snakeGame.getAngleToApple(), -180, 180, 0, 1));
  input.push(map(snakeGame.getDistToApple(), 0, Math.sqrt(2 * (size * size)), 0, 1));
  return brain.act(input);
}

/**
 * Renders the whole grid cell by cell, then overlays the iteration counter
 * and the min / max / mean per-game reward.
 */
function drawChart() {
  // [stroke RGB, fill RGB] of the square drawn for each kind of cell.
  const squareColours = new Map([
    [WALL, [[20, 20, 20], [120, 120, 120]]],
    [SNAKE_CELL, [[150, 150, 150], [200, 200, 200]]],
    [SNAKE_HEAD, [[150, 150, 150], [200, 200, 200]]],
    [EMPTY_CELL, [[150, 150, 150], [200, 200, 200]]],
    [APPLE, [[220, 120, 120], [200, 100, 100]]]
  ]);
  // [stroke RGB, fill RGB] of the disc drawn on top of snake cells.
  const discColours = new Map([
    [SNAKE_CELL, [[120, 220, 120], [100, 200, 100]]],
    [SNAKE_HEAD, [[100, 200, 100], [80, 180, 80]]]
  ]);

  const side = snakeGame.cellSize;
  for (let row = 0; row < snakeGame.HCells; row++) {
    for (let col = 0; col < snakeGame.WCells; col++) {
      const content = snakeGame.getCellContent(col, row);
      const square = squareColours.get(content);
      if (!square) continue; // unknown content: nothing is drawn

      const left = col * side;
      const top = row * side;
      stroke(...square[0]);
      fill(...square[1]);
      rect(left, top, side, side);

      const disc = discColours.get(content);
      if (disc) {
        stroke(...disc[0]);
        fill(...disc[1]);
        ellipse(left + side / 2, top + side / 2, side, side);
      }
    }
  }

  // Text overlay.
  stroke(0);
  strokeWeight(1);
  fill(255);
  text("Itt:" + brain.agent.t, 15, 15);
  const stats = brain.rewardCount;
  const rewardSummary = [stats.totGamesMinReward, stats.totGamesMaxReward, stats.totGamesMeanReward]
    .map((value) => value.toFixed(2))
    .join(" - ");
  text("Rew: " + rewardSummary, 110, 15);
}
const GRID_MODE = 0;         //The input is the full grid + angle & dist to apple
const NEIGHBOURS_CELLS = 1;  //The input is the cells around the head + angle & dist to apple

/**
 * Thin wrapper around RL.DQNAgent that also holds the reward table and the
 * per-game reward statistics.
 */
class Agent {

  /**
   * @param {object} spec settings handed to RL.DQNAgent; this class itself reads
   *   `inputMode`, `neighboursCells` and `size` from it to size the state vector.
   */
  constructor(spec) {
    this.spec = spec;
    // Mirror the settings this agent was built with. They come from the spec
    // (not from the page-level defaults) so they cannot disagree with the
    // network size computed below. Number(): form values may arrive as strings.
    this.inputMode = spec.inputMode;
    this.neighboursCells = Number(spec.neighboursCells);
    this.size = Number(spec.size);

    this.env = {
      // 2 extra states: angle and distance to the apple.
      getNumStates: () => {
        const side = Number(this.spec.size);
        const radius = Number(this.spec.neighboursCells);
        const cellCount = this.spec.inputMode === GRID_MODE
          ? side * side
          : Math.pow(1 + 2 * radius, 2);
        return 2 + cellCount;
      },
      getMaxNumActions: () => 4
    };
    // create the DQN agent
    this.agent = new RL.DQNAgent(this.env, this.spec);

    this.rewards = {
      apple: 100.0,
      death: -100.0,
      farApple: -10.0,
      nearApple: 10.0
    };

    this.rewardCount = {
      nGames: 0,
      gameReward: 0,          // reward accumulated during the current game
      totGamesReward: 0,
      totGamesMeanReward: 0,
      totGamesMinReward: 9999999,
      totGamesMaxReward: -999999,
      // Closes the current game: folds gameReward into the totals and resets it.
      recordReward: function() {
        this.nGames++;
        this.totGamesReward += this.gameReward;
        this.totGamesMeanReward = this.totGamesReward / this.nGames;
        if (this.gameReward > this.totGamesMaxReward) this.totGamesMaxReward = this.gameReward;
        if (this.gameReward < this.totGamesMinReward) this.totGamesMinReward = this.gameReward;
        this.gameReward = 0;
      }
    };

  }

  /**
   * @param {number[]} input state vector
   * @returns {*} the result of the wrapped agent's `act(input)`
   */
  act(input) {
    return this.agent.act(input);
  }

  /**
   * Forwards the reward to the wrapped agent's `learn`.
   * @param {number} reward
   */
  learn(reward) {
    this.agent.learn(reward);
  }
}
// Cell content codes returned by SnakeGame.getCellContent().
const EMPTY_CELL = 0;
const SNAKE_CELL = 1;
const WALL = 2;
const APPLE = 3;
const SNAKE_HEAD = 4;

/**
 * Game world: owns the snake, the apple and the optional wall border.
 */
class SnakeGame {
  /**
   * @param {number} cellSize size of one cell (stored only; not used by the game logic here)
   * @param {number} HCells number of cells along y
   * @param {number} WCells number of cells along x
   */
  constructor(cellSize, HCells, WCells) {
    this.toggleEdges(false);
    this.cellSize = cellSize;
    this.HCells = HCells;
    this.WCells = WCells;
    //Generating Snake
    this.snake = this._generateSnake(HCells, WCells);
    //The first apple is placed right in front of the snake head
    const head = this.snake.getSnake()[0];
    this.apple = {x: head.x + 1, y: head.y};
  }

  /**
   * Turns the wall border on or off.
   * @param {boolean} v
   */
  toggleEdges(v) {
    this._egdes = v; // field name (including its spelling) kept as-is
  }

  /**
   * @returns {number} WALL, APPLE, SNAKE_HEAD, SNAKE_CELL or EMPTY_CELL, checked in that order
   */
  getCellContent(x, y) {
    if (this._isEdge(x, y)) return WALL;
    if (this.apple.x === x && this.apple.y === y) return APPLE;
    if (this.snake.isPresent(x, y)) {
      const head = this.snake.getSnake()[0];
      return (head.x === x && head.y === y) ? SNAKE_HEAD : SNAKE_CELL;
    }
    return EMPTY_CELL;
  }

  /**
   * Picks a random cell that is neither on the snake nor on a wall.
   * Math.floor over [0, n) gives every index 0..n-1 the same probability
   * (rounding over [0, n-1) made the first and last index half as likely).
   * @returns {{x: number, y: number}}
   */
  _generateApple() {
    let x, y;
    do {
      x = Math.floor(random(this.WCells));
      y = Math.floor(random(this.HCells));
    } while (this.snake.isPresent(x, y) || this._isEdge(x, y));
    return {x: x, y: y};
  }

  /**
   * Creates a new snake at the centre of the grid.
   * The world size is handed to Snake as (x extent, y extent), i.e. (WCells, HCells).
   */
  _generateSnake(HCells, WCells) {
    const xSnake = Math.round(this.WCells / 2);
    const ySnake = Math.round(this.HCells / 2);
    return new Snake(xSnake, ySnake, WCells, HCells);
  }

  /**
   * @returns {boolean} true when walls are enabled and (x, y) lies on the outer ring of the grid
   */
  _isEdge(x, y) {
    const onBorder =
      x === 0 ||
      x === this.WCells - 1 ||
      y === 0 ||
      y === this.HCells - 1;
    return this._egdes && onBorder;
  }

  /**
   * Advances the game by one step.
   * @returns {number} 1 when the apple was eaten, -1 when the snake hit a wall
   *   or itself (the game is then reset), 0 otherwise
   */
  update() {
    let result = 0;
    this.snake.move(this.apple);
    const head = this.snake.getSnake()[0];
    if (head.x === this.apple.x && head.y === this.apple.y) {
      this.apple = this._generateApple();
      result = 1;
    }
    if (this._isEdge(head.x, head.y) || this.snake.isCrossed()) {
      //Generating Snake
      this.snake = this._generateSnake(this.HCells, this.WCells);
      //Generating Apple
      this.apple = this._generateApple();
      result = -1;
    }
    return result;
  }

  //AI
  /**
   * Placeholder: no branch does anything yet and nothing is returned.
   */
  getCellNextSnake(direction) {
    switch (direction) {
      case 0: //LEFT:
        break;
      case 1: //FRONT:
        break;
      case 2: //RIGHT:
        break;
    }
  }

  /**
   * @returns {number} Euclidean distance, in cells, between the snake head and the apple
   */
  getDistToApple() {
    const head = this.snake.getSnake()[0];
    const dx = head.x - this.apple.x;
    const dy = head.y - this.apple.y;
    return Math.sqrt(dx * dx + dy * dy);
  }

  /**
   * @returns {number} angle of the head-to-apple vector in degrees (atan2 convention)
   */
  getAngleToApple() {
    const head = this.snake.getSnake()[0];
    const deltaY = this.apple.y - head.y;
    const deltaX = this.apple.x - head.x;
    return Math.atan2(deltaY, deltaX) * 180 / Math.PI;
  }
}
// Direction codes.
const UP = 0;
const RIGHT = 1;
const DOWN = 2;
const LEFT = 3;

/**
 * The snake: an ordered list of cells (head first) plus the current direction.
 */
class Snake {
  /**
   * @param {number} x head column
   * @param {number} y head row
   * @param {number} worldX number of cells along x (valid columns are 0..worldX-1)
   * @param {number} worldY number of cells along y (valid rows are 0..worldY-1)
   */
  constructor(x, y, worldX, worldY) {
    this._world = {
      x: worldX,
      y: worldY
    };
    // Two cells long, heading right.
    this._snake = [{x: x, y: y}, {x: x - 1, y: y}];
    this._direction = RIGHT;
  }

  /**
   * @returns {Array<{x: number, y: number}>} the snake cells, head first
   */
  getSnake() {
    return this._snake;
  }

  /**
   * @returns {boolean} true when any snake cell is at (x, y)
   */
  isPresent(x, y) {
    return this._snake.some((cell) => cell.x === x && cell.y === y);
  }

  /**
   * Moves one cell in the current direction; grows when the new head lands on the apple.
   * @param {{x: number, y: number}} apple
   */
  move(apple) {
    const offset = this._offsetFor(this._direction);
    if (offset) this._updateArray(offset.x, offset.y, apple);
  }

  /**
   * Moves one cell against the current direction, then cuts the snake down to
   * its first two cells.
   * @param {{x: number, y: number}} apple
   */
  revertMove(apple) {
    const offset = this._offsetFor(this._direction);
    if (offset) this._updateArray(-offset.x, -offset.y, apple);
    this._snake = [this._snake[0], this._snake[1]];
  }

  /**
   * Changes direction, unless the request is a U-turn (exact opposite of the
   * current direction), which is ignored.
   */
  updateDirection(direction) {
    this._direction = this._isOpposite(direction, this._direction) ? this._direction : direction;
  }

  /**
   * @returns {boolean} true when dir1 and dir2 are opposite directions;
   *   false for any value that is not a known direction
   */
  _isOpposite(dir1, dir2) {
    switch (dir1) {
      case UP: return dir2 === DOWN;
      case DOWN: return dir2 === UP;
      case LEFT: return dir2 === RIGHT;
      case RIGHT: return dir2 === LEFT;
      default: return false;
    }
  }

  /**
   * @returns {{x: number, y: number}|null} unit step for a direction code, null when unknown
   */
  _offsetFor(direction) {
    switch (direction) {
      case UP: return {x: 0, y: -1};
      case RIGHT: return {x: 1, y: 0};
      case DOWN: return {x: 0, y: 1};
      case LEFT: return {x: -1, y: 0};
      default: return null;
    }
  }

  /**
   * Puts a new head at (old head + offset) and drops the tail unless the
   * apple was reached.
   */
  _updateArray(ox, oy, apple) {
    const head = {x: this._snake[0].x + ox, y: this._snake[0].y + oy};

    //If edge of world, move to opposite side.
    //Valid indices are 0..size-1, so size itself is already outside the grid.
    if (head.x < 0) head.x = this._world.x - 1;
    else if (head.x >= this._world.x) head.x = 0;
    if (head.y < 0) head.y = this._world.y - 1;
    else if (head.y >= this._world.y) head.y = 0;

    const ateApple = head.x === apple.x && head.y === apple.y;
    this._snake.unshift(head);
    if (!ateApple) this._snake.pop();
  }

  /**
   * @returns {boolean} true when the head shares its cell with another snake cell
   */
  isCrossed() {
    const head = this._snake[0];
    for (let i = 1; i < this._snake.length; i++) {
      if (this._snake[i].x === head.x && this._snake[i].y === head.y) return true;
    }
    return false;
  }
}