Skip to content

medcat.components.addons.relation_extraction.bert.model

Classes:

Attributes:

logger module-attribute

logger = getLogger(__name__)

RelExtrBertModel

Bases: RelExtrBaseModel

BertModel class for RelCAT

Class to hold the BERT model + model_config

Parameters:

  • pretrained_model_name_or_path

    (str) –

    Path to load the model from. This can be a Hugging Face model name, e.g. "bert-base-uncased". If left empty, it is normally assumed that the model is loaded from 'model.dat' using the RelCAT.load() method, so if you are initializing/training a model from scratch, be sure to base it on some pretrained model.

  • relcat_config

    (ConfigRelCAT) –

    relcat config.

  • model_config

    (RelExtrBertConfig) –

    Model config; its hf_model_config (the Hugging Face BERT config) is used to build the BERT encoder.

Methods:

Attributes:

Source code in medcat-v2/medcat/components/addons/relation_extraction/bert/model.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def __init__(self, pretrained_model_name_or_path: str,
             relcat_config: ConfigRelCAT,
             model_config: RelExtrBertConfig):
    """Build the BERT encoder, dropout and dense layers used by RelCAT.

    The encoder is created from ``model_config.hf_model_config`` alone,
    i.e. without pretrained weights; pretrained weights are loaded
    separately (see ``load_specific``).

    Args:
        pretrained_model_name_or_path (str): Path or Hugging Face model
            name (e.g. "bert-base-uncased") the weights are meant to come
            from. If empty, the weights are normally expected to be loaded
            from 'model.dat' via ``RelCAT.load()``, so when training from
            scratch base the model on some pretrained model.
        relcat_config (ConfigRelCAT): RelCAT config. ``model.freeze_layers``
            decides whether the encoder's parameters are trainable and
            ``model.dropout`` is the dropout probability.
        model_config (RelExtrBertConfig): Model config holding the Hugging
            Face BERT config (``hf_model_config``).
    """
    super().__init__(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        relcat_config=relcat_config,
        model_config=model_config,
    )

    # Keep explicitly typed references to the constructor arguments.
    self.relcat_config: ConfigRelCAT = relcat_config
    self.model_config: RelExtrBertConfig = model_config
    self.pretrained_model_name_or_path: str = pretrained_model_name_or_path

    # Encoder built from the config only (no pretrained weights yet).
    # NOTE: submodule assignment order (hf_model, drop_out, fc1..fc3) is
    # kept as-is because it determines parameter/state_dict ordering.
    self.hf_model = BertModel(model_config.hf_model_config)

    # Either the whole encoder is frozen or every parameter is trainable.
    encoder_trainable = not self.relcat_config.model.freeze_layers
    for weight in self.hf_model.parameters():
        weight.requires_grad = encoder_trainable

    self.drop_out = nn.Dropout(self.relcat_config.model.dropout)

    # Dense layers fc1..fc3, built from the RelCAT config.
    dense_layers = create_dense_layers(self.relcat_config)
    self.fc1, self.fc2, self.fc3 = dense_layers

drop_out instance-attribute

drop_out = Dropout(dropout)

hf_model instance-attribute

hf_model = BertModel(hf_model_config)

model_config instance-attribute

name class-attribute instance-attribute

name = 'bertmodel_relcat'

pretrained_model_name_or_path instance-attribute

pretrained_model_name_or_path: str = pretrained_model_name_or_path

relcat_config instance-attribute

relcat_config: ConfigRelCAT = relcat_config

load_specific classmethod

load_specific(pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: RelExtrBertConfig, **kwargs) -> RelExtrBertModel
Source code in medcat-v2/medcat/components/addons/relation_extraction/bert/model.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
@classmethod
def load_specific(cls, pretrained_model_name_or_path: str,
                  relcat_config: ConfigRelCAT,
                  model_config: RelExtrBertConfig, **kwargs
                  ) -> "RelExtrBertModel":
    """Create a model and populate its weights from the best available source.

    Weights are resolved in this order:

    1. ``<pretrained_model_name_or_path>/model.dat`` if that file exists;
       it is loaded as a state dict for the whole model. Note that an
       empty ``pretrained_model_name_or_path`` makes this look for
       ``model.dat`` in the current working directory.
    2. ``BertModel.from_pretrained`` with ``pretrained_model_name_or_path``
       when it is non-empty. Only the BERT encoder (``hf_model``) is
       replaced; the dropout and dense layers keep their fresh state.
    3. ``BertModel.from_pretrained`` with a class-level string attribute
       ``pretrained_model_name_or_path``, if the class defines one.
    4. Otherwise the encoder built from ``model_config.hf_model_config`` in
       ``__init__`` is kept unchanged and a warning is logged.

    Args:
        pretrained_model_name_or_path (str): Directory containing
            ``model.dat`` or a Hugging Face model name/path.
        relcat_config (ConfigRelCAT): RelCAT config. ``general.device`` is
            used as ``map_location`` for ``model.dat`` and
            ``model.freeze_layers`` controls whether the encoder is trainable.
        model_config (RelExtrBertConfig): Model config whose
            ``hf_model_config`` is passed to the Hugging Face loader.
        **kwargs: Forwarded to ``BertModel.from_pretrained``.

    Returns:
        RelExtrBertModel: The constructed model (an instance of ``cls``).
    """
    # Use cls so subclasses get an instance of their own type.
    model = cls(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        relcat_config=relcat_config, model_config=model_config)

    model_path = os.path.join(pretrained_model_name_or_path, "model.dat")

    if os.path.exists(model_path):
        # NOTE(review): torch.load unpickles the file; only load model.dat
        # from trusted locations (consider weights_only=True where the
        # installed torch version supports it).
        model.load_state_dict(torch.load(
            model_path, map_location=relcat_config.general.device))
        logger.info("Loaded model from file: %s", model_path)
        return model

    # Previously the empty-name branch read ``cls.pretrained_model_name_or_path``
    # unconditionally; that name is set per instance in __init__, so the class
    # lookup raised AttributeError. Only use a class-level default if one
    # actually exists and is a string.
    hf_source: str = pretrained_model_name_or_path
    if not hf_source:
        class_default = getattr(cls, "pretrained_model_name_or_path", None)
        if isinstance(class_default, str):
            hf_source = class_default

    if hf_source:
        model.hf_model = BertModel.from_pretrained(
            pretrained_model_name_or_path=hf_source,
            config=model_config.hf_model_config,
            ignore_mismatched_sizes=True, **kwargs)
        # The newly loaded encoder replaces the one __init__ configured, so
        # its parameters are trainable by default; re-apply freeze_layers.
        trainable = not relcat_config.model.freeze_layers
        for param in model.hf_model.parameters():
            param.requires_grad = trainable
        logger.info("Loaded model from pretrained: %s", hf_source)
    else:
        logger.warning(
            "No model file found at %s and no pretrained model name given; "
            "using the BERT encoder initialised from the model config "
            "without pretrained weights.", model_path)

    return model