基于tensorflow.js的在线手写数字识别js文件
时间: 2023-05-27 13:05:23 浏览: 171
本文提供一份基于tensorflow.js的在线手写数字识别js文件,可以在浏览器中实现手写数字的识别。
首先需要引入tensorflow.js的库文件:
```html
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.1/dist/tf.min.js"></script>
```
然后定义一些变量,包括canvas和context:
```javascript
// Shared page state used by the drawing and recognition functions below.
const canvas = document.getElementById('canvas');
const context = canvas.getContext('2d');
let model; // assigned asynchronously by loadModel()
let xs; // latest preprocessed input tensor, produced by endDrawing()
let isDrawing = false; // true while a mouse stroke is in progress
```
接着定义一些函数,包括画布的初始化、开始绘制、结束绘制、清除画布等:
```javascript
/**
 * Reset the drawing surface: paint an opaque white background and configure
 * the pen (10px wide, black, rounded joins and caps).
 */
function initCanvas() {
  // Opaque white "paper"; endDrawing later samples these pixels.
  context.fillStyle = '#ffffff';
  context.fillRect(0, 0, canvas.width, canvas.height);

  // Pen settings, assigned in the same order as before.
  Object.assign(context, {
    lineWidth: 10,
    lineJoin: 'round',
    lineCap: 'round',
    strokeStyle: '#000000',
  });
}
/**
 * mousedown handler: begin a new stroke at the cursor position.
 * @param {MouseEvent} event - the mousedown event
 */
function startDrawing(event) {
  // Cursor position relative to the canvas' top-left corner.
  // NOTE(review): offsetLeft/offsetTop come from layout relative to the offsetParent and
  // ignore page scrolling and the canvas border, so strokes can be shifted;
  // canvas.getBoundingClientRect() is more robust, but the mousemove handler must be
  // switched at the same time so both use the same coordinate system.
  const x = event.clientX - canvas.offsetLeft;
  const y = event.clientY - canvas.offsetTop;

  isDrawing = true;
  context.beginPath();
  context.moveTo(x, y);
}
/**
 * mouseup/mouseout handler: finish the stroke and turn the canvas into a model input.
 *
 * Stores in the global `xs` a float tensor of shape [1, 784]: a single colour channel
 * of the canvas, resized to 28x28 and scaled from [0, 255] to [0, 1].
 *
 * The handler is also bound to mouseout, which fires when no stroke is in progress
 * (the cursor merely crossing the canvas, or leaving it right after mouseup); such
 * calls are ignored instead of re-sampling the canvas.
 */
function endDrawing() {
  if (!isDrawing) {
    return;
  }
  isDrawing = false;

  // TensorFlow.js tensors (notably on the WebGL backend) are not garbage-collected:
  // free the previous input before replacing it ...
  if (xs) {
    xs.dispose();
  }
  // ... and let tf.tidy release the intermediate tensors of the preprocessing chain.
  xs = tf.tidy(() => {
    // NOTE(review): the canvas is black ink on white, so strokes map to ~0 and the
    // background to ~1. MNIST-style models are usually trained on white-on-black
    // digits; if this one was, invert here (e.g. tf.scalar(1).sub(...)). Confirm.
    return tf.browser.fromPixels(canvas, 1)
      .resizeNearestNeighbor([28, 28])
      .toFloat()
      .div(255.0)
      .reshape([1, 784]);
  });
}
/**
 * Erase the whole drawing, then restore the white background and pen settings.
 */
function clearCanvas() {
  const { width, height } = canvas;
  context.clearRect(0, 0, width, height);
  initCanvas();
}
```
其中，startDrawing函数在鼠标按下时调用；endDrawing函数在鼠标松开或移出画布时调用(见下文的事件监听代码)，它读取画布的单个颜色通道，缩放为28×28，归一化到[0, 1]，并展平成形状为[1, 784]的张量xs；clearCanvas函数用于清空画布并重新填充白色背景。
最后是加载模型的函数:
```javascript
/**
 * Load the pre-trained Layers model and keep it in the shared `model` variable.
 *
 * @param {string} [url='http://localhost:8000/model.json'] - location of model.json;
 *   the default is the local server assumed by this article.
 * @returns {Promise<object>} resolves to the loaded model (also assigned to `model`);
 *   rejects with tf.loadLayersModel's error if loading fails.
 */
async function loadModel(url = 'http://localhost:8000/model.json') {
  model = await tf.loadLayersModel(url);
  return model;
}
```
loadModel函数在页面加载时调用，用于加载预先训练好的模型。这里假设模型文件(model.json及其引用的权重文件)由本机8000端口上的HTTP服务器提供；如果网页本身不是由同一地址提供的，该服务器需要允许跨域(CORS)访问。模型的输入形状应为[1, 784]，与endDrawing中生成的张量一致。
最后是监听鼠标事件的代码:
```javascript
// Wire the pen to the mouse: pressing starts a stroke, moving extends it,
// and releasing the button or leaving the canvas ends it.
// NOTE(review): these coordinates use offsetLeft/offsetTop like startDrawing; if that is
// switched to getBoundingClientRect(), change this handler in the same edit.
canvas.addEventListener('mousedown', startDrawing);
canvas.addEventListener('mousemove', (event) => {
  if (!isDrawing) {
    return;
  }
  const x = event.clientX - canvas.offsetLeft;
  const y = event.clientY - canvas.offsetTop;
  context.lineTo(x, y);
  context.stroke();
});
for (const type of ['mouseup', 'mouseout']) {
  canvas.addEventListener(type, endDrawing);
}
```
这段代码会监听鼠标的mousedown、mousemove、mouseup和mouseout事件,调用相应的函数。
完整的代码如下:
```javascript
// Shared page state used by the drawing and recognition functions below.
const canvas = document.getElementById('canvas');
const context = canvas.getContext('2d');
let model; // assigned asynchronously by loadModel()
let xs; // latest preprocessed [1, 784] input tensor, produced by endDrawing()
let isDrawing = false; // true while a mouse stroke is in progress
/**
 * Load the pre-trained Layers model and keep it in the shared `model` variable.
 *
 * @param {string} [url='http://localhost:8000/model.json'] - location of model.json;
 *   the default is the local server assumed by this article.
 * @returns {Promise<object>} resolves to the loaded model (also assigned to `model`);
 *   rejects with tf.loadLayersModel's error if loading fails.
 */
async function loadModel(url = 'http://localhost:8000/model.json') {
  model = await tf.loadLayersModel(url);
  return model;
}
/**
 * Reset the drawing surface: opaque white background plus a 10px black pen
 * with rounded joins and caps.
 */
function initCanvas() {
  const { width, height } = canvas;

  // White "paper" that endDrawing later samples.
  context.fillStyle = '#ffffff';
  context.fillRect(0, 0, width, height);

  // Pen settings, assigned in the same order as before.
  Object.assign(context, {
    lineWidth: 10,
    lineJoin: 'round',
    lineCap: 'round',
    strokeStyle: '#000000',
  });
}
/**
 * mousedown handler: begin a new stroke at the cursor position.
 * @param {MouseEvent} event - the mousedown event
 */
function startDrawing(event) {
  // NOTE(review): offsetLeft/offsetTop ignore page scrolling and the canvas border, so
  // strokes can be shifted; getBoundingClientRect() is more robust, but the mousemove
  // handler must be switched at the same time so both use the same coordinates.
  const point = {
    x: event.clientX - canvas.offsetLeft,
    y: event.clientY - canvas.offsetTop,
  };

  isDrawing = true;
  context.beginPath();
  context.moveTo(point.x, point.y);
}
/**
 * mouseup/mouseout handler: finish the stroke, convert the canvas into a model input
 * and run a prediction.
 *
 * Replaces the global `xs` with a float tensor of shape [1, 784] (a single colour
 * channel of the canvas, resized to 28x28 and scaled to [0, 1]) and calls predict().
 *
 * The handler is also bound to mouseout, which fires when no stroke is in progress
 * (the cursor merely crossing the canvas, or leaving it right after mouseup); such
 * calls are ignored instead of re-sampling and re-predicting.
 */
function endDrawing() {
  if (!isDrawing) {
    return;
  }
  isDrawing = false;

  // TensorFlow.js tensors (notably on the WebGL backend) are not garbage-collected:
  // free the previous input before replacing it ...
  if (xs) {
    xs.dispose();
  }
  // ... and let tf.tidy release the intermediate tensors of the preprocessing chain.
  xs = tf.tidy(() => {
    // NOTE(review): the canvas is black ink on white, so strokes map to ~0 and the
    // background to ~1. MNIST-style models are usually trained on white-on-black
    // digits; if this one was, invert here (e.g. tf.scalar(1).sub(...)). Confirm.
    return tf.browser.fromPixels(canvas, 1)
      .resizeNearestNeighbor([28, 28])
      .toFloat()
      .div(255.0)
      .reshape([1, 784]);
  });
  predict();
}
/**
 * Run the model on the latest input `xs` and show the highest-scoring class index
 * (presumably the digit 0-9 for a 10-class model) in the #result element.
 *
 * Does nothing while the model is still loading (loadModel() is asynchronous) or
 * before any input exists; previously this threw a TypeError on `model.predict`.
 */
function predict() {
  if (!model || !xs) {
    return;
  }

  // model.predict allocates a new tensor; dispose it once its values are copied out,
  // otherwise every prediction leaks backend memory.
  const output = model.predict(xs);
  let scores;
  try {
    scores = output.dataSync();
  } finally {
    output.dispose();
  }

  // Index of the highest score; the first one wins on ties.
  let maxIndex = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[maxIndex]) {
      maxIndex = i;
    }
  }
  document.getElementById('result').textContent = String(maxIndex);
}
/**
 * "Clear" button handler: erase the drawing, then restore the white background
 * and pen settings via initCanvas().
 */
function clearCanvas() {
  const { width, height } = canvas;
  context.clearRect(0, 0, width, height);
  initCanvas();
}
// Page setup: start loading the model in the background, prepare the canvas and
// wire up the mouse and button handlers.

// loadModel() returns a promise; without a handler a failed load (server down, CORS,
// bad model.json) becomes an unhandled rejection with no visible hint.
loadModel().catch(function (err) {
  console.error('Failed to load the digit-recognition model:', err);
});
initCanvas();
canvas.addEventListener('mousedown', startDrawing);
canvas.addEventListener('mousemove', function (event) {
  if (isDrawing) {
    // NOTE(review): same offsetLeft/offsetTop coordinates as startDrawing; switch both
    // together if moving to getBoundingClientRect().
    context.lineTo(event.clientX - canvas.offsetLeft, event.clientY - canvas.offsetTop);
    context.stroke();
  }
});
canvas.addEventListener('mouseup', endDrawing);
canvas.addEventListener('mouseout', endDrawing);
document.getElementById('clear').addEventListener('click', clearCanvas);
```
其中，predict函数用于对手写数字进行预测，会在endDrawing函数中调用。我们使用dataSync函数同步取出模型输出的各类别得分，再找出得分最高的类别下标(即识别出的数字)作为预测结果。最后，将预测结果显示在页面上id为result的元素中。
完整的HTML代码如下:
```html
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Online Handwritten Digit Recognition</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.1/dist/tf.min.js"></script>
</head>
<body>
<canvas id="canvas" width="280" height="280" style="border: 1px solid #000000;"></canvas>
<button id="clear">Clear</button>
<div>Result: <span id="result"></span></div>
<script>
// Draw a digit on the canvas; when the stroke ends, the canvas is downscaled to 28x28,
// flattened to a [1, 784] tensor and classified by a pre-trained TensorFlow.js model.
const canvas = document.getElementById('canvas');
const context = canvas.getContext('2d');
let model; // assigned asynchronously by loadModel()
let xs; // latest preprocessed [1, 784] input tensor
let isDrawing = false; // true while a mouse stroke is in progress

/**
 * Load the pre-trained Layers model into `model`.
 * @param {string} [url] - location of model.json (defaults to the local dev server).
 * @returns {Promise<object>} the loaded model.
 */
async function loadModel(url = 'http://localhost:8000/model.json') {
  model = await tf.loadLayersModel(url);
  return model;
}

/** Paint an opaque white background and configure a round 10px black pen. */
function initCanvas() {
  context.fillStyle = '#ffffff';
  context.fillRect(0, 0, canvas.width, canvas.height);
  context.lineWidth = 10;
  context.lineJoin = 'round';
  context.lineCap = 'round';
  context.strokeStyle = '#000000';
}

/**
 * Convert a mouse event to canvas-local coordinates.
 * getBoundingClientRect() is viewport-relative like clientX/clientY, so this stays
 * correct when the page is scrolled; clientLeft/clientTop skip the 1px CSS border.
 */
function getCanvasPoint(event) {
  const rect = canvas.getBoundingClientRect();
  return {
    x: event.clientX - rect.left - canvas.clientLeft,
    y: event.clientY - rect.top - canvas.clientTop,
  };
}

/** mousedown handler: begin a new stroke at the cursor. */
function startDrawing(event) {
  const point = getCanvasPoint(event);
  isDrawing = true;
  context.beginPath();
  context.moveTo(point.x, point.y);
}

/** mousemove handler: extend the current stroke, if any. */
function continueDrawing(event) {
  if (!isDrawing) {
    return;
  }
  const point = getCanvasPoint(event);
  context.lineTo(point.x, point.y);
  context.stroke();
}

/**
 * mouseup/mouseout handler: finish the stroke, rebuild `xs` and run a prediction.
 * mouseout also fires with no stroke in progress (cursor crossing the canvas, or
 * leaving right after mouseup); those calls are ignored.
 */
function endDrawing() {
  if (!isDrawing) {
    return;
  }
  isDrawing = false;

  // TensorFlow.js tensors (notably on WebGL) are not garbage-collected: free the old
  // input, and let tf.tidy free the intermediates of the preprocessing chain.
  if (xs) {
    xs.dispose();
  }
  xs = tf.tidy(() => {
    // NOTE(review): black ink on white maps strokes to ~0 and background to ~1.
    // MNIST-style models usually expect white-on-black; if so, invert here. Confirm.
    return tf.browser.fromPixels(canvas, 1)
      .resizeNearestNeighbor([28, 28])
      .toFloat()
      .div(255.0)
      .reshape([1, 784]);
  });
  predict();
}

/**
 * Classify `xs` and show the highest-scoring class index in #result.
 * Skipped while the model is still loading.
 */
function predict() {
  if (!model || !xs) {
    return;
  }
  const output = model.predict(xs);
  let scores;
  try {
    scores = output.dataSync();
  } finally {
    output.dispose(); // free the output tensor once its values are copied out
  }
  let maxIndex = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[maxIndex]) {
      maxIndex = i;
    }
  }
  document.getElementById('result').textContent = String(maxIndex);
}

/** "Clear" button handler: erase the drawing and restore the background and pen. */
function clearCanvas() {
  context.clearRect(0, 0, canvas.width, canvas.height);
  initCanvas();
}

loadModel().catch(function (err) {
  console.error('Failed to load the digit-recognition model:', err);
});
initCanvas();
canvas.addEventListener('mousedown', startDrawing);
canvas.addEventListener('mousemove', continueDrawing);
canvas.addEventListener('mouseup', endDrawing);
canvas.addEventListener('mouseout', endDrawing);
document.getElementById('clear').addEventListener('click', clearCanvas);
</script>
</body>
</html>
```
阅读全文