Skip to content

medcat.components.ner.trf.tokenizer

Classes:

Attributes:

logger module-attribute

logger = getLogger(__name__)

TransformersTokenizer

TransformersTokenizer(hf_tokenizer: Optional[PreTrainedTokenizerBase] = None, max_len: int = 512, id2type: Optional[Dict] = None, cui2name: Optional[Dict] = None)

hf_tokenizer: Must be able to return token offsets. max_len: Max sequence length; longer inputs will be split into multiple examples. id2type: Can be ignored in most cases; a map from token ID to 'start' or 'sub', indicating whether the token starts a word or is a subword continuation. For BERT, 'start' is every token that does not begin with ##. cui2name: Map from CUI to full name for labels.

Methods:

Attributes:

Source code in medcat-v2/medcat/components/ner/trf/tokenizer.py
24
25
26
27
28
29
30
31
32
33
34
def __init__(self,
             hf_tokenizer: Optional[PreTrainedTokenizerBase] = None,
             max_len: int = 512,
             id2type: Optional[Dict] = None,
             cui2name: Optional[Dict] = None) -> None:
    """Wrap a Hugging Face tokenizer for transformer-based NER.

    Args:
        hf_tokenizer: The HF tokenizer; must be able to return token
            offsets.
        max_len: Maximum sequence length; longer inputs are split into
            multiple examples (chunking).
        id2type: Can be ignored in most cases. Map from token ID to
            'start' or 'sub', i.e. whether the token starts a word or
            is a subword continuation.
        cui2name: Map from CUI to the concept's full name, for labels.
    """
    self.hf_tokenizer = hf_tokenizer
    self.max_len = max_len
    self.id2type = id2type
    self.cui2name = cui2name
    # 'O' = outside any entity; 'X' = special subword label.
    # 'X' is kept in the map in case id2type is provided later.
    self.label_map = dict(O=0, X=1)

cui2name instance-attribute

cui2name = cui2name

hf_tokenizer instance-attribute

hf_tokenizer = hf_tokenizer

id2type instance-attribute

id2type = id2type

label_map instance-attribute

label_map = {'O': 0, 'X': 1}

max_len instance-attribute

max_len = max_len

calculate_label_map

calculate_label_map(dataset) -> None
Source code in medcat-v2/medcat/components/ner/trf/tokenizer.py
36
37
38
39
40
def calculate_label_map(self, dataset) -> None:
    """Extend ``self.label_map`` with every CUI present in the dataset.

    Each previously-unseen CUI receives the next free integer ID;
    existing entries are left untouched.

    Args:
        dataset: Dataset exposing an ``ent_cuis`` column with, per
            document, the list of entity CUIs.
    """
    for doc_cuis in dataset['ent_cuis']:
        for concept_id in doc_cuis:
            self.label_map.setdefault(concept_id, len(self.label_map))

encode

encode(examples: Dict, ignore_subwords: bool = False) -> Dict

Used with huggingface datasets map function to convert medcat_ner dataset into the appropriate form for NER with BERT. It will split long text segments into max_len sequences (performs chunking).

Parameters:

  • examples

    (Dict) –

    Stream of examples.

  • ignore_subwords

    (bool, default: False ) –

    If set to True subwords of any token will get the special label X.

Returns:

  • Dict ( Dict ) –

    The same dict, modified.

Source code in medcat-v2/medcat/components/ner/trf/tokenizer.py
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def encode(self, examples: Dict, ignore_subwords: bool = False) -> Dict:
    """Used with huggingface datasets map function to convert medcat_ner
    dataset into the appropriate form for NER with BERT. It will split
    long text segments into max_len sequences (performs chunking).

    Expects batched examples with 'text', 'ent_starts', 'ent_ends',
    'ent_cuis', 'id' and 'name' columns; adds 'input_ids' and 'labels'
    and rewrites 'id'/'name' so each output chunk carries its source
    document's id/name.

    Args:
        examples (Dict):
            Stream of examples.
        ignore_subwords (bool):
            If set to `True` subwords of any token will get the special
            label `X`.

    Returns:
        Dict: The same dict, modified.
    """
    # Raises ValueError (via ensure_tokenizer) if no tokenizer is set.
    self.hf_tokenizer = self.ensure_tokenizer()
    # Keep the original per-document ids/names so every chunk produced
    # from a long document can be tagged with its source document.
    old_ids = examples['id']
    old_names = examples['name']
    examples['input_ids'] = []
    examples['labels'] = []
    examples['id'] = []
    examples['name'] = []

    for _ind, example in enumerate(zip(
            examples['text'], examples['ent_starts'],
            examples['ent_ends'], examples['ent_cuis'])):
        # Character offsets are needed to align entity spans with tokens;
        # special tokens are skipped so offsets map cleanly onto the text.
        tokens = self.hf_tokenizer(
            example[0], return_offsets_mapping=True,
            add_special_tokens=False)
        # (start, end, cui) triplets sorted by start offset, so entities
        # can be consumed front-to-back (always entities[0]) while the
        # tokens are scanned left to right.
        entities = [(start, end, cui)
                    for start, end, cui in
                    zip(example[1], example[2], example[3])]
        entities.sort(key=lambda x: x[0])
        input_ids = []
        labels = []

        # True while the previous token fell inside entities[0] but that
        # entity's end offset has not yet been reached/passed.
        tkn_part_of_entity = False
        for ind in range(len(tokens['offset_mapping'])):
            offset = tokens['offset_mapping'][ind]
            input_ids.append(tokens['input_ids'][ind])

            # Token is inside the current entity iff its character span
            # lies fully within the entity's [start, end] span.
            if (entities and (offset[0] >= entities[0][0] and
                              offset[1] <= entities[0][1])):
                # Means this token is part of entity at position 0
                tkn_part_of_entity = True
                if (not ignore_subwords or (
                        self.id2type is not None and
                        self.id2type[tokens['input_ids'][ind]] == 'start'
                        )):
                    # Word-start token (or subwords not ignored): label
                    # with the entity's CUI.
                    labels.append(self.label_map[entities[0][2]])
                else:
                    # Subword token and ignore_subwords is on: special X.
                    labels.append(self.label_map['X'])

                if entities[0][1] <= offset[1]:
                    # If it is the last token of the entity,
                    # remove the entity as it is done
                    del entities[0]
                    # Set this so the next token is not removed
                    tkn_part_of_entity = False

            else:
                if tkn_part_of_entity:
                    # Previous token was inside entities[0] but this one
                    # is not, so that entity is finished - drop it.
                    del entities[0]
                    tkn_part_of_entity = False

                # Outside any entity: 'O', or 'X' for ignored subwords.
                if (not ignore_subwords or (
                        self.id2type is not None and
                        self.id2type[tokens['input_ids'][ind]] == 'start'
                        )):
                    labels.append(self.label_map["O"])
                else:
                    labels.append(self.label_map['X'])

            if len(input_ids) >= self.max_len:
                logger.debug(
                    "Document exceeding max length encountered. Length of"
                    " current document is %d. Performing chunking...",
                    len(tokens['offset_mapping']))
                # Split into multiple examples if too long: flush the
                # current chunk and start a fresh one for this document.
                examples['input_ids'].append(input_ids)
                examples['labels'].append(labels)
                examples['id'].append(old_ids[_ind])
                examples['name'].append(old_names[_ind])

                input_ids = []
                labels = []

        # Flush the final (possibly partial) chunk for this document.
        if input_ids:
            examples['input_ids'].append(input_ids)
            examples['labels'].append(labels)
            examples['id'].append(old_ids[_ind])
            examples['name'].append(old_names[_ind])

    return examples

ensure_tokenizer

ensure_tokenizer() -> PreTrainedTokenizerBase
Source code in medcat-v2/medcat/components/ner/trf/tokenizer.py
141
142
143
144
def ensure_tokenizer(self) -> PreTrainedTokenizerBase:
    """Return the wrapped HF tokenizer, failing fast when absent.

    Raises:
        ValueError: If no tokenizer has been loaded yet.
    """
    tokenizer = self.hf_tokenizer
    if tokenizer is None:
        raise ValueError("The tokenizer is not loaded yet")
    return tokenizer

load classmethod

load(path: str) -> TransformersTokenizer
Source code in medcat-v2/medcat/components/ner/trf/tokenizer.py
146
147
148
149
150
151
152
153
154
@classmethod
def load(cls, path: str) -> 'TransformersTokenizer':
    """Restore a tokenizer previously written by :meth:`save`.

    Only attributes that a fresh instance already defines are taken
    from the saved dict; unknown saved keys are ignored.

    Args:
        path: File to read the dill-serialized attributes from.

    Returns:
        TransformersTokenizer: The reconstructed tokenizer.
    """
    instance = cls()
    with open(path, 'rb') as in_file:
        saved = dill.load(in_file)
        for attr in instance.__dict__:
            if attr in saved:
                instance.__dict__[attr] = saved[attr]
    return instance

save

save(path: str) -> None
Source code in medcat-v2/medcat/components/ner/trf/tokenizer.py
137
138
139
def save(self, path: str) -> None:
    """Serialize this tokenizer's attribute dict to ``path`` via dill."""
    with open(path, 'wb') as out_file:
        dill.dump(self.__dict__, out_file)