Implementing Baum-Welch (Forward-Backward) algorithm in SRILM
Andreas Stolcke
stolcke at speech.sri.com
Thu Jan 1 11:05:48 PST 2004
In message <005801c3cf8b$a20869d0$34284484 at cs.technion.ac.il> you wrote:
> Hi,
>
> I'm using disambig for part-of-speech tagging. I create a language model
> over sequences of tags with ngram-count, and provide P(word|tag) in the
> map file.
>
> What I would like to do is to start with this model, based on a tagged
> corpus, and improve it using the Baum-Welch (forward-backward) algorithm,
> with an untagged corpus. After each iteration I should get a new language
> model for the tags and a new map file. After each iteration I would
> like to test the model on some held-out data, so I know when to stop.
>
> How can I implement that in SRILM?
You need to write some scripts to manipulate intermediate data, but you
can pretty much do what you want.
To implement EM for your tagger you have two steps:
1. E-Step: get expected counts for the tag n-grams and the word/tag mapping.
a. Tag n-gram expectations. This step is unfortunately not well supported
by the tools right now. Although disambig uses the FB algorithm,
it doesn't collect (let alone output) the expected counts in a way
that's suitable for reestimating a model from them. You can use
two approximations. First, you could use the 1-best tag sequence
as a stand-in for the real thing and generate tag N-gram counts from
it (that's sometimes called the "Viterbi" approximation of EM).
Second, you can use the -nbest option to generate the top N most
likely taggings of each sentence along with their score. You then
have to normalize the scores to obtain posterior probabilities for
the tag sequences and weight the tag N-gram counts by these
posteriors and total them over your entire training corpus (see the
sketch at the end of this step).
b. Word/tag expectations. Here again, you could use the Viterbi
approximation, simply pairing up the words and their most likely
tags (as output by disambig). However, the most recent version of
disambig actually has an option to collect and output the
expected word/tag bigram counts. I have appended a patch that
should allow you to do this with the 1.3.3 version of disambig.
The option that this adds is
-write-counts file
Outputs the V2-V1 bigram counts corresponding to
the tagging performed on the input data. If -fb
was specified these are expected counts; otherwise
they reflect the 1-best tagging decisions.
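To make 1a more concrete, here is a rough Python sketch of the posterior
weighting (this is not part of SRILM; how you parse the -nbest output into
(log score, tag sequence) pairs, and all file names, are up to you):

    from collections import defaultdict
    from math import fsum

    # One sentence's N-best taggings: a list of (log10 score, tag list) pairs.
    # Normalize the scores to posterior probabilities and accumulate
    # posterior-weighted tag N-gram counts into 'counts'.
    def add_weighted_counts(nbest, counts, order=3):
        best = max(score for score, tags in nbest)
        weights = [10.0 ** (score - best) for score, tags in nbest]
        total = fsum(weights)
        for weight, (score, tags) in zip(weights, nbest):
            posterior = weight / total
            seq = ['<s>'] + list(tags) + ['</s>']
            for n in range(1, order + 1):
                for i in range(len(seq) - n + 1):
                    counts[tuple(seq[i:i + n])] += posterior

    # After the whole corpus has been processed, write the fractional counts
    # in the "w1 ... wN <TAB> count" format that ngram-count -read expects.
    def write_counts(counts, path):
        with open(path, 'w') as f:
            for ngram in sorted(counts):
                f.write(' '.join(ngram) + '\t' + repr(counts[ngram]) + '\n')

    counts = defaultdict(float)
    # ... call add_weighted_counts(nbest, counts) once per sentence ...
    # write_counts(counts, 'tag.counts')

For 1b, once the patch is applied, an invocation along the lines of

    disambig -text train.txt -map word_tag.map -lm tag.lm -fb -write-counts wordtag.counts

(file names are just placeholders) should write out the expected word/tag
counts.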
2. M-step: reestimate the tag N-gram LM and the word/tag mapping probabilities
a. Once you have the tag N-gram counts (obtained by one of the methods
suggested above) you just need to run ngram-count on the count file to
get a new model. Use -float-counts and a suitable discounting method
if you are using fractional counts (an example invocation appears below).
b. Again, just use ngram-count to estimate a word/tag bigram model from the
expected counts. You then have to post-process the LM file to
extract the word/tag probabilities and format them into a map file
usable by disambig (a sketch of this conversion appears below).
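For 2a, something along these lines should do (file names and -order are
placeholders; Witten-Bell is one discounting method that accepts fractional
counts):

    ngram-count -read tag.counts -float-counts -order 3 -wbdiscount -lm tag.lm

For 2b, here is a rough sketch of the post-processing (again not part of
SRILM). It assumes the bigrams in the ARPA file are in tag-word order, so
that each bigram probability is P(word|tag), and that the map file lists,
one word per line, the word followed by tag/probability pairs; check both
assumptions against your own files:

    from collections import defaultdict

    # Read the \2-grams: section of an ARPA LM over (tag, word) pairs and
    # write a disambig map file of the form "word tag1 p1 tag2 p2 ...".
    def lm_to_map(lm_path, map_path):
        probs = defaultdict(list)        # word -> [(tag, P(word|tag)), ...]
        in_bigrams = False
        with open(lm_path) as f:
            for line in f:
                line = line.strip()
                if line == '\\2-grams:':
                    in_bigrams = True
                    continue
                if in_bigrams:
                    if not line or line.startswith('\\'):
                        break            # end of the bigram section
                    fields = line.split()
                    logprob, tag, word = fields[0], fields[1], fields[2]
                    probs[word].append((tag, 10.0 ** float(logprob)))
        with open(map_path, 'w') as f:
            for word in sorted(probs):
                pairs = ' '.join('%s %g' % (tag, p) for tag, p in probs[word])
                f.write('%s %s\n' % (word, pairs))

    # lm_to_map('wordtag.lm', 'word_tag.map')

To decide when to stop iterating, you can watch the total log probabilities
that disambig -totals reports on your held-out data and quit once they stop
improving.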
Hope this helps.
Happy New Year,
Andreas
*** /tmp/T00BSlQ1 Wed Dec 31 13:30:23 2003
--- /tmp/T10e6tMs Wed Dec 31 13:30:23 2003
***************
*** 38,46 ****
--- 38,48 ----
static char *vocab1File = 0;
static char *vocab2File = 0;
static char *mapFile = 0;
+ static char *classesFile = 0;
static char *mapWriteFile = 0;
static char *textFile = 0;
static char *textMapFile = 0;
+ static char *countsFile = 0;
static int keepUnk = 0;
static int tolower1 = 0;
static int tolower2 = 0;
***************
*** 63,70 ****
--- 65,74 ----
{ OPT_STRING, "write-vocab1", &vocab1File, "output observable vocabulary" },
{ OPT_STRING, "write-vocab2", &vocab2File, "output hidden vocabulary" },
{ OPT_STRING, "map", &mapFile, "mapping from observable to hidden tokens" },
+ { OPT_STRING, "classes", &classesFile, "mapping in class expansion format" },
{ OPT_TRUE, "logmap", &logMap, "map file contains log probabilities" },
{ OPT_STRING, "write-map", &mapWriteFile, "output map file (for validation)" },
+ { OPT_STRING, "write-counts", &countsFile, "output substitution counts" },
{ OPT_TRUE, "scale", &scale, "scale map probabilities by unigram probs" },
{ OPT_TRUE, "keep-unk", &keepUnk, "preserve unknown words" },
{ OPT_TRUE, "tolower1", &tolower1, "map observable vocabulary to lowercase" },
***************
*** 88,94 ****
*/
unsigned
disambiguateSentence(Vocab &vocab, VocabIndex *wids, VocabIndex *hiddenWids[],
! LogP totalProb[], VocabMap &map, LM &lm,
unsigned numNbest, Boolean positionMapped = false)
{
static VocabIndex emptyContext[] = { Vocab_None };
--- 92,98 ----
*/
unsigned
disambiguateSentence(Vocab &vocab, VocabIndex *wids, VocabIndex *hiddenWids[],
! LogP totalProb[], VocabMap &map, LM &lm, VocabMap *counts,
unsigned numNbest, Boolean positionMapped = false)
{
static VocabIndex emptyContext[] = { Vocab_None };
***************
*** 236,241 ****
--- 240,256 ----
}
hiddenWids[n][len] = Vocab_None;
}
+
+ /*
+ * update v1-v2 counts if requested
+ */
+ if (counts) {
+ for (unsigned i = 0; i < len; i++) {
+ counts->put(wids[i], hiddenWids[0][i],
+ counts->get(wids[i], hiddenWids[0][i]) + 1);
+ }
+ }
+
return numNbest;
} else {
/*
***************
*** 426,431 ****
--- 441,460 ----
}
cout << endl;
}
+
+ /*
+ * update v1-v2 counts if requested
+ */
+ if (counts) {
+ symbolIter.init();
+ while (symbolProb = symbolIter.next(symbol)) {
+ LogP2 posterior = *symbolProb - totalPosterior;
+
+ counts->put(wids[pos], symbol,
+ counts->get(wids[pos], symbol) +
+ LogPtoProb(posterior));
+ }
+ }
}
/*
***************
*** 442,448 ****
* disambiguate it, and print out the result
*/
void
! disambiguateFile(File &file, VocabMap &map, LM &lm)
{
char *line;
VocabString sentence[maxWordsPerLine];
--- 471,477 ----
* disambiguate it, and print out the result
*/
void
! disambiguateFile(File &file, VocabMap &map, LM &lm, VocabMap *counts)
{
char *line;
VocabString sentence[maxWordsPerLine];
***************
*** 476,482 ****
LogP totalProb[numNbest];
unsigned numHyps =
disambiguateSentence(map.vocab1, wids, hiddenWids,
! totalProb, map, lm, numNbest);
if (!numHyps) {
file.position() << "Disambiguation failed\n";
} else if (totals) {
--- 505,511 ----
LogP totalProb[numNbest];
unsigned numHyps =
disambiguateSentence(map.vocab1, wids, hiddenWids,
! totalProb, map, lm, counts, numNbest);
if (!numHyps) {
file.position() << "Disambiguation failed\n";
} else if (totals) {
***************
*** 521,527 ****
* disambiguate it, and print out the result
*/
void
! disambiguateFileContinuous(File &file, VocabMap &map, LM &lm)
{
char *line;
Array<VocabIndex> wids;
--- 550,557 ----
* disambiguate it, and print out the result
*/
void
! disambiguateFileContinuous(File &file, VocabMap &map, LM &lm,
! VocabMap *counts)
{
char *line;
Array<VocabIndex> wids;
***************
*** 560,566 ****
LogP totalProb[numNbest];
unsigned numHyps = disambiguateSentence(map.vocab1, &wids[0], hiddenWids,
! totalProb, map, lm, numNbest);
if (!numHyps) {
file.position() << "Disambiguation failed\n";
--- 590,596 ----
LogP totalProb[numNbest];
unsigned numHyps = disambiguateSentence(map.vocab1, &wids[0], hiddenWids,
! totalProb, map, lm, counts, numNbest);
if (!numHyps) {
file.position() << "Disambiguation failed\n";
***************
*** 593,599 ****
* disambiguate it, and print out the result
*/
void
! disambiguateTextMap(File &file, Vocab &vocab, LM &lm)
{
char *line;
--- 623,629 ----
* disambiguate it, and print out the result
*/
void
! disambiguateTextMap(File &file, Vocab &vocab, LM &lm, VocabMap *counts)
{
char *line;
***************
*** 664,670 ****
LogP totalProb[numNbest];
unsigned numHyps =
disambiguateSentence(vocab, &wids[0], hiddenWids, totalProb,
! map, lm, numNbest, true);
if (!numHyps) {
file.position() << "Disambiguation failed\n";
--- 694,700 ----
LogP totalProb[numNbest];
unsigned numHyps =
disambiguateSentence(vocab, &wids[0], hiddenWids, totalProb,
! map, lm, counts, numNbest, true);
if (!numHyps) {
file.position() << "Disambiguation failed\n";
***************
*** 720,725 ****
--- 750,764 ----
}
}
+ if (classesFile) {
+ File file(classesFile, "r");
+
+ if (!map.readClasses(file)) {
+ cerr << "format error in classes file\n";
+ exit(1);
+ }
+ }
+
if (lmFile) {
File file(lmFile, "r");
***************
*** 734,746 ****
hiddenLM->debugme(debug);
}
if (textFile) {
File file(textFile, "r");
if (continuous) {
! disambiguateFileContinuous(file, map, *hiddenLM);
} else {
! disambiguateFile(file, map, *hiddenLM);
}
}
--- 773,797 ----
hiddenLM->debugme(debug);
}
+ VocabMap *counts;
+ if (countsFile) {
+ counts = new VocabMap(vocab, hiddenVocab);
+ assert(counts != 0);
+
+ counts->remove(vocab.ssIndex, hiddenVocab.ssIndex);
+ counts->remove(vocab.seIndex, hiddenVocab.seIndex);
+ counts->remove(vocab.unkIndex, hiddenVocab.unkIndex);
+ } else {
+ counts = 0;
+ }
+
if (textFile) {
File file(textFile, "r");
if (continuous) {
! disambiguateFileContinuous(file, map, *hiddenLM, counts);
} else {
! disambiguateFile(file, map, *hiddenLM, counts);
}
}
***************
*** 747,755 ****
if (textMapFile) {
File file(textMapFile, "r");
! disambiguateTextMap(file, vocab, *hiddenLM);
}
if (mapWriteFile) {
File file(mapWriteFile, "w");
map.write(file);
--- 798,812 ----
if (textMapFile) {
File file(textMapFile, "r");
! disambiguateTextMap(file, vocab, *hiddenLM, counts);
}
+ if (countsFile) {
+ File file(countsFile, "w");
+
+ counts->writeBigrams(file);
+ }
+
if (mapWriteFile) {
File file(mapWriteFile, "w");
map.write(file);