Context Biasing

In practical ASR applications, frequently used words are usually recognized well, but rare or user-specific words may suffer from low recognition accuracy. Contextual biasing is the problem of injecting prior knowledge into an ASR system during inference, for example a user’s favorite songs, contacts, apps or location. Conventional ASR systems perform contextual biasing by building an n-gram finite state transducer (FST) from a list of biasing phrases and composing it on-the-fly with the decoder graph during decoding. This biases the recognition result towards the n-grams contained in the contextual FST, and thus improves accuracy in certain scenarios.

In WeNet, we compute the biasing score $P_C(\mathbf y)$, which is interpolated with the base model $P(\mathbf y|\mathbf x)$ via shallow fusion during beam search, for both CTC prefix beam search and CTC WFST beam search.

$$ \mathbf y^* = \arg\max_{\mathbf y} \left[ \log P(\mathbf y|\mathbf x) + \lambda \log P_C(\mathbf y) \right] $$

where $\lambda$ is a tunable hyperparameter controlling how much the contextual LM influences the overall model score during beam search.

Context Graph

Suppose we want to boost the score of the word “cat”, and the biasing score $\lambda \log P_C(\mathbf y)$ for each character is 0.25. The context graph can be constructed as follows:

[Figure: context graph for “cat”]

During decoding, each time the current prefix matches part of a biasing phrase, the corresponding score reward is added. To avoid artificially boosting prefixes that match early on but do not match the entire phrase, a special failure arc is added that removes the boosted score.
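
As an illustration, here is a minimal sketch of how such a graph could be built with OpenFst. The BuildCatGraph helper, the token ids, and the reward argument are assumptions for this example; the actual graph is built by WeNet’s ContextGraph from the model’s symbol table and the --context_score option.

#include <fst/vector-fst.h>

// Hypothetical helper: build a context graph that boosts the single phrase
// "cat". c_id/a_id/t_id are the ids of 'c', 'a', 't' in the model's symbol
// table; reward is the per-character boost (0.25 in the example above).
fst::StdVectorFst BuildCatGraph(int c_id, int a_id, int t_id, float reward) {
  fst::StdVectorFst graph;
  int s0 = graph.AddState();  // start: nothing matched yet
  int s1 = graph.AddState();  // matched "c"
  int s2 = graph.AddState();  // matched "ca"
  int s3 = graph.AddState();  // matched "cat" (final)
  graph.SetStart(s0);
  graph.SetFinal(s3, fst::StdArc::Weight::One());
  // Matching arcs: each consumed character earns the per-character reward.
  graph.AddArc(s0, fst::StdArc(c_id, c_id, reward, s1));
  graph.AddArc(s1, fst::StdArc(a_id, a_id, reward, s2));
  graph.AddArc(s2, fst::StdArc(t_id, t_id, reward, s3));
  // Failure arcs: give back the boost accumulated so far when the full
  // phrase is not matched.
  graph.AddArc(s1, fst::StdArc(0, 0, -1 * reward, s0));
  graph.AddArc(s2, fst::StdArc(0, 0, -2 * reward, s0));
  return graph;
}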

To easily determine the boundaries of a matched hot word, WeNet records only one context-graph state for each prefix. That is, only one hot word can be matched at a time, and other hot words can start matching only after the current hot word match has succeeded or failed.

int ContextGraph::GetNextState(int cur_state, int word_id, float* score,
                               bool* is_start_boundary, bool* is_end_boundary) {
  int next_state = 0;
  // Traverse the arcs of the current state.
  for (fst::ArcIterator<fst::StdFst> aiter(*graph_, cur_state); !aiter.Done();
       aiter.Next()) {
    const fst::StdArc& arc = aiter.Value();
    if (arc.ilabel == 0) {
      // Record the score of the backoff arc; it may be overwritten below.
      *score = arc.weight.Value();
    } else if (arc.ilabel == word_id) {
      // If they match, record the next state and the score.
      next_state = arc.nextstate;
      *score = arc.weight.Value();
      // Check whether this is a start or end boundary of the hot word.
      if (cur_state == 0) {
        *is_start_boundary = true;
      }
      if (graph_->Final(arc.nextstate) == fst::StdArc::Weight::One()) {
        *is_end_boundary = true;
      }
      break;
    }
  }
  return next_state;
}
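
To show how these values are typically consumed, here is a hedged sketch of a beam-search update step; the Hypothesis struct and the ApplyContextBias function are illustrative names, not the exact WeNet decoder code.

// Illustrative only: a decoding hypothesis tracks its context-graph state
// and accumulates the biasing reward returned by GetNextState.
struct Hypothesis {
  float score = 0.0f;     // running shallow-fusion score of this prefix
  int context_state = 0;  // current state in the context graph
};

void ApplyContextBias(ContextGraph* context_graph, int word_id,
                      Hypothesis* hyp) {
  float context_score = 0.0f;
  bool is_start_boundary = false;
  bool is_end_boundary = false;
  hyp->context_state = context_graph->GetNextState(
      hyp->context_state, word_id, &context_score,
      &is_start_boundary, &is_end_boundary);
  // Positive while a hot word is being matched; negative when the backoff
  // arc takes the accumulated boost back.
  hyp->score += context_score;
}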

Pruning

The backoff arc returns the accumulated boost through a single ForwardLink, which makes the cost of that ForwardLink disproportionately large. We therefore have to remove the cost returned by the backoff arc before pruning.

void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinks(
    int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned,
    BaseFloat delta) {
  ...
  BaseFloat link_extra_cost =
            next_tok->extra_cost +
            ((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
             next_tok->tot_cost);  // difference in brackets is >= 0
  // ========== Context code BEGIN ===========
  // graph_cost contains the score of the hot word.
  // link->context_score < 0 means the boost on this link was returned by the backoff arc.
  if (link->context_score < 0) {
    link_extra_cost += link->context_score;
  }
  // ========== Context code END ==========
  // link_extra_cost is the difference in score between the best paths
  // through link source state and through link destination state

Usage

  1. Specify --context_path, the path to a text file of biasing phrases.

    • Each line of the file contains a context.

    • Each context must be splittable into units of the ASR model’s symbol table (i.e., it contains no OOV units); see the example after this list.

  2. Specify --context_score, the reward given for each matched word of a context.
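
For illustration, a hypothetical context.txt (the phrases below are made up) contains one biasing phrase per line:

cat
play my playlist
john smith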

cd /home/wenet/runtime/libtorch
export GLOG_logtostderr=1
export GLOG_v=2
wav_path=docker_resource/test.wav
context_path=docker_resource/context.txt
model_dir=docker_resource/model
./build/decoder_main \
    --chunk_size -1 \
    --wav_path $wav_path \
    --model_path $model_dir/final.zip \
    --context_path $context_path \
    --context_score 3 \
    --unit_path $model_dir/units.txt 2>&1 | tee log.txt