
I'm trying to run the model.predict function inside a web worker, and I can't find anywhere how to import TensorFlow.js inside the web worker.

I can use importScripts('cdn'), but how can I then reference tensorflow to use its functions?

This is my code so far:

worker.js

/// <reference lib="webworker" />
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js');

addEventListener('message', async ({ data }) => {
  const model = data.model;
  const pred = await model.predict(data.tensor);
  postMessage(pred);
});

service.ts

predict() {
  if (typeof Worker !== 'undefined') {
    // Create a new worker
    const worker = new Worker('../workers/model.worker', { type: 'module' });
    worker.onmessage = ({ data }) => {
      console.log(Array.from(data.dataSync()));
    };
    worker.postMessage({ tensor, model: this._model });
  } else {
    // Web Workers are not supported in this environment.
    // You should add a fallback so that your program still executes correctly.
  }
}
Marco Ripamonti

2 Answers


The data exchanged between the main thread and a worker must be serializable, because postMessage copies it with the structured clone algorithm. Therefore, you can pass neither the model itself nor a tf.Tensor. You can, on the other hand, pass the raw data and construct the tensor inside your worker.
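A quick way to see this restriction outside a browser is `structuredClone()`, which applies the same structured clone algorithm as postMessage (assumption: a runtime that provides it, such as a modern browser or Node.js 17+). `FakeTensor` is a hypothetical stand-in for `tf.Tensor`, not the real class: its data survives the copy, but its prototype methods do not, and a bare function cannot be cloned at all.

```javascript
// Hypothetical stand-in for tf.Tensor: data fields plus a method on the prototype.
class FakeTensor {
  constructor(values, shape) {
    this.values = values;
    this.shape = shape;
  }
  dataSync() {
    return this.values;
  }
}

const original = new FakeTensor(new Float32Array([1, 2, 3, 4]), [2, 2]);

// structuredClone uses the same algorithm as postMessage.
const copy = structuredClone(original);

console.log(copy instanceof FakeTensor); // false: the prototype is not preserved
console.log(typeof copy.dataSync);       // undefined: methods are dropped
console.log(Array.from(copy.values));    // the plain data survives

// Functions themselves cannot be cloned: this throws a "DataCloneError".
let cloneError = null;
try {
  structuredClone({ predict: () => 42 });
} catch (err) {
  cloneError = err;
}
console.log(cloneError !== null);        // true
```

This is why the model has to be loaded inside the worker, and only plain values (numbers, arrays, TypedArrays) should cross the boundary.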

For the TypeScript compiler to know that importScripts has created a global variable, you need to declare tf:

declare let tf: any // or find a better type
edkeveked
  • Declaring the variable doesn't seem to work. I declared and called for example tf.ones([2,2]) and this is the error I got: `TypeError: Cannot read property 'ones' of undefined` – Marco Ripamonti Dec 16 '20 at 09:19
  • Sorry, it is a declare statement. Could you try it again? – edkeveked Dec 16 '20 at 09:20
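On the TypeScript side, `declare let tf: any` emits no JavaScript at all; it only tells the compiler that a global `tf` will exist at runtime. A plain `let tf: any` compiles to a real `let tf;` binding, which would shadow the global created by importScripts and stay `undefined`, which would explain the `Cannot read property 'ones' of undefined` error. As an optional runtime safety net, the sketch below reads the global explicitly and fails with a clearer message when the script did not load. `requireGlobal` is a hypothetical helper, and it assumes the TensorFlow.js UMD bundle attaches `tf` to the worker's global object.

```javascript
// Hypothetical helper: look up a library that a classic script
// (e.g. one loaded with importScripts) attached to the global object.
function requireGlobal(name) {
  const lib = globalThis[name];
  if (lib === undefined) {
    throw new Error(
      `Global "${name}" is not defined: was its script loaded with importScripts() before this code ran?`
    );
  }
  return lib;
}

// Simulate what a UMD bundle does: assign a property on the global object.
globalThis.fakeTf = { ones: (shape) => ({ shape }) };

const fakeTf = requireGlobal("fakeTf");
console.log(fakeTf.ones([2, 2]).shape); // the shape that was passed in

// A missing script yields a readable error instead of "... of undefined".
try {
  requireGlobal("notLoaded");
} catch (err) {
  console.log(err.message);
}
```

In the worker, right after the importScripts call, `const tf = requireGlobal("tf");` would give you a local reference that is either valid or fails immediately with a descriptive message.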

You can't send a custom object's methods through postMessage, so you need to create the model and the tensor inside the Worker itself so that they have all their methods.

If you generate your tensor from a DOM element, you first need to create it with tf.browser.fromPixels() on the main thread (Workers have no access to the DOM), then extract the tensor's data as a TypedArray and send that to your Worker. In the Worker you can then create a new tensor instance from that TypedArray.
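A minimal sketch of that hand-off, assuming the values have already been read with `await tensor.data()` on the main thread (tf.browser.fromPixels produces int32 data, hence the Int32Array). Unlike the answer's full example, which hard-codes the reshape inside the worker, this sketch also sends the shape; that is an addition, not part of the original approach. The underlying ArrayBuffer is listed as a transferable so it is moved rather than copied. `packTensorMessage` is a hypothetical name, and `structuredClone` with a `transfer` option stands in for `worker.postMessage` so the snippet runs outside a browser.

```javascript
// Hypothetical helper: bundle raw tensor values and their shape into a plain,
// clone-friendly message, plus the list of buffers to transfer.
function packTensorMessage(values, shape) {
  const expected = shape.reduce((a, b) => a * b, 1);
  if (values.length !== expected) {
    throw new Error(`values has ${values.length} elements, shape needs ${expected}`);
  }
  return { message: { values, shape }, transfer: [values.buffer] };
}

// Stand-in for `await img.data()`: a 2x2 RGB image as int32 values.
const pixels = new Int32Array([255, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 255]);
const { message, transfer } = packTensorMessage(pixels, [2, 2, 3]);

// In the browser this would be: worker.postMessage(message, transfer);
const received = structuredClone(message, { transfer });

// After the transfer, the sender's buffer is detached and the view reads as empty.
console.log(pixels.length);          // 0
console.log(received.values.length); // 12
console.log(received.shape);

// In the worker, the tensor could then be rebuilt with
// tf.tensor(received.values, received.shape).
```

Transferring avoids copying a large image buffer, but the main thread must not touch `pixels` afterwards, which is why the example disposes of the tensor and only uses the extracted data once.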

Here is a rewrite of the mobilenet example using a Worker (the prediction results may take some time to appear):

// ### main script ###
onload = async (evt) => {
  const worker = new Worker( getWorkerURL() );
  const imgElement = document.querySelector('img');
  // get the tensor as usual
  const img = tf.browser.fromPixels(imgElement);
  // extract the data as a TypedArray so we can transfer it to the Worker
  const pixels = await img.data();
  img.dispose();
  // wait until the Worker is ready
  // (strange bug in Chrome where message events are lost otherwise...)
  worker.onmessage = (evt) => {
    // from now on, messages from the Worker carry the results
    worker.onmessage = ({ data }) =>
      showResults(imgElement, data);
    // transfer the data we extracted to the Worker
    worker.postMessage(pixels, [pixels.buffer]);
  };
}
function showResults(imgElement, classes) {
  const predictionContainer = document.createElement('div');
  predictionContainer.className = 'pred-container';

  const imgContainer = document.createElement('div');
  imgContainer.appendChild(imgElement);
  predictionContainer.appendChild(imgContainer);

  const probsContainer = document.createElement('div');
  for (let i = 0; i < classes.length; i++) {
    const row = document.createElement('div');
    row.className = 'row';

    const classElement = document.createElement('div');
    classElement.className = 'cell';
    classElement.innerText = classes[i].className;
    row.appendChild(classElement);

    const probsElement = document.createElement('div');
    probsElement.className = 'cell';
    probsElement.innerText = classes[i].probability.toFixed(3);
    row.appendChild(probsElement);

    probsContainer.appendChild(row);
  }
  predictionContainer.appendChild(probsContainer);

  document.body.prepend(predictionContainer);
}
function getWorkerURL() {
  const elem = document.querySelector("[type='worker-script']");
  const data = elem.textContent;
  const blob = new Blob( [ data ], { type: "text/javascript" } );
  return URL.createObjectURL( blob );
}
/* ### styles ### */
.pred-container {
  margin-bottom: 20px;
}

.pred-container > div {
  display: inline-block;
  margin-right: 20px;
  vertical-align: top;
}
.row {
  display: table-row;
}
.cell {
  display: table-cell;
  padding-right: 20px;
}
<!-- ### worker.js ### -->
<script type="worker-script">
// we need to load tensorflow here
importScripts("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js");
(async ()=> {
const MOBILENET_MODEL_PATH =
    'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';

const IMAGE_SIZE = 224;
const TOPK_PREDICTIONS = 10;

// load the model
// note that 'tf' is available globally thanks to 'importScripts'
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);

// let the main thread know we are ready
// (strange bug in Chrome where message events are lost otherwise...)
postMessage("ready");
self.onmessage = async ( { data } ) => {
  const img = tf.tensor(data);
  const logits = tf.tidy(() => {
    const offset = tf.scalar(127.5);
    // Normalize the image from [0, 255] to [-1, 1].
    const normalized = img.sub(offset).div(offset);

    // Reshape to a single-element batch so we can pass it to predict.
    const batched = normalized.reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]);
    return mobilenet.predict(batched);
  });
  // the input tensor is no longer needed
  img.dispose();

  // Convert logits to probabilities and class names.
  const classes = await getTopKClasses(logits, TOPK_PREDICTIONS);
  // free the output tensor's memory as well
  logits.dispose();
  postMessage(classes);
};
async function getTopKClasses(logits, topK) {
  const values = await logits.data();

  const valuesAndIndices = [];
  for (let i = 0; i < values.length; i++) {
    valuesAndIndices.push({value: values[i], index: i});
  }
  valuesAndIndices.sort((a, b) => {
    return b.value - a.value;
  });
  const topkValues = new Float32Array(topK);
  const topkIndices = new Int32Array(topK);
  for (let i = 0; i < topK; i++) {
    topkValues[i] = valuesAndIndices[i].value;
    topkIndices[i] = valuesAndIndices[i].index;
  }

  const topClassesAndProbs = [];
  for (let i = 0; i < topkIndices.length; i++) {
    topClassesAndProbs.push({
      // would be too big to import https://github.com/tensorflow/tfjs-examples/blob/master/mobilenet/imagenet_classes.js
      // so we just show the index here
      className: topkIndices[i],
      probability: topkValues[i]
    })
  }
  return topClassesAndProbs;
}

})();
</script>

<!-- #### index.html ### -->
<!-- we need to load tensorflow here too -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
<img crossorigin src="https://upload.wikimedia.org/wikipedia/commons/thumb/a/ae/Katri.jpg/577px-Katri.jpg" width="224" height="224">
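Once the logits have been read with `await logits.data()`, the top-K step in getTopKClasses is ordinary array work and does not need TensorFlow.js. A plain-JavaScript sketch of the same logic (the name `topKFromValues` is hypothetical) sorts indices directly instead of building intermediate typed arrays. Inside the worker, TensorFlow.js also provides `tf.topk(logits, k)`, which could replace the manual sort.

```javascript
// Hypothetical plain-JS equivalent of getTopKClasses, operating on the
// values returned by `await logits.data()` (e.g. a Float32Array).
function topKFromValues(values, topK) {
  const indices = Array.from(values, (_, i) => i);
  // Sort indices by descending value, like the original comparator (b.value - a.value).
  indices.sort((a, b) => values[b] - values[a]);
  return indices.slice(0, topK).map((i) => ({
    className: i,         // index only, as in the original (no ImageNet class names)
    probability: values[i],
  }));
}

const logitsValues = new Float32Array([0.1, 0.7, 0.05, 0.15]);
const top2 = topKFromValues(logitsValues, 2);
console.log(top2.map((c) => c.className)); // indices 1 and 3
```

The result has the same `{ className, probability }` shape that showResults expects, so it could be posted back to the main thread unchanged.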
Kaiido
  • The problem with this is that I cannot use the reference to tf imported with importScripts. The compiler doesn't recognize it and throws an error. – Marco Ripamonti Dec 16 '20 at 08:55
  • That would be another issue... See https://stackoverflow.com/questions/57774039/how-to-import-a-node-module-inside-an-angular-web-worker maybe, but I'm really not into Angular – Kaiido Dec 16 '20 at 09:06