Skip to content

medcat.components.addons.relation_extraction.rel_cat

Classes:

Attributes:

logger module-attribute

logger = getLogger(__name__)

BalancedBatchSampler

BalancedBatchSampler(dataset, classes, batch_size, max_samples, max_minority)

Bases: Sampler

Attributes:

Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
126
127
128
129
130
131
132
133
134
135
136
def __init__(self, dataset, classes,
             batch_size, max_samples, max_minority):
    """Store the sampling parameters and precompute dataset indices."""
    self.dataset = dataset
    self.classes = classes
    self.num_classes = len(classes)
    self.batch_size = batch_size
    # caps used when balancing minority/majority classes per batch
    self.max_minority = max_minority
    self.max_samples_per_class = max_samples
    self.indices = [*range(len(dataset))]

batch_size instance-attribute

batch_size = batch_size

classes instance-attribute

classes = classes

dataset instance-attribute

dataset = dataset

indices instance-attribute

indices = list(range(len(dataset)))

max_minority instance-attribute

max_minority = max_minority

max_samples_per_class instance-attribute

max_samples_per_class = max_samples

num_classes instance-attribute

num_classes = len(classes)

RelCAT

RelCAT(base_tokenizer: BaseTokenizer, cdb: CDB, config: ConfigRelCAT = ConfigRelCAT(), task: str = 'train', init_model: bool = False)

The RelCAT class used for training 'Relation-Annotation' models, i.e., annotation of relations between clinical concepts.

Parameters:

  • cdb

    (CDB) –

    cdb, this is used when creating relation datasets.

  • base_tokenizer

    (TokenizerWrapperBERT) –

    The Huggingface tokenizer instance. This can be a pre-trained tokenizer instance from a BERT-style model. For now, only BERT models are supported.

  • config

    (ConfigRelCAT, default: ConfigRelCAT() ) –

    the configuration for RelCAT. Param descriptions available in ConfigRelCAT class docs.

  • task

    (str, default: 'train' ) –

    What task is this model supposed to handle. Defaults to "train"

  • init_model

    (bool, default: False ) –

    loads default model. Defaults to False.

Methods:

Attributes:

Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def __init__(self, base_tokenizer: BaseTokenizer,
             cdb: CDB, config: ConfigRelCAT = ConfigRelCAT(),
             task: str = "train", init_model: bool = False):
    """Set up the RelCAT component, logging, device and data paths.

    Args:
        base_tokenizer (BaseTokenizer): pipeline tokenizer.
        cdb (CDB): concept database used when creating relation datasets.
        config (ConfigRelCAT): RelCAT configuration.
            NOTE(review): the default is a single shared mutable
            ConfigRelCAT() instance created at import time — confirm
            callers never mutate it.
        task (str): task this model handles. Defaults to "train".
        init_model (bool): if True, initialise a default model.
    """
    self.base_tokenizer = base_tokenizer
    # placeholder component; replaced below when init_model is True
    self.component = RelExtrBaseComponent()
    self.task: str = task
    self.checkpoint_path: str = "./"

    set_all_seeds(config.general.seed)

    if init_model:
        self.component = RelExtrBaseComponent(
            config=config, task=task, init_model=True)

    self.cdb = cdb
    # NOTE: basicConfig is a process-wide side effect; it only takes
    # effect if the root logger has no handlers yet
    logging.basicConfig(
        level=self.component.relcat_config.general.log_level)
    logger.setLevel(self.component.relcat_config.general.log_level)

    # use CUDA only when available AND not explicitly disabled in config
    self.is_cuda_available = torch.cuda.is_available()
    self.device = torch.device(
        "cuda" if self.is_cuda_available and
        self.component.relcat_config.general.device != "cpu" else "cpu")
    self._init_data_paths()

addon_type class-attribute instance-attribute

addon_type = 'rel_cat'

base_tokenizer instance-attribute

base_tokenizer = base_tokenizer

cdb instance-attribute

cdb = cdb

checkpoint_path instance-attribute

checkpoint_path: str = './'

component instance-attribute

component = RelExtrBaseComponent()

device instance-attribute

device = device('cuda' if is_cuda_available and device != 'cpu' else 'cpu')

is_cuda_available instance-attribute

is_cuda_available = is_available()

output_key class-attribute instance-attribute

output_key = 'rel_'

task instance-attribute

task: str = task

evaluate_

evaluate_(output_logits, labels, ignore_idx)
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
def evaluate_(self, output_logits, labels, ignore_idx):
    """Compute accuracy, micro precision/recall/F1 and per-label stats.

    Args:
        output_logits: raw classification logits, shape (batch, nclasses).
        labels: gold label ids, shape (batch, 1); entries equal to
            ``ignore_idx`` (padding) are excluded from all metrics.
        ignore_idx: label id to ignore (padding).

    Returns:
        Tuple of (acc, micro_recall, micro_precision, micro_f1,
        pred_labels, true_labels, stat_per_label).
    """
    # ignore index (padding) when calculating accuracy
    idxs = (labels != ignore_idx).squeeze()
    labels_ = labels.squeeze()[idxs].to(self.device)
    pred_labels = torch.softmax(output_logits, dim=1).max(1)[1]
    pred_labels = pred_labels[idxs].to(self.device)

    true_labels = labels_.cpu().numpy().tolist(
    ) if labels_.is_cuda else labels_.numpy().tolist()
    pred_labels = pred_labels.cpu().numpy().tolist(
    ) if pred_labels.is_cuda else pred_labels.numpy().tolist()

    unique_labels = set(true_labels)

    batch_size = len(true_labels)

    stat_per_label = dict()

    total_tp, total_fp, total_tn, total_fn = 0, 0, 0, 0

    for label in unique_labels:
        stat_per_label[label] = {
            "tp": 0, "fp": 0, "tn": 0, "fn": 0,
            "f1": 0.0, "acc": 0.0, "prec": 0.0, "recall": 0.0}

        for true_lbl, pred_lbl in zip(true_labels, pred_labels):
            if true_lbl == label and pred_lbl == label:
                stat_per_label[label]["tp"] += 1
                total_tp += 1
            elif true_lbl == label:
                # gold label missed by the model -> false NEGATIVE.
                # (The previous implementation counted this as a false
                # positive and the spurious-prediction case below as a
                # false negative, which swapped precision and recall.)
                stat_per_label[label]["fn"] += 1
                total_fn += 1
            elif pred_lbl == label:
                # label predicted for a sample of another class ->
                # false POSITIVE
                stat_per_label[label]["fp"] += 1
                total_fp += 1
            else:
                stat_per_label[label]["tn"] += 1
                total_tn += 1

        lbl_tp_tn = stat_per_label[label]["tn"] + \
            stat_per_label[label]["tp"]

        # guard denominators against division by zero
        lbl_tp_fn = stat_per_label[label]["fn"] + \
            stat_per_label[label]["tp"]
        lbl_tp_fn = lbl_tp_fn if lbl_tp_fn > 0.0 else 1.0

        lbl_tp_fp = stat_per_label[label]["tp"] + \
            stat_per_label[label]["fp"]
        lbl_tp_fp = lbl_tp_fp if lbl_tp_fp > 0.0 else 1.0

        stat_per_label[label]["acc"] = lbl_tp_tn / batch_size
        stat_per_label[label]["prec"] = (stat_per_label[label]["tp"] /
                                         lbl_tp_fp)
        stat_per_label[label]["recall"] = (stat_per_label[label]["tp"] /
                                           lbl_tp_fn)

        lbl_re_pr = stat_per_label[label]["recall"] + \
            stat_per_label[label]["prec"]
        lbl_re_pr = lbl_re_pr if lbl_re_pr > 0.0 else 1.0

        stat_per_label[label]["f1"] = (
            2 * (stat_per_label[label]["recall"] *
                 stat_per_label[label]["prec"])) / lbl_re_pr

    tp_fn = total_fn + total_tp
    tp_fn = tp_fn if tp_fn > 0.0 else 1.0

    tp_fp = total_fp + total_tp
    tp_fp = tp_fp if tp_fp > 0.0 else 1.0

    micro_recall = total_tp / tp_fn
    micro_precision = total_tp / tp_fp

    re_pr = micro_recall + micro_precision
    re_pr = re_pr if re_pr > 0.0 else 1.0
    micro_f1 = (2 * (micro_recall * micro_precision)) / re_pr

    acc = total_tp / batch_size

    return (acc, micro_recall, micro_precision, micro_f1,
            pred_labels, true_labels, stat_per_label)

evaluate_results

evaluate_results(data_loader, pad_id)
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
def evaluate_results(self, data_loader, pad_id):
    """Evaluate the model over `data_loader` and log aggregate metrics.

    Args:
        data_loader: yields (token_ids, e1_e2_start, labels, _, _)
            batches.
        pad_id: padding token id, used to build the attention mask.

    Returns:
        dict: batch-averaged "loss", "accuracy", "precision",
        "recall" and "f1".
    """
    logger.info("Evaluating test samples...")
    rc_cnf = self.component.relcat_config
    # weighted loss only when weights are both present and enabled
    if (rc_cnf.train.class_weights is not None and
            rc_cnf.train.enable_class_weights):
        criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(
            rc_cnf.train.class_weights).to(self.device))
    else:
        criterion = nn.CrossEntropyLoss()

    total_loss, total_acc, total_f1, total_recall, total_precision = (
        0.0, 0.0, 0.0, 0.0, 0.0)
    all_batch_stats_per_label = []

    self.component.model.eval()

    # NOTE(review): `i` is reused after this loop for averaging; an
    # empty data_loader would leave it unbound — assumes >= 1 batch.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            token_ids, e1_e2_start, labels, _, _ = data
            # mask out padding tokens
            attention_mask = (token_ids != pad_id).float().to(self.device)
            token_type_ids = torch.zeros(
                (*token_ids.shape[:2],)).long().to(self.device)

            labels = labels.to(self.device)

            model_output, pred_classification_logits = (
                self.component.model(token_ids,
                                     token_type_ids=token_type_ids,
                                     attention_mask=attention_mask,
                                     Q=None,
                                     e1_e2_start=e1_e2_start))

            batch_loss = criterion(pred_classification_logits.view(
                -1, rc_cnf.train.nclasses).to(self.device),
                labels.squeeze(1))
            total_loss += batch_loss.item()

            (batch_accuracy, batch_recall, batch_precision, batch_f1,
             pred_labels, true_labels, batch_stats_per_label) = (
                self.evaluate_(pred_classification_logits,
                               labels, ignore_idx=-1))

            all_batch_stats_per_label.append(batch_stats_per_label)

            total_acc += batch_accuracy
            total_recall += batch_recall
            total_precision += batch_precision
            total_f1 += batch_f1

    # sum the per-label stats across batches ...
    final_stats_per_label = {}

    for batch_label_stats in all_batch_stats_per_label:
        for label_id, stat_dict in batch_label_stats.items():

            if label_id not in final_stats_per_label.keys():
                final_stats_per_label[label_id] = stat_dict
            else:
                for stat, score in stat_dict.items():
                    final_stats_per_label[label_id][stat] += score

    # ... then average everything over the number of batches
    for label_id, stat_dict in final_stats_per_label.items():
        for stat_name, value in stat_dict.items():
            final_stats_per_label[label_id][stat_name] = value / (i + 1)

    total_loss = total_loss / (i + 1)
    total_acc = total_acc / (i + 1)
    total_precision = total_precision / (i + 1)
    total_f1 = total_f1 / (i + 1)
    total_recall = total_recall / (i + 1)

    results = {
        "loss": total_loss,
        "accuracy": total_acc,
        "precision": total_precision,
        "recall": total_recall,
        "f1": total_f1
    }

    logger.info("=" * 20 + " Evaluation Results " + "=" * 20)
    logger.info(" no. of batches:" + str(i + 1))
    for key in sorted(results.keys()):
        logger.info(" %s = %0.3f" % (key, results[key]))
    logger.info("-" * 23 + " class stats " + "-" * 23)
    for label_id, stat_dict in final_stats_per_label.items():
        logger.info(
            "label: %s | f1: %0.3f | prec : %0.3f | acc: %0.3f | "
            "recall: %0.3f ",
            rc_cnf.general.idx2labels[label_id],
            stat_dict["f1"],
            stat_dict["prec"],
            stat_dict["acc"],
            stat_dict["recall"]
        )
    logger.info("-" * 59)
    logger.info("=" * 59)

    return results

load classmethod

load(load_path: str = './') -> RelCAT
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
@classmethod
def load(cls, load_path: str = "./") -> "RelCAT":
    """Load a serialised RelCAT (and its CDB, if present) from disk.

    Args:
        load_path (str): folder containing the saved component and,
            optionally, a "cdb.dat" file. Defaults to "./".

    Returns:
        RelCAT: the loaded instance with component and device set.
    """
    if os.path.exists(os.path.join(load_path, "cdb.dat")):
        cdb = cast(CDB, deserialise(os.path.join(load_path, "cdb.dat")))
    else:
        # fall back to an empty CDB; the caller must supply a real one
        cdb = CDB(config=Config())
        logger.info(
            "The default CDB file name 'cdb.dat' doesn't exist in the "
            "specified path, you will need to load & set "
            "a CDB manually via rel_cat.cdb = CDB.load('path') ")

    component = RelExtrBaseComponent.load(
        pretrained_model_name_or_path=load_path)

    # honour an explicit "cpu" setting even when CUDA is available
    device = torch.device(
        "cuda" if torch.cuda.is_available() and
        component.relcat_config.general.device != "cpu" else "cpu")

    rel_cat = RelCAT(
        # NOTE: this is a throwaway tokenizer just for registrations
        create_tokenizer(cdb.config.general.nlp.provider, cdb.config),
        cdb=cdb, config=component.relcat_config, task=component.task)
    rel_cat.device = device
    rel_cat.component = component

    return rel_cat

pipe

pipe(stream: Iterable[MutableDocument], *args, **kwargs) -> Iterator[MutableDocument]
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
def pipe(self, stream: Iterable[MutableDocument], *args, **kwargs
         ) -> Iterator[MutableDocument]:
    """Yield each doc from `stream` with predicted relations appended.

    For every doc, a relation dataset is built from the doc, the model
    is run batch-wise, and one dict per predicted relation is appended
    to the doc's "relations" addon data.
    """
    rc_cnf = self.component.relcat_config

    predict_rel_dataset = RelData(
        cdb=self.cdb, config=rc_cnf,
        tokenizer=self.component.tokenizer)

    self.component.model = self.component.model.to(self.device)

    for doc_id, doc in enumerate(stream, 0):
        # relations for this doc only; no train/test split
        predict_rel_dataset.dataset, _ = self._create_test_train_datasets(
            data=predict_rel_dataset.create_base_relations_from_doc(
                doc, doc_id=str(doc_id)),
            split_sets=False)

        predict_dataloader = DataLoader(
            dataset=predict_rel_dataset, shuffle=False,
            batch_size=rc_cnf.train.batch_size,
            num_workers=0, collate_fn=self.component.padding_seq,
            pin_memory=rc_cnf.general.pin_memory)

        total_rel_found = len(
            predict_rel_dataset.dataset["output_relations"])
        # running index into output_relations across batches
        rel_idx = -1

        logger.info("total relations for doc: " + str(total_rel_found))
        logger.info("processing...")

        pbar = tqdm(total=total_rel_found)

        for i, data in enumerate(predict_dataloader):
            with torch.no_grad():
                token_ids, e1_e2_start, labels, _, _ = data

                attention_mask = (
                    token_ids != self.component.pad_id
                    ).float().to(self.device)
                token_type_ids = torch.zeros(
                    *token_ids.shape[:2]).long().to(self.device)

                (model_output,
                 pred_classification_logits) = self.component.model(
                    token_ids, token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    e1_e2_start=e1_e2_start)

                # NOTE: the inner `i` shadows the batch index above;
                # neither value is used after its loop
                for i, pred_rel_logits in enumerate(
                        pred_classification_logits):
                    rel_idx += 1

                    # .max(0) returns (max softmax value, argmax index)
                    confidence = torch.softmax(
                        pred_rel_logits, dim=0).max(0)
                    predicted_label_id = int(confidence[1].item())

                    relations: list = doc.get_addon_data(  # type: ignore
                        "relations")
                    out_rels = predict_rel_dataset.dataset[
                        "output_relations"][rel_idx]
                    # NOTE(review): the indices into out_rels (2, 3, 8,
                    # 9, 18-21) follow the RelData record layout —
                    # verify against RelData before changing
                    relations.append(
                        {
                            "relation": rc_cnf.general.idx2labels[
                                predicted_label_id],
                            "label_id": predicted_label_id,
                            "ent1_text": out_rels[2],
                            "ent2_text": out_rels[3],
                            "confidence": float("{:.3f}".format(
                                confidence[0])),
                            "start_ent1_char_pos": out_rels[18],
                            "end_ent1_char_pos": out_rels[19],
                            "start_ent2_char_pos": out_rels[20],
                            "end_ent2_char_pos": out_rels[21],
                            "start_entity_id": out_rels[8],
                            "end_entity_id": out_rels[9],
                        })
                pbar.update(len(token_ids))
        pbar.close()

        yield doc

predict_text_with_anns

predict_text_with_anns(text: str, annotations: list[dict]) -> MutableDocument

Creates a doc from text and annotation input. Predicts using self.__call__

Parameters:

  • text

    (str) –

    text

  • annotations

    (dict) –

    dict containing the entities from NER (of your choosing), the format must be the following format: [ { "cui": "202099003", -this is optional "value": "discoid lateral meniscus", "start": 294, "end": 318 }, { "cui": "202099003", "value": "Discoid lateral meniscus", "start": 1905, "end": 1929, } ]

Returns:

Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
def predict_text_with_anns(self, text: str, annotations: list[dict]
                           ) -> MutableDocument:
    """ Creates a doc from text and annotation input.
    Predicts using self.__call__

    Args:
        text (str): text
        annotations (list[dict]): the entities from NER
            (of your choosing), each entry in the following format:
                [
                    {
                        "cui": "202099003", -this is optional
                        "value": "discoid lateral meniscus",
                        "start": 294,
                        "end": 318
                    },
                    {
                        "cui": "202099003",
                        "value": "Discoid lateral meniscus",
                        "start": 1905,
                        "end": 1929,
                    }
                ]

    Returns:
        MutableDocument: doc with the relations.
    """
    # NOTE: This runs not an empty language, but the specified one
    base_tokenizer = create_tokenizer(
        self.cdb.config.general.nlp.provider, self.cdb.config)
    doc = base_tokenizer(text)

    for ann in annotations:
        # collect indices of tokens whose end char falls inside the
        # (start, end] annotation span
        tkn_idx = []
        for ind, word in enumerate(doc):
            end_char = word.base.char_index + len(word.base.text)
            if end_char <= ann['end'] and end_char > ann['start']:
                tkn_idx.append(ind)
        # NOTE(review): min()/max() raise ValueError when no token
        # matched the span, and ann["cui"] is read unconditionally
        # despite being documented as optional — confirm inputs.
        entity = base_tokenizer.create_entity(
            doc, min(tkn_idx), max(tkn_idx) + 1, label=ann["value"])
        entity.cui = ann["cui"]
        entity.set_addon_data('start', ann['start'])
        entity.set_addon_data('end', ann['end'])
        doc.ner_ents.append(entity)

    doc = self(doc)

    return doc

save

save(save_path: str = './') -> None
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
219
220
def save(self, save_path: str = "./") -> None:
    """Persist the underlying relation-extraction component to disk.

    Args:
        save_path (str): target folder. Defaults to "./".
    """
    self.component.save(save_path=save_path)

train

train(export_data_path: str = '', train_csv_path: str = '', test_csv_path: str = '', checkpoint_path: str = './')
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
def train(self, export_data_path: str = "", train_csv_path: str = "",
          test_csv_path: str = "", checkpoint_path: str = "./"):
    """Train the relation extraction model.

    Exactly one data source must be provided: either CSV path(s) or a
    MedCAT Trainer export JSON; otherwise ValueError is raised.

    Args:
        export_data_path (str): path to a MedCAT Trainer export JSON.
        train_csv_path (str): path to a training CSV.
        test_csv_path (str): optional test CSV; when empty, the train
            CSV is split into train/test sets.
        checkpoint_path (str): folder used to load/save training state.
    """
    if self.is_cuda_available:
        logger.info("Training on device: %s%s",
                    str(torch.cuda.get_device_name(0)), str(self.device))

    self.component.model = self.component.model.to(self.device)

    rc_cnf = self.component.relcat_config

    # resize vocab just in case more tokens have been added
    self.component.model_config.vocab_size = (
        self.component.tokenizer.get_size())

    train_rel_data = RelData(
        cdb=self.cdb, config=rc_cnf,
        tokenizer=self.component.tokenizer)
    test_rel_data = RelData(
        cdb=self.cdb, config=rc_cnf,
        tokenizer=self.component.tokenizer)

    if train_csv_path != "":
        if test_csv_path != "":
            train_rel_data.dataset, _ = self._create_test_train_datasets(
                train_rel_data.create_base_relations_from_csv(
                    train_csv_path), split_sets=False)
            # NOTE(review): the test relations are built via
            # train_rel_data; test_rel_data.create_base_relations_...
            # was presumably intended — confirm it is stateless.
            test_rel_data.dataset, _ = self._create_test_train_datasets(
                train_rel_data.create_base_relations_from_csv(
                    test_csv_path), split_sets=False)
        else:
            (train_rel_data.dataset,
             test_rel_data.dataset) = self._create_test_train_datasets(
                train_rel_data.create_base_relations_from_csv(
                    train_csv_path), split_sets=True)

    elif export_data_path != "":
        export_data = {}
        with open(export_data_path) as f:
            export_data = json.load(f)
        (train_rel_data.dataset,
         test_rel_data.dataset) = self._create_test_train_datasets(
            train_rel_data.create_relations_from_export(export_data),
            split_sets=True)
    else:
        raise ValueError(
            "NO DATA HAS BEEN PROVIDED (MedCAT Trainer export "
            "JSON/CSV/spacy_DOCS)")

    # cap the batch size at the dataset size
    train_dataset_size = len(train_rel_data)
    batch_size = (
        train_dataset_size if train_dataset_size < rc_cnf.train.batch_size
        else rc_cnf.train.batch_size)

    # to use stratified batching
    if rc_cnf.train.stratified_batching:
        sampler = BalancedBatchSampler(
            train_rel_data, [
                i for i in
                range(rc_cnf.train.nclasses)],
            batch_size,
            rc_cnf.train.batching_samples_per_class,
            rc_cnf.train.batching_minority_limit)

        train_dataloader = DataLoader(
            train_rel_data, num_workers=0,
            collate_fn=self.component.padding_seq,
            batch_sampler=sampler,
            pin_memory=rc_cnf.general.pin_memory)
    else:
        train_dataloader = DataLoader(
            train_rel_data, batch_size=batch_size,
            shuffle=rc_cnf.train.shuffle_data,
            num_workers=0,
            collate_fn=self.component.padding_seq,
            pin_memory=rc_cnf.general.pin_memory)

    test_dataset_size = len(test_rel_data)
    test_batch_size = (
        test_dataset_size if
        test_dataset_size < rc_cnf.train.batch_size
        else rc_cnf.train.batch_size)
    test_dataloader = DataLoader(
        test_rel_data,
        batch_size=test_batch_size,
        shuffle=rc_cnf.train.shuffle_data,
        num_workers=0,
        collate_fn=self.component.padding_seq,
        pin_memory=rc_cnf.general.pin_memory)

    # loss weighting: explicit weights > computed balanced weights > none
    if (rc_cnf.train.class_weights is not None and
            rc_cnf.train.enable_class_weights):
        criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(
            numpy.asarray(rc_cnf.train.class_weights)
            ).to(self.device))
    elif rc_cnf.train.enable_class_weights:
        all_class_lbl_ids = [
            rec[5] for rec in train_rel_data.dataset["output_relations"]]
        rc_cnf.train.class_weights = (
            compute_class_weight(class_weight="balanced",
                                 classes=numpy.unique(all_class_lbl_ids),
                                 y=all_class_lbl_ids).tolist())
        criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(
            rc_cnf.train.class_weights).to(
                self.device))
    else:
        criterion = nn.CrossEntropyLoss()

    if self.component.optimizer is None:
        # only optimise parameters that require gradients
        parameters = filter(lambda p: p.requires_grad,
                            self.component.model.parameters())
        self.component.optimizer = AdamW(
            parameters, lr=self.component.relcat_config.train.lr,
            weight_decay=rc_cnf.train.adam_weight_decay,
            betas=rc_cnf.train.adam_betas, eps=rc_cnf.train.adam_epsilon)

    if self.component.scheduler is None:
        self.component.scheduler = MultiStepLR(
            self.component.optimizer,
            milestones=rc_cnf.train.multistep_milestones,
            gamma=rc_cnf.train.multistep_lr_gamma)

    # resume epoch counter / best F1 from checkpoint (if any)
    self.epoch, self.best_f1 = load_state(
        self.component.model, self.component.optimizer,
        self.component.scheduler, load_best=False, path=checkpoint_path,
        relcat_config=rc_cnf)

    logger.info("Starting training process...")

    losses_per_epoch, accuracy_per_epoch, f1_per_epoch = load_results(
        path=checkpoint_path)

    # grow nclasses if the dataset introduced new labels
    if train_rel_data.dataset["nclasses"] > rc_cnf.train.nclasses:
        rc_cnf.train.nclasses = train_rel_data.dataset["nclasses"]
        self.component.model.relcat_config.train.nclasses = (
            rc_cnf.train.nclasses)

    rc_cnf.general.labels2idx.update(train_rel_data.dataset["labels2idx"])
    rc_cnf.general.idx2labels = {
        int(v): k for k, v in rc_cnf.general.labels2idx.items()}

    gradient_acc_steps = (
        rc_cnf.train.gradient_acc_steps)
    max_grad_norm = rc_cnf.train.max_grad_norm

    # continue counting epochs from a resumed checkpoint
    _epochs = self.epoch + rc_cnf.train.nepochs

    for epoch in range(0, _epochs):
        epoch_losses, epoch_precision, epoch_f1 = self._train_epoch(
            epoch, gradient_acc_steps, max_grad_norm, train_dataset_size,
            train_dataloader, test_dataloader, criterion, _epochs,
            checkpoint_path)
        losses_per_epoch.extend(epoch_losses)
        accuracy_per_epoch.extend(epoch_precision)
        f1_per_epoch.extend(epoch_f1)

RelCATAddon

RelCATAddon(config: ConfigRelCAT, rel_cat: RelCAT)

Bases: AddonComponent

Methods:

Attributes:

Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
46
47
48
49
def __init__(self, config: ConfigRelCAT,
             rel_cat: "RelCAT"):
    """Wrap an existing RelCAT instance as a pipeline addon."""
    self._rel_cat = rel_cat
    self.config = config

addon_type class-attribute instance-attribute

addon_type = 'rel_cat'

config instance-attribute

config: ConfigRelCAT = config

name property

name: str

output_key class-attribute instance-attribute

output_key = 'relations'

create_new classmethod

create_new(config: ConfigRelCAT, base_tokenizer: BaseTokenizer, cdb: CDB) -> RelCATAddon

Factory method to create a new RelCATAddon instance.

Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
51
52
53
54
55
56
@classmethod
def create_new(cls, config: ConfigRelCAT, base_tokenizer: BaseTokenizer,
               cdb: CDB) -> 'RelCATAddon':
    """Factory method to create a new RelCATAddon instance.

    Args:
        config (ConfigRelCAT): the RelCAT configuration.
        base_tokenizer (BaseTokenizer): tokenizer for the new model.
        cdb (CDB): concept database used for relation datasets.

    Returns:
        RelCATAddon: addon wrapping a freshly initialised RelCAT.
    """
    return cls(config,
               RelCAT(base_tokenizer, cdb, config=config, init_model=True))

create_new_component classmethod

create_new_component(cnf: ComponentConfig, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> RelCATAddon
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
58
59
60
61
62
63
64
65
66
67
68
69
70
@classmethod
def create_new_component(
        cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
        cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
        ) -> 'RelCATAddon':
    """Create a RelCAT addon, loading from disk when a path is given."""
    if not isinstance(cnf, ConfigRelCAT):
        raise ValueError(f"Incompatible config: {cnf}")
    if model_load_path is None:
        # no saved model -> initialise a brand new one
        return cls.create_new(cnf, tokenizer, cdb)
    addon_folder = os.path.join(
        model_load_path, COMPONENTS_FOLDER,
        cls.NAME_PREFIX + cls.addon_type)
    return cls.load_existing(cnf, tokenizer, cdb, addon_folder)

deserialise_from classmethod

deserialise_from(folder_path: str, **init_kwargs) -> RelCATAddon
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
 94
 95
 96
 97
 98
 99
100
101
102
103
@classmethod
def deserialise_from(cls, folder_path: str, **init_kwargs
                     ) -> 'RelCATAddon':
    """Rebuild the addon from a serialised folder.

    The tokenizer, config and CDB arrive through ``init_kwargs``
    (sent by the model-load machinery).
    """
    tokenizer = init_kwargs['tokenizer']
    cnf = init_kwargs['cnf']
    cdb = init_kwargs['cdb']
    return cls.load_existing(
        load_path=folder_path, base_tokenizer=tokenizer,
        cnf=cnf, cdb=cdb)

get_init_attrs classmethod

get_init_attrs() -> list[str]
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
108
109
110
@classmethod
def get_init_attrs(cls) -> list[str]:
    """No constructor attributes are recorded for this addon."""
    attrs: list[str] = []
    return attrs

get_strategy

get_strategy() -> SerialisingStrategy
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
105
106
def get_strategy(self) -> SerialisingStrategy:
    """This addon handles its own (de)serialisation."""
    strategy = SerialisingStrategy.MANUAL
    return strategy

ignore_attrs classmethod

ignore_attrs() -> list[str]
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
112
113
114
@classmethod
def ignore_attrs(cls) -> list[str]:
    """No attributes are excluded from serialisation by name."""
    ignored: list[str] = []
    return ignored

include_properties classmethod

include_properties() -> list[str]
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
116
117
118
@classmethod
def include_properties(cls) -> list[str]:
    """No properties are explicitly included in serialisation."""
    props: list[str] = []
    return props

load_existing classmethod

load_existing(cnf: ConfigRelCAT, base_tokenizer: BaseTokenizer, cdb: CDB, load_path: str) -> RelCATAddon

Factory method to load an existing RelCAT addon from disk.

Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
72
73
74
75
76
77
78
79
80
81
82
@classmethod
def load_existing(cls, cnf: ConfigRelCAT,
                  base_tokenizer: BaseTokenizer,
                  cdb: CDB,
                  load_path: str) -> 'RelCATAddon':
    """Factory method to load an existing RelCAT addon from disk."""
    rel_cat = RelCAT.load(load_path)
    # the loaded model carries a throwaway tokenizer, so install the
    # caller's tokenizer and recompute the data paths
    rel_cat.base_tokenizer = base_tokenizer
    rel_cat._init_data_paths()
    return cls(cnf, rel_cat)

serialise_to

serialise_to(folder_path: str) -> None
Source code in medcat-v2/medcat/components/addons/relation_extraction/rel_cat.py
84
85
86
def serialise_to(self, folder_path: str) -> None:
    """Serialise the wrapped RelCAT model into `folder_path`.

    Creates the target folder (including missing parents) when needed;
    the previous `os.mkdir` raised if the folder already existed or a
    parent directory was missing.
    """
    os.makedirs(folder_path, exist_ok=True)
    self._rel_cat.save(folder_path)