wrong output of classifier


I'm new to machine learning, and I used the MNIST demo model to train a cat and dog classifier, but it doesn't seem to work very well. Here are some diagrams of the model:



[Screenshots: onEpochEnd and onBatchEnd training charts, perClassAccuracy chart, and an info panel]



It seems that this model predicts every input as a cat.
This is my code. Please help me.



index.js:




import {IMAGE_H, IMAGE_W, MnistData} from './data.js';
import * as ui from './ui.js';

// tf, tfvis and localforage are assumed to be loaded as globals via <script> tags.

let classNum = 0;

function createConvModel() {
  const model = tf.sequential();
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_H, IMAGE_W, 3],
    kernelSize: 5,
    filters: 32,
    activation: 'relu'
  }));

  model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));

  model.add(tf.layers.conv2d({kernelSize: 5, filters: 32, activation: 'relu'}));

  model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));

  model.add(tf.layers.conv2d({kernelSize: 5, filters: 64, activation: 'relu'}));

  model.add(tf.layers.flatten());

  model.add(tf.layers.dense({units: 64, activation: 'relu'}));

  // One softmax output per class (classNum is read from localforage in load()).
  model.add(tf.layers.dense({units: classNum, activation: 'softmax'}));

  return model;
}

function createDenseModel() {
  const model = tf.sequential();
  model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 3]}));
  model.add(tf.layers.dense({units: 42, activation: 'relu'}));
  model.add(tf.layers.dense({units: classNum, activation: 'softmax'}));
  return model;
}

async function train(model, fitCallbacks) {
  ui.logStatus('Training model...');

  const optimizer = 'rmsprop';

  model.compile({
    optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  const batchSize = 64;

  const trainEpochs = ui.getTrainEpochs();

  let trainBatchCount = 0;

  const trainData = data.getTrainData();
  const valData = data.getValData();
  const testData = data.getTestData();

  await model.fit(trainData.xs, trainData.labels, {
    batchSize: batchSize,
    validationData: [valData.xs, valData.labels],
    shuffle: true,
    epochs: trainEpochs,
    callbacks: fitCallbacks
  });
  console.log("complete");

  const classNames = ['cat', 'dog'];
  const [preds, labels] = doPrediction(model, testData);
  const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
  const container = {name: 'Accuracy', tab: 'Evaluation'};
  tfvis.show.perClassAccuracy(container, classAccuracy, classNames);
}

function doPrediction(model, testData) {
  const testxs = testData.xs;
  const labels = testData.labels.argMax([-1]);
  const preds = model.predict(testxs).argMax([-1]);

  testxs.dispose();
  return [preds, labels];
}

function createModel() {
  let model;
  const modelType = ui.getModelTypeId();
  if (modelType === 'ConvNet') {
    model = createConvModel();
  } else if (modelType === 'DenseNet') {
    model = createDenseModel();
  } else {
    throw new Error(`Invalid model type: ${modelType}`);
  }
  return model;
}

async function watchTraining(model) {
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  const container = {
    name: 'charts', tab: 'Training', styles: {height: '1000px'}
  };
  const callbacks = tfvis.show.fitCallbacks(container, metrics);
  return train(model, callbacks);
}

let data;
async function load() {
  tf.disableDeprecationWarnings();
  classNum = await localforage.getItem('classNum');
  tfvis.visor();
  data = new MnistData();
  await data.load();
}

ui.setTrainButtonCallback(async () => {
  ui.logStatus('Loading data...');
  await load();

  ui.logStatus('Creating model...');
  const model = createModel();
  model.summary();

  ui.logStatus('Starting model training...');

  await watchTraining(model);
});
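To check the symptom described above (every input predicted as a cat), it can help to see what `doPrediction` and `tfvis.metrics.perClassAccuracy` compute. The sketch below redoes that computation with plain arrays so it can be checked by hand. `argMaxRows` and `perClassAccuracy` here are hypothetical helpers written for illustration, not tfjs or tfvis functions, and the `{accuracy, count}` result shape is an assumption modelled on what the tfvis chart displays.

```javascript
// Hypothetical helpers (not part of tfjs/tfvis): plain-array versions of
// argMax over softmax rows and per-class accuracy.

// For each row of class probabilities, return the index of the largest one.
function argMaxRows(probs) {
  return probs.map((row) => row.indexOf(Math.max(...row)));
}

// For each class k: the fraction of samples whose true label is k that were
// also predicted as k, plus how many samples of class k there were.
function perClassAccuracy(labels, preds, numClasses) {
  const correct = new Array(numClasses).fill(0);
  const count = new Array(numClasses).fill(0);
  labels.forEach((label, i) => {
    count[label]++;
    if (preds[i] === label) correct[label]++;
  });
  return count.map((c, k) => ({accuracy: c === 0 ? 0 : correct[k] / c, count: c}));
}

// A model that always favours class 0 ("cat"):
const preds = argMaxRows([[0.6, 0.4], [0.7, 0.3], [0.55, 0.45], [0.9, 0.1]]);
console.log(perClassAccuracy([0, 1, 1, 0], preds, 2));
// -> [ { accuracy: 1, count: 2 }, { accuracy: 0, count: 2 } ]
```

With two classes, "always cat" shows up as accuracy 1 for cat and 0 for dog, even though overall accuracy is 50% on a balanced test set.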





data.js:






export const IMAGE_H = 64;
export const IMAGE_W = 64;
const IMAGE_SIZE = IMAGE_H * IMAGE_W;
let NUM_CLASSES = 0;
let trainImagesLabels;
let testLabels;
let trainImages ;
let testImages ;
let validateImages;
let validateLabels;
let validateSplit = 0.2;
let modelId;
let classNum;

/**
 * Loads the cat/dog images saved in localforage and provides them as
 * tf.Tensors (adapted from the MNIST demo's MnistData class).
 */
export class MnistData {
  constructor() {}

  // Shuffle two parallel arrays in place with the same random permutation
  // (Fisher-Yates), so each image keeps its label.
  static shuffleSwap(arr1, arr2) {
    if (arr1.length == 1) return {arr1, arr2};
    let i = arr1.length;
    while (--i > 0) {
      const j = Math.floor(Math.random() * (i + 1)); // 0 <= j <= i
      [arr1[i], arr1[j]] = [arr1[j], arr1[i]];
      [arr2[i], arr2[j]] = [arr2[j], arr2[i]];
    }
    return {arr1, arr2};
  }

  async load() {
    // Get the data from localforage.
    this.trainImages = await localforage.getItem('dataset');
    this.trainImagesLabels = await localforage.getItem('datasetLabel');
    this.modelId = await localforage.getItem('modelId');
    this.classNum = await localforage.getItem('classNum');

    this.trainImages.shift();
    this.trainImagesLabels.shift();

    // Construct the validation data: alternately take one sample from the
    // front and one from the back of the training arrays.
    let status = false;
    let maxVal = Math.floor(this.trainImages.length * 0.2);

    this.validateImages = new Array();
    this.validateLabels = new Array();
    for (let i = 0; i < maxVal; i++) {
      if (status) {
        this.validateImages.push(this.trainImages.pop());
        this.validateLabels.push(this.trainImagesLabels.pop());
        status = false;
      } else {
        this.validateImages.push(this.trainImages.shift());
        this.validateLabels.push(this.trainImagesLabels.shift());
        status = true;
      }
    }

    // Construct the test data the same way.
    this.testImages = new Array();
    this.testLabels = new Array();
    for (let i = 0; i < maxVal; i++) {
      if (status) {
        this.testImages.push(this.trainImages.pop());
        this.testLabels.push(this.trainImagesLabels.pop());
        status = false;
      } else {
        this.testImages.push(this.trainImages.shift());
        this.testLabels.push(this.trainImagesLabels.shift());
        status = true;
      }
    }

    // Shuffle the validation and training sets.
    let val = MnistData.shuffleSwap(this.validateImages, this.validateLabels);
    this.validateImages = val.arr1;
    this.validateLabels = val.arr2;
    let train = MnistData.shuffleSwap(this.trainImages, this.trainImagesLabels);
    this.trainImages = train.arr1;
    this.trainImagesLabels = train.arr2;
  }

  getTrainData() {
    const xs = tf.tensor4d(this.trainImages);
    const labels = tf.oneHot(tf.tensor1d(this.trainImagesLabels, 'int32'), this.classNum);
    return {xs, labels};
  }

  getValData() {
    const xs = tf.tensor4d(this.validateImages);
    const labels = tf.oneHot(tf.tensor1d(this.validateLabels, 'int32'), this.classNum);
    return {xs, labels};
  }

  getTestData() {
    const xs = tf.tensor4d(this.testImages);
    const labels = tf.oneHot(tf.tensor1d(this.testLabels, 'int32'), this.classNum);
    return {xs, labels};
  }
}
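The split in `load()` alternates between taking a sample from the front and from the back of the arrays, and `shuffleSwap` is meant to be a Fisher-Yates shuffle that applies one permutation to two parallel arrays. Below is a plain-array sketch of both steps, assuming strings stand in for the pixel data; `shuffleTogether`, `takeAlternating` and `splitData` are hypothetical names, and unlike `load()` the sketch does not drop the first element before splitting.

```javascript
// Fisher-Yates shuffle of two parallel arrays with one shared permutation.
// `random` defaults to Math.random; a fixed function makes the result testable.
function shuffleTogether(arr1, arr2, random = Math.random) {
  for (let i = arr1.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1)); // 0 <= j <= i
    [arr1[i], arr1[j]] = [arr1[j], arr1[i]];
    [arr2[i], arr2[j]] = [arr2[j], arr2[i]];
  }
  return {arr1, arr2};
}

// Move `count` samples out of (images, labels), alternating front/back.
// `state.fromBack` carries the alternation over between calls, like `status`.
function takeAlternating(images, labels, count, state) {
  const outImages = [];
  const outLabels = [];
  for (let i = 0; i < count; i++) {
    if (state.fromBack) {
      outImages.push(images.pop());
      outLabels.push(labels.pop());
    } else {
      outImages.push(images.shift());
      outLabels.push(labels.shift());
    }
    state.fromBack = !state.fromBack;
  }
  return {images: outImages, labels: outLabels};
}

// Validation and test each get floor(n * fraction) samples; the rest is training.
function splitData(images, labels, fraction = 0.2) {
  const n = Math.floor(images.length * fraction);
  const state = {fromBack: false};
  const val = takeAlternating(images, labels, n, state);
  const test = takeAlternating(images, labels, n, state);
  shuffleTogether(val.images, val.labels);
  shuffleTogether(images, labels); // what remains is the training set
  return {train: {images, labels}, val, test};
}
```

Because every swap is applied to both arrays at the same indices, each image stays paired with its label, which is what the per-class metrics depend on.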





I added some pictures at the beginning.

// getClassNum: each image's parent folder name is used as its class.
function getClassNum(files) {
  let classArr = new Array();
  let dirArr = new Array();
  let imageNum = 0;
  for (let i = 0; i < files.length; i++) {
    if (files[i].type.split('/')[0] == 'image' && files[i].type.split('/')[1] == 'jpeg') {
      dirArr = files[i].webkitRelativePath.split('/');
      let currentClassIndex = dirArr.length - 2;
      let isExist = false;
      if (currentClassIndex <= 0) {
        // The image sits directly in the selected folder: no class folder.
        isExist = true;
      } else {
        imageNum++;
      }

      if (classArr == null) {
        classArr.push(dirArr[currentClassIndex]);
      }

      for (let j = 0; j < classArr.length; j++) {
        if (classArr[j] == dirArr[currentClassIndex]) {
          isExist = true;
        }
      }

      if (!isExist) {
        classArr.push(dirArr[currentClassIndex]);
      }
    }
  }
  let classNum = classArr.length;
  return {classNum, imageNum, classArr};
}

// getDataset: read every JPEG into a nested pixel array plus its class index.
function getDataset(files, classArr, imgNum) {
  let trainLabelArr = new Array();
  let trainDataArr = new Array();
  for (let i = 0; i < files.length; i++) {
    if (files[i].type.split('/')[0] == 'image' && files[i].type.split('/')[1] == 'jpeg') {
      let dirArr = files[i].webkitRelativePath.split('/');
      let currentClassIndex = dirArr.length - 2;
      if (currentClassIndex >= 0) {
        for (let j = 0; j < classArr.length; j++) {
          if (dirArr[currentClassIndex] == classArr[j]) {
            let reader = new FileReader();
            reader.readAsDataURL(files[i]);
            // Runs asynchronously, after getDataset has already returned.
            reader.onload = function () {
              document.getElementById('image').setAttribute('src', reader.result);
              // The <img> may not have finished loading the new src at this point.
              let tensor = tf.browser.fromPixels(document.getElementById('image'));
              let nest = tensor.arraySync();
              tensor.dispose();

              trainDataArr.push(nest);
              trainLabelArr.push(j);
            };
          }
        }
      }
    }
  }
  return {trainDataArr, trainLabelArr};
}

// getfiles: handler for the folder <input>; saves the dataset to localforage.
async function fileChange(that) {
  let files = that.files;
  let container = getClassNum(files);

  let data = getDataset(files, container.classArr, container.imageNum);
  let trainDataArr = data.trainDataArr;
  let trainLabelArr = data.trainLabelArr;

  // Wait roughly 10 ms per image for the FileReader callbacks to fill the arrays.
  setTimeout(function () {
    localforage.setItem('dataset', trainDataArr, function (err, result) {});
    localforage.setItem('datasetLabel', trainLabelArr, function (err, result) {});
    localforage.setItem('modelId', modelId, function (err, result) {});
    localforage.setItem('classNum', container.classNum, function (err, result) {});
  }, container.imageNum * 10);
}
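`fileChange` waits a fixed `imageNum * 10` ms for the `FileReader` callbacks before saving, so the saved arrays may be incomplete if reading takes longer. A promise-based alternative is to create one promise per image and await them all before saving. In this sketch `readPixels` is a hypothetical async function (file to nested pixel array); in the browser it would have to wrap the FileReader/`<img>` loading and `tf.browser.fromPixels`, waiting for the image's load event. `getClassInfo` and `buildDataset` are hypothetical names; `getClassInfo` mirrors the folder-name logic of `getClassNum`.

```javascript
// Sketch only: readPixels is an injected, hypothetical async loader.
function isJpeg(file) {
  return file.type === 'image/jpeg';
}

// Class name = the image's immediate parent folder (as in getClassNum).
function getClassInfo(files) {
  const classArr = [];
  let imageNum = 0;
  for (const file of files) {
    if (!isJpeg(file)) continue;
    const parts = file.webkitRelativePath.split('/');
    if (parts.length < 3) continue; // image sits directly in the selected folder
    const className = parts[parts.length - 2];
    imageNum++;
    if (!classArr.includes(className)) classArr.push(className);
  }
  return {classNum: classArr.length, imageNum, classArr};
}

// Read all images and resolve only when every one has been loaded.
async function buildDataset(files, classArr, readPixels) {
  const samples = [];
  for (const file of files) {
    if (!isJpeg(file)) continue;
    const parts = file.webkitRelativePath.split('/');
    if (parts.length < 3) continue;
    const label = classArr.indexOf(parts[parts.length - 2]);
    if (label === -1) continue;
    samples.push(readPixels(file).then((pixels) => ({pixels, label})));
  }
  // Promise.all preserves input order, so pixels and labels stay paired.
  const loaded = await Promise.all(samples);
  return {
    trainDataArr: loaded.map((s) => s.pixels),
    trainLabelArr: loaded.map((s) => s.label),
  };
}
```

Inside `fileChange`, the result could then be awaited before storing it, e.g. `const data = await buildDataset(files, info.classArr, readPixels);` followed by `await localforage.setItem('dataset', data.trainDataArr);`, since `localforage.setItem` also returns a Promise.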












share|improve this question




























    0















    I'm new to machine learning and i used an mnist demo model to train a cat and dog classifier.But it doesn't seem to work very well.Here are some diagrams of the model:



    onEpochEndonBatchEndperClassAccuracy



    info



    It seems that this model always predicts any input as a cat.
    This is my code. Please help me.



    index.js:




    import IMAGE_H, IMAGE_W, MnistData from './data.js';


    import * as ui from './ui.js';


    let classNum = 0;
    function createConvModel()

    const model = tf.sequential();
    model.add(tf.layers.conv2d(
    inputShape: [IMAGE_H, IMAGE_W, 3],
    kernelSize: 5,
    filters: 32,
    activation: 'relu'
    ));

    model.add(tf.layers.maxPooling2d(poolSize: 2, strides: 2));

    model.add(tf.layers.conv2d(kernelSize: 5, filters: 32, activation: 'relu'));

    model.add(tf.layers.maxPooling2d(poolSize: 2, strides: 2));

    model.add(tf.layers.conv2d(kernelSize: 5, filters: 64, activation: 'relu'));

    model.add(tf.layers.flatten());

    model.add(tf.layers.dense(units: 64, activation: 'relu'));

    model.add(tf.layers.dense(units: classNum, activation: 'softmax'));

    return model;



    function createDenseModel()
    const model = tf.sequential();
    model.add(tf.layers.flatten(inputShape: [IMAGE_H, IMAGE_W, 3]));
    model.add(tf.layers.dense(units: 42, activation: 'relu'));
    model.add(tf.layers.dense(units: classNum, activation: 'softmax'));
    return model;


    async function train(model, fitCallbacks)
    ui.logStatus('Training model...');

    const optimizer = 'rmsprop';

    model.compile(
    optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
    );

    const batchSize = 64;

    const trainEpochs = ui.getTrainEpochs();

    let trainBatchCount = 0;

    const trainData = data.getTrainData();
    const valData = data.getValData();
    const testData = data.getTestData();


    await model.fit(trainData.xs, trainData.labels,
    batchSize:batchSize,
    validationData:[valData.xs,valData.labels],
    shuffle:true,
    epochs: trainEpochs,
    callbacks: fitCallbacks
    );
    console.log("complete");
    const classNames = ['cat','dog'];
    const [preds, labels] = doPrediction(model,testData);
    const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
    const container = name: 'Accuracy', tab: 'Evaluation' ;
    tfvis.show.perClassAccuracy(container, classAccuracy, classNames);



    function doPrediction(model,testData)
    const testxs = testData.xs;
    const labels = testData.labels.argMax([-1]);
    const preds = model.predict(testxs).argMax([-1]);

    testxs.dispose();
    return [preds, labels];


    function createModel()
    let model;
    const modelType = ui.getModelTypeId();
    if (modelType === 'ConvNet')
    model = createConvModel();
    else if (modelType === 'DenseNet')
    model = createDenseModel();
    else
    throw new Error(`Invalid model type: $modelType`);

    return model;


    async function watchTraining(model)
    const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
    const container =
    name: 'charts', tab: 'Training', styles: height: '1000px'
    ;
    const callbacks = tfvis.show.fitCallbacks(container, metrics);
    return train(model, callbacks);


    let data;
    async function load()
    tf.disableDeprecationWarnings();
    classNum = await localforage.getItem('classNum');
    tfvis.visor();
    data = new MnistData();
    await data.load();



    ui.setTrainButtonCallback(async () =>
    ui.logStatus('Loading data...');
    await load();

    ui.logStatus('Creating model...');
    const model = createModel();
    model.summary();

    ui.logStatus('Starting model training...');

    await watchTraining(model);
    );





    data.js:






    export const IMAGE_H = 64;
    export const IMAGE_W = 64;
    const IMAGE_SIZE = IMAGE_H * IMAGE_W;
    let NUM_CLASSES = 0;
    let trainImagesLabels;
    let testLabels;
    let trainImages ;
    let testImages ;
    let validateImages;
    let validateLabels;
    let validateSplit = 0.2;
    let modelId;
    let classNum;

    /**
    * A class that fetches the sprited MNIST dataset and provide data as
    * tf.Tensors.
    */
    export class MnistData
    constructor()

    //shuffle
    static shuffleSwap(arr1,arr2)
    if(arr1.length == 1) return arr1,arr2;
    let i = arr1.length;
    while(--i > 1)
    let j = Math.floor(Math.random() * (i+1));
    [arr1[i], arr1[j]] = [arr1[j], arr1[i]];
    [arr2[i], arr2[j]] = [arr2[j], arr2[i]];

    return arr1,arr2;


    async load()
    //get data from localforage
    this.trainImages = await localforage.getItem('dataset');
    this.trainImagesLabels = await localforage.getItem('datasetLabel');
    this.modelId = await localforage.getItem('modelId');
    this.classNum = await localforage.getItem('classNum');

    this.trainImages.shift();
    this.trainImagesLabels.shift();

    //construct the validateData
    let status = false;
    let maxVal = Math.floor(this.trainImages.length * 0.2);

    this.validateImages = new Array();
    this.validateLabels = new Array();
    for(let i=0;i<maxVal;i++)
    if(status)
    this.validateImages.push(this.trainImages.pop());
    this.validateLabels.push(this.trainImagesLabels.pop());
    status = false;
    else
    this.validateImages.push(this.trainImages.shift());
    this.validateLabels.push(this.trainImagesLabels.shift());
    status = true;


    //construct the testData
    this.testImages = new Array();
    this.testLabels = new Array();
    for(let i=0;i<maxVal;i++)
    if(status)
    this.testImages.push(this.trainImages.pop());
    this.testLabels.push(this.trainImagesLabels.pop());
    status = false;
    else
    this.testImages.push(this.trainImages.shift());
    this.testLabels.push(this.trainImagesLabels.shift());
    status = true;


    //shuffle
    let val = MnistData.shuffleSwap(this.validateImages,this.validateLabels);
    this.validateImages = val.arr1;
    this.validateLabels = val.arr2;
    let train = MnistData.shuffleSwap(this.trainImages,this.trainImagesLabels);
    this.trainImages = train.arr1;
    this.trainImagesLabels = train.arr2;




    getTrainData()
    const xs = tf.tensor4d(this.trainImages);
    const labels = tf.oneHot(tf.tensor1d(this.trainImagesLabels,'int32'),this.classNum);
    return xs, labels;




    getValData()
    const xs = tf.tensor4d(this.validateImages);
    const labels = tf.oneHot(tf.tensor1d(this.validateLabels,'int32'),this.classNum);
    return xs, labels;


    getTestData()
    const xs = tf.tensor4d(this.testImages);
    const labels = tf.oneHot(tf.tensor1d(this.testLabels,'int32'),this.classNum);
    return xs, labels;





    I added some pictures at the beginning.




     
    //getclassNum
    function getClassNum(files)
    let classArr = new Array();
    let dirArr = new Array();
    let imageNum = 0;
    for (let i = 0; i < files.length; i++)
    if (files[i].type.split('/')[0] == 'image' && files[i].type.split('/')[1] == 'jpeg')
    dirArr = files[i].webkitRelativePath.split('/');
    let currentClassIndex = dirArr.length - 2;
    let isExist = false;
    if (currentClassIndex <= 0)
    isExist = true;
    else
    imageNum++;

    if (classArr == null)
    classArr.push(dirArr[currentClassIndex]);

    for (let j = 0; j < classArr.length; j++)
    if (classArr[j] == dirArr[currentClassIndex])
    isExist = true;


    if (!isExist)
    classArr.push(dirArr[currentClassIndex]);



    let classNum = classArr.length;
    return classNum, imageNum, classArr;

    //get nested array
    function getDataset(files, classArr,imgNum)
    let trainLabelArr = new Array();
    let trainDataArr = new Array();
    for (let i = 0; i < files.length; i++)
    if (files[i].type.split('/')[0] == 'image'&& files[i].type.split('/')[1] == 'jpeg')
    let dirArr = files[i].webkitRelativePath.split('/');
    let currentClassIndex = dirArr.length - 2;
    if (currentClassIndex >= 0)
    for(let j=0;j<classArr.length;j++)
    if(dirArr[currentClassIndex]==classArr[j])
    let reader = new FileReader();
    reader.readAsDataURL(files[i]);
    reader.onload = function ()
    document.getElementById('image').setAttribute( 'src', reader.result);
    let tensor= tf.browser.fromPixels(document.getElementById('image'));
    let nest = tensor.arraySync();

    trainDataArr.push(nest);
    trainLabelArr.push(j);






    returntrainDataArr,trainLabelArr,trainDataLength

    //getfiles
    async function fileChange(that)
    let files = that.files;
    let container = getClassNum(files);

    let data = getDataset(files, container.classArr,container.imageNum);
    let trainDataArr = data.trainDataArr;
    let trainLabelArr = data.trainLabelArr;

    setTimeout(function ()

    localforage.setItem('dataset',trainDataArr,function (err,result)

    );
    localforage.setItem('datasetLabel',trainLabelArr,function (err,result)

    );
    localforage.setItem('modelId',modelId,function (err,result)

    );
    localforage.setItem('classNum',container.classNum,function (err,result)

    );
    ,container.imageNum * 10);


    }












    share|improve this question


























      0












      0








      0








      I'm new to machine learning and i used an mnist demo model to train a cat and dog classifier.But it doesn't seem to work very well.Here are some diagrams of the model:



      onEpochEndonBatchEndperClassAccuracy



      info



      It seems that this model always predicts any input as a cat.
      This is my code. Please help me.



      index.js:




      import IMAGE_H, IMAGE_W, MnistData from './data.js';


      import * as ui from './ui.js';


      let classNum = 0;
      function createConvModel()

      const model = tf.sequential();
      model.add(tf.layers.conv2d(
      inputShape: [IMAGE_H, IMAGE_W, 3],
      kernelSize: 5,
      filters: 32,
      activation: 'relu'
      ));

      model.add(tf.layers.maxPooling2d(poolSize: 2, strides: 2));

      model.add(tf.layers.conv2d(kernelSize: 5, filters: 32, activation: 'relu'));

      model.add(tf.layers.maxPooling2d(poolSize: 2, strides: 2));

      model.add(tf.layers.conv2d(kernelSize: 5, filters: 64, activation: 'relu'));

      model.add(tf.layers.flatten());

      model.add(tf.layers.dense(units: 64, activation: 'relu'));

      model.add(tf.layers.dense(units: classNum, activation: 'softmax'));

      return model;



      function createDenseModel()
      const model = tf.sequential();
      model.add(tf.layers.flatten(inputShape: [IMAGE_H, IMAGE_W, 3]));
      model.add(tf.layers.dense(units: 42, activation: 'relu'));
      model.add(tf.layers.dense(units: classNum, activation: 'softmax'));
      return model;


      async function train(model, fitCallbacks)
      ui.logStatus('Training model...');

      const optimizer = 'rmsprop';

      model.compile(
      optimizer,
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy'],
      );

      const batchSize = 64;

      const trainEpochs = ui.getTrainEpochs();

      let trainBatchCount = 0;

      const trainData = data.getTrainData();
      const valData = data.getValData();
      const testData = data.getTestData();


      await model.fit(trainData.xs, trainData.labels,
      batchSize:batchSize,
      validationData:[valData.xs,valData.labels],
      shuffle:true,
      epochs: trainEpochs,
      callbacks: fitCallbacks
      );
      console.log("complete");
      const classNames = ['cat','dog'];
      const [preds, labels] = doPrediction(model,testData);
      const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
      const container = name: 'Accuracy', tab: 'Evaluation' ;
      tfvis.show.perClassAccuracy(container, classAccuracy, classNames);



      function doPrediction(model,testData)
      const testxs = testData.xs;
      const labels = testData.labels.argMax([-1]);
      const preds = model.predict(testxs).argMax([-1]);

      testxs.dispose();
      return [preds, labels];


      function createModel()
      let model;
      const modelType = ui.getModelTypeId();
      if (modelType === 'ConvNet')
      model = createConvModel();
      else if (modelType === 'DenseNet')
      model = createDenseModel();
      else
      throw new Error(`Invalid model type: $modelType`);

      return model;


      async function watchTraining(model)
      const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
      const container =
      name: 'charts', tab: 'Training', styles: height: '1000px'
      ;
      const callbacks = tfvis.show.fitCallbacks(container, metrics);
      return train(model, callbacks);


      let data;
      async function load()
      tf.disableDeprecationWarnings();
      classNum = await localforage.getItem('classNum');
      tfvis.visor();
      data = new MnistData();
      await data.load();



      ui.setTrainButtonCallback(async () =>
      ui.logStatus('Loading data...');
      await load();

      ui.logStatus('Creating model...');
      const model = createModel();
      model.summary();

      ui.logStatus('Starting model training...');

      await watchTraining(model);
      );





      data.js:






      export const IMAGE_H = 64;
      export const IMAGE_W = 64;
      const IMAGE_SIZE = IMAGE_H * IMAGE_W;
      let NUM_CLASSES = 0;
      let trainImagesLabels;
      let testLabels;
      let trainImages ;
      let testImages ;
      let validateImages;
      let validateLabels;
      let validateSplit = 0.2;
      let modelId;
      let classNum;

      /**
      * A class that fetches the sprited MNIST dataset and provide data as
      * tf.Tensors.
      */
      export class MnistData
      constructor()

      //shuffle
      static shuffleSwap(arr1,arr2)
      if(arr1.length == 1) return arr1,arr2;
      let i = arr1.length;
      while(--i > 1)
      let j = Math.floor(Math.random() * (i+1));
      [arr1[i], arr1[j]] = [arr1[j], arr1[i]];
      [arr2[i], arr2[j]] = [arr2[j], arr2[i]];

      return arr1,arr2;


      async load()
      //get data from localforage
      this.trainImages = await localforage.getItem('dataset');
      this.trainImagesLabels = await localforage.getItem('datasetLabel');
      this.modelId = await localforage.getItem('modelId');
      this.classNum = await localforage.getItem('classNum');

      this.trainImages.shift();
      this.trainImagesLabels.shift();

      //construct the validateData
      let status = false;
      let maxVal = Math.floor(this.trainImages.length * 0.2);

      this.validateImages = new Array();
      this.validateLabels = new Array();
      for(let i=0;i<maxVal;i++)
      if(status)
      this.validateImages.push(this.trainImages.pop());
      this.validateLabels.push(this.trainImagesLabels.pop());
      status = false;
      else
      this.validateImages.push(this.trainImages.shift());
      this.validateLabels.push(this.trainImagesLabels.shift());
      status = true;


      //construct the testData
      this.testImages = new Array();
      this.testLabels = new Array();
      for(let i=0;i<maxVal;i++)
      if(status)
      this.testImages.push(this.trainImages.pop());
      this.testLabels.push(this.trainImagesLabels.pop());
      status = false;
      else
      this.testImages.push(this.trainImages.shift());
      this.testLabels.push(this.trainImagesLabels.shift());
      status = true;


      //shuffle
      let val = MnistData.shuffleSwap(this.validateImages,this.validateLabels);
      this.validateImages = val.arr1;
      this.validateLabels = val.arr2;
      let train = MnistData.shuffleSwap(this.trainImages,this.trainImagesLabels);
      this.trainImages = train.arr1;
      this.trainImagesLabels = train.arr2;




      getTrainData()
      const xs = tf.tensor4d(this.trainImages);
      const labels = tf.oneHot(tf.tensor1d(this.trainImagesLabels,'int32'),this.classNum);
      return xs, labels;




      getValData()
      const xs = tf.tensor4d(this.validateImages);
      const labels = tf.oneHot(tf.tensor1d(this.validateLabels,'int32'),this.classNum);
      return xs, labels;


      getTestData()
      const xs = tf.tensor4d(this.testImages);
      const labels = tf.oneHot(tf.tensor1d(this.testLabels,'int32'),this.classNum);
      return xs, labels;





      I added some pictures at the beginning.




       
      //getclassNum
      function getClassNum(files)
      let classArr = new Array();
      let dirArr = new Array();
      let imageNum = 0;
      for (let i = 0; i < files.length; i++)
      if (files[i].type.split('/')[0] == 'image' && files[i].type.split('/')[1] == 'jpeg')
      dirArr = files[i].webkitRelativePath.split('/');
      let currentClassIndex = dirArr.length - 2;
      let isExist = false;
      if (currentClassIndex <= 0)
      isExist = true;
      else
      imageNum++;

      if (classArr == null)
      classArr.push(dirArr[currentClassIndex]);

      for (let j = 0; j < classArr.length; j++)
      if (classArr[j] == dirArr[currentClassIndex])
      isExist = true;


      if (!isExist)
      classArr.push(dirArr[currentClassIndex]);



      let classNum = classArr.length;
      return classNum, imageNum, classArr;

      //get nested array
      function getDataset(files, classArr,imgNum)
      let trainLabelArr = new Array();
      let trainDataArr = new Array();
      for (let i = 0; i < files.length; i++)
      if (files[i].type.split('/')[0] == 'image'&& files[i].type.split('/')[1] == 'jpeg')
      let dirArr = files[i].webkitRelativePath.split('/');
      let currentClassIndex = dirArr.length - 2;
      if (currentClassIndex >= 0)
      for(let j=0;j<classArr.length;j++)
      if(dirArr[currentClassIndex]==classArr[j])
      let reader = new FileReader();
      reader.readAsDataURL(files[i]);
      reader.onload = function ()
      document.getElementById('image').setAttribute( 'src', reader.result);
      let tensor= tf.browser.fromPixels(document.getElementById('image'));
      let nest = tensor.arraySync();

      trainDataArr.push(nest);
      trainLabelArr.push(j);






      returntrainDataArr,trainLabelArr,trainDataLength

      //getfiles
      async function fileChange(that)
      let files = that.files;
      let container = getClassNum(files);

      let data = getDataset(files, container.classArr,container.imageNum);
      let trainDataArr = data.trainDataArr;
      let trainLabelArr = data.trainLabelArr;

      setTimeout(function ()

      localforage.setItem('dataset',trainDataArr,function (err,result)

      );
      localforage.setItem('datasetLabel',trainLabelArr,function (err,result)

      );
      localforage.setItem('modelId',modelId,function (err,result)

      );
      localforage.setItem('classNum',container.classNum,function (err,result)

      );
      ,container.imageNum * 10);


      }












      share|improve this question
















      I'm new to machine learning and i used an mnist demo model to train a cat and dog classifier.But it doesn't seem to work very well.Here are some diagrams of the model:



      onEpochEndonBatchEndperClassAccuracy



      info



      It seems that this model always predicts any input as a cat.
      This is my code. Please help me.



      index.js:




      import {IMAGE_H, IMAGE_W, MnistData} from './data.js';
      import * as ui from './ui.js';

      let classNum = 0;

      function createConvModel() {
        const model = tf.sequential();
        model.add(tf.layers.conv2d({
          inputShape: [IMAGE_H, IMAGE_W, 3],
          kernelSize: 5,
          filters: 32,
          activation: 'relu'
        }));
        model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
        model.add(tf.layers.conv2d({kernelSize: 5, filters: 32, activation: 'relu'}));
        model.add(tf.layers.maxPooling2d({poolSize: 2, strides: 2}));
        model.add(tf.layers.conv2d({kernelSize: 5, filters: 64, activation: 'relu'}));
        model.add(tf.layers.flatten());
        model.add(tf.layers.dense({units: 64, activation: 'relu'}));
        model.add(tf.layers.dense({units: classNum, activation: 'softmax'}));
        return model;
      }

      function createDenseModel() {
        const model = tf.sequential();
        model.add(tf.layers.flatten({inputShape: [IMAGE_H, IMAGE_W, 3]}));
        model.add(tf.layers.dense({units: 42, activation: 'relu'}));
        model.add(tf.layers.dense({units: classNum, activation: 'softmax'}));
        return model;
      }

      async function train(model, fitCallbacks) {
        ui.logStatus('Training model...');

        const optimizer = 'rmsprop';
        model.compile({
          optimizer,
          loss: 'categoricalCrossentropy',
          metrics: ['accuracy'],
        });

        const batchSize = 64;
        const trainEpochs = ui.getTrainEpochs();

        const trainData = data.getTrainData();
        const valData = data.getValData();
        const testData = data.getTestData();

        await model.fit(trainData.xs, trainData.labels, {
          batchSize: batchSize,
          validationData: [valData.xs, valData.labels],
          shuffle: true,
          epochs: trainEpochs,
          callbacks: fitCallbacks
        });
        console.log("complete");

        const classNames = ['cat', 'dog'];
        const [preds, labels] = doPrediction(model, testData);
        const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
        const container = {name: 'Accuracy', tab: 'Evaluation'};
        tfvis.show.perClassAccuracy(container, classAccuracy, classNames);
      }

      function doPrediction(model, testData) {
        const testxs = testData.xs;
        const labels = testData.labels.argMax(-1);
        const preds = model.predict(testxs).argMax(-1);
        testxs.dispose();
        return [preds, labels];
      }

      function createModel() {
        let model;
        const modelType = ui.getModelTypeId();
        if (modelType === 'ConvNet') {
          model = createConvModel();
        } else if (modelType === 'DenseNet') {
          model = createDenseModel();
        } else {
          throw new Error(`Invalid model type: ${modelType}`);
        }
        return model;
      }

      async function watchTraining(model) {
        const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
        const container = {
          name: 'charts', tab: 'Training', styles: {height: '1000px'}
        };
        const callbacks = tfvis.show.fitCallbacks(container, metrics);
        return train(model, callbacks);
      }

      let data;
      async function load() {
        tf.disableDeprecationWarnings();
        classNum = await localforage.getItem('classNum');
        tfvis.visor();
        data = new MnistData();
        await data.load();
      }

      ui.setTrainButtonCallback(async () => {
        ui.logStatus('Loading data...');
        await load();

        ui.logStatus('Creating model...');
        const model = createModel();
        model.summary();

        ui.logStatus('Starting model training...');
        await watchTraining(model);
      });
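The perClassAccuracy chart shows how often each true class is predicted correctly. A quick extra check is to count how many test predictions land in each class. The minimal sketch below assumes the predicted class indices have already been copied out of the tensor into a plain array, for example with `preds.arraySync()` on the `preds` tensor returned by `doPrediction`. `predictionHistogram` and `dominantShare` are hypothetical helper names, not part of TensorFlow.js or tfjs-vis:

```javascript
// Count how many predictions fall into each class index.
// `predIndices` is a plain array of integers in [0, numClasses).
function predictionHistogram(predIndices, numClasses) {
  const counts = new Array(numClasses).fill(0);
  for (const p of predIndices) {
    if (!Number.isInteger(p) || p < 0 || p >= numClasses) {
      throw new RangeError(`Unexpected class index: ${p}`);
    }
    counts[p] += 1;
  }
  return counts;
}

// Fraction of all predictions taken by the most frequent class.
// A value of 1 means the model predicted one class for every input.
function dominantShare(predIndices, numClasses) {
  if (predIndices.length === 0) return 0;
  const counts = predictionHistogram(predIndices, numClasses);
  return Math.max(...counts) / predIndices.length;
}

// Example: every test image was predicted as class 0 ('cat').
const preds = [0, 0, 0, 0, 0];
console.log(predictionHistogram(preds, 2)); // [ 5, 0 ]
console.log(dominantShare(preds, 2));       // 1
```

A share at or near 1, together with an overall accuracy close to the proportion of the majority class, suggests the model has collapsed to predicting a single class. That matches the behaviour described above.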





      data.js:






      export const IMAGE_H = 64;
      export const IMAGE_W = 64;
      const IMAGE_SIZE = IMAGE_H * IMAGE_W;
      let NUM_CLASSES = 0;
      let trainImagesLabels;
      let testLabels;
      let trainImages;
      let testImages;
      let validateImages;
      let validateLabels;
      let validateSplit = 0.2;
      let modelId;
      let classNum;

      /**
       * A class that fetches the sprited MNIST dataset and provide data as
       * tf.Tensors.
       */
      export class MnistData {
        constructor() {}

        //shuffle
        static shuffleSwap(arr1, arr2) {
          if (arr1.length == 1) return {arr1, arr2};
          let i = arr1.length;
          while (--i > 1) {
            let j = Math.floor(Math.random() * (i + 1));
            [arr1[i], arr1[j]] = [arr1[j], arr1[i]];
            [arr2[i], arr2[j]] = [arr2[j], arr2[i]];
          }
          return {arr1, arr2};
        }

        async load() {
          //get data from localforage
          this.trainImages = await localforage.getItem('dataset');
          this.trainImagesLabels = await localforage.getItem('datasetLabel');
          this.modelId = await localforage.getItem('modelId');
          this.classNum = await localforage.getItem('classNum');

          this.trainImages.shift();
          this.trainImagesLabels.shift();

          //construct the validateData
          let status = false;
          let maxVal = Math.floor(this.trainImages.length * 0.2);

          this.validateImages = new Array();
          this.validateLabels = new Array();
          for (let i = 0; i < maxVal; i++) {
            if (status) {
              this.validateImages.push(this.trainImages.pop());
              this.validateLabels.push(this.trainImagesLabels.pop());
              status = false;
            } else {
              this.validateImages.push(this.trainImages.shift());
              this.validateLabels.push(this.trainImagesLabels.shift());
              status = true;
            }
          }

          //construct the testData
          this.testImages = new Array();
          this.testLabels = new Array();
          for (let i = 0; i < maxVal; i++) {
            if (status) {
              this.testImages.push(this.trainImages.pop());
              this.testLabels.push(this.trainImagesLabels.pop());
              status = false;
            } else {
              this.testImages.push(this.trainImages.shift());
              this.testLabels.push(this.trainImagesLabels.shift());
              status = true;
            }
          }

          //shuffle
          let val = MnistData.shuffleSwap(this.validateImages, this.validateLabels);
          this.validateImages = val.arr1;
          this.validateLabels = val.arr2;
          let train = MnistData.shuffleSwap(this.trainImages, this.trainImagesLabels);
          this.trainImages = train.arr1;
          this.trainImagesLabels = train.arr2;
        }

        getTrainData() {
          const xs = tf.tensor4d(this.trainImages);
          const labels = tf.oneHot(tf.tensor1d(this.trainImagesLabels, 'int32'), this.classNum);
          return {xs, labels};
        }

        getValData() {
          const xs = tf.tensor4d(this.validateImages);
          const labels = tf.oneHot(tf.tensor1d(this.validateLabels, 'int32'), this.classNum);
          return {xs, labels};
        }

        getTestData() {
          const xs = tf.tensor4d(this.testImages);
          const labels = tf.oneHot(tf.tensor1d(this.testLabels, 'int32'), this.classNum);
          return {xs, labels};
        }
      }
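Two details in `data.js` are worth ruling out.

First, `shuffleSwap` is meant to be a Fisher-Yates shuffle, but `while(--i > 1)` stops before the final swap at `i = 1`. The result is not a uniform shuffle, and for a two-element array it never swaps at all.

Second, `tf.browser.fromPixels` returns integer pixel values from 0 to 255, and `getTrainData` passes them to `tf.tensor4d` unscaled. The MNIST demo this code started from scales its pixels to [0, 1] when loading them, and unscaled inputs can make training less stable.

The sketch below shows both fixes in plain JavaScript on nested arrays. `shufflePaired` and `scalePixels` are hypothetical names. In TensorFlow.js the scaling would more typically be done on the tensor itself, for example `tf.tensor4d(images).div(255)`:

```javascript
// Shuffle two parallel arrays in place with the same permutation
// (Fisher-Yates). The loop runs down to i = 1 so every position can move.
function shufflePaired(images, labels, random = Math.random) {
  if (images.length !== labels.length) {
    throw new Error('images and labels must have the same length');
  }
  for (let i = images.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1)); // 0 <= j <= i
    [images[i], images[j]] = [images[j], images[i]];
    [labels[i], labels[j]] = [labels[j], labels[i]];
  }
  return { images, labels };
}

// Recursively map every number in a nested [H][W][C] (or deeper) array
// from the 0..255 byte range to floats in [0, 1].
function scalePixels(nested) {
  return Array.isArray(nested) ? nested.map(scalePixels) : nested / 255;
}

const imgs = ['a', 'b', 'c', 'd'];
const lbls = [0, 1, 0, 1];
shufflePaired(imgs, lbls); // imgs and lbls are permuted together
console.log(JSON.stringify(scalePixels([[[0, 255, 51]]]))); // [[[0,1,0.2]]]
```

Passing `random` as a parameter is only there so the shuffle can be made deterministic in tests. In normal use the default `Math.random` is fine.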





      I added some pictures at the beginning.

      //getclassNum
      function getClassNum(files) {
        let classArr = new Array();
        let dirArr = new Array();
        let imageNum = 0;
        for (let i = 0; i < files.length; i++) {
          if (files[i].type.split('/')[0] == 'image' && files[i].type.split('/')[1] == 'jpeg') {
            dirArr = files[i].webkitRelativePath.split('/');
            let currentClassIndex = dirArr.length - 2;
            let isExist = false;
            if (currentClassIndex <= 0) {
              isExist = true;
            } else {
              imageNum++;
            }

            if (classArr == null) {
              classArr.push(dirArr[currentClassIndex]);
            }

            for (let j = 0; j < classArr.length; j++) {
              if (classArr[j] == dirArr[currentClassIndex]) {
                isExist = true;
              }
            }

            if (!isExist) {
              classArr.push(dirArr[currentClassIndex]);
            }
          }
        }
        let classNum = classArr.length;
        return {classNum, imageNum, classArr};
      }

      //get nested array
      function getDataset(files, classArr, imgNum) {
        let trainLabelArr = new Array();
        let trainDataArr = new Array();
        for (let i = 0; i < files.length; i++) {
          if (files[i].type.split('/')[0] == 'image' && files[i].type.split('/')[1] == 'jpeg') {
            let dirArr = files[i].webkitRelativePath.split('/');
            let currentClassIndex = dirArr.length - 2;
            if (currentClassIndex >= 0) {
              for (let j = 0; j < classArr.length; j++) {
                if (dirArr[currentClassIndex] == classArr[j]) {
                  let reader = new FileReader();
                  reader.readAsDataURL(files[i]);
                  reader.onload = function () {
                    document.getElementById('image').setAttribute('src', reader.result);
                    let tensor = tf.browser.fromPixels(document.getElementById('image'));
                    let nest = tensor.arraySync();

                    trainDataArr.push(nest);
                    trainLabelArr.push(j);
                  };
                }
              }
            }
          }
        }
        return {trainDataArr, trainLabelArr, trainDataLength};
      }

      //getfiles
      async function fileChange(that) {
        let files = that.files;
        let container = getClassNum(files);

        let data = getDataset(files, container.classArr, container.imageNum);
        let trainDataArr = data.trainDataArr;
        let trainLabelArr = data.trainLabelArr;

        setTimeout(function () {
          localforage.setItem('dataset', trainDataArr, function (err, result) {});
          localforage.setItem('datasetLabel', trainLabelArr, function (err, result) {});
          localforage.setItem('modelId', modelId, function (err, result) {});
          localforage.setItem('classNum', container.classNum, function (err, result) {});
        }, container.imageNum * 10);
      }
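In `getDataset`, each `FileReader` `onload` callback runs asynchronously, so the function returns before any image has been pushed. `fileChange` then relies on a `setTimeout` of `imageNum * 10` ms to guess when loading has finished. Separately, `fromPixels` is called immediately after setting the `<img>` `src`, and the new image is not necessarily decoded by then.

A more predictable pattern is to wrap each load in a promise and wait for all of them with `Promise.all`. In the minimal sketch below, `loadPixels` is a hypothetical function you would supply: it returns a promise of the nested pixel array for one file, for example by waiting for the image's `onload` before calling `tf.browser.fromPixels`. The other names are also illustrative:

```javascript
// Return the class folder name for a path such as "root/cat/001.jpg",
// or null when the file sits directly in the selected root folder
// (the same rule getClassNum uses).
function classFromPath(relativePath) {
  const parts = relativePath.split('/');
  return parts.length >= 3 ? parts[parts.length - 2] : null;
}

// Build images and labels only after every file has finished loading.
// `files` are objects with `type` and `webkitRelativePath` (like File);
// `loadPixels(file)` is a caller-supplied async loader (hypothetical).
async function buildDataset(files, loadPixels) {
  const labelled = files
    .filter(f => f.type === 'image/jpeg')
    .map(f => ({ file: f, cls: classFromPath(f.webkitRelativePath) }))
    .filter(e => e.cls !== null);
  // Class names in first-seen order, as in getClassNum.
  const classNames = [...new Set(labelled.map(e => e.cls))];

  // Promise.all preserves input order, so images[k] always matches labels[k]
  // no matter which load finishes first.
  const images = await Promise.all(labelled.map(e => loadPixels(e.file)));
  const labels = labelled.map(e => classNames.indexOf(e.cls));
  return { images, labels, classNames };
}

// Example with stand-in files and a loader that finishes in random order.
const fakeFiles = [
  { type: 'image/jpeg', webkitRelativePath: 'pets/cat/1.jpg' },
  { type: 'image/jpeg', webkitRelativePath: 'pets/dog/1.jpg' },
  { type: 'image/png', webkitRelativePath: 'pets/dog/2.png' },
  { type: 'image/jpeg', webkitRelativePath: 'pets/cat/2.jpg' },
];
const fakeLoad = f => new Promise(resolve =>
  setTimeout(() => resolve(f.webkitRelativePath), Math.random() * 20));

buildDataset(fakeFiles, fakeLoad).then(({ labels, classNames }) => {
  console.log(classNames); // [ 'cat', 'dog' ]
  console.log(labels);     // [ 0, 1, 0 ]
});
```

In the browser, `Array.from(that.files)` turns the `FileList` into an array. localForage's `setItem` also returns a promise, so the results can be saved with `await localforage.setItem(...)` rather than after a fixed delay.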








      tensorflow.js






      asked Mar 7 at 10:28 by chong zhao (edited Mar 7 at 16:13)






















          1 Answer














          Answering my own question: after a day of testing, I found that this model needs a lot of data. In my tests, each category needed at least 1,000 images; with too little training data, the model only ever outputs one class. This model also performs very well on objects with few distinguishing features, such as letters and signs, but not very well on animals or natural scenes.
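Before training, it can help to check how many images each class actually contributes. The answer above found that this model needed roughly 1,000 images per category before it stopped predicting a single class, and a large imbalance between classes can also push a classifier toward the majority class. `classCounts` and `underRepresented` are illustrative names, and the `minPerClass` default of 1000 comes from the testing described in this answer, not from any TensorFlow.js rule:

```javascript
// Count examples per integer label (labels are indices in [0, numClasses)).
function classCounts(labels, numClasses) {
  const counts = new Array(numClasses).fill(0);
  for (const y of labels) counts[y] += 1;
  return counts;
}

// List the classes that have fewer than `minPerClass` examples.
function underRepresented(labels, numClasses, minPerClass = 1000) {
  return classCounts(labels, numClasses)
    .map((count, cls) => ({ cls, count }))
    .filter(e => e.count < minPerClass);
}

// Example: 300 cats (label 0) and 1200 dogs (label 1).
const labels = [...new Array(300).fill(0), ...new Array(1200).fill(1)];
console.log(underRepresented(labels, 2)); // [ { cls: 0, count: 300 } ]
```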






                answered Mar 8 at 7:45









                chong zhao

                11




                11