Logo Search packages:      
Sourcecode: libneuralnet version File versions  Download package

classifier.cc

//
// classifier.cc
//  
// Made by Guillaume Stordeur
// Login   <kami@GrayArea.Masaq>
// 
// Started on  Fri Aug  2 05:53:59 2002 Guillaume Stordeur
// Last update Thu May  1 19:39:32 2003 Guillaume Stordeur
//

#include <stdlib.h>
#include <time.h>
#include <iostream>
#include <string>
#include <vector>
#include "neuralnet.hh"
#include <exception/exception.hh>
#include <data/csvloader.hh>
//
// Main
//
int   main(int argc, char *argv[])
{  
  srand(time(NULL));

  if (argc < 4)
    {
      std::cout << "Usage: classifier network.nn data.csv separator\n";
      return 1;
    }


  // Init cvsloader and get cols
  std::string csvfile(argv[2]);
  Data::CsvLoader loader;
  std::vector<Data::Column*>  cols;
  cols = loader.read_file(csvfile, argv[3][0]);

  // Init neural net
  NeuralNet::NeuralNet  nn(argv[1]);
 

  // Check that inputs match
  if (nn.getNbInputNeurons() != cols.size())
    {
      std::cerr << "ERROR: Data inputs do not match number of input neurons:" << std::endl;
      std::cerr << "       nbInputNeurons = " << nn.getNbInputNeurons() << std::endl;
      std::cerr << "       nbDataInputs   = " << cols.size() << std::endl;
      return 1;
    }
 
  // Output the names of the columns
  for (unsigned int i = 0; i < cols.size(); i++)
    std::cout << cols[i]->get_name() << ";";
  for (unsigned int i = 0; i < nn.getNbOutputNeurons(); i++)
    {
      std::cout << "Output" << i + 1;
      if (i < nn.getNbOutputNeurons() - 1)
      std::cout << ";";
    }
  std::cout << std::endl;

  // Output nn results
  std::vector<float> res, input;
  for (unsigned int j = 0; j < cols[0]->get_size(); j++)
    {
      input.clear();
      res.clear();
      for (unsigned int i = 0; i < cols.size(); i++)
      {
        if (cols[i]->get_type() == Data::STRING)
          {
            std::cout << "ERROR: This classifier does not support strings,";
            std::cout << "please discretize the data first." << std::endl;
            return 1;
          }
        float value;
        if (cols[i]->get_type() == Data::INT)
          value = (float) (dynamic_cast<Data::ColumnTyped<int>*> (cols[i]))->get_value(j);
        else if (cols[i]->get_type() == Data::FLOAT)
          value = (dynamic_cast<Data::ColumnTyped<float>*> (cols[i]))->get_value(j);
        input.push_back(value);
        std::cout << value << ";";
      }
      res = nn.output(input);
      for (unsigned int k = 0; k < res.size(); k++)
      {
        std::cout << res[k];
        if (k < res.size() - 1)
          std::cout << ";";
      }
      std::cout << std::endl;
    }
  return 0;
}

Generated by  Doxygen 1.6.0   Back to index