Implementing Baum-Welch (Forward-Backward) algorithm in SRILM

Andreas Stolcke stolcke at speech.sri.com
Thu Jan 1 11:05:48 PST 2004


In message <005801c3cf8b$a20869d0$34284484 at cs.technion.ac.il> you wrote:
> Hi,
> 
> I'm using disambig for part-of-speech tagging. I create a language model
> over sequences of tags with ngram-count, and provide P(word|tag) in the
> map file. 
> 
> What I would like to do is to start with this model, based on a tagged
> corpus, and improve it using the Baum-Welch (forward-backward) algorithm
> on an untagged corpus. After each iteration I should get a new language
> model for the tags and a new map file. After each iteration I would
> like to test the model on some held-out data, so I know when to stop.
> 
> How can I implement that in SRILM?

You will need to write some scripts to manipulate intermediate data, but you
can pretty much do what you want.
Implementing EM for your tagger involves two steps:

1. E-step:   get expected counts for the tag N-grams and the word/tag mapping.

   a. Tag N-gram expectations.   This step is unfortunately not well supported 
	by the tools right now.  Although disambig uses the forward-backward
	(FB) algorithm, it doesn't collect (let alone output) the expected
	counts in a way that's suitable for reestimating a model from them.
	You can use one of two approximations.  First, you could use the
	1-best tag sequence as a stand-in for the real thing and generate
	tag N-gram counts from it (that's sometimes called the "Viterbi"
	approximation of EM).  Second, you can use the -nbest option to
	generate the top N most likely taggings of each sentence along with
	their scores.  You then have to normalize the scores to obtain
	posterior probabilities for the tag sequences, weight the tag N-gram
	counts by these posteriors, and total them over your entire training
	corpus (a script sketch for this follows at the end of this step).

   b. Word/tag expectations.  Here again, you could use the Viterbi
	approximation, simply pairing up the words and their most likely
	tags (as output by disambig).  However, the most recent version of
	disambig actually has an option to collect and output the 
	expected word/tag bigram counts.  I have appended a patch that 
	should allow you to do this with the 1.3.3 version of disambig.
	The option that this adds is:

       -write-counts file
              Outputs the V1-V2 bigram counts corresponding to the
              tagging performed on the input data.  If -fb was
              specified these are expected counts; otherwise they
              reflect the 1-best tagging decisions.
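
   As a concrete illustration of the N-best approximation in 1a: the
   normalization and count accumulation amount to a small script.  Below is
   a minimal Python sketch.  The input format it assumes -- for each
   sentence, one hypothesis per line consisting of a log10 score followed
   by its tag sequence, with sentences separated by blank lines -- is my
   assumption, so adapt the parsing to whatever your version of
   disambig -nbest actually emits.

	#!/usr/bin/env python
	# Sketch: accumulate posterior-weighted tag N-gram counts from
	# N-best taggings.  The input format is assumed (see above);
	# scores are taken to be log10 probabilities, as is usual for
	# SRILM output.
	import sys
	from collections import defaultdict

	ORDER = 3                       # tag N-gram order to collect
	counts = defaultdict(float)     # N-gram tuple -> expected count

	def flush(hyps):
	    # Normalize one sentence's hypothesis scores to posteriors
	    # (shifting by the best score for numerical stability) and
	    # add posterior-weighted counts of all orders up to ORDER.
	    if not hyps:
	        return
	    best = max(score for score, tags in hyps)
	    weights = [10.0 ** (score - best) for score, tags in hyps]
	    total = sum(weights)
	    for (score, tags), w in zip(hyps, weights):
	        posterior = w / total
	        padded = ["<s>"] + tags + ["</s>"]
	        for n in range(1, ORDER + 1):
	            for i in range(len(padded) - n + 1):
	                counts[tuple(padded[i:i+n])] += posterior

	hyps = []
	for line in sys.stdin:
	    fields = line.split()
	    if not fields:              # blank line ends a sentence
	        flush(hyps)
	        hyps = []
	    else:
	        hyps.append((float(fields[0]), fields[1:]))
	flush(hyps)

	# Write SRILM count-file format: N-gram, tab, count.
	for ngram, count in counts.items():
	    print("%s\t%g" % (" ".join(ngram), count))

   The resulting count file can then be fed to ngram-count in step 2a.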

2. M-step:   reestimate the tag N-gram LM and the word/tag mapping probabilities.

   a.  Once you have the tag N-gram counts (obtained by one of the methods
	suggested above), you just need to run ngram-count on the count file
	to get a new model.  Use -float-counts and a suitable discounting
	method if you are using fractional counts.

   b.  Again, just use ngram-count to estimate a word/tag bigram model from
	the expected counts.  You then have to post-process the LM file to
	extract the word/tag probabilities and format them into a map file
	usable by disambig (see the sketch after this list).
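
   To make the M-step concrete, here is a hypothetical sketch of the
   post-processing in 2b.  The ngram-count options shown in the comment are
   one plausible choice (-wbdiscount, i.e. Witten-Bell, is among the
   discounting methods that tolerate fractional counts), and the file names
   are made up.  The sketch assumes the map file format
   "word tag1 p1 tag2 p2 ..."; check the details against the disambig man
   page.  Note also that a bigram LM over word/tag pairs gives you
   p(tag|word); if your setup calls for P(word|tag), as in your original
   map file, you will need to renormalize accordingly.

	#!/usr/bin/env python
	# Sketch: extract word/tag bigram probabilities from an ARPA
	# format LM and write them out as a disambig map file.  The LM
	# is assumed to have been estimated from the expected counts,
	# e.g. with something like:
	#   ngram-count -read wordtag.counts -float-counts -wbdiscount \
	#               -order 2 -lm wordtag.lm
	import sys
	from collections import defaultdict

	mapping = defaultdict(list)     # word -> [(tag, prob), ...]

	in_bigrams = False
	for line in open(sys.argv[1]):
	    line = line.strip()
	    if line == "\\2-grams:":    # start of the bigram section
	        in_bigrams = True
	        continue
	    if line.startswith("\\"):   # any other section, or \end\
	        in_bigrams = False
	        continue
	    if in_bigrams and line:
	        # ARPA bigram entry: logprob word tag [backoff-weight]
	        fields = line.split()
	        word, tag = fields[1], fields[2]
	        mapping[word].append((tag, 10.0 ** float(fields[0])))

	for word, tags in mapping.items():
	    out = [word]
	    for tag, p in tags:
	        out.append(tag)
	        out.append("%g" % p)
	    print(" ".join(out))

   Iterating the E- and M-steps and evaluating the model on your held-out
   data after each round then gives you the stopping criterion you asked
   about.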

Hope this helps.

Happy New Year,

Andreas 

*** /tmp/T00BSlQ1	Wed Dec 31 13:30:23 2003
--- /tmp/T10e6tMs	Wed Dec 31 13:30:23 2003
***************
*** 38,46 ****
--- 38,48 ----
  static char *vocab1File = 0;
  static char *vocab2File = 0;
  static char *mapFile = 0;
+ static char *classesFile = 0;
  static char *mapWriteFile = 0;
  static char *textFile = 0;
  static char *textMapFile = 0;
+ static char *countsFile = 0;
  static int keepUnk = 0;
  static int tolower1 = 0;
  static int tolower2 = 0;
***************
*** 63,70 ****
--- 65,74 ----
      { OPT_STRING, "write-vocab1", &vocab1File, "output observable vocabulary" },
      { OPT_STRING, "write-vocab2", &vocab2File, "output hidden vocabulary" },
      { OPT_STRING, "map", &mapFile, "mapping from observable to hidden tokens" },
+     { OPT_STRING, "classes", &classesFile, "mapping in class expansion format" },
      { OPT_TRUE, "logmap", &logMap, "map file contains log probabilities" },
      { OPT_STRING, "write-map", &mapWriteFile, "output map file (for validation)" },
+     { OPT_STRING, "write-counts", &countsFile, "output substitution counts" },
      { OPT_TRUE, "scale", &scale, "scale map probabilities by unigram probs" },
      { OPT_TRUE, "keep-unk", &keepUnk, "preserve unknown words" },
      { OPT_TRUE, "tolower1", &tolower1, "map observable vocabulary to lowercase" },
***************
*** 88,94 ****
   */
  unsigned
  disambiguateSentence(Vocab &vocab, VocabIndex *wids, VocabIndex *hiddenWids[],
! 		     LogP totalProb[], VocabMap &map, LM &lm,
  		     unsigned numNbest, Boolean positionMapped = false)
  {
      static VocabIndex emptyContext[] = { Vocab_None };
--- 92,98 ----
   */
  unsigned
  disambiguateSentence(Vocab &vocab, VocabIndex *wids, VocabIndex *hiddenWids[],
! 		     LogP totalProb[], VocabMap &map, LM &lm, VocabMap *counts,
  		     unsigned numNbest, Boolean positionMapped = false)
  {
      static VocabIndex emptyContext[] = { Vocab_None };
***************
*** 236,241 ****
--- 240,256 ----
  	    }
  	    hiddenWids[n][len] = Vocab_None;
  	}
+ 
+ 	/* 
+ 	 * update v1-v2 counts if requested 
+ 	 */
+ 	if (counts) {
+ 	    for (unsigned i = 0; i < len; i++) {
+ 		counts->put(wids[i], hiddenWids[0][i],
+ 			    counts->get(wids[i], hiddenWids[0][i]) + 1);
+ 	    }
+ 	}
+ 
  	return numNbest;
      } else {
  	/*
***************
*** 426,431 ****
--- 441,460 ----
  		}
  		cout << endl;
  	    }
+ 
+ 	    /* 
+ 	     * update v1-v2 counts if requested 
+ 	     */
+ 	    if (counts) {
+ 		symbolIter.init();
+ 		while (symbolProb = symbolIter.next(symbol)) {
+ 		    LogP2 posterior = *symbolProb - totalPosterior;
+ 
+ 		    counts->put(wids[pos], symbol,
+ 			        counts->get(wids[pos], symbol) +
+ 					LogPtoProb(posterior));
+ 		}
+ 	    }
  	}
  
          /*
***************
*** 442,448 ****
   * disambiguate it, and print out the result
   */
  void
! disambiguateFile(File &file, VocabMap &map, LM &lm)
  {
      char *line;
      VocabString sentence[maxWordsPerLine];
--- 471,477 ----
   * disambiguate it, and print out the result
   */
  void
! disambiguateFile(File &file, VocabMap &map, LM &lm, VocabMap *counts)
  {
      char *line;
      VocabString sentence[maxWordsPerLine];
***************
*** 476,482 ****
  	    LogP totalProb[numNbest];
  	    unsigned numHyps =
  			disambiguateSentence(map.vocab1, wids, hiddenWids,
! 						totalProb, map, lm, numNbest);
  	    if (!numHyps) {
  		file.position() << "Disambiguation failed\n";
  	    } else if (totals) {
--- 505,511 ----
  	    LogP totalProb[numNbest];
  	    unsigned numHyps =
  			disambiguateSentence(map.vocab1, wids, hiddenWids,
! 					totalProb, map, lm, counts, numNbest);
  	    if (!numHyps) {
  		file.position() << "Disambiguation failed\n";
  	    } else if (totals) {
***************
*** 521,527 ****
   * disambiguate it, and print out the result
   */
  void
! disambiguateFileContinuous(File &file, VocabMap &map, LM &lm)
  {
      char *line;
      Array<VocabIndex> wids;
--- 550,557 ----
   * disambiguate it, and print out the result
   */
  void
! disambiguateFileContinuous(File &file, VocabMap &map, LM &lm,
! 							VocabMap *counts)
  {
      char *line;
      Array<VocabIndex> wids;
***************
*** 560,566 ****
  
      LogP totalProb[numNbest];
      unsigned numHyps = disambiguateSentence(map.vocab1, &wids[0], hiddenWids,
! 					    totalProb, map, lm, numNbest);
  
      if (!numHyps) {
  	file.position() << "Disambiguation failed\n";
--- 590,596 ----
  
      LogP totalProb[numNbest];
      unsigned numHyps = disambiguateSentence(map.vocab1, &wids[0], hiddenWids,
! 					totalProb, map, lm, counts, numNbest);
  
      if (!numHyps) {
  	file.position() << "Disambiguation failed\n";
***************
*** 593,599 ****
   * disambiguate it, and print out the result
   */
  void
! disambiguateTextMap(File &file, Vocab &vocab, LM &lm)
  {
      char *line;
  
--- 623,629 ----
   * disambiguate it, and print out the result
   */
  void
! disambiguateTextMap(File &file, Vocab &vocab, LM &lm, VocabMap *counts)
  {
      char *line;
  
***************
*** 664,670 ****
  	    LogP totalProb[numNbest];
  	    unsigned numHyps =
  		    disambiguateSentence(vocab, &wids[0], hiddenWids, totalProb,
! 						    map, lm, numNbest, true);
  
  	    if (!numHyps) {
  		file.position() << "Disambiguation failed\n";
--- 694,700 ----
  	    LogP totalProb[numNbest];
  	    unsigned numHyps =
  		    disambiguateSentence(vocab, &wids[0], hiddenWids, totalProb,
! 					    map, lm, counts, numNbest, true);
  
  	    if (!numHyps) {
  		file.position() << "Disambiguation failed\n";
***************
*** 720,725 ****
--- 750,764 ----
  	}
      }
  
+     if (classesFile) {
+ 	File file(classesFile, "r");
+ 
+ 	if (!map.readClasses(file)) {
+ 	    cerr << "format error in classes file\n";
+ 	    exit(1);
+ 	}
+     }
+ 
      if (lmFile) {
  	File file(lmFile, "r");
  
***************
*** 734,746 ****
  	hiddenLM->debugme(debug);
      }
  
      if (textFile) {
  	File file(textFile, "r");
  
  	if (continuous) {
! 	    disambiguateFileContinuous(file, map, *hiddenLM);
  	} else {
! 	    disambiguateFile(file, map, *hiddenLM);
  	}
      }
  
--- 773,797 ----
  	hiddenLM->debugme(debug);
      }
  
+     VocabMap *counts;
+     if (countsFile) {
+ 	counts = new VocabMap(vocab, hiddenVocab);
+ 	assert(counts != 0);
+ 
+ 	counts->remove(vocab.ssIndex, hiddenVocab.ssIndex);
+ 	counts->remove(vocab.seIndex, hiddenVocab.seIndex);
+ 	counts->remove(vocab.unkIndex, hiddenVocab.unkIndex);
+     } else {
+ 	counts = 0;
+     }
+ 
      if (textFile) {
  	File file(textFile, "r");
  
  	if (continuous) {
! 	    disambiguateFileContinuous(file, map, *hiddenLM, counts);
  	} else {
! 	    disambiguateFile(file, map, *hiddenLM, counts);
  	}
      }
  
***************
*** 747,755 ****
      if (textMapFile) {
  	File file(textMapFile, "r");
  
! 	disambiguateTextMap(file, vocab, *hiddenLM);
      }
  
      if (mapWriteFile) {
  	File file(mapWriteFile, "w");
  	map.write(file);
--- 798,812 ----
      if (textMapFile) {
  	File file(textMapFile, "r");
  
! 	disambiguateTextMap(file, vocab, *hiddenLM, counts);
      }
  
+     if (countsFile) {
+ 	File file(countsFile, "w");
+ 
+ 	counts->writeBigrams(file);
+     }
+ 
      if (mapWriteFile) {
  	File file(mapWriteFile, "w");
  	map.write(file);
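
For reference, once the patch is applied, the E-step of 1b might be invoked
as something like the following (the file names are made up; all options
other than -write-counts are standard disambig options):

    disambig -text untagged.txt -map tag.map -lm tag.lm \
	-fb -write-counts wordtag.counts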


