Source code for genalog.text.anchor

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# ---------------------------------------------------------

"""
Baseline alignment algorithm is slow on long documents.
The idea is to break down the longer text into smaller fragments
for quicker alignment on individual pieces. We refer "anchor words"
as these points of breakage.

The bulk of this algorithm is to identify these "anchor words".

This is an re-implementation of the algorithm in this paper
"A Fast Alignment Scheme for Automatic OCR Evaluation of Books"
(https://ieeexplore.ieee.org/document/6065412)

We rely on `genalog.text.alignment` to align the subsequences.
"""
import itertools
from collections import Counter

from genalog.text import alignment, preprocess
from genalog.text.alignment import GAP_CHAR
from genalog.text.lcs import LCS

# The recursive portion of the algorithm runs on segments
# longer than this value to find anchor points within the
# longer segment (to break it up further).
MAX_ALIGN_SEGMENT_LENGTH = 100  # in characters


def get_unique_words(tokens, case_sensitive=False):
    """Get a set of unique words from a list of tokens

    Arguments:
        tokens (list) : a list of word tokens
        case_sensitive (bool, optional) : whether unique words are case
            sensitive. Defaults to False.

    Returns:
        set: a set of unique words (original alphabetical case of the word is preserved)
    """
    if case_sensitive:
        word_count = Counter(tokens)
        return {word for word, count in word_count.items() if count < 2}
    else:
        tokens_lowercase = [tk.lower() for tk in tokens]
        word_count = Counter(tokens_lowercase)
        return {tk for tk in tokens if word_count[tk.lower()] < 2}
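
# Illustrative doctest-style sketch (not part of the original module): only
# tokens that occur exactly once survive, and `case_sensitive` controls
# whether differently-cased tokens count as the same word.
#
#   >>> sorted(get_unique_words(["The", "cat", "the", "dog"]))
#   ['cat', 'dog']
#   >>> sorted(get_unique_words(["The", "cat", "the", "dog"], case_sensitive=True))
#   ['The', 'cat', 'dog', 'the']
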
def segment_len(tokens):
    """Get the character length of the segment

    Arguments:
        tokens (list) : a list of tokens

    Returns:
        int : the length of the segment in characters
    """
    return sum(map(len, tokens))
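
# Illustrative sketch (not part of the original module): `segment_len` sums
# the character lengths of the tokens only; separators between tokens are
# not counted.
#
#   >>> segment_len(["foo", "barbaz"])
#   9
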
def get_word_map(unique_words, src_tokens):
    """Arrange the set of unique words in the order they originally appear in the text

    Arguments:
        unique_words (set) : a set of unique words
        src_tokens (list) : a list of tokens

    Returns:
        list : a ``word_map``: a list of word coordinate tuples
        ``(word, word_index)`` defined as follows:

        1. ``word`` is a typical word token
        2. ``word_index`` is the index of the word in the source token array
    """
    # Find the indices of the unique words in the source text
    unique_word_indices = map(src_tokens.index, unique_words)
    word_map = list(zip(unique_words, unique_word_indices))
    word_map.sort(key=lambda x: x[1])  # Re-arrange order by the index
    return word_map
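
# Illustrative sketch (not part of the original module): regardless of set
# iteration order, the resulting word map is sorted by position in the
# source tokens.
#
#   >>> get_word_map({"a", "b"}, ["b", "a", "c"])
#   [('b', 0), ('a', 1)]
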
def get_anchor_map(gt_tokens, ocr_tokens, min_anchor_len=2):
    """Find the location of anchor words in both the gt and ocr text.
    Anchor words are locations where we can split both the source gt
    and ocr text into smaller text fragments for faster alignment.

    Arguments:
        gt_tokens (list) : a list of ground truth tokens
        ocr_tokens (list) : a list of tokens from OCR'ed document
        min_anchor_len (int, optional) : minimum length of the anchor word.
            Defaults to 2.

    Returns:
        tuple: a 2-element ``(anchor_map_gt, anchor_map_ocr)`` tuple:

        1. ``anchor_map_gt`` is a ``word_map`` that locates all the anchor words in the gt tokens
        2. ``anchor_map_ocr`` is a ``word_map`` that locates all the anchor words in the ocr tokens

        And ``len(anchor_map_gt) == len(anchor_map_ocr)``

        ::

            For example:
                Input:
                    gt_tokens:  ["b", "a", "c"]
                    ocr_tokens: ["c", "b", "a"]
                Output:
                    ([("b", 0), ("a", 1)], [("b", 1), ("a", 2)])
    """
    # 1. Get unique words common in both gt and ocr
    unique_words_gt = get_unique_words(gt_tokens)
    unique_words_ocr = get_unique_words(ocr_tokens)
    unique_words_common = unique_words_gt.intersection(unique_words_ocr)
    if not unique_words_common:
        return [], []
    # 2. Arrange the common unique words in their original order
    unique_word_map_gt = get_word_map(unique_words_common, gt_tokens)
    unique_word_map_ocr = get_word_map(unique_words_common, ocr_tokens)
    # Unzip to get the ordered unique_words
    ordered_unique_words_gt, _ = zip(*unique_word_map_gt)
    ordered_unique_words_ocr, _ = zip(*unique_word_map_ocr)
    # Join words into a space-separated string for finding LCS
    unique_words_gt_str = preprocess.join_tokens(ordered_unique_words_gt)
    unique_words_ocr_str = preprocess.join_tokens(ordered_unique_words_ocr)
    # 3. Find the LCS between the two ordered lists of unique words
    lcs = LCS(unique_words_gt_str, unique_words_ocr_str)
    lcs_str = lcs.get_str()
    # 4. Break up the LCS string into tokens
    lcs_words = set(preprocess.tokenize(lcs_str))
    # 5. Anchor words are the unique words in the LCS string
    anchor_words = lcs_words.intersection(unique_words_common)
    # 6. Filter the unique words to keep the anchor words ONLY
    anchor_map_gt = list(
        filter(
            # word_coordinate is a (unique_word, unique_word_index) tuple
            lambda word_coordinate: word_coordinate[0] in anchor_words,
            unique_word_map_gt,
        )
    )
    anchor_map_ocr = list(
        filter(
            lambda word_coordinate: word_coordinate[0] in anchor_words,
            unique_word_map_ocr,
        )
    )
    return anchor_map_gt, anchor_map_ocr
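
# Walking through the docstring example above (illustrative; the exact
# behavior depends on `genalog.text.preprocess` and `genalog.text.lcs`):
# the ordered unique-word strings are "b a c" and "c b a", whose longest
# common subsequence is "b a", so "c" never becomes an anchor word.
#
#   >>> get_anchor_map(["b", "a", "c"], ["c", "b", "a"])
#   ([('b', 0), ('a', 1)], [('b', 1), ('a', 2)])
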
def find_anchor_recur(
    gt_tokens,
    ocr_tokens,
    start_pos_gt=0,
    start_pos_ocr=0,
    max_seg_length=MAX_ALIGN_SEGMENT_LENGTH,
):
    """Recursively find anchor positions in the gt and ocr text

    Arguments:
        gt_tokens (list) : a list of ground truth tokens
        ocr_tokens (list) : a list of tokens from OCR'ed document
        start_pos_gt (int, optional) : a constant to add to all the resulting
            gt indices. Defaults to 0.
        start_pos_ocr (int, optional) : a constant to add to all the resulting
            ocr indices. Defaults to 0.
        max_seg_length (int, optional) : trigger recursion if any text segment
            is longer than this. Defaults to ``MAX_ALIGN_SEGMENT_LENGTH``.

    Raises:
        ValueError: when the numbers of anchor points in gt and ocr differ.

    Returns:
        tuple : two lists of token indices where each list is the position of
        the anchor in the input ``gt_tokens`` and ``ocr_tokens``
    """
    # 1. Try to find anchor words
    anchor_word_map_gt, anchor_word_map_ocr = get_anchor_map(gt_tokens, ocr_tokens)
    # 2. Check invariant
    if len(anchor_word_map_gt) != len(anchor_word_map_ocr):
        raise ValueError("Unequal number of anchor points across gt and ocr string")
    # Return empty if no anchor word is found
    if len(anchor_word_map_gt) == 0:
        return [], []
    # 3. Unzip the maps to get the indices of the anchor tokens
    _, anchor_indices_gt = map(list, zip(*anchor_word_map_gt))
    _, anchor_indices_ocr = map(list, zip(*anchor_word_map_ocr))
    output_gt_anchors = set(map(lambda x: x + start_pos_gt, anchor_indices_gt))
    output_ocr_anchors = set(map(lambda x: x + start_pos_ocr, anchor_indices_ocr))
    # 4. Find the split point of each segment
    seg_start_gt = list(itertools.chain([0], anchor_indices_gt))
    seg_start_ocr = list(itertools.chain([0], anchor_indices_ocr))
    start_n_end_gt = zip(seg_start_gt, itertools.chain(anchor_indices_gt, [None]))
    start_n_end_ocr = zip(seg_start_ocr, itertools.chain(anchor_indices_ocr, [None]))
    gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt]
    ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr]
    # 5. Loop through each segment
    for gt_seg, ocr_seg, gt_start, ocr_start in zip(
        gt_segments, ocr_segments, seg_start_gt, seg_start_ocr
    ):
        if (
            segment_len(gt_seg) > max_seg_length
            or segment_len(ocr_seg) > max_seg_length
        ):
            # Recur on the segment in between the two anchors.
            # We assume the first token in the segment is an anchor word.
            gt_anchors, ocr_anchors = find_anchor_recur(
                gt_seg[1:],
                ocr_seg[1:],
                start_pos_gt=gt_start + 1,
                start_pos_ocr=ocr_start + 1,
                max_seg_length=max_seg_length,
            )
            # Shift the token indices
            # (these are indices into a subsequence and do not reflect the
            # true position in the source sequence)
            gt_anchors = set(map(lambda x: x + start_pos_gt, gt_anchors))
            ocr_anchors = set(map(lambda x: x + start_pos_ocr, ocr_anchors))
            # Merge recursion results
            output_gt_anchors = output_gt_anchors.union(gt_anchors)
            output_ocr_anchors = output_ocr_anchors.union(ocr_anchors)
    return sorted(output_gt_anchors), sorted(output_ocr_anchors)
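
# Illustrative sketch (not part of the original module): on inputs short
# enough to avoid recursion, the result is simply the anchor indices from
# `get_anchor_map`, offset by `start_pos_gt`/`start_pos_ocr`.
#
#   >>> find_anchor_recur(["b", "a", "c"], ["c", "b", "a"])
#   ([0, 1], [1, 2])
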
def align_w_anchor(gt, ocr, gap_char=GAP_CHAR, max_seg_length=MAX_ALIGN_SEGMENT_LENGTH):
    """A faster alignment scheme for two text segments. This method first
    breaks the strings into smaller segments with anchor words.
    Then these smaller segments are aligned.

    **NOTE:** this function shares the same contract as
    ``genalog.text.alignment.align()``. These two methods are
    interchangeable and their alignment results should be similar.

    ::

        For example:
            Ground Truth: "The planet Mars, I scarcely need remind the reader,"
            Noisy Text:   "The plamet Maris, I scacely neee remind te reader,"

            Here the unique anchor words are "I", "remind" and "reader".
            Thus, the algorithm will split into the following segment pairs:

            "The planet Mars, "
            "The plamet Maris, "

            "I scarcely need "
            "I scacely neee "

            "remind the reader,"
            "remind te reader,"

            And run sequence alignment on each pair.

    Arguments:
        gt (str) : ground truth text
        ocr (str) : text with OCR noise
        gap_char (str, optional) : gap char used in the alignment algorithm.
            Defaults to GAP_CHAR.
        max_seg_length (int, optional) : maximum segment length. Segments
            longer than this threshold will continue to be split recursively
            into smaller segments. Defaults to ``MAX_ALIGN_SEGMENT_LENGTH``.

    Returns:
        tuple : a 2-element ``(aligned_gt, aligned_noise)`` tuple of the
        aligned ground truth and noise strings
    """
    gt_tokens = preprocess.tokenize(gt)
    ocr_tokens = preprocess.tokenize(ocr)
    # 1. Find anchor positions
    gt_anchors, ocr_anchors = find_anchor_recur(
        gt_tokens, ocr_tokens, max_seg_length=max_seg_length
    )
    # 2. Split into segments at the anchor positions
    start_n_end_gt = zip(
        itertools.chain([0], gt_anchors), itertools.chain(gt_anchors, [None])
    )
    start_n_end_ocr = zip(
        itertools.chain([0], ocr_anchors), itertools.chain(ocr_anchors, [None])
    )
    gt_segments = [gt_tokens[start:end] for start, end in start_n_end_gt]
    ocr_segments = [ocr_tokens[start:end] for start, end in start_n_end_ocr]
    # 3. Run alignment on each segment
    aligned_segments_gt = []
    aligned_segments_ocr = []
    for gt_segment, noisy_segment in zip(gt_segments, ocr_segments):
        gt_segment = preprocess.join_tokens(gt_segment)
        noisy_segment = preprocess.join_tokens(noisy_segment)
        # Run the alignment algorithm
        aligned_seg_gt, aligned_seg_ocr = alignment.align(
            gt_segment, noisy_segment, gap_char=gap_char
        )
        if aligned_seg_gt and aligned_seg_ocr:  # skip empty strings ""
            aligned_segments_gt.append(aligned_seg_gt)
            aligned_segments_ocr.append(aligned_seg_ocr)
    # 4. Stitch all the aligned segments together
    aligned_gt = " ".join(aligned_segments_gt)
    aligned_noise = " ".join(aligned_segments_ocr)
    return aligned_gt, aligned_noise
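
# Intended usage (illustrative sketch; the exact gap placement depends on
# `genalog.text.alignment.align`, but the two returned strings should have
# equal length, as is the usual contract of sequence alignment):
#
#   >>> aligned_gt, aligned_noise = align_w_anchor(
#   ...     "The planet Mars, I scarcely need remind the reader,",
#   ...     "The plamet Maris, I scacely neee remind te reader,",
#   ... )
#   >>> len(aligned_gt) == len(aligned_noise)
#   True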