/* Listing of EST_SCFG_inout.cc from the speech-tools source package. */

/*************************************************************************/
/*                                                                       */
/*                Centre for Speech Technology Research                  */
/*                     University of Edinburgh, UK                       */
/*                         Copyright (c) 1997                            */
/*                        All Rights Reserved.                           */
/*                                                                       */
/*  Permission is hereby granted, free of charge, to use and distribute  */
/*  this software and its documentation without restriction, including   */
/*  without limitation the rights to use, copy, modify, merge, publish,  */
/*  distribute, sublicense, and/or sell copies of this work, and to      */
/*  permit persons to whom this work is furnished to do so, subject to   */
/*  the following conditions:                                            */
/*   1. The code must retain the above copyright notice, this list of    */
/*      conditions and the following disclaimer.                         */
/*   2. Any modifications must be clearly marked as such.                */
/*   3. Original authors' names are not deleted.                         */
/*   4. The authors' names are not used to endorse or promote products   */
/*      derived from this software without specific prior written        */
/*      permission.                                                      */
/*                                                                       */
/*  THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK        */
/*  DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING      */
/*  ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT   */
/*  SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE     */
/*  FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES    */
/*  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN   */
/*  AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,          */
/*  ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF       */
/*  THIS SOFTWARE.                                                       */
/*                                                                       */
/*************************************************************************/
/*             Author :  Alan W Black                                    */
/*             Date   :  October 1997                                    */
/*-----------------------------------------------------------------------*/
/*                                                                       */
/* Implementation of an inside-outside reestimation procedure for        */
/* building a stochastic CFG seeded with a bracketed corpus.             */
/* Based on "Inside-Outside Reestimation from Partially Bracketed        */
/* Corpora", F. Pereira and Y. Schabes, pp 128-135, 30th ACL, Newark,    */
/* Delaware 1992.                                                        */
/*                                                                       */
/* This should really be done in the log domain.  Addition in the log    */
/* domain can be done with a formula in Huang, Ariki and Jack:           */
/*    log(a+b) = log(1 + exp(log(a) - log(b))) + log(b)                  */
/*                                                                       */
/*=======================================================================*/
#include <stdlib.h>
#include "EST_SCFG_Chart.h"
#include "EST_simplestats.h"
#include "EST_math.h"
#include "EST_TVector.h"

static const EST_bracketed_string def_val_s;
static EST_bracketed_string error_return_s;
// Per-element-type definitions of EST_TVector's static members for
// EST_bracketed_string.  These are explicit specializations of static data
// members of a class template, which standard C++ requires to be written
// with a leading "template <>".
template <> const EST_bracketed_string *EST_TVector<EST_bracketed_string>::def_val=&def_val_s;
template <> EST_bracketed_string *EST_TVector<EST_bracketed_string>::error_return=&error_return_s;


#if defined(INSTANTIATE_TEMPLATES)
#include "../base_class/EST_TVector.cc"

template class EST_TVector<EST_bracketed_string>;
#endif

void set_corpus(EST_Bcorpus &b, LISP examples)
{
    // Size the corpus to the number of list elements in examples, then
    // load element k of the list into corpus slot k.
    b.resize(siod_llength(examples));

    int idx = 0;
    LISP rest = examples;
    while (rest != NIL)
    {
        b.a_no_check(idx).set_bracketed_string(car(rest));
        rest = cdr(rest);
        ++idx;
    }
}

void EST_bracketed_string::init()
{
    // Put the object into its empty state: no leaves, no tables.
    p_length = 0;
    valid_spans = 0;
    symbols = 0;
    // Give bs a defined value before registering its address with gc_protect.
    bs = NIL;
    gc_protect(&bs);
}

EST_bracketed_string::EST_bracketed_string()
{
    // An empty bracketed string: only the default state is needed.
    this->init();
}

EST_bracketed_string::EST_bracketed_string(LISP string)
{
    // Establish the empty state first so set_bracketed_string() starts
    // from null tables, then load the given tree.
    this->init();
    this->set_bracketed_string(string);
}

EST_bracketed_string::~EST_bracketed_string()
{
    // Drop the tree reference and unregister bs from gc protection.
    bs = NIL;
    gc_unprotect(&bs);

    delete [] symbols;

    // valid_spans rows are released using p_length as the row count.
    for (int row = 0; row < p_length; ++row)
        delete [] valid_spans[row];
    delete [] valid_spans;
}

void EST_bracketed_string::set_bracketed_string(LISP string)
{
    // Replace the current tree with `string`: rebuild the leaf table
    // (symbols) and the table of spans permitted by the bracketing
    // (valid_spans[i][j], 1 when leaves i..j-1 form a valid span).
    int i,j;

    // Release everything built for a previous string.  The old rows of
    // valid_spans were allocated with the old p_length, so free them
    // before p_length is overwritten (same row count the destructor uses).
    bs=NIL;
    delete [] symbols;
    symbols = 0;
    if (valid_spans != 0)
    {
        for (i=0; i < p_length; i++)
            delete [] valid_spans[i];
        delete [] valid_spans;
        valid_spans = 0;
    }

    p_length = find_num_nodes(string);
    symbols = new LISP[p_length];

    set_leaf_indices(string,0,symbols);

    bs = string;

    valid_spans = new int*[length()];
    for (i=0; i < length(); i++)
    {
        // Zero the whole row (not just j > i) so no entry is left
        // uninitialised.
        valid_spans[i] = new int[length()+1];
        for (j=0; j <= length(); j++)
            valid_spans[i][j] = 0;
    }

    // fill in valid table
    if (p_length > 0)
        find_valid(0,bs);
}

int EST_bracketed_string::find_num_nodes(LISP string)
{
    // Count the atoms (leaves) of the tree.  NIL is the empty list and is
    // not counted as an atom.
    if (string == NIL)
        return 0;
    if (!CONSP(string))
        return 1;
    return find_num_nodes(car(string)) + find_num_nodes(cdr(string));
}

int  EST_bracketed_string::set_leaf_indices(LISP string,int i,LISP *syms)
{
    // Walk the list `string` left to right, recursing into sub-lists, and
    // store in syms[i], syms[i+1], ... the cons cell whose car is each leaf
    // atom.  Returns the next unused index.
    //
    // A NIL element (an empty sub-bracket "()") is treated as an empty
    // subtree and takes no slot.  This matches find_num_nodes() and
    // num_leafs(), which do not count NIL.  Otherwise syms, sized by
    // find_num_nodes(), would be written past its end.
    if (string == NIL)
        return i;
    else if ((car(string) != NIL) && !CONSP(car(string)))
    {
        syms[i] = string;
        return set_leaf_indices(cdr(string),i+1,syms);
    }
    else  // car is a (possibly empty) tree
    {
        return set_leaf_indices(cdr(string),
                                set_leaf_indices(car(string),i,syms),
                                syms);
    }
}

void EST_bracketed_string::find_valid(int s,LISP t) const
{
    LISP l;
    int c;

    if (consp(t))
    {
      for (c=s,l=t; l != NIL; l=cdr(l))
      {
          c += num_leafs(car(l));
          valid_spans[s][c] = 1;
      }
      find_valid(s,car(t));
      find_valid(s+num_leafs(car(t)),cdr(t));
    }
}

int EST_bracketed_string::num_leafs(LISP t) const
{
    // Number of atoms in tree t; NIL (the empty list) contributes none.
    if (t == NIL)
        return 0;
    if (consp(t))
        return num_leafs(car(t)) + num_leafs(cdr(t));
    return 1;
}

EST_SCFG_traintest::EST_SCFG_traintest(void) : EST_SCFG() 
{
    // No per-sentence inside/outside cache exists until init_io_cache().
    outside = 0;
    inside = 0;
    // Reestimation accumulators start empty; they are sized later.
    d.resize(0);
    n.resize(0);
}

EST_SCFG_traintest::~EST_SCFG_traintest(void)
{
    // This destructor releases nothing itself; the visible training and
    // testing loops free each sentence's cache with clear_io_cache().
}

00207 void EST_SCFG_traintest::load_corpus(const EST_String &filename)
{
    set_corpus(corpus,vload(filename,1));
}

// From the formula in the paper
00213 double EST_SCFG_traintest::f_I_cal(int c, int p, int i, int k)
{
    // Find Inside probability
    double res;

    if (i == k-1)
    {
      res = prob_U(p,terminal(corpus.a_no_check(c).symbol_at(i)));
//    printf("prob_U p %s (%d) %d m %s (%d) res %g\n",
//           (const char *)nonterminal(p),p,
//           i,
//           (const char *)corpus.a_no_check(c).symbol_at(i),
//           terminal(corpus.a_no_check(c).symbol_at(i)),
//           res);
    }
    else if (corpus.a_no_check(c).valid(i,k) == TRUE)
    {
      int j;
      double s=0;
      int q,r;
      
      for (q = 0; q < num_nonterminals(); q++)
          for (r = 0; r < num_nonterminals(); r++)
          {
            double pBpqr = prob_B(p,q,r);
            if (pBpqr > 0)
                for (j=i+1; j < k; j++)
                {
                  double in = f_I(c,q,i,j);
                  if (in > 0)
                      s += pBpqr * in * f_I(c,r,j,k);
                }
          }
      res = s;
    }
    else
      res = 0.0;

    inside[p][i][k] = res;

//    printf("f_I p %s i %d k %d res %g\n",
//       (const char *)nonterminal(p),i,k,res);

    return res;
}

00259 double EST_SCFG_traintest::f_O_cal(int c, int p, int i, int k)
{
    // Find Outside probability
    double res;

    if ((i == 0) && (k == corpus.a_no_check(c).length()))
    {
      if (p == distinguished_symbol()) // distinguished non-terminal
          res = 1.0;
      else
          res = 0.0;
    }
    else if (corpus.a_no_check(c).valid(i,k) == TRUE)
    {
      double s1=0.0;
      double s2,s3;
      double pBqrp,pBqpr;
      int j;
      int q,r;

      for (q = 0; q < num_nonterminals(); q++)
          for (r = 0; r < num_nonterminals(); r++)
          {
            pBqrp = prob_B(q,r,p);
            s2 = s3 = 0.0;
            if (pBqrp > 0)
            {
                for (j=0;j < i; j++)
                {
                  double out = f_O(c,q,j,k);
                  if (out > 0)
                      s2 += out * f_I(c,r,j,i);
                }
                s2 *= pBqrp;
            }
            pBqpr = prob_B(q,p,r);
            if (pBqpr > 0)
            {
                for (j=k+1;j <= corpus.a_no_check(c).length(); j++)
                {
                  double out = f_O(c,q,i,j);
                  if (out > 0)
                      s3 += out * f_I(c,r,k,j);
                }
                s3 *= pBqpr;
            }
            s1 += s2 + s3;
          }
      res = s1;
    }
    else  // not a valid bracketing
      res = 0.0;

    outside[p][i][k] = res;

    return res;
}

00317 void EST_SCFG_traintest::reestimate_rule_prob_B(int c, int ri, int p, int q, int r)
{
    // Restimate probability for binary rules
    int i,j,k;
    double n2=0;
    
    double pBpqr = prob_B(p,q,r);

    if (pBpqr > 0)
    {
      for (i=0; i <= corpus.a_no_check(c).length()-2; i++)
          for (j=i+1; j <= corpus.a_no_check(c).length()-1; j++)
          {
            double d1 = f_I(c,q,i,j);
            if (d1 == 0) continue;
            for (k=j+1; k <= corpus.a_no_check(c).length(); k++)
            {
                double d2 = f_I(c,r,j,k);
                if (d2 == 0) continue;
                double d3 = f_O(c,p,i,k);
                if (d3 == 0) continue;
                n2 += d1 * d2 * d3;
            }
          }
      n2 *= pBpqr;
    }
    // f_P(c) is probably redundant
    double fp = f_P(c);
    double n1,d1;
    n1 = n2 / fp;

    d1 = f_P(c,p) / fp;
    //      printf("n1 %f d1 %f n2 %f fp %f\n",n1,d1,n2,fp);
    n[ri] += n1;
    d[ri] += d1;

}

00355 void EST_SCFG_traintest::reestimate_rule_prob_U(int c,int ri, int p, int m)
{
    // Restimate probability for unary rules
    int i;

//    printf("reestimate_rule_prob_U: %f p %s m %s\n",
//       prob_U(ip,im),
//       (const char *)p,
//       (const char *)m);

    double n2=0;
      
    for (i=1; i < corpus.a_no_check(c).length(); i++)
      if (m == terminal(corpus.a_no_check(c).symbol_at(i-1)))
          n2 += prob_U(p,m) * f_O(c,p,i-1,i);

    double fP = f_P(c);
    if (fP != 0)
    {
      n[ri] += n2 / fP;
      d[ri] += f_P(c,p) / fP;
    }
}

00379 double EST_SCFG_traintest::f_P(int c)
{
    return f_I(c,distinguished_symbol(),0,corpus.a_no_check(c).length());
}

00384 double EST_SCFG_traintest::f_P(int c,int p)
{
    int i,j;
    double db=0;

    for (i=0; i < corpus.a_no_check(c).length(); i++)
      for (j=i+1; j <= corpus.a_no_check(c).length(); j++)
      {
          double d1 = f_O(c,p,i,j);
          if (d1 == 0) continue;
          db += f_I(c,p,i,j)*d1;
      }

    return db;
}

00400 void EST_SCFG_traintest::reestimate_grammar_probs(int passes,
                              int startpass,
                              int checkpoint,
                              int spread,
                              const EST_String &outfile)
{
    // Iterate over the corpus cummulating factors for each rules
    // This reduces the sapce requirements and recalculations of
    // values for each sentences.  
    // Repeat training passes to number specified
    int pass = 0;
    double zero=0;
    double se;
    int ri,c;

    n.resize(rules.length());
    d.resize(rules.length());

    for (pass = startpass; pass < passes; pass++)
    {
      EST_Litem *r;
      double mC, lPc;

      d.fill(zero);
      n.fill(zero);
      set_rule_prob_cache();

      for (mC=0.0,lPc=0.0,c=0; c < corpus.length(); c++)
      {
          // For skipping some sentences to speed up convergence
          if ((spread > 0) && (((c+(pass*spread))%100) >= spread))
            continue;
          printf(" %d",c); fflush(stdout);
          if (corpus.a_no_check(c).length() == 0) continue;
          init_io_cache(c,num_nonterminals());
          for (ri=0,r=rules.head(); r != 0; r=next(r),ri++)
          {
            if (rules(r).type() == est_scfg_binary_rule)
                reestimate_rule_prob_B(c,ri,
                                 rules(r).mother(),
                                 rules(r).daughter1(),
                                 rules(r).daughter2());
            else
                reestimate_rule_prob_U(c,
                                 ri,
                                 rules(r).mother(),
                                 rules(r).daughter1());
          }
          lPc += safe_log(f_P(c));
          mC += corpus.a_no_check(c).length();
          clear_io_cache(c);
      }
      printf("\n");

      for (se=0.0,ri=0,r=rules.head(); r != 0; r=next(r),ri++)
      {
          double n_prob = n[ri]/d[ri];
          se += (n_prob-rules(r).prob())*(n_prob-rules(r).prob());
          rules(r).set_prob(n_prob);
      }
      printf("pass %d cross entropy %g RMSE %f\n",
             pass,-(lPc/mC),sqrt(se/rules.length()));
      
      if (checkpoint != -1) 
      {
          if ((pass % checkpoint) == checkpoint-1)
          {
            char cp[20];
            sprintf(cp,".%03d",pass);
            save(outfile+cp);
            user_gc(NIL);  // just to keep things neat
          }
      }

    }
}

00477 void EST_SCFG_traintest::train_inout(int passes,
                             int startpass,
                             int checkpoint,
                             int spread,
                             const EST_String &outfile)
{
    // Train a Stochastic CFG using the inside outside algorithm

    reestimate_grammar_probs(passes, startpass, checkpoint, 
                       spread, outfile);
}

00489 void EST_SCFG_traintest::init_io_cache(int c,int nt)
{
    // Build an array to cache the in/out values
    int i,j,k;
    int mc = corpus.a_no_check(c).length()+1;
    
    inside = new double**[nt];
    outside = new double**[nt];
    for (i=0; i < nt; i++)
    {
      inside[i] = new double*[mc];
      outside[i] = new double*[mc];
      for (j=0; j < mc; j++)
      {
          inside[i][j] = new double[mc];
          outside[i][j] = new double[mc];
          for (k=0; k < mc; k++)
          {
            inside[i][j][k] = -1;
            outside[i][j][k] = -1;
          }
      }
    }
}
      
00514 void EST_SCFG_traintest::clear_io_cache(int c)
{
    int mc = corpus.a_no_check(c).length()+1;
    int i,j;

    if (inside == 0)
      return;

    for (i=0; i < num_nonterminals(); i++)
    {
      for (j=0; j < mc; j++)
      {
          delete [] inside[i][j];
          delete [] outside[i][j];
      }
      delete [] inside[i];
      delete [] outside[i];
    }

    delete [] inside;
    delete [] outside;

    inside = 0;
    outside = 0;
}

double EST_SCFG_traintest::cross_entropy()
{
    double lPc=0,mC=0;
    int c;

    for (c=0; c < corpus.length(); c++)
    {
      lPc += log(f_P(c));
      mC += corpus.a_no_check(c).length();
    }

    return -(lPc/mC);
}

00554 void EST_SCFG_traintest::test_corpus()
{
    // Test corpus against current grammar.
    double mC,lPc;
    int c,i;
    int failed=0;
    double fP;

    // Lets try simply finding the cross entropy
    n.resize(rules.length());
    d.resize(rules.length());
    for (i=0; i < rules.length(); i++)
      d[i] = n[i] = 0.0;

    for (mC=0.0,lPc=0.0,c=0; c < corpus.length(); c++)
    {
      if (corpus.length() > 50)
      {
          printf(" %d",c); 
          fflush(stdout);
      }
      init_io_cache(c,num_nonterminals());
      fP = f_P(c);
      if (fP == 0)
          failed++;
      else
      {
          lPc += safe_log(fP);
          mC += corpus.a_no_check(c).length();
      }
      clear_io_cache(c);
    }
    if (corpus.length() > 50)
      printf("\n");

    cout << "cross entropy " << -(lPc/mC) << " (" << failed << " failed out of " <<
      corpus.length() << " sentences )" << endl;

}


/* Listing generated by Doxygen 1.6.0 */