Skip to content

Commit 0001d9e

Browse files
Create a TFDF demo using the penguins dataset (#1046)
Create a TFDF demo using the penguins dataset Replace node:12 docker with tfjs ci docker
1 parent ac95c7f commit 0001d9e

File tree

3 files changed

+205
-2
lines changed

3 files changed

+205
-2
lines changed

cloudbuild.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
steps:
2-
- name: 'node:12'
2+
- name: 'gcr.io/learnjs-174218/release'
33
entrypoint: 'yarn'
44
id: 'yarn'
55
args: ['install']
6-
- name: 'node:12'
6+
- name: 'gcr.io/learnjs-174218/release'
77
entrypoint: 'yarn'
88
id: 'test'
99
args: ['presubmit']

tfdf-penguins/README.md

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Tensorflow Decision Forests Penguins Demo
2+
3+
[See this example live!](https://storage.googleapis.com/tfjs-examples/tfdf-penguins/index.html)
4+
5+
## Contents
6+
7+
The demo shows how to use the Tensorflow.js decision forests package
8+
to run a converted model.
9+
10+
## Converting Model
11+
12+
1. Create a [Python TensorFlow Decision Forests model](https://www.tensorflow.org/decision_forests).
13+
14+
2. Save the model (will be exported as a SavedModel).
15+
16+
3. Run the model through the [tensorflowjs_converter](https://www.tensorflow.org/js/guide/conversion).
17+
```sh
18+
$ tensorflowjs_converter /path/to/saved_model /path/to/tfjs_model
19+
```
20+
21+
4. Use the [tfjs-tfdf library](https://github.com/tensorflow/tfjs/tree/master/tfjs-tfdf) to run the converted model in the web.
22+
23+
## Demo
24+
25+
The demo in the index.html file is based on the [SimpleML for Sheets tutorial](https://simplemlforsheets.com/tutorial.html) and shows how to predict the species of a penguin based on other information about it.

tfdf-penguins/index.html

+178
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
<!-- Copyright 2023 Google LLC. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License.
11+
============================================================================== -->
12+
<!DOCTYPE html>
13+
<html>
14+
15+
<head>
16+
<meta charset="UTF-8" />
17+
<title>Yggdrasil Decision Forests in TensorFlow.JS</title>
18+
19+
<!-- Import @tensorflow/tfjs -->
20+
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
21+
22+
<!-- Import @tensorflow/tfjs-tfdf
23+
Note that we need to explicitly load dist/tf-tfdf.min.js so that it can
24+
locate WASM module files from their default location (dist/). -->
25+
<!-- TODO: Make TFDF search for WASM path relative to ./dist/ instead of ./ -->
26+
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tfdf/dist/tf-tfdf.min.js"></script>
27+
28+
<!-- Import papaparse to parse the penguins csv file -->
29+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/papaparse.min.js"></script>
30+
31+
<style>
32+
.button_box {
33+
display: flex;
34+
}
35+
36+
.button_box button {
37+
width: 200px;
38+
margin: 5px;
39+
}
40+
table, th, td {
41+
border: 1px solid black;
42+
border-collapse: collapse;
43+
44+
}
45+
</style>
46+
</head>
47+
48+
<body>
49+
<h1>Yggdrasil Decision Forests in TensorFlow.JS</h1>
50+
51+
<p>
52+
This example demonstrates how to use run TensorFlow Decision Forest models convereted to TensorFlow.JS model. The model was trained by following the <a href="https://simplemlforsheets.com/tutorial.html">Simple ML for Sheets Tutorial</a>.
53+
</p>
54+
<table id="penguin_table"/>
55+
56+
<div class="button_box">
57+
<button id="btn_apply_model" type="button" disabled>Loading Model...</button>
58+
</div>
59+
60+
<script>
61+
// Penguin dataset from https://simplemlforsheets.com/tutorial.html
62+
const dataset = `species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex,year,Pred:species,Pred:Conf.species
63+
,Biscoe,47.8,15,215,5650,male,2007,,
64+
,Torgersen,,,,,,2007,,
65+
,Dream,40.2,17.1,193,3400,female,2009,,
66+
,Dream,36,17.8,195,3450,female,2009,,
67+
,Biscoe,49.8,15.9,229,5950,male,2009,,
68+
,Biscoe,38.6,17.2,199,3750,female,2009,,
69+
,Dream,49,19.5,210,3950,male,2008,,
70+
,Dream,40.7,17,190,3725,male,2009,,
71+
,Biscoe,52.5,15.6,221,5450,male,2009,,
72+
,Biscoe,46.2,14.4,214,4650,,2008,,
73+
,Torgersen,40.2,17,176,3450,female,2009,,
74+
,Biscoe,46.5,14.5,213,4400,female,2007,,
75+
,Biscoe,49.5,16.2,229,5800,male,2008,,
76+
,Torgersen,36.2,16.1,187,3550,female,,
77+
,Biscoe,41.3,21.1,195,4400,male,2008,,
78+
,Biscoe,45.1,14.5,207,5050,female,2007,,
79+
,Biscoe,47.5,15,218,4950,female,2009,,
80+
,Biscoe,49.1,15,228,5500,male,2009,,
81+
,Biscoe,45.5,15,220,5000,male,2008,,
82+
,Biscoe,35,17.9,192,3725,female,2009,,
83+
,Torgersen,35.5,17.5,190,3700,female,,
84+
,Biscoe,46.3,15.8,215,5050,male,2007,,
85+
,Dream,42.5,16.7,187,3350,female,2008,,
86+
,Torgersen,34.1,18.1,193,3475,,2007,,
87+
,Dream,37.5,18.5,199,4475,male,2009,,
88+
,Dream,36.4,17,195,3325,female,2007,,
89+
,Dream,45.7,17.3,193,3600,female,2009,,
90+
,Dream,51.9,19.5,206,3950,male,2009,,
91+
,Biscoe,46.2,14.5,209,4800,female,2007,,
92+
,Dream,42.5,17.3,187,3350,female,2009,,`;
93+
// The model (once loaded).
94+
let model = null;
95+
96+
function renderTable(data) {
97+
const table = document.getElementById("penguin_table");
98+
while (table.hasChildNodes()) {
99+
table.removeChild(table.lastChild);
100+
}
101+
102+
const headers = Object.keys(data[0]);
103+
const headerRow = table.insertRow();
104+
for (let header of headers) {
105+
const cell = headerRow.insertCell();
106+
cell.innerHTML = header;
107+
}
108+
109+
for (let record of data) {
110+
const row = table.insertRow();
111+
for (let header of headers) {
112+
const cell = row.insertCell();
113+
cell.innerHTML = record[header] ?? '';
114+
}
115+
}
116+
}
117+
118+
// Parse the CSV data
119+
const {data} = Papa.parse(dataset, {header: true});
120+
renderTable(data);
121+
122+
async function loadModel() {
123+
model = await tfdf.loadTFDFModel('https://storage.googleapis.com/tfjs-examples/tfdf-penguins/tfjs_model/model.json');
124+
const button = document.getElementById("btn_apply_model");
125+
button.disabled = false;
126+
button.innerText = "Classify Penguins!";
127+
}
128+
loadModel();
129+
130+
async function applyModel() {
131+
const inputs = {};
132+
const toDispose = [];
133+
134+
for (const {name, dtype} of model.inputs) {
135+
if (dtype === null) {
136+
continue;
137+
}
138+
const defaultVal = dtype === 'string' ? '' : 0;
139+
inputs[name] = tf.tensor1d(data.map(d => d[name] ?? defaultVal), dtype);
140+
toDispose.push(inputs[name]);
141+
}
142+
143+
const predictions = await model.executeAsync(inputs);
144+
toDispose.push(predictions);
145+
146+
const [classificationsTensor, confidencesTensor] = tf.tidy(() => {
147+
const classificationIndices = tf.argMax(predictions, 1);
148+
const classifications = classificationIndices.sub(3);
149+
const confidences = tf.gather(predictions, classificationIndices, 1, 1);
150+
151+
return [classifications, confidences];
152+
});
153+
154+
toDispose.push(classificationsTensor, confidencesTensor);
155+
156+
const classNames = ['Adelie', 'Gentoo', 'Chinstrap'];
157+
const classifications = (await classificationsTensor.array())
158+
.map(index => classNames[index]);
159+
const confidences = await confidencesTensor.array();
160+
161+
const dataCopy = structuredClone(data); // Make a copy of the data
162+
for (let i = 0; i < confidences.length; i++) {
163+
const record = dataCopy[i];
164+
record['Pred:species'] = classifications[i];
165+
record['Pred:Conf.species'] = confidences[i];
166+
}
167+
168+
// Clean up tensors
169+
tf.dispose(toDispose);
170+
171+
renderTable(dataCopy);
172+
}
173+
174+
document.getElementById("btn_apply_model").onclick = applyModel;
175+
</script>
176+
</body>
177+
178+
</html>

0 commit comments

Comments
 (0)