<!DOCTYPE html>
<!--
  Host page for the random-forest demo sketch.
  The body is intentionally empty: script.js calls createCanvas() in its
  setup() function, which is what puts a canvas on the page.
-->
<html>

  <head>
    <!-- Prototype.js is not loaded; the tag below is kept disabled. -->
    <!-- <script src="https://ajax.googleapis.com/ajax/libs/prototype/1.7.2.0/prototype.js"></script> -->
    <!-- p5.js core (0.5.14). -->
    <script data-require="p5.js@*" data-semver="0.5.14" src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.5.14/p5.min.js"></script>
    <!-- p5.dom addon. NOTE(review): this URL pins 0.5.7 while the core above is 0.5.14 - confirm the mismatch is intended. -->
    <script data-require="p5.dom.js@*" data-semver="0.5.7" src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.5.7/addons/p5.dom.min.js"></script>
    <!-- Project scripts: randomForest.js is listed before script.js, which instantiates RandomForest in setup(). -->
    <script src="randomForest.js"></script>
    <script src="script.js"></script>
  </head>

  <body>
  </body>

</html>
// ---- Canvas size (pixels) ----
const WIDTH = 400;
const HEIGHT = 400;

// ---- Data sets ----
// Samples used to train the forest, and how many of them to generate.
const learningSet = [];
const lsSize = 200;
// Samples used to measure the prediction error, and how many of them to generate.
const dataToPredict = [];
const predSize = 100;

// ---- Classifier ----
// Category identifiers handed to the forest; getClassLabel() maps them to names.
const classes = [0, 1, 2, 3, 4, 5];
// RandomForest instance, assigned in setup().
let RFmotor;

// ---- Training progress ----
let trained = false;
let createdTrees = 0;
const forestSize = 300;

// ---- Text / progress-bar layout ----
// Top-left origin of everything the sketch draws.
const origX = 10;
const origY = 10;
// Drawing cursor: pbX advances with the progress bar, pbY with each text line.
let pbX = origX;
let pbY = origY;
// A progress-bar segment is drawn once every `drawModulo` trees.
const drawModulo = 10;
// Width of one segment so that all segments span the canvas minus both margins.
const offX = (WIDTH - 2 * pbX) / (forestSize / drawModulo);
// Vertical distance between two text lines.
const offY = 15;

/**
 * p5.js entry point: creates the canvas, generates the learning and
 * evaluation data sets, registers the progress-bar listener and builds the
 * (still untrained) forest.
 */
function setup() {
  createCanvas(WIDTH, HEIGHT);
  background(0);
  initLearning();
  initDataToPredict();

  // Draws one white progress-bar segment each time the tree counter is at
  // 1 modulo `drawModulo`, then restores a black stroke.
  const onTreeCreated = () => {
    fill(255);
    stroke(255);
    const isSegmentDue = createdTrees % drawModulo === 1;
    if (isSegmentDue) {
      rect(pbX, pbY, offX, 7.6);
      pbX += offX;
    }
    stroke(0);
  };
  window.addEventListener('treecreated', onTreeCreated);

  RFmotor = new RandomForest(learningSet, classes);
}

/**
 * p5.js frame callback.
 *
 * While the forest is not trained, each frame adds one tree. Once
 * `forestSize` trees exist, the frame prints the end-of-training messages and
 * draws the cross-validation confusion matrix. The following frame predicts
 * every sample of `dataToPredict`, prints each failure and the overall error
 * rate, then stops the loop.
 */
function draw() {
  fill(255);

  // Moves the text cursor down by `gap` and writes `message` at the left margin.
  const printLine = (message, gap) => {
    pbY += gap;
    text(message, origX, pbY);
  };

  if (!trained && createdTrees < forestSize) {
    // Still training: grow the forest by one tree per frame.
    RFmotor.train();
    createdTrees++;
    return;
  }

  if (!trained) {
    trained = true;
    printLine('End of training.', 2 * offY);
    printLine('Evaluating training...', offY);
    const confusion = RFmotor.evaluateCrossValidation(2, 10);
    drawMatrice(confusion.data, 200, 35);
    console.log(confusion);
    printLine('Predicting data.', offY);
    return;
  }

  // Prediction pass: count failures globally and per expected category.
  let failures = 0;
  const perClass = classes.map((_, index) => ({
    label: getClassLabel(index),
    nbTot: 0,
    nbErr: 0
  }));

  for (const sample of dataToPredict) {
    const predicted = RFmotor.predict(sample);
    const expected = sample.category;
    perClass[expected].nbTot++;
    if (predicted === expected) {
      continue;
    }
    failures++;
    perClass[expected].nbErr++;
    printLine('Pred failed : ' + getClassLabel(predicted) + ' instead of ' + getClassLabel(expected), offY);
  }

  const total = dataToPredict.length;
  printLine(`Error : ${failures}/${total} (${failures * 100 / total}%)`, 2 * offY);
  console.log(perClass);
  noLoop();
}

/**
 * Draws a numeric matrix as a grid of text, one cell every 25 px
 * horizontally and one row every 20 px vertically.
 *
 * Each row is drawn over its own length, so non-square matrices are rendered
 * completely instead of being truncated to (or read past) `data.length`
 * columns.
 *
 * @param {number[][]} data - Rows of numbers; each value is rendered with `toFixed()`.
 * @param {number} x - X coordinate of the first column.
 * @param {number} y - Y coordinate of the first row.
 */
function drawMatrice(data, x, y) {
  const columnStep = 25;
  const rowStep = 20;
  data.forEach((row, rowIndex) => {
    row.forEach((value, columnIndex) => {
      text(value.toFixed(), x + columnIndex * columnStep, y + rowIndex * rowStep);
    });
  });
}

/**
 * Appends `lsSize` freshly generated samples to `learningSet`.
 */
function initLearning() {
  let generated = 0;
  while (generated < lsSize) {
    learningSet.push(generateRandomColor());
    generated += 1;
  }
}

/**
 * Appends `predSize` freshly generated samples to `dataToPredict`.
 */
function initDataToPredict() {
  let generated = 0;
  while (generated < predSize) {
    dataToPredict.push(generateRandomColor());
    generated += 1;
  }
}

// Per-category description of the generated animals. Fixed traits are plain
// numbers; height, width and weight are [min, max] ranges passed to random().
const ANIMAL_PROFILES = new Map([
  // dog
  [0, { legs: 4, walkOnLegs: 4, wings: 0, tail: 1, height: [40, 100], width: [40, 120], weight: [15, 70] }],
  // bird
  [1, { legs: 2, walkOnLegs: 2, wings: 2, tail: 0, height: [7, 20], width: [8, 30], weight: [0.2, 5] }],
  // fish
  [2, { legs: 0, walkOnLegs: 0, wings: 0, tail: 1, height: [0.3, 10], width: [5, 45], weight: [0.4, 4] }],
  // monkey
  [3, { legs: 4, walkOnLegs: 2, wings: 0, tail: 1, height: [70, 140], width: [30, 60], weight: [10, 30] }],
  // horse
  [4, { legs: 4, walkOnLegs: 4, wings: 0, tail: 1, height: [150, 220], width: [180, 240], weight: [80, 120] }],
  // human
  [5, { legs: 2, walkOnLegs: 2, wings: 0, tail: 0, height: [155, 200], width: [38, 70], weight: [50, 120] }],
]);

/**
 * Generates one random labelled sample (despite the name, samples describe
 * animals, not colors).
 *
 * A category is picked with `random(classes)`, then height, width and weight
 * are drawn, in that order, from the category's ranges.
 *
 * The `params` keys are always emitted in the same order (legs, walkOnLegs,
 * wings, tail, height, width, weight). RandomForest flattens `params` by key
 * order, so a per-category ordering (as the bird and fish literals previously
 * had) put different features at the same vector index.
 *
 * @returns {{params: Object<string, number>, category: number}|undefined}
 *   The sample, or undefined when the picked category has no profile.
 */
function generateRandomColor() {
  const category = random(classes);
  const profile = ANIMAL_PROFILES.get(category);
  if (profile === undefined) {
    return undefined;
  }
  return {
    params: {
      legs: profile.legs,
      walkOnLegs: profile.walkOnLegs,
      wings: profile.wings,
      tail: profile.tail,
      height: random(profile.height[0], profile.height[1]),
      width: random(profile.width[0], profile.width[1]),
      weight: random(profile.weight[0], profile.weight[1]),
    },
    category,
  };
}

/**
 * Returns the animal name for a category number.
 *
 * @param {number} i - Category number, 0 to 5.
 * @returns {string|undefined} The label, or undefined for any other value.
 */
function getClassLabel(i) {
  const labels = ['dog', 'bird', 'fish', 'monkey', 'horse', 'human'];
  // Only genuine integer numbers may index the table, so strings such as '0'
  // or 'length' yield undefined.
  if (!Number.isInteger(i)) {
    return undefined;
  }
  return labels[i];
}
/**
 * Random forest classifier whose tree nodes split samples on the cosine
 * similarity between a sample and a per-node "classifier" vector, computed
 * over a random subset of the features.
 *
 * A sample is `{ params: {...numeric features}, category }`. Features are
 * flattened with `Object.values(params)`, so every sample must list its
 * `params` keys in the same order.
 */
class RandomForest {
  /**
   * @param {Array<{params: Object<string, number>, category: *}>} learning
   *   Non-empty training set.
   * @param {Array} classes - All category values that may appear in `learning`.
   * @param {number} [percentParam=-1] - Percentage of the features used at
   *   each node; -1 selects round(sqrt(featureCount)). The resulting count is
   *   clamped to [1, featureCount].
   * @throws {Error} When `learning` is empty.
   */
  constructor(learning, classes, percentParam = -1) {
    if (!learning || learning.length === 0) {
      throw new Error('RandomForest: the learning set must contain at least one sample');
    }
    this._forest = [];
    this._nParameters = Object.keys(learning[0].params).length;
    this._categories = classes;
    this._subSampleSize = Math.round(0.64 * learning.length);
    const requested =
      percentParam === -1
        ? Math.round(Math.sqrt(this._nParameters))
        : Math.round((percentParam / 100.0) * this._nParameters);
    // Clamp: fewer than one feature is meaningless, and asking for more
    // distinct features than exist could never be satisfied.
    const wanted = Number.isFinite(requested) ? requested : 1;
    this._numberRandomParameters = Math.min(this._nParameters, Math.max(1, wanted));
    this._percentParam = percentParam;
    this._learning = learning;
  }

  /**
   * Adds one tree to the forest and announces it with a 'treecreated' event
   * on `window` (when a window exists).
   */
  train() {
    this._populateForest(1, this._learning, true);
  }

  /**
   * Classifies a sample by majority vote over all trees.
   *
   * @param {{params: Object<string, number>}} sample
   * @returns {number} Index, within the `classes` array given to the
   *   constructor, of the category with the most votes (the first one on a
   *   tie; 0 when the forest is empty).
   */
  predict(sample) {
    const votes = this._categories.map(() => 0);
    const features = this._arrayFromParams(sample.params);
    for (const tree of this._forest) {
      const category = this._getCategoryFromTree(tree, features);
      votes[this._categories.indexOf(category)]++;
    }
    let winner = -1;
    let winnerVotes = -1;
    for (let i = 0; i < this._categories.length; i++) {
      if (votes[i] > winnerVotes) {
        winner = i;
        winnerVotes = votes[i];
      }
    }
    return winner;
  }

  /**
   * Estimates accuracy with repeated two-way hold-out validation. Each
   * iteration shuffles the learning set, splits it 60/40, trains a forest on
   * each part and predicts the other part.
   *
   * @param {number} itterations - Number of shuffle/split rounds.
   * @param {number} forestSize - Number of trees in each temporary forest.
   * @returns {{data: number[][]}} Confusion matrix; `data[a][p]` counts the
   *   samples whose actual category has index `a` in `classes` and whose
   *   predicted category has index `p`.
   */
  evaluateCrossValidation(itterations, forestSize) {
    const size = this._categories.length;
    const matrice = {
      data: Array.from({ length: size }, () => new Array(size).fill(0)),
    };

    for (let itt = 0; itt < itterations; itt++) {
      const shuffled = this._shuffleArray(this._learning);
      // Same partition as "index < length * 0.6 goes to A".
      const cut = Math.ceil(shuffled.length * 0.6);
      const partA = shuffled.slice(0, cut);
      const partB = shuffled.slice(cut);
      // Predict A with B as learning set, then B with A as learning set.
      this._accumulateConfusion(matrice.data, partB, partA, forestSize);
      this._accumulateConfusion(matrice.data, partA, partB, forestSize);
    }
    return matrice;
  }

  /********************************************************************************/
  /*                                 PRIVATE                                      */
  /********************************************************************************/

  /** Trains a temporary forest on `trainingSet` and tallies its predictions for `testSet` into `confusion`. */
  _accumulateConfusion(confusion, trainingSet, testSet, forestSize) {
    const forest = new RandomForest(trainingSet, this._categories, this._percentParam);
    forest._populateForest(forestSize, trainingSet, false);
    for (const sample of testSet) {
      // Rows are addressed by category index, not by the raw category value.
      const actual = this._categories.indexOf(sample.category);
      const predicted = forest.predict(sample);
      if (actual !== -1 && predicted !== -1) {
        confusion[actual][predicted]++;
      }
    }
  }

  /** Walks a tree from `node` down to a leaf and returns that leaf's category. */
  _getCategoryFromTree(node, data) {
    let current = node;
    while (current.category === undefined) {
      const f = this._cosineSimilarity(data, current.classifier, current.paramThreshold);
      // ">=" mirrors the rule used by _createIndicator at training time.
      current = f >= current.threshold ? current.left : current.right;
    }
    return current.category;
  }

  /** Appends `forestSize` new trees, optionally dispatching 'treecreated' on `window` after each. */
  _populateForest(forestSize, learning, dispatchEvent = false) {
    // Outside a browser there is nobody to notify; skip rather than throw.
    const canNotify =
      dispatchEvent && typeof window !== 'undefined' && typeof CustomEvent !== 'undefined';
    for (let i = 0; i < forestSize; i++) {
      this._forest.push(this._createTree(learning));
      if (canNotify) {
        window.dispatchEvent(new CustomEvent('treecreated'));
      }
    }
  }

  /** Builds one tree from a random sub-sample of `learning`. */
  _createTree(learning) {
    const subSample = this._getRandom(learning, this._subSampleSize);
    const node = this._nodeFactory(subSample);
    if (node.category === undefined) {
      this._makeLeftAndRightNodes(node, subSample);
    }
    return node;
  }

  /**
   * Creates a node for `subSample`: a leaf when every sample has the same
   * category, otherwise a split node with a random feature subset, a
   * classifier vector and a threshold.
   */
  _nodeFactory(subSample) {
    const node = {};
    const firstCategory = subSample[0].category;
    // A pure node can only ever lead to leaves of this category, so stop here.
    if (subSample.every((sample) => sample.category === firstCategory)) {
      node.category = firstCategory;
      return node;
    }
    // Distinct feature indexes, each within [0, _nParameters).
    node.paramThreshold = this._getRandom(this._parameterIndexes(), this._numberRandomParameters);
    node.classifier = this._makeClassifier(subSample);
    node.threshold = this._getThreshold(subSample, node);
    return node;
  }

  /** Returns [0, 1, ..., _nParameters - 1]. */
  _parameterIndexes() {
    return Array.from({ length: this._nParameters }, (_, index) => index);
  }

  /** Splits `subSample` on `node` and recursively builds both children. */
  _makeLeftAndRightNodes(node, subSample) {
    const indicator = this._createIndicator(node, subSample);
    const left = [];
    const right = [];
    indicator.forEach((flag, i) => {
      (flag > 0 ? left : right).push(subSample[i]);
    });
    node.left = this._nodeFactory(left);
    node.right = this._nodeFactory(right);
    if (node.left.category === undefined) {
      this._makeLeftAndRightNodes(node.left, left);
    }
    if (node.right.category === undefined) {
      this._makeLeftAndRightNodes(node.right, right);
    }
  }

  /**
   * Returns, for each sample, 1 when its similarity to the node classifier is
   * at or above the node threshold and 0 otherwise.
   */
  _createIndicator(node, subSample) {
    const indicator = subSample.map((sample) => {
      const f = this._cosineSimilarity(
        this._arrayFromParams(sample.params),
        node.classifier,
        node.paramThreshold
      );
      return f >= node.threshold ? 1 : 0;
    });
    // If every sample landed on the same side the recursion would never
    // terminate, so the first sample is forced onto the other side.
    const first = indicator[0];
    if (indicator.every((flag) => flag === first)) {
      indicator[0] = first === 0 ? 1 : 0;
    }
    return indicator;
  }

  /**
   * Builds the node classifier: the per-feature sum of the samples belonging
   * to the most represented category of `subSample`.
   */
  _makeClassifier(subSample) {
    // Histogram of the categories present in the sub-sample.
    const categoryCount = this._categories.map(() => 0);
    for (const sample of subSample) {
      const idx = this._categories.indexOf(sample.category);
      if (idx !== -1) {
        categoryCount[idx]++;
      }
    }
    // Index of the most represented category (first one on a tie).
    let mostRepresentative = -1;
    let best = -1;
    categoryCount.forEach((count, i) => {
      if (count > best) {
        best = count;
        mostRepresentative = i;
      }
    });
    // Feature vectors of that category. The histogram index is mapped back to
    // the category value before comparing with the samples.
    const dominantCategory = this._categories[mostRepresentative];
    const vectors = subSample
      .filter((sample) => sample.category === dominantCategory)
      .map((sample) => this._arrayFromParams(sample.params));
    // For each feature, sum its values over the selected vectors.
    const classifier = new Array(this._nParameters).fill(0);
    for (const vector of vectors) {
      for (let i = 0; i < this._nParameters; i++) {
        classifier[i] += vector[i];
      }
    }
    // Optional averaging. This flag is never set inside the class; it only
    // takes effect if external code assigns it.
    if (this.moyennageClassfier) {
      for (let i = 0; i < classifier.length; i++) {
        classifier[i] /= vectors.length;
      }
    }
    return classifier;
  }

  /** Returns the midpoint between the lowest and highest similarity of `subSample` to the node classifier. */
  _getThreshold(subSample, node) {
    let fmin = Infinity;
    let fmax = -Infinity;
    for (const sample of subSample) {
      const f = this._cosineSimilarity(
        this._arrayFromParams(sample.params),
        node.classifier,
        node.paramThreshold
      );
      if (f < fmin) {
        fmin = f;
      }
      if (f > fmax) {
        fmax = f;
      }
    }
    return (fmax + fmin) / 2.0;
  }

  /** Flattens a params object into an array of its values, in key order. */
  _arrayFromParams(params) {
    return Object.values(params);
  }

  /**
   * Cosine similarity of `s1` and `s2` restricted to the positions listed in
   * `indexes`. Returns 0 when either restricted vector is all zeros (the
   * ratio would otherwise be NaN).
   */
  _cosineSimilarity(s1, s2, indexes) {
    let product = 0.0;
    let sqr1 = 0.0;
    let sqr2 = 0.0;
    for (const idx of indexes) {
      product += s1[idx] * s2[idx];
      sqr1 += s1[idx] * s1[idx];
      sqr2 += s2[idx] * s2[idx];
    }
    const norm = Math.sqrt(sqr1) * Math.sqrt(sqr2);
    return norm === 0 ? 0 : product / norm;
  }

  /**
   * Picks `n` distinct elements of `arr` at random (no element position is
   * used twice).
   *
   * @throws {RangeError} When `n` exceeds `arr.length`.
   */
  _getRandom(arr, n) {
    const len = arr.length;
    if (n > len) {
      throw new RangeError("getRandom: more elements taken than available");
    }
    // Partial Fisher-Yates shuffle on a copy: the first n slots end up
    // holding n distinct positions of the input.
    const pool = arr.slice();
    for (let i = 0; i < n; i++) {
      const j = i + Math.floor(Math.random() * (len - i));
      [pool[i], pool[j]] = [pool[j], pool[i]];
    }
    return pool.slice(0, n);
  }

  /** Returns a shuffled copy of `originalArray`; the input is not modified. */
  _shuffleArray(originalArray) {
    const array = [].concat(originalArray);
    // Fisher-Yates: swap each position, from the end, with a random earlier one.
    for (let currentIndex = array.length - 1; currentIndex > 0; currentIndex--) {
      const randomIndex = Math.floor(Math.random() * (currentIndex + 1));
      [array[currentIndex], array[randomIndex]] = [array[randomIndex], array[currentIndex]];
    }
    return array;
  }
}