Logo Search packages:      
Sourcecode: libneuralnet version File versions  Download package

neuron.hh

//
// neuron.hh
//
// Made by Guillaume Stordeur
// Login   <kami@GrayArea.Masaq>
//
// Started on  Thu Aug  1 04:59:35 2002 Guillaume Stordeur
// Last update Mon May  5 21:45:49 2003 Guillaume Stordeur
//

#ifndef     NEURON_HH_
# define    NEURON_HH_

# include <algorithm>
# include <cassert>
# include <cmath>
# include <vector>
# include "utils.hh"
# include "Matrix.hh"

// ACT_STRING(T, S): assign to S the printable name of the activation
// function type T, or "unknown" for any value outside the enum.
// The enumerators are fully qualified so the macro also expands
// correctly at call sites outside namespace NeuralNet; S is
// parenthesized so any assignable expression may be passed.
#define ACT_STRING(T, S)                        \
  switch (T)                                    \
    {                                           \
    case ::NeuralNet::ACT_SGN:                  \
      (S) = "sgn"; break;                       \
    case ::NeuralNet::ACT_LINEAR:               \
      (S) = "linear"; break;                    \
    case ::NeuralNet::ACT_SIGMOID:              \
      (S) = "sigmoid"; break;                   \
    case ::NeuralNet::ACT_SIGMOID_APPROX:       \
      (S) = "sigmoid_approx"; break;            \
    case ::NeuralNet::ACT_TANH:                 \
      (S) = "tanh"; break;                      \
    case ::NeuralNet::ACT_TANH_APPROX:          \
      (S) = "tanh_approx"; break;               \
    case ::NeuralNet::ACT_GAUSS:                \
      (S) = "gauss"; break;                     \
    default:                                    \
      (S) = "unknown"; break;                   \
    }

// Bit flags selecting what Neuron::clearset() clears or sets.
// Each value is fully parenthesized so the macros compose correctly
// inside larger expressions (e.g. ~DWSUM, DWSUM + DELTAE, OUTPUT * 2);
// unparenthesized, `DWSUM + DELTAE` would expand to `1 << 1 + 1 << 2`.
// Bit 0 is unused.
#define DWSUM           (1 << 1)
#define DELTAE          (1 << 2)
#define MOMENTUM        (1 << 3)
#define OLD_DELTAE      (1 << 4)
#define DECAY_DELTAE    (1 << 5)
#define TRI             (1 << 6)
#define OUTPUT          (1 << 7)

namespace NeuralNet
{

// Activation functions a neuron can use. Values follow declaration
// order (0 through 6); ACT_STRING maps each one to its display name.
enum e_ActivationFunction
  {
    ACT_SGN,            // 0
    ACT_LINEAR,         // 1
    ACT_SIGMOID,        // 2
    ACT_SIGMOID_APPROX, // 3
    ACT_TANH,           // 4
    ACT_TANH_APPROX,    // 5
    ACT_GAUSS           // 6
  };
typedef enum e_ActivationFunction ActivationFunctionType;




//--------------------------------------
// Main Neuron class: the abstract base of
// every neuron type in this header.
//
// Each input connection i has an entry in
// _inputNeurons, _weights, _momentum, _tri,
// _deltae and _oldDeltae. The inline helpers
// below index those vectors with the same i
// (bounded by _inputNeurons.size()), so they
// must be kept at least that long --
// presumably done by addConnection(), which
// is defined outside this header.
//
class Neuron
{
public:
  // Constructor: zero the cached values and clear every flag.
  Neuron() { _dwsum = _s = _output = 0; _fixed = _recurrent = _timeLagged = false; }
  virtual ~Neuron() {}

  //  Add a connection from src to this neuron's inputs
  void      addConnection(Neuron *src, float weight);
  void      addConnection(Neuron *src); // random weight

  //  Update weights directly (stochastic mode)
  void      updateBackpropStochastic(float lRate,
                                     float moment,
                                     float delta);

  // Update deltae for batch mode
  virtual void    updateBatch(float delta);

  // Final weight-change functions for batch mode.
  // rprop
  virtual void    updateWeights(float nPlus, float nMinus,
                                float deltaMin, float deltaMax,
                                bool errUp);
  // backprop
  virtual void    updateWeights(float lRate, float moment);
  // quickprop
  virtual void    updateWeights(float lRate, float moment, float mu);

  //  Returns the index of n in the _inputNeurons vector,
  //  or -1 if n is not an input of this neuron.
  int       isInputNeuron(Neuron *n);

  //-=-=-=-=-=-=-=-=-=-=-=-=-=-=
  // INTERFACE

  void      setWeight(unsigned int i, float w)
  { assert(i < _weights.size()); _weights[i] = w; }

  void      setWeights(Matrix<double>& m);

  // Set every per-connection _tri entry to t.
  void      setTri(float t)
  { std::fill(_tri.begin(), _tri.end(), t); }

  // Set every per-connection _oldDeltae entry to t.
  void      setOldDeltae(float t)
  { std::fill(_oldDeltae.begin(), _oldDeltae.end(), t); }

  void      clearOldDeltae() { setOldDeltae(0); }

  void      setFixed(bool b) { _fixed = b; }
  void      setTimeLagged(bool b) { _timeLagged = b; }

  void      incDwsum(float a) { _dwsum += a; }

  void      clearDwsum() { _dwsum = 0; }

  void      clearDeltae()
  { std::fill(_deltae.begin(), _deltae.end(), 0.0f); }

  void      clearMomentum()
  { std::fill(_momentum.begin(), _momentum.end(), 0.0f); }

  // Set _deltae[i] = decay * _weights[i] for every input connection.
  void      decayDeltae(float decay)
  {
    for (unsigned int i = 0; i < _inputNeurons.size(); i++)
      _deltae[i] = decay * _weights[i];
  }

  bool      getRecurrent() const { return _recurrent; }

  void      setRecurrent(bool i) { _recurrent = i; }

  // All-in-one clear/set, driven by the DWSUM, OUTPUT, TRI,
  // OLD_DELTAE, DELTAE, DECAY_DELTAE and MOMENTUM bit flags.
  // `tri` is the value written when TRI is set; `decay` scales the
  // weights when DECAY_DELTAE is set. If both DELTAE and DECAY_DELTAE
  // are set, the decayed value wins because it is written last.
  virtual  void   clearset(int flags, float decay, float tri)
  {
    if (flags & DWSUM)
      _dwsum = 0;
    if (flags & OUTPUT)
      _output = 0;
    // The loop below only acts on per-connection flags; skip it when
    // none is present (covers DWSUM alone, OUTPUT alone, and both).
    if (!(flags & (TRI | OLD_DELTAE | DELTAE | DECAY_DELTAE | MOMENTUM)))
      return;
    for (unsigned int i = 0; i < _inputNeurons.size(); i++)
      {
        if (flags & TRI)
          _tri[i] = tri;
        if (flags & OLD_DELTAE)
          _oldDeltae[i] = 0;
        if (flags & DELTAE)
          _deltae[i] = 0;
        if (flags & DECAY_DELTAE)
          _deltae[i] = decay * _weights[i];
        if (flags & MOMENTUM)
          _momentum[i] = 0;
      }
  }

  unsigned int getNBInputs() const
  { return static_cast<unsigned int>(_inputNeurons.size()); }

  virtual float   getOutput() const { return _output; }

  bool      getFixed() const { return _fixed; }
  bool      getTimeLagged() const { return _timeLagged; }

  float     getDwsum() const { return _dwsum; }

  float     getWeight(unsigned int i) const
  { assert(i < _weights.size()); return _weights[i]; }

  // Copy of the weight vector, widened to double, wrapped in a Matrix.
  Matrix<double>  getWeights() const
  {
    std::vector<double> w(_weights.begin(), _weights.end());
    Matrix<double> m(w);
    return m;
  }

  Neuron    *getInputNeuron(unsigned int i) const
  { assert(i < _inputNeurons.size()); return _inputNeurons[i]; }

  //-=-=-=-=-=-=-=-=-=-=-=-=-=-=
  // NEURON SPECIFIC FUNCTIONS

  //  Get the derivative F'(s)
  virtual float   getFPrime() const = 0;

  //  Process inputs and calculate activation (output)
  virtual float   refreshOutput() = 0;

  //  Display info to standard output
  virtual void    display() const = 0;

protected:
  //  Calc weighted sum
  void      _calcWeightedSum();

  //  neurons that have their output connected to this neuron
  std::vector<Neuron*>  _inputNeurons;

  //  weight vector, 1st elem being the threshold
  std::vector<float>    _weights;

  //  learning momentum vector
  std::vector<float>    _momentum;

  //  per-connection value written by setTri() / clearset(TRI)
  std::vector<float>    _tri;

  //  per-connection deltae (current and previous)
  std::vector<float>    _deltae, _oldDeltae;

  //  last output, updated when output function is called
  float     _output;

  //  weighted sum
  float     _s;

  //  delta-weight sum, used in backpropagation
  float     _dwsum;

  //  fixed-weights and timeLagged
  bool      _fixed, _timeLagged;

  //  Neuron has its input fed from another neuron in the network
  bool      _recurrent;
};

//--------------------------------------
// Threshold Neuron
// * output = 1
// This neuron is connected to a
// threshold connection of another neuron,
// and does not take any inputs.
// It always returns 1.
//
class ThresholdNeuron : public Neuron
{
public:
  //  Constructor
  ThresholdNeuron() { _output = 1; }

  //    Process inputs and calculate activation (output)
  float refreshOutput(void) { return _output; }

  //  Get the derivative F'(s)
  float     getFPrime(void) const { return 1; }

  //  Display info to standard output
  void      display(void) const {}
};


//--------------------------------------
// Input Neuron
// * output = input
// The output is either set directly via
// setInput(), or copied from the first
// input connection when one exists.
//
class InputNeuron : public Neuron
{
public:
  // Start out with `input` (0 by default) as the output value.
  InputNeuron(float input = 0) { setInput(input); }

  // Store a new externally supplied value as the output.
  void setInput(float input) { _output = input; }

  // Mirror the first upstream neuron's output if there is one;
  // otherwise keep whatever value was last stored.
  virtual float refreshOutput()
  {
    if (_inputNeurons.empty())
      return _output;
    _output = _inputNeurons.front()->getOutput();
    return _output;
  }

  // Identity activation: the derivative is constant.
  virtual float getFPrime() const { return 1; }

  // Prints nothing.
  virtual void display() const {}
};


//--------------------------------------
// Sigmoid activation Neuron
// * output = sigmoid(activation)
//
class SigmoidNeuron : public Neuron
{
public:
  // Connect the threshold neuron n with initial weight `threshold`.
  SigmoidNeuron(ThresholdNeuron *n, float threshold = 0)
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate the activation (defined out of line).
  virtual float refreshOutput();

  // F'(s) written in terms of the cached output y: y * (1 - y).
  virtual float getFPrime() const
  {
    const float y = _output;
    return y * (1 - y);
  }

  // Display info to standard output (defined out of line).
  virtual void display() const;

protected:
  // Sigmoid non-linear activation function.
  inline float    _sigmoid(float s);
};

//--------------------------------------
// Approximated Sigmoid activation Neuron
// * output = sigmoid_approx(activation)
//
class ApproxSigmoidNeuron : public Neuron
{
public:
  // Connect the threshold neuron n with initial weight `threshold`.
  ApproxSigmoidNeuron(ThresholdNeuron *n, float threshold = 0)
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate the activation (defined out of line).
  virtual float refreshOutput();

  // F'(s) = 1 - |s| inside (-1, 1), and 0 once |s| reaches 1.
  virtual float getFPrime() const
  {
    if (fabs(_s) >= 1)
      return 0;
    return 1 - fabs(_s);
  }

  // Display info to standard output (defined out of line).
  virtual void display() const;

protected:
  // Approximated sigmoid non-linear activation function.
  inline float    _sigmoid_approx(float s);
};

//--------------------------------------
// Gauss activation Neuron
// * output = gauss(activation)
//
class GaussNeuron : public Neuron
{
public:
  //  Constructor: connect the threshold neuron n with
  //  initial weight `threshold`.
  GaussNeuron(ThresholdNeuron *n,
              float       threshold = 0) { addConnection(n, threshold); }

  //  Process inputs and calculate activation (output)
  float     refreshOutput(void);

  //  Get the derivative F'(s) = -2 * s * output.
  //  This equals d/ds exp(-s^2) = -2 s exp(-s^2), assuming _gauss(s)
  //  is exp(-s^2) -- its definition is not in this header; confirm.
  float     getFPrime(void) const { return (-2 * _s * _output); }

  //  Display info to standard output
  void      display(void) const;


protected:

  //  Gaussian non-linear activation function
  inline float    _gauss(float s);

};

//--------------------------------------
// Tanh activation Neuron
// * output = tanh(activation)
//
class TanhNeuron : public Neuron
{
public:
  // Connect the threshold neuron n with initial weight `threshold`.
  TanhNeuron(ThresholdNeuron *n, float threshold = 0)
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate the activation (defined out of line).
  virtual float refreshOutput();

  // F'(s) written in terms of the cached output y: 1 - y^2.
  virtual float getFPrime() const
  {
    const float y = _output;
    return 1 - y * y;
  }

  // Display info to standard output (defined out of line).
  virtual void display() const;

protected:
  // Tanh non-linear activation function.
  inline float    _tanh(float s);
};

//--------------------------------------
// Approximated Tanh activation Neuron
// * output = tanh_approx(activation)
//
class ApproxTanhNeuron : public Neuron
{
public:
  //  Constructor: connect the threshold neuron n with
  //  initial weight `threshold`.
  ApproxTanhNeuron(ThresholdNeuron  *n,
                   float          threshold = 0) { addConnection(n, threshold); }

  //    Process inputs and calculate activation (output)
  float     refreshOutput(void);

  //  Get the derivative F'(s):
  //    F'(s) = 1 - 0.52074 * |s|   for |s| <= 1.92033
  //    F'(s) = 0                   beyond that
  //  tanh is odd, so its derivative is even and non-negative:
  //  F'(-s) == F'(s), and F'(0) == 1. The negative side therefore
  //  mirrors the positive side (1 + 0.52074 * s for s < 0). Since
  //  0.52074 * 1.92033 ~= 1, F' reaches 0 continuously at the clip point.
  //  NOTE(review): assumes _tanh_approx (defined elsewhere) is odd like
  //  tanh -- confirm against its definition.
  float     getFPrime(void) const
  {
    const float mag = (_s < 0) ? -_s : _s;   // |s|
    if (mag > 1.92033)
      return 0;
    return 1 - 0.52074 * mag;
  }

  //  Display info to standard output
  void      display(void) const;


protected:

  //  Approximated Tanh non-linear activation function
  inline float    _tanh_approx(float s);

};

//----------------------------------------
// Sgn (step) activation Neuron
// * output = 0 if activation < 0
//            1 if activation >= 0
//
class SgnNeuron : public Neuron
{
public:
  // Connect the threshold neuron n with initial weight `threshold`.
  SgnNeuron(ThresholdNeuron *n, float threshold = 0)
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate the activation (defined out of line).
  virtual float refreshOutput();

  // The step function has no usable derivative; a constant 1 is
  // reported as a stand-in.
  // FIXME: not derivable
  virtual float getFPrime() const { return 1; }

  // Display info to standard output (defined out of line).
  virtual void display() const;
};


//----------------------------------------
// Linear activation Neuron
// * output = activation
//
class LinearNeuron : public Neuron
{
public:
  // Connect the threshold neuron n with initial weight `threshold`.
  LinearNeuron(ThresholdNeuron *n, float threshold = 0)
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate the activation (defined out of line).
  virtual float refreshOutput();

  // Identity activation: the derivative is constant.
  virtual float getFPrime() const { return 1; }

  // Display info to standard output (defined out of line).
  virtual void display() const;
};


//
//    Factory: allocate a new neuron whose activation function is
//    selected by `type`. `tneuron` and `threshold` are presumably
//    forwarded to the concrete neuron's constructor (threshold
//    connection and its initial weight) -- defined elsewhere; confirm
//    the mapping and who owns/deletes the returned pointer.
//
Neuron *newNeuron(ActivationFunctionType type,
                  ThresholdNeuron *tneuron,
                  float threshold);

} // end NeuralNet namespace

#endif          /* !NEURON_HH_ */

Generated by  Doxygen 1.6.0   Back to index