|
| 1 | +<!-- Copyright 2023 Google LLC. All Rights Reserved. |
| 2 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +you may not use this file except in compliance with the License. |
| 4 | +You may obtain a copy of the License at |
| 5 | +http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +Unless required by applicable law or agreed to in writing, software |
| 7 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +See the License for the specific language governing permissions and |
| 10 | +limitations under the License. |
| 11 | +============================================================================== --> |
| 12 | +<!DOCTYPE html> |
| 13 | +<html> |
| 14 | + |
| 15 | +<head> |
| 16 | + <meta charset="UTF-8" /> |
| 17 | + <title>Yggdrasil Decision Forests in TensorFlow.JS</title> |
| 18 | + |
| 19 | + <!-- Import @tensorflow/tfjs --> |
| 20 | + <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script> |
| 21 | + |
| 22 | + <!-- Import @tensorflow/tfjs-tfdf |
| 23 | + Note that we need to explicitly load dist/tf-tfdf.min.js so that it can |
| 24 | + locate WASM module files from their default location (dist/). --> |
| 25 | + <!-- TODO: Make TFDF search for WASM path relative to ./dist/ instead of ./ --> |
| 26 | + <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tfdf/dist/tf-tfdf.min.js"></script> |
| 27 | + |
| 28 | + <!-- Import papaparse to parse the penguins csv file --> |
| 29 | + <script src=" https://cdn.jsdelivr.net/npm/[email protected]/papaparse.min.js" ></script> |
| 30 | + |
| 31 | + <style> |
| 32 | + .button_box { |
| 33 | + display: flex; |
| 34 | + } |
| 35 | + |
| 36 | + .button_box button { |
| 37 | + width: 200px; |
| 38 | + margin: 5px; |
| 39 | + } |
| 40 | + table, th, td { |
| 41 | + border: 1px solid black; |
| 42 | + border-collapse: collapse; |
| 43 | + |
| 44 | + } |
| 45 | + </style> |
| 46 | +</head> |
| 47 | + |
| 48 | +<body> |
| 49 | + <h1>Yggdrasil Decision Forests in TensorFlow.JS</h1> |
| 50 | + |
| 51 | + <p> |
| 52 | + This example demonstrates how to use run TensorFlow Decision Forest models convereted to TensorFlow.JS model. The model was trained by following the <a href="https://simplemlforsheets.com/tutorial.html">Simple ML for Sheets Tutorial</a>. |
| 53 | + </p> |
| 54 | + <table id="penguin_table"/> |
| 55 | + |
| 56 | + <div class="button_box"> |
| 57 | + <button id="btn_apply_model" type="button" disabled>Loading Model...</button> |
| 58 | + </div> |
| 59 | + |
| 60 | + <script> |
| 61 | + // Penguin dataset from https://simplemlforsheets.com/tutorial.html |
| 62 | + const dataset = `species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex,year,Pred:species,Pred:Conf.species |
| 63 | +,Biscoe,47.8,15,215,5650,male,2007,, |
| 64 | +,Torgersen,,,,,,2007,, |
| 65 | +,Dream,40.2,17.1,193,3400,female,2009,, |
| 66 | +,Dream,36,17.8,195,3450,female,2009,, |
| 67 | +,Biscoe,49.8,15.9,229,5950,male,2009,, |
| 68 | +,Biscoe,38.6,17.2,199,3750,female,2009,, |
| 69 | +,Dream,49,19.5,210,3950,male,2008,, |
| 70 | +,Dream,40.7,17,190,3725,male,2009,, |
| 71 | +,Biscoe,52.5,15.6,221,5450,male,2009,, |
| 72 | +,Biscoe,46.2,14.4,214,4650,,2008,, |
| 73 | +,Torgersen,40.2,17,176,3450,female,2009,, |
| 74 | +,Biscoe,46.5,14.5,213,4400,female,2007,, |
| 75 | +,Biscoe,49.5,16.2,229,5800,male,2008,, |
| 76 | +,Torgersen,36.2,16.1,187,3550,female,, |
| 77 | +,Biscoe,41.3,21.1,195,4400,male,2008,, |
| 78 | +,Biscoe,45.1,14.5,207,5050,female,2007,, |
| 79 | +,Biscoe,47.5,15,218,4950,female,2009,, |
| 80 | +,Biscoe,49.1,15,228,5500,male,2009,, |
| 81 | +,Biscoe,45.5,15,220,5000,male,2008,, |
| 82 | +,Biscoe,35,17.9,192,3725,female,2009,, |
| 83 | +,Torgersen,35.5,17.5,190,3700,female,, |
| 84 | +,Biscoe,46.3,15.8,215,5050,male,2007,, |
| 85 | +,Dream,42.5,16.7,187,3350,female,2008,, |
| 86 | +,Torgersen,34.1,18.1,193,3475,,2007,, |
| 87 | +,Dream,37.5,18.5,199,4475,male,2009,, |
| 88 | +,Dream,36.4,17,195,3325,female,2007,, |
| 89 | +,Dream,45.7,17.3,193,3600,female,2009,, |
| 90 | +,Dream,51.9,19.5,206,3950,male,2009,, |
| 91 | +,Biscoe,46.2,14.5,209,4800,female,2007,, |
| 92 | +,Dream,42.5,17.3,187,3350,female,2009,,`; |
| 93 | + // The model (once loaded). |
| 94 | + let model = null; |
| 95 | + |
| 96 | + function renderTable(data) { |
| 97 | + const table = document.getElementById("penguin_table"); |
| 98 | + while (table.hasChildNodes()) { |
| 99 | + table.removeChild(table.lastChild); |
| 100 | + } |
| 101 | + |
| 102 | + const headers = Object.keys(data[0]); |
| 103 | + const headerRow = table.insertRow(); |
| 104 | + for (let header of headers) { |
| 105 | + const cell = headerRow.insertCell(); |
| 106 | + cell.innerHTML = header; |
| 107 | + } |
| 108 | + |
| 109 | + for (let record of data) { |
| 110 | + const row = table.insertRow(); |
| 111 | + for (let header of headers) { |
| 112 | + const cell = row.insertCell(); |
| 113 | + cell.innerHTML = record[header] ?? ''; |
| 114 | + } |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + // Parse the CSV data |
| 119 | + const {data} = Papa.parse(dataset, {header: true}); |
| 120 | + renderTable(data); |
| 121 | + |
| 122 | + async function loadModel() { |
| 123 | + model = await tfdf.loadTFDFModel('https://storage.googleapis.com/tfjs-examples/tfdf-penguins/tfjs_model/model.json'); |
| 124 | + const button = document.getElementById("btn_apply_model"); |
| 125 | + button.disabled = false; |
| 126 | + button.innerText = "Classify Penguins!"; |
| 127 | + } |
| 128 | + loadModel(); |
| 129 | + |
| 130 | + async function applyModel() { |
| 131 | + const inputs = {}; |
| 132 | + const toDispose = []; |
| 133 | + |
| 134 | + for (const {name, dtype} of model.inputs) { |
| 135 | + if (dtype === null) { |
| 136 | + continue; |
| 137 | + } |
| 138 | + const defaultVal = dtype === 'string' ? '' : 0; |
| 139 | + inputs[name] = tf.tensor1d(data.map(d => d[name] ?? defaultVal), dtype); |
| 140 | + toDispose.push(inputs[name]); |
| 141 | + } |
| 142 | + |
| 143 | + const predictions = await model.executeAsync(inputs); |
| 144 | + toDispose.push(predictions); |
| 145 | + |
| 146 | + const [classificationsTensor, confidencesTensor] = tf.tidy(() => { |
| 147 | + const classificationIndices = tf.argMax(predictions, 1); |
| 148 | + const classifications = classificationIndices.sub(3); |
| 149 | + const confidences = tf.gather(predictions, classificationIndices, 1, 1); |
| 150 | + |
| 151 | + return [classifications, confidences]; |
| 152 | + }); |
| 153 | + |
| 154 | + toDispose.push(classificationsTensor, confidencesTensor); |
| 155 | + |
| 156 | + const classNames = ['Adelie', 'Gentoo', 'Chinstrap']; |
| 157 | + const classifications = (await classificationsTensor.array()) |
| 158 | + .map(index => classNames[index]); |
| 159 | + const confidences = await confidencesTensor.array(); |
| 160 | + |
| 161 | + const dataCopy = structuredClone(data); // Make a copy of the data |
| 162 | + for (let i = 0; i < confidences.length; i++) { |
| 163 | + const record = dataCopy[i]; |
| 164 | + record['Pred:species'] = classifications[i]; |
| 165 | + record['Pred:Conf.species'] = confidences[i]; |
| 166 | + } |
| 167 | + |
| 168 | + // Clean up tensors |
| 169 | + tf.dispose(toDispose); |
| 170 | + |
| 171 | + renderTable(dataCopy); |
| 172 | + } |
| 173 | + |
| 174 | + document.getElementById("btn_apply_model").onclick = applyModel; |
| 175 | + </script> |
| 176 | +</body> |
| 177 | + |
| 178 | +</html> |
0 commit comments