medcat.trainer

Classes:

Attributes:

logger module-attribute

logger = getLogger(__name__)

Trainer

Trainer(cdb: CDB, caller: Callable[[str], MutableDocument], pipeline: Pipeline)

Methods:

Attributes:

Source code in medcat-v2/medcat/trainer.py
def __init__(self, cdb: CDB, caller: Callable[[str], MutableDocument],
             pipeline: Pipeline):
    self.cdb = cdb
    self.config = cdb.config
    self.caller = caller
    self._pipeline = pipeline

caller instance-attribute

caller = caller

cdb instance-attribute

cdb = cdb

config instance-attribute

config = config

strict_train class-attribute instance-attribute

strict_train: bool = False

add_and_train_concept

add_and_train_concept(cui: str, name: str, mut_doc: Optional[MutableDocument] = None, mut_entity: Optional[Union[list[MutableToken], MutableEntity]] = None, ontologies: set[str] = set(), name_status: str = 'A', type_ids: set[str] = set(), description: str = '', full_build: bool = True, negative: bool = False, devalue_others: bool = False, do_add_concept: bool = True) -> None

Add a name to an existing concept, or add a new concept, or do not do anything if the name or concept already exists. Perform training if mut_entity and mut_doc are set.

Parameters:

  • cui

    (str) –

    CUI of the concept.

  • name

    (str) –

    Name to be linked to the concept (in the case of MedCATtrainer this is simply the selected value in text, no preprocessing or anything needed).

  • mut_doc

    (Optional[MutableDocument], default: None ) –

    Spacy representation of the document that was manually annotated.

  • mut_entity

    (Optional[Union[list[MutableToken], MutableEntity]], default: None ) –

    Given the spacy document, this is the annotated span of text - list of annotated tokens that are marked with this CUI.

  • ontologies

    (set[str], default: set() ) –

    ontologies in which the concept exists (e.g. SNOMEDCT, HPO)

  • name_status

    (str, default: 'A' ) –

    One of P, N, A

  • type_ids

    (set[str], default: set() ) –

    Semantic type identifier (have a look at TUIs in UMLS or SNOMED-CT)

  • description

    (str, default: '' ) –

    Description of this concept.

  • full_build

    (bool, default: True ) –

    If True, the dictionary self.addl_info will also be populated; it contains a lot of extra information about concepts but can be very memory consuming. This is not necessary for the normal functioning of MedCAT.

  • negative

    (bool, default: False ) –

    Is this a negative or positive example.

  • devalue_others

    (bool, default: False ) –

    If set, CUIs other than cui to which this name is assigned will receive negative training, provided negative=False.

  • do_add_concept

    (bool, default: True ) –

    Whether to add concept to CDB.
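The call below is a hypothetical usage sketch, not part of this API page: the `trainer`, `doc`, and `ent` objects are assumed to exist (a constructed `Trainer` plus a pipeline document and its annotated entity), and only the keyword names come from the documented signature.

```python
# Hypothetical sketch: `trainer`, `doc` and `ent` are assumed to exist
# (a Trainer instance plus a pipeline document and its annotated entity).
# Only the keyword names below come from the documented signature.
kwargs = dict(
    cui="C0020538",        # concept to add/train (example CUI)
    name="hypertension",   # raw selected text, no preprocessing needed
    name_status="P",       # mark this name as preferred
    negative=False,        # positive training example
    devalue_others=True,   # negatively train other CUIs sharing this name
)
# trainer.add_and_train_concept(mut_doc=doc, mut_entity=ent, **kwargs)
```

With `devalue_others=True` and `negative=False`, any other CUIs already linked to the prepared names would receive a negative training step, as described above.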

Source code in medcat-v2/medcat/trainer.py
def add_and_train_concept(self,
                          cui: str,
                          name: str,
                          mut_doc: Optional[MutableDocument] = None,
                          mut_entity: Optional[
                              Union[list[MutableToken],
                                    MutableEntity]] = None,
                          ontologies: set[str] = set(),
                          name_status: str = 'A',
                          type_ids: set[str] = set(),
                          description: str = '',
                          full_build: bool = True,
                          negative: bool = False,
                          devalue_others: bool = False,
                          do_add_concept: bool = True) -> None:
    r"""Add a name to an existing concept, or add a new concept, or do not
    do anything if the name or concept already exists. Perform training if
    spacy_entity and spacy_doc are set.

    Args:
        cui (str):
            CUI of the concept.
        name (str):
            Name to be linked to the concept (in the case of MedCATtrainer
            this is simply the selected value in text, no preprocessing or
            anything needed).
        mut_doc (Optional[MutableDocument]):
            Spacy representation of the document that was manually
            annotated.
        mut_entity (Optional[Union[list[MutableToken],
                                   MutableEntity]]):
            Given the spacy document, this is the annotated span of text -
            list of annotated tokens that are marked with this CUI.
        ontologies (set[str]):
            ontologies in which the concept exists (e.g. SNOMEDCT, HPO)
        name_status (str):
            One of `P`, `N`, `A`
        type_ids (set[str]):
            Semantic type identifier (have a look at TUIs in UMLS or
            SNOMED-CT)
        description (str):
            Description of this concept.
        full_build (bool):
            If True, the dictionary self.addl_info will also be populated;
            it contains a lot of extra information about concepts but can
            be very memory consuming. This is not necessary for the normal
            functioning of MedCAT.
        negative (bool):
            Is this a negative or positive example.
        devalue_others (bool):
            If set, CUIs other than `cui` to which this name is assigned
            will receive negative training, provided `negative=False`.
        do_add_concept (bool):
            Whether to add concept to CDB.
    """
    names = prepare_name(name, self._pipeline.tokenizer_with_tag, {},
                         self._pn_configs)
    if (not names and cui not in self.cdb.cui2info and
            name_status == 'P'):
        logger.warning(
            "No names were able to be prepared in "
            "CAT.add_and_train_concept method. As such no preferred name "
            "will be able to be specified. The CUI: '%s' and raw name: "
            "'%s'", cui, name)
    # Only if not negative, otherwise do not add the new name if in fact
    # it should not be detected
    if do_add_concept and not negative:
        self.cdb._add_concept(cui=cui, names=names, ontologies=ontologies,
                              name_status=name_status, type_ids=type_ids,
                              description=description,
                              full_build=full_build)

    if mut_entity is None or mut_doc is None:
        return
    linker = self._pipeline.get_component(
        CoreComponentType.linking)
    if not isinstance(linker, TrainableComponent):
        logger.warning(
            "Linker cannot be trained during add_and_train_concept"
            "because it has no train method: %s", linker)
    else:
        # Train Linking
        if isinstance(mut_entity, list):
            mut_entity = self._pipeline.entity_from_tokens(mut_entity)
        linker.train(cui=cui, entity=mut_entity, doc=mut_doc,
                     negative=negative, names=names)

        if not negative and devalue_others:
            # Find all cuis
            cuis: set[str] = set()
            for n in names:
                if n in self.cdb.name2info:
                    info = self.cdb.name2info[n]
                    cuis.update(info['per_cui_status'].keys())
            # Remove the cui for which we just added positive training
            if cui in cuis:
                cuis.remove(cui)
            # Add negative training for all other CUIs that link to
            # these names
            for _cui in cuis:
                linker.train(cui=_cui, entity=mut_entity, doc=mut_doc,
                             negative=True)

train_supervised_raw

train_supervised_raw(data: MedCATTrainerExport, reset_cui_count: bool = False, nepochs: int = 1, print_stats: int = 0, use_filters: bool = False, terminate_last: bool = False, use_overlaps: bool = False, use_cui_doc_limit: bool = False, test_size: float = 0, devalue_others: bool = False, use_groups: bool = False, never_terminate: bool = False, train_from_false_positives: bool = False, extra_cui_filter: Optional[set[str]] = None, disable_progress: bool = False, train_addons: bool = False) -> tuple

Train supervised based on the raw data provided.

The raw data is expected in the following format:

{'projects':
    [ # list of projects
        { # project 1
            'name': '<some name>',
            # list of documents
            'documents': [{'name': '<some name>',  # document 1
                           'text': '<text of the document>',
                           # list of annotations
                           'annotations': [# annotation 1
                                           {'start': -1,
                                            'end': 1,
                                            'cui': 'cui',
                                            'value': '<text value>'},
                                           ...],
                           }, ...]
        }, ...
    ]
}

Please take care that this is more of a simulated online training than supervised training.

When filtering, the filters within the CAT model are used first, then the ones from MedCATtrainer (MCT) export filters, and finally the extra_cui_filter (if set). That is to say, the expectation is: extra_cui_filter ⊆ MCT filter ⊆ Model/config filter.
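The expected structure can be sketched as a plain Python dict; the project name, document text, and CUI below are placeholders, and the commented-out `trainer` call assumes a constructed `Trainer` instance:

```python
# A minimal MedCATtrainer-style export, shaped as described above.
# All names, texts and CUIs here are placeholders for illustration.
export = {
    "projects": [
        {
            "name": "demo project",
            "documents": [
                {
                    "name": "doc 1",
                    "text": "Patient has hypertension.",
                    "annotations": [
                        # character offsets into the document text
                        {"start": 12, "end": 24,
                         "cui": "C0020538", "value": "hypertension"},
                    ],
                },
            ],
        },
    ],
}
# trainer.train_supervised_raw(export, nepochs=2, use_filters=True)
```

Note that `start`/`end` are character offsets into `text`, so the slice `text[start:end]` should equal the annotation's `value`.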

Parameters:

  • data

    (dict[str, list[dict[str, dict]]]) –

    The raw data, e.g. from MedCATtrainer on export.

  • reset_cui_count

    (bool, default: False ) –

    Used for training with weight_decay (annealing). Each concept has a count that is there from the beginning of the CDB; that count is used for annealing. Resetting the count will significantly increase the training impact. This will reset the count only for concepts that exist in the training data.

  • nepochs

    (int, default: 1 ) –

    Number of epochs for which to run the training.

  • print_stats

    (int, default: 0 ) –

    If > 0 it will print stats every print_stats epochs.

  • use_filters

    (bool, default: False ) –

    Each project in MedCATtrainer can have filters; whether to respect those filters when calculating metrics.

  • terminate_last

    (bool, default: False ) –

    If true, concept termination will be done after all training.

  • use_overlaps

    (bool, default: False ) –

    Allow overlapping entities, nearly always False as it is very difficult to annotate overlapping entities.

  • use_cui_doc_limit

    (bool, default: False ) –

    If True the metrics for a CUI will be only calculated if that CUI appears in a document, in other words if the document was annotated for that CUI. Useful in very specific situations when during the annotation process the set of CUIs changed.

  • test_size

    (float, default: 0 ) –

    If > 0 the data set will be split into train and test based on this ratio. Should be between 0 and 1. Usually 0.1 is fine.

  • devalue_others

    (bool, default: False ) –

    Check add_name for more details.

  • use_groups

    (bool, default: False ) –

    If True concepts that have groups will be combined and stats will be reported on groups.

  • never_terminate

    (bool, default: False ) –

    If True no termination will be applied

  • train_from_false_positives

    (bool, default: False ) –

    If True it will use false positive examples detected by medcat and train from them as negative examples.

  • extra_cui_filter

    (Optional[set], default: None ) –

    This filter will be intersected with all other filters, or if all others are not set then only this one will be used.

  • checkpoint

    (Optional[medcat.utils.checkpoint.Checkpoint]) –

    The MedCAT Checkpoint object

  • disable_progress

    (bool, default: False ) –

    Whether to disable the progress output (tqdm). Defaults to False.

  • train_addons

    (bool, default: False ) –

    Whether to also train the addons (e.g. MetaCATs). Defaults to False.

Returns:

  • tuple ( tuple ) –

    Consisting of the following parts:
      fp (dict): False positives for each CUI.
      fn (dict): False negatives for each CUI.
      tp (dict): True positives for each CUI.
      p (dict): Precision for each CUI.
      r (dict): Recall for each CUI.
      f1 (dict): F1 for each CUI.
      cui_counts (dict): Number of occurrences for each CUI.
      examples (dict): FP/FN examples of sentences for each CUI.
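A sketch of consuming that return value; the `trainer` call is commented out since it requires a loaded model, and the helper name `macro_f1` is an illustrative assumption, not part of the API:

```python
# Sketch: unpack the 8-tuple returned by train_supervised_raw and compute
# a macro-averaged F1 over the evaluated CUIs. `macro_f1` is our own
# helper, not part of the medcat API.
def macro_f1(f1_per_cui: dict) -> float:
    """Average the per-CUI F1 scores; 0.0 when nothing was evaluated."""
    if not f1_per_cui:
        return 0.0
    return sum(f1_per_cui.values()) / len(f1_per_cui)

# fp, fn, tp, p, r, f1, cui_counts, examples = trainer.train_supervised_raw(export)
# print(macro_f1(f1))
```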

Source code in medcat-v2/medcat/trainer.py
    def train_supervised_raw(self,
                             data: MedCATTrainerExport,
                             reset_cui_count: bool = False,
                             nepochs: int = 1,
                             print_stats: int = 0,
                             use_filters: bool = False,
                             terminate_last: bool = False,
                             use_overlaps: bool = False,
                             use_cui_doc_limit: bool = False,
                             test_size: float = 0,
                             devalue_others: bool = False,
                             use_groups: bool = False,
                             never_terminate: bool = False,
                             train_from_false_positives: bool = False,
                             extra_cui_filter: Optional[set[str]] = None,
                             #  checkpoint: Optional[Checkpoint] = None,
                             disable_progress: bool = False,
                             train_addons: bool = False,
                             ) -> tuple:
        """Train supervised based on the raw data provided.

        The raw data is expected in the following format:
        {'projects':
            [ # list of projects
                { # project 1
                    'name': '<some name>',
                    # list of documents
                    'documents': [{'name': '<some name>',  # document 1
                                    'text': '<text of the document>',
                                    # list of annotations
                                    'annotations': [# annotation 1
                                                    {'start': -1,
                                                    'end': 1,
                                                    'cui': 'cui',
                                                    'value': '<text value>'},
                                                    ...],
                                    }, ...]
                }, ...
            ]
        }

        Please take care that this is more of a simulated online training
        than supervised training.

        When filtering, the filters within the CAT model are used first,
        then the ones from MedCATtrainer (MCT) export filters,
        and finally the extra_cui_filter (if set).
        That is to say, the expectation is:
        extra_cui_filter ⊆ MCT filter ⊆ Model/config filter.

        Args:
            data (dict[str, list[dict[str, dict]]]):
                The raw data, e.g. from MedCATtrainer on export.
            reset_cui_count (bool):
                Used for training with weight_decay (annealing). Each concept
                has a count that is there from the beginning of the CDB, that
                count is used for annealing. Resetting the count will
                significantly increase the training impact. This will reset
                the count only for concepts that exist in the training
                data.
            nepochs (int):
                Number of epochs for which to run the training.
            print_stats (int):
                If > 0 it will print stats every print_stats epochs.
            use_filters (bool):
                Each project in MedCATtrainer can have filters; whether to
                respect those filters when calculating metrics.
            terminate_last (bool):
                If true, concept termination will be done after all training.
            use_overlaps (bool):
                Allow overlapping entities, nearly always False as it is very
                difficult to annotate overlapping entities.
            use_cui_doc_limit (bool):
                If True the metrics for a CUI will be only calculated if that
                CUI appears in a document, in other words if the document was
                annotated for that CUI. Useful in very specific situations
                when during the annotation process the set of CUIs changed.
            test_size (float):
                If > 0 the data set will be split into train and test based
                on this ratio. Should be between 0 and 1. Usually 0.1 is
                fine.
            devalue_others (bool):
                Check add_name for more details.
            use_groups (bool):
                If True concepts that have groups will be combined and stats
                will be reported on groups.
            never_terminate (bool):
                If True no termination will be applied
            train_from_false_positives (bool):
                If True it will use false positive examples detected by medcat
                and train from them as negative examples.
            extra_cui_filter (Optional[set]):
                This filter will be intersected with all other filters, or if
                all others are not set then only this one will be used.
            checkpoint (Optional[medcat.utils.checkpoint.Checkpoint]):
                The MedCAT Checkpoint object
            disable_progress (bool):
                Whether to disable the progress output (tqdm). Defaults to
                False.
            train_addons (bool):
                Whether to also train the addons (e.g. MetaCATs). Defaults
                to False.

        Returns:
            tuple: Consisting of the following parts
                fp (dict):
                    False positives for each CUI.
                fn (dict):
                    False negatives for each CUI.
                tp (dict):
                    True positives for each CUI.
                p (dict):
                    Precision for each CUI.
                r (dict):
                    Recall for each CUI.
                f1 (dict):
                    F1 for each CUI.
                cui_counts (dict):
                    Number of occurrences for each CUI.
                examples (dict):
                    FP/FN examples of sentences for each CUI.
        """
        # checkpoint = self._init_ckpts(is_resumed, checkpoint)

        # the config.linking.filters stuff is used directly in
        # medcat.linking.context_based_linker and
        # medcat.linking.vector_context_model
        # as such, they need to be kept up to date with per-project filters
        # However, the original state needs to be kept track of
        # so that it can be restored after training

        fp: dict[str, int] = {}
        fn: dict[str, int] = {}
        tp: dict[str, int] = {}
        p: dict[str, float] = {}
        r: dict[str, float] = {}
        f1: dict[str, float] = {}
        examples: dict[str, object] = {}

        cui_counts: dict[str, int] = {}

        if test_size == 0:
            logger.info("Running without a test set, or train==test")
            test_set = data
            train_set = data
        else:
            train_set, test_set, _, _ = make_mc_train_test(data, self.cdb,
                                                           test_size=test_size)

        # if print_stats > 0:
        #     fp, fn, tp, p, r, f1, cui_counts, examples = self._print_stats(
        #         test_set, use_project_filters=use_filters,
        #         use_cui_doc_limit=use_cui_doc_limit, use_overlaps=use_overlaps,
        #         use_groups=use_groups, extra_cui_filter=extra_cui_filter)
        if reset_cui_count:
            self._reset_cui_counts(train_set)

        # Remove entities that were terminated
        if not never_terminate:
            for ann in (ann for project in train_set['projects']
                        for doc in project['documents']
                        for ann in doc['annotations']):
                if ann.get('killed', False):
                    self.unlink_concept_name(ann['cui'], ann['value'], False)

        # latest_trained_step = (checkpoint.count if checkpoint is not None
        #                        else 0)
        # (current_epoch,
        #  current_project,
        #  current_document) = self._get_training_start(train_set,
        #                                               latest_trained_step)
        current_epoch = 0
        current_project = 0
        current_document = 0

        for epoch in trange(current_epoch, nepochs, initial=current_epoch,
                            total=nepochs, desc='Epoch', leave=False,
                            disable=disable_progress):
            self._perform_epoch(current_project, current_document, train_set,
                                disable_progress, extra_cui_filter,
                                use_filters, train_from_false_positives,
                                devalue_others, terminate_last,
                                never_terminate)

        # if print_stats > 0 and (epoch + 1) % print_stats == 0:
        #     fp, fn, tp, p, r, f1, cui_counts, examples = self._print_stats(
        #         test_set, epoch=epoch + 1, use_project_filters=use_filters,
        #         use_cui_doc_limit=use_cui_doc_limit,
        #         use_overlaps=use_overlaps, use_groups=use_groups,
        #         extra_cui_filter=extra_cui_filter)

        # # reset the state of filters
        # self.config.linking.filters = orig_filters

        if (train_addons and
                # NOTE if no annotations, no point
                count_all_annotations(data) > 0):
            self._train_addons(data)

        return fp, fn, tp, p, r, f1, cui_counts, examples

train_unsupervised

train_unsupervised(data_iterator: Iterable[str], nepochs: int = 1, fine_tune: bool = True, progress_print: int = 1000) -> None

Runs training on the data, note that the maximum length of a line or document is 1M characters. Anything longer will be trimmed.

Parameters:

  • data_iterator

    (Iterable) –

    Simple iterator over sentences/documents, e.g. an open file, an array, or anything that can be used in a for loop.

  • nepochs

    (int, default: 1 ) –

    Number of epochs for which to run the training.

  • fine_tune

    (bool, default: True ) –

    If False old training will be removed.

  • progress_print

    (int, default: 1000 ) –

    Print progress after N lines.

  • checkpoint

    (Optional[CheckpointUT]) –

    The MedCAT checkpoint object

  • is_resumed

    (bool) –

    If True resume the previous training; If False, start a fresh new training.

Source code in medcat-v2/medcat/trainer.py
def train_unsupervised(self,
                       data_iterator: Iterable[str],
                       nepochs: int = 1,
                       fine_tune: bool = True,
                       progress_print: int = 1000,
                       #    checkpoint: Optional[Checkpoint] = None,
                       ) -> None:
    """Runs training on the data, note that the maximum length of a line
    or document is 1M characters. Anything longer will be trimmed.

    Args:
        data_iterator (Iterable):
            Simple iterator over sentences/documents, e.g. an open file
            or an array or anything that we can use in a for loop.
        nepochs (int):
            Number of epochs for which to run the training.
        fine_tune (bool):
            If False old training will be removed.
        progress_print (int):
            Print progress after N lines.
        checkpoint (Optional[medcat.utils.checkpoint.CheckpointUT]):
            The MedCAT checkpoint object
        is_resumed (bool):
            If True resume the previous training; If False, start a fresh
            new training.
    """
    with self.config.meta.prepare_and_report_training(
        data_iterator, nepochs, False
    ) as wrapped_iter:
        with temp_changed_config(self.config.components.linking,
                                 'train', True):
            self._train_unsupervised(wrapped_iter, nepochs, fine_tune,
                                     progress_print)
unlink_concept_name

unlink_concept_name(cui: str, name: str, preprocessed_name: bool = False) -> None

Unlink a concept name from the CUI (or all CUIs if full_unlink), removes the link from the Concept Database (CDB). As a consequence medcat will never again link the name to this CUI - meaning the name will not be detected as a concept in the future.

Parameters:

  • cui

    (str) –

    The CUI from which the name will be removed.

  • name

    (str) –

    The span of text to be removed from the linking dictionary.

  • preprocessed_name

    (bool, default: False ) –

    Whether the name being used is preprocessed.

Examples:

>>> # To never again link C0020538 to HTN
>>> cat.unlink_concept_name('C0020538', 'htn', False)
Source code in medcat-v2/medcat/trainer.py
def unlink_concept_name(self, cui: str, name: str,
                        preprocessed_name: bool = False) -> None:
    """Unlink a concept name from the CUI (or all CUIs if full_unlink),
    removes the link from the Concept Database (CDB). As a consequence
    medcat will never again link the `name` to this CUI - meaning the
    name will not be detected as a concept in the future.

    Args:
        cui (str):
            The CUI from which the `name` will be removed.
        name (str):
            The span of text to be removed from the linking dictionary.
        preprocessed_name (bool):
            Whether the name being used is preprocessed.

    Examples:

        >>> # To never again link C0020538 to HTN
        >>> cat.unlink_concept_name('C0020538', 'htn', False)
    """

    cuis = [cui]
    if preprocessed_name:
        names: dict[str, NameDescriptor] = {
            name: NameDescriptor([], set(), name, name.isupper())}
    else:
        names = prepare_name(name, self._pipeline.tokenizer, {},
                             self._pn_configs)

    # If full unlink find all CUIs
    if self.config.general.full_unlink:
        logger.warning("In the config `full_unlink` is set to `True`. "
                       "Thus removing all CUIs linked to the specified "
                       "name (%s)", name)
        for n in names:
            if n not in self.cdb.name2info:
                continue
            cuis.extend(self.cdb.name2info[n]['per_cui_status'].keys())

    # Remove name from all CUIs
    for c in cuis:
        self.cdb._remove_names(cui=c, names=names.keys())