
I am trying to perform an element-wise addition of the weights of two different models.

I wrote the following code:

async function getWeights(url){
  return new Promise(async function(resolve, reject){
    const model = await tf.loadLayersModel(url);
    resolve(model.layers[0].getWeights);
  });
}

async function aggregate(){
  return new Promise(function (resolve, reject){
    weights.push(getWeights('file://./mymodel/modelReceived.json'));
    weights.push(getWeights('file://./mymodel/model.json'));
    let averageLayer = tf.layers.average();
    console.log(weights.length);
    const average = averageLayer.apply([weights[0], weights[1]]);
    model.layers[0].setWeights[average];
    resolve(model);
  });

}

async function returnValue(){
  var model = await aggregate();
  console.log(model);
}

returnValue();

However, I am getting this error:

(node:20468) UnhandledPromiseRejectionWarning: Error: A merge layer should be called on an Array of at least 2 inputs. Got 1 input(s).

I created the models with the following code:

const modelOne = tf.sequential();
modelOne.add(tf.layers.dense({units: 100, activation: 'relu', inputShape: [50]}));
modelOne.compile({optimizer: 'sgd', loss: 'meanSquaredError', metrics: ['accuracy']});

Can anyone explain the error to me? Are there any alternative ways to do the addition?

1 Answer


The function getWeights() returns a Promise, so when you call weights.push(getWeights('...')) you are pushing a Promise instead of the weight tensors. Inside an async function, it can be updated like so:

weights.push(await getWeights('...'))
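To see why the missing await matters, here is a plain-JavaScript sketch with no TensorFlow.js dependency. fakeGetWeights() is a hypothetical stand-in for getWeights() that resolves to an array after a short delay.

```javascript
// Hypothetical stand-in for getWeights(): resolves to an array after 10 ms.
async function fakeGetWeights(values) {
  await new Promise((resolve) => setTimeout(resolve, 10));
  return values;
}

// Buggy pattern: pushes the pending Promises themselves.
function collectWithoutAwait() {
  const weights = [];
  weights.push(fakeGetWeights([1, 2]));
  weights.push(fakeGetWeights([3, 4]));
  return weights; // [Promise, Promise]
}

// Fixed pattern: awaits each Promise and pushes the resolved value.
async function collectWithAwait() {
  const weights = [];
  weights.push(await fakeGetWeights([1, 2]));
  weights.push(await fakeGetWeights([3, 4]));
  return weights; // [[1, 2], [3, 4]]
}
```

If the two loads do not depend on each other, `await Promise.all([...])` would run them in parallel instead of one after the other.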

The Promise in getWeights() also resolves to a function (i.e. model.layers[0].getWeights) instead of to the weights. Call the method so that it resolves to the array of weight tensors:

resolve(model.layers[0].getWeights())
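The difference between referencing a method and calling it is easy to miss. In this sketch, `layer` is a hypothetical plain object standing in for model.layers[0]; the real layer returns tensors rather than strings.

```javascript
// Hypothetical stand-in for model.layers[0].
const layer = {
  getWeights() {
    return ['kernel', 'bias'];
  },
};

const withoutCall = layer.getWeights; // the function object itself
const withCall = layer.getWeights(); // the array the function returns

console.log(typeof withoutCall); // function
console.log(Array.isArray(withCall)); // true
```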

You should not need to combine an explicit Promise with async/await, because an async function already returns a Promise. You can simplify the getWeights() function like so:

async function getWeights(url){
  const model = await tf.loadLayersModel(url);
  return model.layers[0].getWeights();
}
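The extra `new Promise(...)` wrapper is not only redundant: if something fails after an await inside an async executor, that rejection never reaches the outer Promise, which stays pending while the inner rejection goes unhandled. The sketch below uses a hypothetical failingLoad() in place of tf.loadLayersModel() called with a bad path; it shows the simplified form passing the error on to the caller.

```javascript
// Hypothetical loader that always fails, standing in for
// tf.loadLayersModel() called with a bad path.
async function failingLoad() {
  throw new Error('model not found');
}

// Original pattern: async executor inside new Promise.
// Not called here: the executor's rejection would go unhandled
// and the returned Promise would never settle.
function wrappedGetWeights() {
  return new Promise(async function (resolve, reject) {
    const model = await failingLoad();
    resolve(model.layers[0].getWeights());
  });
}

// Simplified pattern: the rejection reaches whoever awaits the call.
async function simplifiedGetWeights() {
  const model = await failingLoad();
  return model.layers[0].getWeights();
}

simplifiedGetWeights().catch((err) => {
  console.log('caught:', err.message); // caught: model not found
});
```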

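aggregate() needs similar updates: make it an async function so it can await both getWeights() calls, and call setWeights as a method, i.e. model.layers[0].setWeights(...) rather than model.layers[0].setWeights[...], which only reads a property and never updates the layer.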

You can find more detail on Promises and async/await here: https://stackoverflow.com/a/14220323

vabarbosa