TensorFlow.jsを使ってKerasで作成したモデルを利用してみる

みなさん,こんにちは。
シンノユウキ(shinno1993)です。

機械学習にはPythonを使う必要がある,ということはよく言われているかと思います.しかし,Googleから公開されているTensorFlowのJavaScript版であるTensorFlow.jsを利用することで,ブラウザ上で手軽に機械学習を行うことができます.今回はその方法と簡単な例を紹介したいと思います.

ではいきましょう!

TensorFlow.jsとは?

はじめに,今回紹介するTensorFlow.jsを紹介したいと思います.

TensorFlow.jsのオリジナルであるTensorFlowというのは,Googleが公開している機械学習用のオープン・ソースのライブラリです.PythonやJavaなどから利用できるAPIを備えています.

TensorFlow.jsはそのJavaScript版です.ブラウザ上で機械学習モデルのトレーニングや,そのデプロイできるようになります.

またトレーニングだけでなく,既存の学習済みモデルを読み込んで利用したり,そのモデルを再学習することができるようになります.機械学習の成果をWeb上で利用するためには最適なツールと言えるでしょう.

公式ページは以下からご確認ください.

TensorFlow.js | Machine Learning for JavaScript Developers
Train and deploy models in the browser, Node.js, or Google Cloud Platform. TensorFlow.js is an open source ML platform f...

MNISTのPredictionをブラウザでやってみる

手始めに,Kerasで学習済みのモデルをTensorFlow.jsで読み込み,ブラウザ上でPredictionしてみましょう.

KerasでMNIST学習済みモデルを作成しよう!

まずは,Kerasでモデルを構築し,MNISTで学習させて,TensorFlow.jsに渡す学習済みモデルを作成しましょう.ここでは,Google Colabを用いて,こちらのコードを使用して学習させました.

File not found · keras-team/keras
Deep Learning for humans. Contribute to keras-team/keras development by creating an account on GitHub.

モデルをTensorFlow.jsで読める形に変換しよう!

作成したモデルは,そのままの形ではTensorFlow.jsに渡すことができません.TensorFlow.jsで利用できる形に変換してあげましょう.以下のコードのようにしてください.

model.save('model.h5')
!pip install tensorflowjs
!tensorflowjs_converter --input_format keras \
                       model.h5 \
                       target_path

まずモデルを保存し,tensorflowjsのconverterを使用してモデルを変換させます.

このモデルをダウンロードしておきましょう.

HTMLコードはこちら!

では実際のHTMLコードを示します.なお,こちらのコードは以下の記事を参考にしています.

TensorFlow.jsでMNIST学習済モデルを読み込みブラウザで手書き文字認識をする - Qiita
先日行われたTensorFlow Dev Summit 2018の「Machine Learning in JavaScript」で、Webブラウザ上で実行可能な機械学習ライブラリとしてTensor…
<!DOCTYPE html>
<html lang="ja">
<head>
    <script src="//cdnjs.cloudflare.com/ajax/libs/numeral.js/2.0.6/numeral.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/signature_pad@2.3.2/dist/signature_pad.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.0"> </script>
    <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/css/bootstrap.min.css" integrity="sha384-MCw98/SFnGE8fJT3GXwEOngsV7Zt27NXFoaoApmYm81iuXoPkFOJwJ8ERdknLPMO" crossorigin="anonymous">
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>ブラウザ上でMNISTをTensorFlow.jsで認識する</title>
    <style>
        .row{
            margin-bottom: 15px;
        }
    </style>
</head>
<body>
  <div class="container" style="max-width:850px;">
    <h1 class="title">ブラウザ上でMNISTをTensorFlow.jsで認識する</h1>
    <h2 class="subtitle">Kerasで学習ずみのモデルを使用</h2>
    <div class="row">
        <div class="col">
            <canvas id="draw-area" width="280" height="280" style="border: 2px solid;"></canvas>
        </div>
    </div>
    <div class="row">
        <div class="col">
            <button id="predict-button" class="btn btn-primary" type="button" onclick="prediction()" disabled>Prediction</button>
            <button class="btn btn-secondary" type="button" onclick="reset()">Reset</button>
        </div>
    </div>
    <div class="row">
        <div class="col">
            <table class="table">
                <thead>
                    <tr>
                        <th>Number</th>
                        <th>Accuracy</th>
                    </tr>
                </thead>
                <tbody>
                    <tr>
                        <th>0</th>
                        <td class="accuracy" data-row-index="0">-</td>
                    </tr>
                    <tr>
                        <th>1</th>
                        <td class="accuracy" data-row-index="1">-</td>
                    </tr>
                    <tr>
                        <th>2</th>
                        <td class="accuracy" data-row-index="2">-</td>
                    </tr>
                    <tr>
                        <th>3</th>
                        <td class="accuracy" data-row-index="3">-</td>
                    </tr>
                    <tr>
                        <th>4</th>
                        <td class="accuracy" data-row-index="4">-</td>
                    </tr>
                    <tr>
                        <th>5</th>
                        <td class="accuracy" data-row-index="5">-</td>
                    </tr>
                    <tr>
                        <th>6</th>
                        <td class="accuracy" data-row-index="6">-</td>
                    </tr>
                    <tr>
                        <th>7</th>
                        <td class="accuracy" data-row-index="7">-</td>
                    </tr>
                    <tr>
                        <th>8</th>
                        <td class="accuracy" data-row-index="8">-</td>
                    </tr>
                    <tr>
                        <th>9</th>
                        <td class="accuracy" data-row-index="9">-</td>
                    </tr>
                </tbody>
            </table>  
        </div>
    </div>
  </div>
  <script src="https://code.jquery.com/jquery-3.3.1.slim.min.js" integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo" crossorigin="anonymous"></script>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.3/umd/popper.min.js" integrity="sha384-ZMP7rVo3mIykV+2+9J3UJ46jBk0WLaUAdn689aCwoqbBJiSnjAK/l8WvCWPIPm49" crossorigin="anonymous"></script>
  <script src="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/js/bootstrap.min.js" integrity="sha384-ChfqqxuZUCnJSK3+MXmPNIyE6ZbWh2IMqE241rYiqJxyMiZ6OW/JmZQ5stwEULTy" crossorigin="anonymous"></script>
  
  <script>
    // init SignaturePad
    const drawElement = document.getElementById('draw-area');
    const signaturePad = new SignaturePad(drawElement, {
       minWidth: 6,
       maxWidth: 6,
       penColor: 'white',
       backgroundColor: 'black',
    });

    // load pre-trained model
    let model;
    tf.loadModel('./model/model.json')
        .then(pretrainedModel => {
            document.getElementById('predict-button').removeAttribute('disabled', false);
            model = pretrainedModel;
    });

    function getImageData() {
      const inputWidth = 28;
      const inputHeight = 28;

      // resize
      const tmpCanvas = document.createElement('canvas').getContext('2d');
      tmpCanvas.drawImage(drawElement, 0, 0, inputWidth, inputHeight);

      // convert grayscale
      let imageData = tmpCanvas.getImageData(0, 0, inputWidth, inputHeight);
      for (let i = 0; i < imageData.data.length; i+=4) {
        const avg = (imageData.data[i] + imageData.data[i+1] + imageData.data[i+2]) / 3;
        imageData.data[i] = imageData.data[i+1] = imageData.data[i+2] = avg;
      }
      return imageData;
    }

    function getAccuracyScores(imageData) {
      const score = tf.tidy(() => {
        // convert to tensor (shape: [width, height, channels])  
        const channels = 1; // grayscale              
        let input = tf.fromPixels(imageData, channels);
        // normalized
        input = tf.cast(input, 'float32').div(tf.scalar(255));
        // reshape input format (shape: [batch_size, width, height, channels])
        input = input.expandDims();
        // predict
        return model.predict(input).dataSync();
      });
      return score;
    }

    function prediction() {
      const imageData = getImageData();
      const accuracyScores = getAccuracyScores(imageData);
      const maxAccuracy = accuracyScores.indexOf(Math.max.apply(null, accuracyScores));

      const elements = document.querySelectorAll(".accuracy");
      elements.forEach(el => {
        el.parentNode.classList.remove('table-success');
        const rowIndex = Number(el.dataset.rowIndex);
        if (maxAccuracy === rowIndex) {
          el.parentNode.classList.add('table-success');
        }
        var formatedNumber = numeral(accuracyScores[rowIndex]).format('0.00%');
        el.innerText = formatedNumber;
      })
    }

    function reset() {
      signaturePad.clear();
      let elements = document.querySelectorAll(".accuracy");
      elements.forEach(el => {
        el.parentNode.classList.remove('table-success');
        el.innerText = '-';
      })
    }

  </script>
</body>
</html>

実際にHTMLを表示してみると,以下のようなページが表示されます.

ここで紹介したコードでは,Canvasに書いた文字をPredictionするとその文字がどれに該当するのかを予測することができます.

まとめ

今回はKerasで学習したモデルをブラウザ上で利用するためのライブラリ:TensorFlow.jsを紹介しました.Python環境だけでなく,JavaScriptでも利用できるので機械学習を活用する場が広がりますね.参考になれば幸いです.

タイトルとURLをコピーしました