みなさん,こんにちは。
シンノユウキ(shinno1993)です。
機械学習にはPythonを使う必要がある,ということはよく言われているかと思います.しかし,Googleから公開されているTensorFlowのJavaScript版であるTensorFlow.jsを利用することで,ブラウザ上で手軽に機械学習を行うことができます.今回はその方法と簡単な例を紹介したいと思います.
ではいきましょう!
TensorFlow.jsとは?
はじめに,今回紹介するTensorFlow.jsを紹介したいと思います.
TensorFlow.jsのオリジナルであるTensorFlowというのは,Googleが公開している機械学習用のオープン・ソースのライブラリです.PythonやJavaなどから利用できるAPIを備えています.
TensorFlow.jsはそのJavaScript版です.ブラウザ上で機械学習モデルのトレーニングや,そのデプロイできるようになります.
またトレーニングだけでなく,既存の学習済みモデルを読み込んで利用したり,そのモデルを再学習することができるようになります.機械学習の成果をWeb上で利用するためには最適なツールと言えるでしょう.
公式ページは以下からご確認ください.

MNISTのPredictionをブラウザでやってみる
手始めに,Kerasで学習済みのモデルをTensorFlow.jsで読み込み,ブラウザ上でPredictionしてみましょう.
KerasでMNIST学習済みモデルを作成しよう!
まずは,Kerasでモデルを構築し,MNISTで学習させて,TensorFlow.jsに渡す学習済みモデルを作成しましょう.ここでは,Google Colabを用いて,こちらのコードを使用して学習させました.
モデルをTensorFlow.jsで読める形に変換しよう!
作成したモデルは,そのままの形ではTensorFlow.jsに渡すことができません.TensorFlow.jsで利用できる形に変換してあげましょう.以下のコードのようにしてください.
model.save('model.h5')
!pip install tensorflowjs
!tensorflowjs_converter --input_format keras \
model.h5 \
target_path
まずモデルを保存し,tensorflowjsのconverterを使用してモデルを変換させます.
このモデルをダウンロードしておきましょう.
HTMLコードはこちら!
では実際のHTMLコードを示します.なお,こちらのコードは以下の記事を参考にしています.

<!DOCTYPE html>
<html lang="ja">
<head>
<script src="//cdnjs.cloudflare.com/ajax/libs/numeral.js/2.0.6/numeral.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/signature_pad@2.3.2/dist/signature_pad.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.0"> </script>
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/css/bootstrap.min.css" integrity="sha384-MCw98/SFnGE8fJT3GXwEOngsV7Zt27NXFoaoApmYm81iuXoPkFOJwJ8ERdknLPMO" crossorigin="anonymous">
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>ブラウザ上でMNISTをTensorFlow.jsで認識する</title>
<style>
.row{
margin-bottom: 15px;
}
</style>
</head>
<body>
<div class="container" style="max-width:850px;">
<h1 class="title">ブラウザ上でMNISTをTensorFlow.jsで認識する</h1>
<h2 class="subtitle">Kerasで学習ずみのモデルを使用</h2>
<div class="row">
<div class="col">
<canvas id="draw-area" width="280" height="280" style="border: 2px solid;"></canvas>
</div>
</div>
<div class="row">
<div class="col">
<button id="predict-button" class="btn btn-primary" type="button" onclick="prediction()" disabled>Prediction</button>
<button class="btn btn-secondary" type="button" onclick="reset()">Reset</button>
</div>
</div>
<div class="row">
<div class="col">
<table class="table">
<thead>
<tr>
<th>Number</th>
<th>Accuracy</th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td class="accuracy" data-row-index="0">-</td>
</tr>
<tr>
<th>1</th>
<td class="accuracy" data-row-index="1">-</td>
</tr>
<tr>
<th>2</th>
<td class="accuracy" data-row-index="2">-</td>
</tr>
<tr>
<th>3</th>
<td class="accuracy" data-row-index="3">-</td>
</tr>
<tr>
<th>4</th>
<td class="accuracy" data-row-index="4">-</td>
</tr>
<tr>
<th>5</th>
<td class="accuracy" data-row-index="5">-</td>
</tr>
<tr>
<th>6</th>
<td class="accuracy" data-row-index="6">-</td>
</tr>
<tr>
<th>7</th>
<td class="accuracy" data-row-index="7">-</td>
</tr>
<tr>
<th>8</th>
<td class="accuracy" data-row-index="8">-</td>
</tr>
<tr>
<th>9</th>
<td class="accuracy" data-row-index="9">-</td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
<script src="https://code.jquery.com/jquery-3.3.1.slim.min.js" integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo" crossorigin="anonymous"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.3/umd/popper.min.js" integrity="sha384-ZMP7rVo3mIykV+2+9J3UJ46jBk0WLaUAdn689aCwoqbBJiSnjAK/l8WvCWPIPm49" crossorigin="anonymous"></script>
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/js/bootstrap.min.js" integrity="sha384-ChfqqxuZUCnJSK3+MXmPNIyE6ZbWh2IMqE241rYiqJxyMiZ6OW/JmZQ5stwEULTy" crossorigin="anonymous"></script>
<script>
// init SignaturePad
const drawElement = document.getElementById('draw-area');
const signaturePad = new SignaturePad(drawElement, {
minWidth: 6,
maxWidth: 6,
penColor: 'white',
backgroundColor: 'black',
});
// load pre-trained model
let model;
tf.loadModel('./model/model.json')
.then(pretrainedModel => {
document.getElementById('predict-button').removeAttribute('disabled', false);
model = pretrainedModel;
});
function getImageData() {
const inputWidth = 28;
const inputHeight = 28;
// resize
const tmpCanvas = document.createElement('canvas').getContext('2d');
tmpCanvas.drawImage(drawElement, 0, 0, inputWidth, inputHeight);
// convert grayscale
let imageData = tmpCanvas.getImageData(0, 0, inputWidth, inputHeight);
for (let i = 0; i < imageData.data.length; i+=4) {
const avg = (imageData.data[i] + imageData.data[i+1] + imageData.data[i+2]) / 3;
imageData.data[i] = imageData.data[i+1] = imageData.data[i+2] = avg;
}
return imageData;
}
function getAccuracyScores(imageData) {
const score = tf.tidy(() => {
// convert to tensor (shape: [width, height, channels])
const channels = 1; // grayscale
let input = tf.fromPixels(imageData, channels);
// normalized
input = tf.cast(input, 'float32').div(tf.scalar(255));
// reshape input format (shape: [batch_size, width, height, channels])
input = input.expandDims();
// predict
return model.predict(input).dataSync();
});
return score;
}
function prediction() {
const imageData = getImageData();
const accuracyScores = getAccuracyScores(imageData);
const maxAccuracy = accuracyScores.indexOf(Math.max.apply(null, accuracyScores));
const elements = document.querySelectorAll(".accuracy");
elements.forEach(el => {
el.parentNode.classList.remove('table-success');
const rowIndex = Number(el.dataset.rowIndex);
if (maxAccuracy === rowIndex) {
el.parentNode.classList.add('table-success');
}
var formatedNumber = numeral(accuracyScores[rowIndex]).format('0.00%');
el.innerText = formatedNumber;
})
}
function reset() {
signaturePad.clear();
let elements = document.querySelectorAll(".accuracy");
elements.forEach(el => {
el.parentNode.classList.remove('table-success');
el.innerText = '-';
})
}
</script>
</body>
</html>
実際にHTMLを表示してみると,以下のようなページが表示されます.

ここで紹介したコードでは,Canvasに書いた文字をPredictionするとその文字がどれに該当するのかを予測することができます.
まとめ
今回はKerasで学習したモデルをブラウザ上で利用するためのライブラリ:TensorFlow.jsを紹介しました.Python環境だけでなく,JavaScriptでも利用できるので機械学習を活用する場が広がりますね.参考になれば幸いです.
