Skip to content

medcat.components.ner.trf.transformers_ner

Classes:

Functions:

Attributes:

TrCBCreator module-attribute

TrCBCreator = Callable[[Trainer], TrainerCallback]

logger module-attribute

logger = getLogger(__name__)

TransformersNER

TransformersNER(cdb: CDB, base_tokenizer: BaseTokenizer, component: TransformersNERComponent, config: Optional[ConfigTransformersNER] = None, training_arguments=None)

Bases: AbstractEntityProvidingComponent

Methods:

Attributes:

Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
52
53
54
55
56
57
58
def __init__(self, cdb: CDB,
             base_tokenizer: BaseTokenizer,
             component: 'TransformersNERComponent',
             config: Optional[ConfigTransformersNER] = None,
             training_arguments=None,) -> None:
    """Wrap an already-constructed transformers NER component.

    Args:
        cdb (CDB): The concept database.
        base_tokenizer (BaseTokenizer): The base (document) tokenizer.
        component (TransformersNERComponent): The underlying component
            that does the actual NER work.
        config (Optional[ConfigTransformersNER]): The component config.
            Defaults to None.
        training_arguments: Transformers training arguments. Defaults
            to None.
    """
    # NOTE(review): cdb, config and training_arguments are accepted but
    # not stored here - presumably consumed by callers/factories; confirm.
    # Entities produced by this component are written straight to the
    # document's linked entities.
    super().__init__(write_to_linked_ents=True)
    self._component = component

name class-attribute instance-attribute

name = 'transformers_ner'

should_save property

should_save: bool

create_new classmethod

create_new(cdb: CDB, base_tokenizer: BaseTokenizer, config: Optional[ConfigTransformersNER] = None, training_arguments=None) -> TransformersNER
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
60
61
62
63
64
65
66
67
68
@classmethod
def create_new(cls, cdb: CDB, base_tokenizer: BaseTokenizer,
               config: Optional[ConfigTransformersNER] = None,
               training_arguments=None) -> 'TransformersNER':
    """Build a brand-new underlying component and wrap it in an instance.

    Args:
        cdb (CDB): The concept database.
        base_tokenizer (BaseTokenizer): The base (document) tokenizer.
        config (Optional[ConfigTransformersNER]): The component config.
            Defaults to None.
        training_arguments: Transformers training arguments. Defaults
            to None.

    Returns:
        TransformersNER: The newly created instance.
    """
    component = TransformersNERComponent(cdb, base_tokenizer, config,
                                         training_arguments)
    return cls(cdb=cdb, base_tokenizer=base_tokenizer, config=config,
               training_arguments=training_arguments, component=component)

create_new_component classmethod

create_new_component(cnf: ComponentConfig, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> TransformersNER
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@classmethod
def create_new_component(
        cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
        cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
        ) -> 'TransformersNER':
    """Create (or load, if a model path is given) the NER component.

    Args:
        cnf (ComponentConfig): The component config (unused here; the
            custom config is read off the CDB's config instead).
        tokenizer (BaseTokenizer): The base (document) tokenizer.
        cdb (CDB): The concept database.
        vocab (Vocab): The vocabulary (unused by this component).
        model_load_path (Optional[str]): The model pack path to load
            from, or None to create a fresh component.

    Returns:
        TransformersNER: The created (or loaded) component.

    Raises:
        ValueError: If the config on the CDB is not of the expected type.
    """
    config = cdb.config.components.ner.custom_cnf
    if not isinstance(config, ConfigTransformersNER):
        raise ValueError(
            "Did not find correct Transformers NER config. "
            f"Found: {config}")
    # TODO: anywhere to get these?
    training_arguments = None
    # Guard clause: with no load path there is nothing on disk - create.
    if model_load_path is None:
        return cls.create_new(cdb, tokenizer, config, training_arguments)
    load_path = os.path.join(
        model_load_path, COMPONENTS_FOLDER, cls.NAME_PREFIX + "ner")
    return cls.load_existing(cdb, tokenizer, load_path,
                             training_arguments, config)

deserialise_from classmethod

deserialise_from(folder_path: str, **init_kwargs) -> TransformersNER
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
127
128
129
130
131
132
133
134
135
@classmethod
def deserialise_from(cls, folder_path: str, **init_kwargs
                     ) -> 'TransformersNER':
    """Load a serialised instance back from *folder_path*.

    Args:
        folder_path (str): The folder to load from.
        **init_kwargs: Must provide 'cdb', 'tokenizer' and 'cnf'.

    Returns:
        TransformersNER: The loaded instance.
    """
    cdb = init_kwargs['cdb']
    tokenizer = init_kwargs['tokenizer']
    # 'cnf' is the Config.components.ner (of type Ner) config
    ner_config = init_kwargs['cnf'].custom_cnf
    return cls.load_existing(load_path=folder_path, cdb=cdb,
                             base_tokenizer=tokenizer, config=ner_config)

get_folder_name

get_folder_name() -> str
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
121
122
def get_folder_name(self) -> str:
    """Return the folder name used when saving this component."""
    return f"{self.NAME_PREFIX}{self.get_type().name}"

get_init_attrs classmethod

get_init_attrs() -> list[str]
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
140
141
142
@classmethod
def get_init_attrs(cls) -> list[str]:
    """No constructor attributes are serialised automatically."""
    init_attrs: list[str] = []
    return init_attrs

get_strategy

get_strategy() -> SerialisingStrategy
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
137
138
def get_strategy(self) -> SerialisingStrategy:
    """This component handles its own (de)serialisation manually."""
    return SerialisingStrategy.MANUAL

get_type

get_type()
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
 99
100
def get_type(self) -> CoreComponentType:
    """Return the core component type (NER) of this component."""
    return CoreComponentType.ner

ignore_attrs classmethod

ignore_attrs() -> list[str]
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
144
145
146
@classmethod
def ignore_attrs(cls) -> list[str]:
    """No attributes need ignoring during serialisation."""
    return list()

include_properties classmethod

include_properties() -> list[str]
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
148
149
150
@classmethod
def include_properties(cls) -> list[str]:
    """No properties need to be explicitly included when serialising."""
    included: list[str] = []
    return included

load_existing classmethod

load_existing(cdb: CDB, base_tokenizer: BaseTokenizer, load_path: str, training_arguments=None, config: Optional[ConfigTransformersNER] = None) -> TransformersNER
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
89
90
91
92
93
94
95
96
97
@classmethod
def load_existing(cls, cdb: CDB, base_tokenizer: BaseTokenizer,
                  load_path: str, training_arguments=None,
                  config: Optional[ConfigTransformersNER] = None,
                  ) -> 'TransformersNER':
    """Load a previously saved component and wrap it in a new instance.

    Args:
        cdb (CDB): The concept database.
        base_tokenizer (BaseTokenizer): The base (document) tokenizer.
        load_path (str): The folder to load the component from.
        training_arguments: Transformers training arguments. Defaults
            to None.
        config (Optional[ConfigTransformersNER]): The component config.
            Defaults to None.

    Returns:
        TransformersNER: The instance wrapping the loaded component.
    """
    loaded = _load_component(cdb, load_path, base_tokenizer)
    return cls(cdb=cdb, base_tokenizer=base_tokenizer, config=config,
               training_arguments=training_arguments, component=loaded)

predict_entities

predict_entities(doc: MutableDocument, ents: list[MutableEntity] | None = None) -> list[MutableEntity]
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
111
112
113
114
115
116
117
def predict_entities(self, doc: MutableDocument,
                     ents: list[MutableEntity] | None = None
                     ) -> list[MutableEntity]:
    """Predict the entities for a document using the wrapped component.

    Args:
        doc (MutableDocument): The document to annotate.
        ents (list[MutableEntity] | None): Must be None or empty; this
            component does not accept pre-defined entities.

    Returns:
        list[MutableEntity]: The predicted entities.

    Raises:
        ValueError: If non-empty pre-defined entities are passed in.
    """
    if ents:
        # Fixed typo in the message ("ne" -> "not").
        raise ValueError(
            "This method should not be called with pre-defined entities")
    # The component returns a (document, entities) pair; only the
    # entities are of interest here.
    return self._component(doc)[1]

save

save(folder: str, overwrite: bool = False) -> None
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
106
107
108
109
def save(self, folder: str, overwrite: bool = False) -> None:
    """Serialise the wrapped component into *folder*.

    Args:
        folder (str): The target folder.
        overwrite (bool): Whether to overwrite existing files.
            Defaults to False.
    """
    _save_component(self._component,
                    folder, serialiser=self._def_serialiser,
                    overwrite=overwrite)

serialise_to

serialise_to(folder_path: str) -> None
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
124
125
def serialise_to(self, folder_path: str) -> None:
    """Serialise this component to *folder_path* (delegates to save())."""
    self.save(folder_path)

TransformersNERComponent

TransformersNERComponent(cdb: CDB, base_tokenizer: BaseTokenizer, config: Optional[ConfigTransformersNER] = None, training_arguments=None)

TODO: Add documentation

Methods:

Attributes:

Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def __init__(self, cdb: CDB,
             base_tokenizer: BaseTokenizer,
             config: Optional[ConfigTransformersNER] = None,
             training_arguments=None) -> None:
    """Set up the transformers-based NER component.

    Loads the token-classification model and its tokenizer, and prepares
    (default) training arguments.

    Args:
        cdb (CDB): The concept database.
        base_tokenizer (BaseTokenizer): The base (document) tokenizer.
        config (Optional[ConfigTransformersNER]): The component config.
            If None, the custom config on `cdb.config` is used when it is
            of the right type, otherwise a default config is created.
        training_arguments: Transformers `TrainingArguments`. If None, a
            set of defaults is used.
    """
    self.base_tokenizer = base_tokenizer
    self.cdb = cdb
    if config is None:
        # Prefer a config already attached to the CDB if it is of the
        # expected type; otherwise fall back to a default config.
        cnf = cdb.config.components.ner.custom_cnf
        if cnf is not None and isinstance(cnf, ConfigTransformersNER):
            config = cnf
        else:
            config = ConfigTransformersNER()

    self.config = config
    # seed everything for reproducible runs
    set_all_seeds(config.general.seed)

    self.model = AutoModelForTokenClassification.from_pretrained(
        config.general.model_name)

    # Get the tokenizer either create a new one or load existing
    if os.path.exists(os.path.join(
            config.general.model_name, 'tokenizer.dat')):
        self.tokenizer = TransformersTokenizer.load(
            os.path.join(config.general.model_name, 'tokenizer.dat'))
    else:
        hf_tokenizer = AutoTokenizer.from_pretrained(
            self.config.general.model_name)
        self.tokenizer = TransformersTokenizer(hf_tokenizer)

    if training_arguments is None:
        self.training_arguments = TrainingArguments(
            output_dir='./results',
            # directory for storing logs
            logging_dir='./logs',
            # total number of training epochs
            num_train_epochs=10,
            # batch size per device during training
            per_device_train_batch_size=1,
            # batch size for evaluation
            per_device_eval_batch_size=1,
            # strength of weight decay
            weight_decay=0.14,
            warmup_ratio=0.01,
            # Should be smaller when finetuning an existing deid model
            learning_rate=4.47e-05,
            eval_accumulation_steps=1,
            # We want to get to bs=4
            gradient_accumulation_steps=4,
            do_eval=True,
            # eval_strategy since transformers==4.41
            eval_strategy='epoch',
            logging_strategy='epoch',     # type: ignore
            save_strategy='epoch',        # type: ignore
            # Can be changed if our preference is not recall but precision
            # or f1
            metric_for_best_model='eval_recall',
            load_best_model_at_end=True,
            remove_unused_columns=False)
    else:
        self.training_arguments = training_arguments

base_tokenizer instance-attribute

base_tokenizer = base_tokenizer

cdb instance-attribute

cdb = cdb

config instance-attribute

config = config

model instance-attribute

model = from_pretrained(model_name)

tokenizer instance-attribute

tokenizer = load(join(model_name, 'tokenizer.dat'))

training_arguments instance-attribute

training_arguments = TrainingArguments(output_dir='./results', logging_dir='./logs', num_train_epochs=10, per_device_train_batch_size=1, per_device_eval_batch_size=1, weight_decay=0.14, warmup_ratio=0.01, learning_rate=4.47e-05, eval_accumulation_steps=1, gradient_accumulation_steps=4, do_eval=True, eval_strategy='epoch', logging_strategy='epoch', save_strategy='epoch', metric_for_best_model='eval_recall', load_best_model_at_end=True, remove_unused_columns=False)

batch_generator staticmethod

batch_generator(stream: Iterable[MutableDocument], batch_size_chars: int) -> Iterable[list[MutableDocument]]
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
@staticmethod
def batch_generator(stream: Iterable[MutableDocument],
                    batch_size_chars: int
                    ) -> Iterable[list[MutableDocument]]:
    """Yield batches of documents of roughly *batch_size_chars* length.

    A batch is emitted as soon as the cumulative character count of its
    documents reaches the threshold; any remainder is yielded at the end.

    Args:
        stream (Iterable[MutableDocument]): The documents to batch.
        batch_size_chars (int): The (soft) character budget per batch.

    Yields:
        list[MutableDocument]: The next batch of documents.
    """
    batch: list = []
    total_chars = 0
    for document in stream:
        batch.append(document)
        total_chars += len(document.base.text)
        if total_chars >= batch_size_chars:
            yield batch
            batch, total_chars = [], 0

    # flush whatever did not fill a whole batch
    if batch:
        yield batch

create_eval_pipeline

create_eval_pipeline()
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
def create_eval_pipeline(self):
    """Create the HF token-classification pipeline used for inference.

    Also applies several backwards-compatibility fixes so that models
    saved with older transformers versions keep working.
    """
    if self.config.general.chunking_overlap_window is None:
        logger.warning(
            "Chunking overlap window attribute in the config is set to "
            "None, hence chunking is disabled. Be cautious, PII data MAY "
            "BE REVEALED. To enable chunking, set the value to 0 or above")
    self.ner_pipe = pipeline(
        model=self.model, task="ner",
        tokenizer=self.tokenizer.hf_tokenizer,
        stride=self.config.general.chunking_overlap_window)
    if not hasattr(self.ner_pipe.tokenizer, '_in_target_context_manager'):
        # NOTE: this will fix the DeID model(s) created before medcat 1.9.3
        #       though this fix may very well be unstable
        self.ner_pipe.tokenizer._in_target_context_manager = False
    if not hasattr(self.ner_pipe.tokenizer, 'split_special_tokens'):
        # NOTE: this will fix the DeID model(s) created with transformers
        #       before 4.42 and allow them to run with later transformers
        self.ner_pipe.tokenizer.split_special_tokens = False
    if (not hasattr(self.ner_pipe.tokenizer, 'pad_token') and
            hasattr(self.ner_pipe.tokenizer, '_pad_token')):
        # NOTE: This will fix the DeID model(s) created with transformers
        #       before 4.47 and allow them to run with later transformers
        #       versions
        #       In 4.47 the special tokens started to be used differently,
        #       yet our saved model is not aware of that. So we need to
        #       explicitly fix that.
        special_tokens_map = self.ner_pipe.tokenizer.__dict__.get(
            '_special_tokens_map', {})
        for name in self.ner_pipe.tokenizer.SPECIAL_TOKENS_ATTRIBUTES:
            # previously saved in (e.g) _pad_token
            prev_val = getattr(self.ner_pipe.tokenizer, f"_{name}")
            # now saved in the special tokens map by its name
            special_tokens_map[name] = prev_val
        # the map is saved in __dict__ explicitly, and
        # it is later used in __getattr__ of the base class.
        self.ner_pipe.tokenizer.__dict__[
            '_special_tokens_map'] = special_tokens_map

    # keep the pipeline on the same device as the model
    self.ner_pipe.device = self.model.device

eval

eval(json_path: Union[str, list, None] = None, dataset=None, ignore_extra_labels=False, meta_requirements=None)
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
def eval(self, json_path: Union[str, list, None] = None, dataset=None,
         ignore_extra_labels=False, meta_requirements=None):
    """Evaluate the model on a MedCATtrainer export or a loaded dataset.

    Args:
        json_path (Union[str, list, None]): Path(s) to a MedCATtrainer
            export. Only used if `dataset` is None. Defaults to None.
        dataset: A pre-loaded dataset. Defaults to None.
        ignore_extra_labels: Passed through to dataset preparation.
            Defaults to False.
        meta_requirements: Passed through to dataset preparation.
            Defaults to None.

    Returns:
        tuple: The metrics dataframe, the examples, and the dataset.
    """
    if dataset is None:
        json_path = self._prepare_dataset(
            json_path, ignore_extra_labels=ignore_extra_labels,
            meta_requirements=meta_requirements,
            file_name='data_eval.json')
        # Load dataset
        dataset = datasets.load_dataset(
            os.path.abspath(transformers_ner.__file__),
            data_files={'train': json_path},  # type: ignore
            split='train',
            cache_dir='/tmp/')

    # Encode dataset
    # Note: tokenizer.encode performs chunking
    encoded_dataset = dataset.map(
            lambda examples: self.tokenizer.encode(
                examples, ignore_subwords=False),
            batched=True,
            remove_columns=['ent_cuis', 'ent_ends', 'ent_starts', 'text'])

    data_collator = CollateAndPadNER(
        self.tokenizer.hf_tokenizer.pad_token_id)  # type: ignore
    # TODO: switch from trainer to model prediction
    # NOTE: no training happens here - the Trainer is only used as a
    #       convenient prediction harness (train_dataset=None).
    trainer = Trainer(
            model=self.model,
            args=self.training_arguments,
            train_dataset=None,
            eval_dataset=encoded_dataset,  # type: ignore
            compute_metrics=None,
            data_collator=data_collator,  # type: ignore
            tokenizer=None)

    # Run an eval step and return metrics
    p = trainer.predict(encoded_dataset)  # type: ignore
    df, examples = metrics(p, return_df=True, tokenizer=self.tokenizer,
                           dataset=encoded_dataset)

    return df, examples, dataset

expand_model_with_concepts

expand_model_with_concepts(cui2preferred_name: dict[str, str], use_avg_init: bool = True) -> None

Expand the model with new concepts and their preferred names, which requires subsequent retraining of the model.

Parameters:

  • cui2preferred_name

    (Dict[str, str]) –

    Dictionary where each key is the literal ID of the concept to be added and each value is its preferred name.

  • use_avg_init

    (bool, default: True ) –

    Whether to use the average of existing weights or biases as the initial value for the new concept. Defaults to True.

Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
def expand_model_with_concepts(self, cui2preferred_name: dict[str, str],
                               use_avg_init: bool = True) -> None:
    """Expand the model with new concepts and their preferred names, which
    requires subsequent retraining of the model.

    Args:
        cui2preferred_name(Dict[str, str]):
            Dictionary where each key is the literal ID of the concept to
            be added and each value is its preferred name.
        use_avg_init(bool):
            Whether to use the average of existing weights or biases as
            the initial value for the new concept. Defaults to True.
    """
    # Computed once up front so that every newly-added row is initialised
    # from the *original* classifier head, not one grown mid-loop.
    avg_weight = torch.mean(self.model.classifier.weight, dim=0,
                            keepdim=True)
    avg_bias = torch.mean(self.model.classifier.bias, dim=0, keepdim=True)

    new_cuis = set()
    for label, preferred_name in cui2preferred_name.items():
        if label in self.model.config.label2id.keys():
            logger.warning(
                "Concept ID '%s' already exists in the model, skipping...",
                label)
            continue

        sname = preferred_name.lower().replace(" ", "~")
        new_names = {
            sname: NameDescriptor(
                tokens=[],
                snames={sname},
                raw_name=preferred_name,
                is_upper=True
            )
        }
        self.cdb.add_names(
            cui=label, names=new_names, name_status="P", full_build=True)

        # max() instead of sorted(...)[-1]: same value, O(n) not O(n log n)
        new_label_id = max(self.model.config.label2id.values()) + 1
        self.model.config.label2id[label] = new_label_id
        self.model.config.id2label[new_label_id] = label
        self.tokenizer.label_map[label] = new_label_id

        if use_avg_init:
            self.model.classifier.weight = torch.nn.Parameter(
                torch.cat((self.model.classifier.weight, avg_weight), 0)
            )
            self.model.classifier.bias = torch.nn.Parameter(
                torch.cat((self.model.classifier.bias, avg_bias), 0)
            )
        else:
            self.model.classifier.weight = torch.nn.Parameter(
                torch.cat((self.model.classifier.weight, torch.randn(
                    1, self.model.config.hidden_size)), 0)
            )
            self.model.classifier.bias = torch.nn.Parameter(
                torch.cat((self.model.classifier.bias, torch.randn(1)), 0)
            )
        self.model.num_labels += 1
        self.model.classifier.out_features += 1

        new_cuis.add(label)

    if new_cuis:
        # Rebuilding cui2name once after the loop is equivalent to the
        # previous per-iteration rebuild (nothing reads it mid-loop) and
        # avoids re-deriving the whole map for every added concept.
        self.tokenizer.cui2name = {k: self.cdb.get_name(k) for
                                   k in self.tokenizer.label_map.keys()}

    logger.info("Model expanded with the new concept(s): %s and shall be "
                "retrained before use.", str(new_cuis))

get_hash

get_hash() -> str

A partial hash trying to catch differences between models.

Returns:

  • str ( str ) –

    The hex hash.

Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
336
337
338
339
340
341
342
343
344
345
346
347
348
def get_hash(self) -> str:
    """A partial hash trying to catch differences between models.

    Returns:
        str: The hex hash.
    """
    # Ensure last_train_on is populated so the config hash is stable.
    if self.config.general.last_train_on is None:
        self.config.general.last_train_on = datetime.now().timestamp()

    hasher = Hasher()
    hasher.update(self.config.get_hash())
    return hasher.hexdigest()

pipe

Process many documents at once.

Parameters:

  • stream

    (Iterable[MutableDocument]) –

    List of documents.

  • *args

    Extra arguments (not used here).

  • **kwargs

    Extra keyword arguments (not used here).

Yields:

Returns:

Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
def pipe(self, stream: Iterable[Union[MutableDocument, None]],
         *args, **kwargs) -> Iterator[tuple[MutableDocument,
                                            list[MutableEntity]]]:
    """Process many documents at once.

    Args:
        stream (Iterable[MutableDocument]):
            List of documents.
        *args: Extra arguments (not used here).
        **kwargs: Extra keyword arguments (not used here).

    Yields:
        Doc: The same document.

    Returns:
        Iterator[tuple[MutableDocument, list[MutableEntity]]]: The stream
            of documents and entities
    """
    # Guard clause: None or an empty stream yields nothing.
    if not stream:
        return

    yield from self._process(  # type: ignore
        stream, self.config.general.pipe_batch_size_in_chars)

train

train(json_path: Union[str, list, None] = None, ignore_extra_labels=False, dataset=None, meta_requirements=None, train_json_path: Union[str, list, None] = None, test_json_path: Union[str, list, None] = None, trainer_callbacks: Optional[list[TrCBCreator]] = None) -> tuple

Train or continue training a model given a json_path containing a MedCATtrainer export. It will continue training if an existing model is loaded or start new training if the model is blank/new.

Parameters:

  • json_path

    (str or list, default: None ) –

    Path/Paths to a MedCATtrainer export containing the meta_annotations we want to train for.

  • ignore_extra_labels

Only makes sense when an existing deid model was loaded and we want to ignore labels from the new data that did not exist in the old model.

  • dataset

    Defaults to None.

  • meta_requirements

    Defaults to None

  • train_json_path

    (Union[str, list, None], default: None ) –

    The json path for the training data. Defaults to None.

  • test_json_path

    (Union[str, list, None], default: None ) –

    The json path for the test data. Defaults to None.

  • trainer_callbacks

    (list[TrCBCreator], default: None ) –

    A list of trainer callbacks for collecting metrics during the training at the client side. The transformers Trainer object will be passed in when each callback is called.

Returns:

  • Tuple ( tuple ) –

    The dataframe, examples, and the dataset

Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
def train(self,
          json_path: Union[str, list, None] = None,
          ignore_extra_labels=False,
          dataset=None,
          meta_requirements=None,
          train_json_path: Union[str, list, None] = None,
          test_json_path: Union[str, list, None] = None,
          trainer_callbacks: Optional[list[TrCBCreator]] = None
          ) -> tuple:
    """Train or continue training a model given a json_path containing a
    MedCATtrainer export. It will continue training if an existing model
    is loaded or start new training if the model is blank/new.

    Args:
        json_path (str or list):
            Path/Paths to a MedCATtrainer export containing the
            meta_annotations we want to train for.
        ignore_extra_labels:
            Only makes sense when an existing deid model was loaded and
            from the new data we want to ignore labels that did not exist
            in the old model.
        dataset: Defaults to None.
        meta_requirements: Defaults to None
        train_json_path (Union[str, list, None]):
            The json path for the training data. Defaults to None.
        test_json_path (Union[str, list, None]):
            The json path for the test data. Defaults to None.
        trainer_callbacks (list[TrCBCreator]):
            A list of trainer callbacks for collecting metrics during the
            training at the client side. The transformers Trainer object
            will be passed in when each callback is called.

    Returns:
        Tuple: The dataframe, examples, and the dataset
    """

    if dataset is None:
        # Load the medcattrainer export
        if json_path is not None:
            # NOTE(review): file_name 'data_eval.json' looks copied from
            #       eval(); presumably 'data_train.json' was intended -
            #       confirm before changing.
            json_path = self._prepare_dataset(
                json_path, ignore_extra_labels=ignore_extra_labels,
                meta_requirements=meta_requirements,
                file_name='data_eval.json')
        elif test_json_path is not None and train_json_path is not None:
            train_json_path = self._prepare_dataset(
                train_json_path, ignore_extra_labels=ignore_extra_labels,
                meta_requirements=meta_requirements,
                file_name='data_train.json')
            test_json_path = self._prepare_dataset(
                test_json_path, ignore_extra_labels=ignore_extra_labels,
                meta_requirements=meta_requirements,
                file_name='data_test.json')
        # Load dataset

        # NOTE: The following is for backwards compatibility
        #       in datasets==2.20.0 `trust_remote_code=True`
        #       must be explicitly specified, otherwise an error is raised.
        #       On the other hand, the keyword argument was added in
        #       datasets==2.16.0 yet we support datasets>=2.2.0.
        #       So we need to use the kwarg if applicable and omit
        #       its use otherwise.
        if func_has_kwarg(datasets.load_dataset, 'trust_remote_code'):
            ds_load_dataset = partial(datasets.load_dataset,
                                      trust_remote_code=True)
        else:
            ds_load_dataset = datasets.load_dataset
        if json_path:
            dataset = ds_load_dataset(os.path.abspath(
                transformers_ner.__file__),
                data_files={'train': json_path},  # type: ignore
                split='train',
                cache_dir='/tmp/')
            # We split before encoding so the split is document level,
            # as encoding  does the document splitting into max_seq_len
            dataset = dataset.train_test_split(
                test_size=self.config.general.test_size)  # type: ignore
        elif train_json_path and test_json_path:
            dataset = ds_load_dataset(
                os.path.abspath(transformers_ner.__file__),
                data_files={
                    'train': train_json_path,
                    'test': test_json_path},  # type: ignore
                cache_dir='/tmp/')
        else:
            raise ValueError(
                "Either json_path or train_json_path and test_json_path "
                "must be provided when no dataset is provided")

    # Update labelmap in case the current dataset has more labels
    # than what we had before
    self.tokenizer.calculate_label_map(dataset['train'])
    self.tokenizer.calculate_label_map(dataset['test'])

    if self.model.num_labels != len(self.tokenizer.label_map):
        logger.warning(
            "The dataset contains labels we've not seen before, "
            "model is being reinitialized")
        logger.warning("Model: %s vs Dataset: %s",
                       self.model.num_labels,
                       len(self.tokenizer.label_map))
        self.model = AutoModelForTokenClassification.from_pretrained(
            self.config.general.model_name,
            num_labels=len(self.tokenizer.label_map),
            ignore_mismatched_sizes=True)
        self.tokenizer.cui2name = {
            k: self.cdb.get_name(k)
            for k in self.tokenizer.label_map.keys()}

    # keep the model's label maps in sync with the tokenizer's
    self.model.config.id2label = {
        v: k for k, v in self.tokenizer.label_map.items()}
    self.model.config.label2id = self.tokenizer.label_map

    # Encode dataset
    # Note: tokenizer.encode performs chunking
    encoded_dataset = dataset.map(
            lambda examples: self.tokenizer.encode(
                examples, ignore_subwords=False),
            batched=True,
            remove_columns=['ent_cuis', 'ent_ends', 'ent_starts', 'text'])

    data_collator = CollateAndPadNER(
        self.tokenizer.hf_tokenizer.pad_token_id)  # type: ignore
    trainer = Trainer(
            model=self.model,
            args=self.training_arguments,
            train_dataset=encoded_dataset['train'],
            eval_dataset=encoded_dataset['test'],
            compute_metrics=lambda p: metrics(
                p, tokenizer=self.tokenizer,
                dataset=encoded_dataset['test'],
                verbose=self.config.general.verbose_metrics),
            data_collator=data_collator,  # type: ignore
            tokenizer=None)
    if trainer_callbacks:
        for tr_callback in trainer_callbacks:
            tcbo = tr_callback(trainer)
            # NOTE: No idea why mypy isn't able to find the method
            #       It reports (`[attr-defined]`):
            #          error: "Trainer" has no attribute "callback_handler"
            trainer.add_callback(tcbo)  # type: ignore

    # NOTE: No idea why mypy isn't able to find the method
    #       It reports:
    #          error: "Trainer" has no attribute "train"  [attr-defined]
    trainer.train()  # type: ignore

    # Save the training time
    self.config.general.last_train_on = datetime.now().timestamp()

    output_dir = self.training_arguments.output_dir
    if output_dir is None:
        # NOTE: shouldn't ever really happen
        raise ValueError("Unable to save output during training "
                         "since output path is None")
    # Save everything
    _save_component(self, save_dir_path=os.path.join(
        output_dir, 'final_model'),
        overwrite=True)

    # Run an eval step and return metrics
    p = trainer.predict(encoded_dataset['test'])  # type: ignore
    df, examples = metrics(p, return_df=True, tokenizer=self.tokenizer,
                           dataset=encoded_dataset['test'])

    # Create the pipeline for eval
    self.create_eval_pipeline()

    return df, examples, dataset

func_has_kwarg

func_has_kwarg(func: Callable, keyword: str)
Source code in medcat-v2/medcat/components/ner/trf/transformers_ner.py
782
783
784
def func_has_kwarg(func: Callable, keyword: str):
    """Return True if *func*'s signature declares a parameter *keyword*."""
    return keyword in inspect.signature(func).parameters