import React, { Component } from 'react';
import ReactDOM from 'react-dom';
import XORTask from './XORTask';
import './style.css';
ReactDOM.render(<XORTask/>, document.getElementById('root'));
{
"name": "@plnkr/starter-react",
"version": "1.0.2",
"description": "React starter template",
"dependencies": {
"@tensorflow/tfjs": "^2.0.1",
"plotly.js": "^1.54.7",
"react": "^16.13.0",
"react-dom": "^16.13.0",
"react-plotly.js": "^2.4.0",
"react-charts": "^2.0.0-beta.7",
"recharts": "^1.8.5"
},
"plnkr": {
"runtime": "system",
"useHotReload": true
}
}
/* Headings and paragraphs are rendered in the generic sans-serif family. */
h1, p {
  font-family: sans-serif;
}
import React, { useEffect, useRef, useState } from 'react';
export default ({ data, size = 400, fontSize = 18, squareAmount }) => {
const canvasRef = useRef();
const [ctx, changeCtx] = useState(null);
const squareSize = size / squareAmount;
console.log(squareSize)
useEffect(() => {
changeCtx(canvasRef.current.getContext('2d'));
}, []);
useEffect(() => {
if (ctx) {
ctx.clearRect(0, 0, size, size);
data.forEach(({ x1, x2, out }) => {
// draw the square
ctx.fillStyle = `rgba(0, 0, 0, ${out})`;
ctx.fillRect(x1 * size, x2 * size, squareSize, squareSize);
// put the value
ctx.fillStyle = 'red';
ctx.font = `${fontSize}px serif`;
ctx.fillText(out.toFixed(2), x1 * size + squareSize / 2 - 20, x2 * size + squareSize / 2 + 5);
});
}
}, [data, ctx]);
return (
<canvas
ref={canvasRef}
width={size}
height={size}
/>
)
}
import React from 'react';
import { LineChart, XAxis, YAxis, CartesianGrid, Line } from 'recharts';
export default ({ width = 400, height = 430, loss = [] }) => {
return (
<LineChart
width={width}
height={height}
margin={{ top: 0, left: 0, bottom: 0, right: 0 }}
data={loss}>
<XAxis dataKey="epoch"/>
<YAxis/>
<CartesianGrid
/>
<Line
type="monotone"
dataKey="loss"
stroke="#8884d8"
dot={false}/>
</LineChart>
)
}
import React, { useEffect, useState } from 'react';
import LossPlot from './components/LossPlot';
import Canvas from './components/Canvas';
import * as tf from "@tensorflow/tfjs";
let model;
export default () => {
const [data, changeData] = useState([]);
const [lossHistory, changeLossHistory] = useState([]);
useEffect(() => {
async function initModel() {
const trainingInput = [[0, 0], [1, 0], [0, 1], [1, 1]];
const trainingInputTensor = tf.tensor(trainingInput, [trainingInput.length, 2]);
const trainingOutput = [[0], [1], [1], [0]]
const trainingOutputTensor = tf.tensor(trainingOutput, [trainingOutput.length, 1]);
const testInput = generateInputs(10);
const testInputTensor = tf.tensor(testInput, [testInput.length, 2]);
model = tf.sequential();
model.add(tf.layers.dense({ inputShape: [2], units: 3, activation: 'sigmoid' }));
model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
model.compile({
optimizer: tf.train.adam(0.1),
loss: 'meanSquaredError'
});
await model.fit(trainingInputTensor, trainingOutputTensor, {
epochs: 1000,
shuffle: true,
callbacks: {
onEpochEnd: async (epoch, { loss }) => {
changeLossHistory((prevHistory) => [...prevHistory, {
epoch,
loss
}]);
const output = model.predict(testInputTensor).arraySync();
changeData(() => output.map(([out], i) => ({
out,
x1: testInput[i][0],
x2: testInput[i][1]
})));
await tf.nextFrame();
}
}
})
}
initModel();
}, []);
return (
<div>
<Canvas data={data} squareAmount={10}/>
<LossPlot loss={lossHistory}/>
</div>
);
}
/**
 * Builds an evenly spaced grid of 2-D points covering [0, 1) x [0, 1).
 *
 * @param {number} squareAmount - number of points along each axis.
 * @returns {number[][]} pairs `[x1, x2]` with `x1` varying slowest; for an
 *   integer `squareAmount > 0` there are `squareAmount ** 2` of them, each
 *   coordinate being `index / squareAmount`. Empty when `squareAmount <= 0`.
 */
function generateInputs(squareAmount) {
  const input = [];
  // Iterate over integer indices and divide, instead of repeatedly adding
  // 1 / squareAmount: the accumulated rounding error lets a float loop run
  // one step too far (ten additions of 0.1 give 0.9999999999999999 < 1),
  // which produced an extra row and column at ~1.0.
  for (let i = 0; i < squareAmount; i += 1) {
    for (let j = 0; j < squareAmount; j += 1) {
      input.push([i / squareAmount, j / squareAmount]);
    }
  }
  return input;
}