Logo Search packages:      
Sourcecode: speech-tools version File versions  Download package

wagon_aux.cc

/*************************************************************************/
/*                                                                       */
/*                Centre for Speech Technology Research                  */
/*                     University of Edinburgh, UK                       */
/*                      Copyright (c) 1996,1997                          */
/*                        All Rights Reserved.                           */
/*                                                                       */
/*  Permission is hereby granted, free of charge, to use and distribute  */
/*  this software and its documentation without restriction, including   */
/*  without limitation the rights to use, copy, modify, merge, publish,  */
/*  distribute, sublicense, and/or sell copies of this work, and to      */
/*  permit persons to whom this work is furnished to do so, subject to   */
/*  the following conditions:                                            */
/*   1. The code must retain the above copyright notice, this list of    */
/*      conditions and the following disclaimer.                         */
/*   2. Any modifications must be clearly marked as such.                */
/*   3. Original authors' names are not deleted.                         */
/*   4. The authors' names are not used to endorse or promote products   */
/*      derived from this software without specific prior written        */
/*      permission.                                                      */
/*                                                                       */
/*  THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK        */
/*  DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING      */
/*  ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT   */
/*  SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE     */
/*  FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES    */
/*  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN   */
/*  AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,          */
/*  ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF       */
/*  THIS SOFTWARE.                                                       */
/*                                                                       */
/*************************************************************************/
/*                     Author :  Alan W Black                            */
/*                     Date   :  May 1996                                */
/*-----------------------------------------------------------------------*/
/*                                                                       */
/*  Various method functions                                             */
/*=======================================================================*/

#include <stdlib.h>
#include <iostream.h>
#include <string.h>
#include "EST_unix.h"
#include "EST_cutils.h"
#include "EST_Token.h"
#include "EST_Wagon.h"

EST_Val WNode::predict(const WVector &d)
{
    if (leaf())
      return impurity.value();
    else if (question.ask(d))
      return left->predict(d);
    else
      return right->predict(d);
}

WNode *WNode::predict_node(const WVector &d)
{
    if (leaf())
      return this;
    else if (question.ask(d))
      return left->predict_node(d);
    else
      return right->predict_node(d);
}

int WNode::pure(void)
{
    //  A node is pure if it has no sub-nodes or its not of type class

    if ((left == 0) && (right == 0))
      return TRUE;
    else if (get_impurity().type() != wnim_class)
      return TRUE;
    else
      return FALSE;
}

void WNode::prune(void)
{
    // Check all sub-nodes and if they are all of the same class 
    // delete their sub nodes.  Returns pureness of this node

    if (pure() == FALSE)
    {
      // Ok lets try and make it pure
      if (left != 0) left->prune();
      if (right != 0) right->prune();

      // Have to check purity as well as values to ensure left and right
      // don't further split
      if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
          (left->get_impurity().value() == right->get_impurity().value()))
      {
           delete left; left = 0;
           delete right; right = 0;
      }
    }    

}

void WNode::held_out_prune()
{
    // prune tree with held out data
    // Check if node's questions differentiates for the held out data
    // if not, prune all sub_nodes

    // Rescore with prune data
    set_impurity(WImpurity(get_data()));  // for this new data

    if (left != 0)
    {
      wgn_score_question(question,get_data());
      if (question.get_score() < get_impurity().measure())
      {  // its worth goint ot the next level
          wgn_find_split(question,get_data(),
                   left->get_data(),
                   right->get_data());
          left->held_out_prune();
          right->held_out_prune();
      }
      else
      {  // not worth the split so prune both sub_nodes
          delete left; left = 0;
          delete right; right = 0;
      }
    }    
}

void WNode::print_out(ostream &s, int margin)
{
    int i;

    s << endl;
    for (i=0;i<margin;i++) s << " ";
    s << "(";
    if (left==0) // base case
      s << impurity;
    else
    {
      s << question;
      left->print_out(s,margin+1);
      right->print_out(s,margin+1);
    }
    s << ")";
}

ostream & operator <<(ostream &s, WNode &n)
{
    // Output this node and its sub-node 

    n.print_out(s,0);
    s << endl;
    return s;
}

void WDataSet::load_description(const EST_String &fname, LISP ignores)
{
    // Initialise a dataset with sizes and types
    EST_String tname;
    int i;
    LISP description,d;

    description = car(vload(fname,1));
    dlength = siod_llength(description);

    p_type.resize(dlength);
    p_ignore.resize(dlength);
    p_name.resize(dlength);

    if (wgn_predictee_name == "")
      wgn_predictee = 0;  // default predictee is first field
    else
      wgn_predictee = -1;

    for (i=0,d=description; d != NIL; d=cdr(d),i++)
    {
      p_name[i] = get_c_string(car(car(d)));
      tname = get_c_string(car(cdr(car(d))));
      p_ignore[i] = FALSE;
      if ((wgn_predictee_name != "") && (wgn_predictee_name == p_name[i]))
          wgn_predictee = i;
      if ((wgn_count_field_name != "") && 
          (wgn_count_field_name == p_name[i]))
          wgn_count_field = i;
      if ((tname == "count") || (i == wgn_count_field))
      {
          // The count must be ignored, repeat it if you want it too
          p_type[i] = wndt_ignore;  // the count must be ignored
          p_ignore[i] = TRUE;
          wgn_count_field = i;
      }
      else if ((tname == "ignore") || (siod_member_str(p_name[i],ignores)))
      {
          p_type[i] = wndt_ignore;  // user specified ignore
          p_ignore[i] = TRUE;
          if (i == wgn_predictee)
            wagon_error(EST_String("predictee \"")+p_name[i]+
                      "\" can't be ignored \n");
      }
      else if (siod_llength(car(d)) > 2)
      {
          LISP rest = cdr(car(d));
          EST_StrList sl;
          siod_list_to_strlist(rest,sl);
          p_type[i] = wgn_discretes.def(sl);
          if (streq(get_c_string(car(rest)),"_other_"))
            wgn_discretes[p_type[i]].def_val("_other_");
      }
      else if (tname == "binary")
          p_type[i] = wndt_binary;
      else if (tname == "cluster")
          p_type[i] = wndt_cluster;
      else if (tname == "float")
          p_type[i] = wndt_float;
      else 
      {
          wagon_error(EST_String("Unknown type \"")+tname+
                  "\" for field number "+itoString(i)+
                        " in description file \""+fname+"\"");
      }
    }

    if (wgn_predictee == -1)
    {
      wagon_error(EST_String("predictee field \"")+wgn_predictee_name+
                  "\" not found in description ");
    }
}

const int WQuestion::ask(const WVector &w) const
{
    // Ask this question of the given vector
    switch (op)
    {
      case wnop_equal:    // for numbers
      if (w.get_flt_val(feature_pos) == operand1.Float())
          return TRUE;
      else 
          return FALSE;
      case wnop_binary:    // for numbers
      if (w.get_int_val(feature_pos) == 1)
          return TRUE;
      else 
          return FALSE;
      case wnop_greaterthan:
      if (w.get_flt_val(feature_pos) > operand1.Float())
          return TRUE;
      else 
          return FALSE;
      case wnop_lessthan:
      if (w.get_flt_val(feature_pos) < operand1.Float())
          return TRUE;
      else 
          return FALSE;
      case wnop_is:       // for classes
      if (w.get_int_val(feature_pos) == operand1.Int())
          return TRUE;
      else
          return FALSE;
      case wnop_in:       // for subsets -- note operand is list of ints
      if (ilist_member(operandl,w.get_int_val(feature_pos)))
          return TRUE;
      else
          return FALSE;
      default:
      wagon_error("Unknown test operator");
    }
    
    return FALSE;
}

ostream& operator<<(ostream& s, const WQuestion &q)
{
    EST_String name;
    static EST_Regex needquotes(".*[()'\";., \t\n\r].*");

    s << "(" << wgn_dataset.feat_name(q.get_fp());
    switch (q.get_op())
    {
      case wnop_equal:
      s << " = " << q.get_operand1().string();
      break;
      case wnop_binary:
      break;
      case wnop_greaterthan:
      s << " > " << q.get_operand1().Float();
      break;
      case wnop_lessthan:
      s << " < " << q.get_operand1().Float();
      break;
      case wnop_is:
      name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
          name(q.get_operand1().Int());
      s << " is ";
      if (name.matches(needquotes))
          s << quote_string(name,"\"","\\",1);
      else
          s << name;
      break;
      case wnop_matches:
      name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
          name(q.get_operand1().Int());
      s << " matches " << quote_string(name,"\"","\\",1);
      break;
      case wnop_in:
      s << " in (";
      for (int l=0; l < q.get_operandl().length(); l++)
      {
          name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
            name(q.get_operandl().nth(l));
          if (name.matches(needquotes))
            s << quote_string(name,"\"","\\",1);
          else
            s << name;
          s << " ";
      }
      s << ")";
      break;
        // SunCC wont let me add this
//      default:
//    s << " unknown operation ";
    }
    s << ")";
    
    return s;
}

EST_Val WImpurity::value(void)
{
    // Returns the recommended value for this 
    EST_String s;
    double prob;
    
    if (t==wnim_unset)
    {
      cerr << "WImpurity: no value currently set\n";
      return EST_Val(0.0);
    }
    else if (t==wnim_class)
      return EST_Val(p.most_probable(&prob));
    else if (t==wnim_cluster)
      return EST_Val(a.mean());
    else
      return EST_Val(a.mean());
}

double WImpurity::samples(void)
{
    if (t==wnim_float)
      return a.samples();
    else if (t==wnim_class)
      return (int)p.samples();
    else if (t==wnim_cluster)
      return members.length();
    else
      return 0;
}

WImpurity::WImpurity(const WVectorVector &ds)
{
    int i;

    t=wnim_unset;
    for (i=0; i < ds.n(); i++)
    {
      if (wgn_count_field == -1)
          cumulate((*(ds(i)))[wgn_predictee],1);
      else
          cumulate((*(ds(i)))[wgn_predictee],
                 (*(ds(i)))[wgn_count_field]);
    }
}

float WImpurity::measure(void)
{
    if (t == wnim_float)
      return a.variance()*a.samples();
    else if (t == wnim_class)
      return p.entropy()*p.samples();
    else if (t == wnim_cluster)
      return cluster_impurity();
    else
    {
      cerr << "WImpurity: can't measure unset object" << endl;
      return 0.0;
    }
}

float WImpurity::cluster_impurity()
{
    // Find the mean distance between all members of the dataset
    // Uses the global DistMatrix for distances between members of
    // the cluster set.  Distances are assumed to be symmetric thus only
    // the bottom half of the distance matrix is filled
    EST_Litem *pp, *q;
    int i,j;
    double dist;

    a.reset();
    for (pp=members.head(); pp != 0; pp=next(pp))
    {
      i = members.item(pp);
      for (q=next(pp); q != 0; q=next(q))
      {
          j = members.item(q);
          dist = (j < i ? wgn_DistMatrix.a_no_check(i,j) :
                        wgn_DistMatrix.a_no_check(j,i));
          a+=dist;  // cumulate for whole cluster
      }
    }

    // This is sum distance between cross product of members
    return a.sum();
}

float WImpurity::cluster_distance(int i)
{
    // Distance this unit is from all others in this cluster
    // in absolute standard deviations from the the mean.
    float dist = cluster_member_mean(i);
    float mdist = dist-a.mean();
    
    if (mdist == 0.0)
      return 0.0;
    else
      return fabs((dist-a.mean())/a.stddev());

}

int WImpurity::in_cluster(int i)
{
    // Would this be a member of this cluster?.  Returns 1 if 
    // its distance is less than at least one other
    float dist = cluster_member_mean(i);
    EST_Litem *pp;

    for (pp=members.head(); pp != 0; pp=next(pp))
    {
      if (dist < cluster_member_mean(members.item(pp)))
          return 1;
    }
    return 0;
}

float WImpurity::cluster_ranking(int i)
{
    // Position in ranking closest to centre
    float dist = cluster_distance(i);
    EST_Litem *pp;
    int ranking = 1;

    for (pp=members.head(); pp != 0; pp=next(pp))
    {
      if (dist >= cluster_distance(members.item(pp)))
          ranking++;
    }

    return ranking;
}

float WImpurity::cluster_member_mean(int i)
{
    // Returns the mean difference between this member and all others
    // in cluster
    EST_Litem *q;
    int j,n;
    double dist,sum;

    for (sum=0.0,n=0,q=members.head(); q != 0; q=next(q))
    {
      j = members.item(q);
      if (i != j)
      {
          dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
          sum += dist;
          n++;
      }
    }

    return ( n == 0 ? 0.0 : sum/n );
}

void WImpurity::cumulate(const float pv,double count)
{
    // Cumulate data for impurity calculation 

    if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
    {
      t = wnim_cluster;
      members.append((int)pv);
    }
    else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
    {
      if (t == wnim_unset)
          p.init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
      t = wnim_class;
      p.cumulate((int)pv,count);
    }
    else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
    {
      t = wnim_float;
      a.cumulate((int)pv,count);
    }
    else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
    {
      t = wnim_float;
      a.cumulate(pv,count);
    }
    else
    {
      wagon_error("WImpurity: cannot cumulate EST_Val type");
    }
}

ostream & operator <<(ostream &s, WImpurity &imp)
{
    if (imp.t == wnim_float)
      s << "(" << imp.a.stddev() << " " << imp.a.mean() << ")";
    else if (imp.t == wnim_cluster)
    {
      EST_Litem *p;
      s << "((";
      for (p=imp.members.head(); p != 0; p=next(p))
      {
          // Ouput cluster member and its mean distance to others
          s << "(" << imp.members.item(p) << " " <<
            imp.cluster_member_mean(imp.members.item(p)) << ")";
          if (next(p) != 0)
            s << " ";
      }
      s << ") ";
      // Mean of cross product of distances (cluster score)
      s << imp.a.mean() << ")";
    }
    else if (imp.t == wnim_class)
    {
      int i;
      EST_String name;
      double prob;

      s << "(";
      for (i=imp.p.item_start(); !imp.p.item_end(i); i=imp.p.item_next(i))
      {
          imp.p.item_prob(i,name,prob);
          s << "(" << name << " " << prob << ") ";
      }
      s << imp.p.most_probable(&prob) << ")";
    }
    else
      s << "([WImpurity unset])";
    
    return s;
}





Generated by  Doxygen 1.6.0   Back to index