import React, { Component } from 'react';
import ReactDOM from 'react-dom';

import XORTask from './XORTask';
import './style.css';

ReactDOM.render(<XORTask/>, document.getElementById('root'));
{
  "name": "@plnkr/starter-react",
  "version": "1.0.2",
  "description": "React starter template",
  "dependencies": {
    "@tensorflow/tfjs": "^2.0.1",
    "plotly.js": "^1.54.7",
    "react": "^16.13.0",
    "react-dom": "^16.13.0",
    "react-plotly.js": "^2.4.0",
    "react-charts": "^2.0.0-beta.7",
    "recharts": "^1.8.5"
  },
  "plnkr": {
    "runtime": "system",
    "useHotReload": true
  }
}
/* Headings and paragraphs use the platform's default sans-serif face. */
h1, p {
  font-family: sans-serif;
}
import React, { useEffect, useRef, useState } from 'react';

export default ({ data, size = 400, fontSize = 18, squareAmount }) => {
    const canvasRef = useRef();
    const [ctx, changeCtx] = useState(null);

    const squareSize = size / squareAmount;
    console.log(squareSize)

    useEffect(() => {
        changeCtx(canvasRef.current.getContext('2d'));
    }, []);

    useEffect(() => {
        if (ctx) {
            ctx.clearRect(0, 0, size, size);
            data.forEach(({ x1, x2, out }) => {
                // draw the square
                ctx.fillStyle = `rgba(0, 0, 0, ${out})`;
                ctx.fillRect(x1 * size, x2 * size, squareSize, squareSize);
                // put the value
                ctx.fillStyle = 'red';
                ctx.font = `${fontSize}px serif`;
                ctx.fillText(out.toFixed(2), x1 * size + squareSize / 2 - 20, x2 * size + squareSize / 2 + 5);
            });
        }
    }, [data, ctx]);

    return (
        <canvas
            ref={canvasRef}
            width={size}
            height={size}
        />
    )
}
import React from 'react';
import { LineChart, XAxis, YAxis, CartesianGrid, Line } from 'recharts';

export default ({ width = 400, height = 430, loss = [] }) => {
    return (
        <LineChart
            width={width}
            height={height}
            margin={{ top: 0, left: 0, bottom: 0, right: 0 }}
            data={loss}>
            <XAxis dataKey="epoch"/>
            <YAxis/>
            <CartesianGrid

            />
            <Line
                type="monotone"
                dataKey="loss"
                stroke="#8884d8"
                dot={false}/>
        </LineChart>
    )
}
import React, { useEffect, useState } from 'react';
import LossPlot from './components/LossPlot';
import Canvas from './components/Canvas';
import * as tf from "@tensorflow/tfjs";

let model;

export default () => {
    const [data, changeData] = useState([]);
    const [lossHistory, changeLossHistory] = useState([]);

    useEffect(() => {

        async function initModel() {
            const trainingInput = [[0, 0], [1, 0], [0, 1], [1, 1]];
            const trainingInputTensor = tf.tensor(trainingInput, [trainingInput.length, 2]);

            const trainingOutput = [[0], [1], [1], [0]]
            const trainingOutputTensor = tf.tensor(trainingOutput, [trainingOutput.length, 1]);

            const testInput = generateInputs(10);
            const testInputTensor = tf.tensor(testInput, [testInput.length, 2]);

            model = tf.sequential();
            model.add(tf.layers.dense({ inputShape: [2], units: 3, activation: 'sigmoid' }));
            model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
            model.compile({
                optimizer: tf.train.adam(0.1),
                loss: 'meanSquaredError'
            });

            await model.fit(trainingInputTensor, trainingOutputTensor, {
                epochs: 1000,
                shuffle: true,
                callbacks: {
                    onEpochEnd: async (epoch, { loss }) => {
                        changeLossHistory((prevHistory) => [...prevHistory, {
                            epoch,
                            loss
                        }]);

                        const output = model.predict(testInputTensor).arraySync();
                        changeData(() => output.map(([out], i) => ({
                            out,
                            x1: testInput[i][0],
                            x2: testInput[i][1]
                        })));
                        await tf.nextFrame();
                    }
                }
            })
        }
        initModel();
    }, []);

    return (
        <div>
            <Canvas data={data} squareAmount={10}/>
            <LossPlot loss={lossHistory}/>
        </div>
    );
}

/**
 * Builds a regular grid of sample points covering the unit square.
 *
 * Returns `squareAmount * squareAmount` points `[x1, x2]`. Each coordinate is
 * `k / squareAmount` for k = 0 .. squareAmount - 1, i.e. the top-left corner
 * of each grid cell. Points are ordered by x1, then x2.
 *
 * Integer counters are used on purpose. Accumulating `i += 1 / squareAmount`
 * drifts in floating point (ten additions of 0.1 give 0.9999999999999999,
 * which is still < 1). That produced an extra row and column, and a
 * non-positive argument never terminated.
 *
 * @param {number} squareAmount - number of cells per axis.
 * @returns {number[][]} list of [x1, x2] pairs; empty when squareAmount <= 0.
 */
function generateInputs(squareAmount) {
    const input = [];
    for (let i = 0; i < squareAmount; i++) {
        for (let j = 0; j < squareAmount; j++) {
            input.push([i / squareAmount, j / squareAmount]);
        }
    }
    return input;
}