medcat.components.addons.relation_extraction.models

Classes:

Attributes:

logger module-attribute

logger = getLogger(__name__)

BaseModelBluePrint

Bases: Module

Base class for the RelCAT models

Class to hold the HF model + model_config

Parameters:

  • pretrained_model_name_or_path

    (str) –

    path to load the model from; this can be a HF model, e.g. "bert-base-uncased". If left empty, it is normally assumed that a model is loaded from 'model.dat' using the RelCAT.load() method, so if you are initializing/training a model from scratch, be sure to base it on some model.

  • relcat_config

    (ConfigRelCAT) –

    relcat config.

  • model_config

    (PretrainedConfig) –

    HF BERT config for the model.

Methods:

  • forward

    Forward pass for the model

  • output2logits

    Convert the output of the model to logits

Attributes:

Source code in medcat-v2/medcat/components/addons/relation_extraction/models.py
def __init__(self, pretrained_model_name_or_path: str,
             relcat_config: ConfigRelCAT,
             model_config: Union[PretrainedConfig,
                                 RelExtrBaseConfig]):
    """ Class to hold the HF model + model_config

    Args:
        pretrained_model_name_or_path (str): path to load the model from;
                this can be a HF model, e.g. "bert-base-uncased".
                If left empty, it is normally assumed that a model is
                loaded from 'model.dat' using the RelCAT.load() method,
                so if you are initializing/training a model from scratch,
                be sure to base it on some model.
        relcat_config (ConfigRelCAT): relcat config.
        model_config (PretrainedConfig): HF BERT config for the model.
    """
    super().__init__()

drop_out instance-attribute

drop_out: Dropout

fc1 instance-attribute

fc1: Linear

fc2 instance-attribute

fc2: Linear

fc3 instance-attribute

fc3: Linear

hf_model instance-attribute

hf_model: PreTrainedModel

model_config instance-attribute

model_config: PretrainedConfig

relcat_config instance-attribute

relcat_config: ConfigRelCAT

forward

forward(input_ids: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, token_type_ids: Optional[Tensor] = None, position_ids: Any = None, head_mask: Any = None, encoder_hidden_states: Any = None, encoder_attention_mask: Any = None, Q: Any = None, e1_e2_start: Any = None, pooled_output: Any = None) -> Optional[tuple[Tensor, Tensor]]

Forward pass for the model

Parameters:

  • input_ids

    (Tensor, default: None ) –

    input token ids. Defaults to None.

  • attention_mask

    (Tensor, default: None ) –

    attention mask for the input ids. Defaults to None.

  • token_type_ids

    (Tensor, default: None ) –

    token type ids for the input ids. Defaults to None.

  • position_ids

    (Any, default: None ) –

    The position IDs. Defaults to None.

  • head_mask

    (Any, default: None ) –

    The head mask. Defaults to None.

  • encoder_hidden_states

    (Any, default: None ) –

    Encoder hidden states. Defaults to None.

  • encoder_attention_mask

    (Any, default: None ) –

    Encoder attention mask. Defaults to None.

  • Q

    (Any, default: None ) –

    Q. Defaults to None.

  • e1_e2_start

    (Any, default: None ) –

    Start and end indices for the entities in the input ids. Defaults to None.

  • pooled_output

    (Any, default: None ) –

    The pooled output. Defaults to None.

Returns:

  • Optional[tuple[Tensor, Tensor]]

    Optional[tuple[torch.Tensor, torch.Tensor]]: Logits for the relation classification task.

Source code in medcat-v2/medcat/components/addons/relation_extraction/models.py
def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Any = None,
        head_mask: Any = None,
        encoder_hidden_states: Any = None,
        encoder_attention_mask: Any = None,
        Q: Any = None,
        e1_e2_start: Any = None,
        pooled_output: Any = None
        ) -> Optional[tuple[torch.Tensor, torch.Tensor]]:
    """Forward pass for the model

    Args:
        input_ids (torch.Tensor): input token ids. Defaults to None.
        attention_mask (torch.Tensor): attention mask for the input ids.
            Defaults to None.
        token_type_ids (torch.Tensor): token type ids for the input ids.
            Defaults to None.
        position_ids (Any): The position IDs. Defaults to None.
        head_mask (Any): The head mask. Defaults to None.
        encoder_hidden_states (Any): Encoder hidden states.
            Defaults to None.
        encoder_attention_mask (Any): Encoder attention mask.
            Defaults to None.
        Q (Any): Q. Defaults to None.
        e1_e2_start (Any): Start and end indices for the entities in
            the input ids. Defaults to None.
        pooled_output (Any): The pooled output. Defaults to None.

    Returns:
        Optional[tuple[torch.Tensor, torch.Tensor]]:
            Logits for the relation classification task.
    """
    return None

output2logits

output2logits(pooled_output: Tensor, sequence_output: Tensor, input_ids: Tensor, e1_e2_start: Tensor) -> Optional[Tensor]

Convert the output of the model to logits

Parameters:

  • pooled_output

    (Tensor) –

    output of the pooled layer.

  • sequence_output

    (Tensor) –

    output of the sequence layer.

  • input_ids

    (Tensor) –

    input token ids.

  • e1_e2_start

    (Tensor) –

    start and end indices for the entities in the input ids.

Returns:

  • logits ( Tensor ) –

    logits for the relation classification task.

Source code in medcat-v2/medcat/components/addons/relation_extraction/models.py
def output2logits(self, pooled_output: torch.Tensor,
                  sequence_output: torch.Tensor,
                  input_ids: torch.Tensor,
                  e1_e2_start: torch.Tensor) -> Optional[torch.Tensor]:
    """ Convert the output of the model to logits

    Args:
        pooled_output (torch.Tensor): output of the pooled layer.
        sequence_output (torch.Tensor): output of the sequence layer.
        input_ids (torch.Tensor): input token ids.
        e1_e2_start (torch.Tensor): start and end indices for the entities
            in the input ids.

    Returns:
        logits (torch.Tensor): logits for the relation classification task.
    """
    return None

RelExtrBaseModel

RelExtrBaseModel(relcat_config: ConfigRelCAT, model_config: RelExtrBaseConfig, pretrained_model_name_or_path)

Bases: BaseModelBluePrint

Methods:

Attributes:

Source code in medcat-v2/medcat/components/addons/relation_extraction/models.py
def __init__(self, relcat_config: ConfigRelCAT,
             model_config: RelExtrBaseConfig,
             pretrained_model_name_or_path):
    super().__init__(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        relcat_config=relcat_config,
        model_config=model_config)

    self.relcat_config: ConfigRelCAT = relcat_config
    self.model_config: RelExtrBaseConfig = model_config
    self.hf_model = PreTrainedModel(config=model_config.hf_model_config)
    self.pretrained_model_name_or_path: str = pretrained_model_name_or_path

    self._reinitialize_dense_and_frozen_layers(relcat_config=relcat_config)

    logger.info("RelCAT model config: %s",
                str(self.model_config.hf_model_config))

hf_model instance-attribute

hf_model = PreTrainedModel(config=hf_model_config)

model_config instance-attribute

model_config: RelExtrBaseConfig = model_config

name class-attribute instance-attribute

name = 'basemodel_relcat'

pretrained_model_name_or_path instance-attribute

pretrained_model_name_or_path: str = pretrained_model_name_or_path

relcat_config instance-attribute

relcat_config: ConfigRelCAT = relcat_config

forward

forward(input_ids: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, token_type_ids: Optional[Tensor] = None, position_ids: Any = None, head_mask: Any = None, encoder_hidden_states: Any = None, encoder_attention_mask: Any = None, Q: Any = None, e1_e2_start: Any = None, pooled_output: Any = None) -> tuple[Tensor, Tensor]
Source code in medcat-v2/medcat/components/addons/relation_extraction/models.py
def forward(self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            position_ids: Any = None,
            head_mask: Any = None,
            encoder_hidden_states: Any = None,
            encoder_attention_mask: Any = None,
            Q: Any = None,
            e1_e2_start: Any = None,
            pooled_output: Any = None
            ) -> tuple[torch.Tensor, torch.Tensor]:

    if input_ids is not None:
        input_shape = input_ids.size()
    else:
        raise ValueError("You have to specify input_ids")

    if attention_mask is None:
        attention_mask = torch.ones(
            input_shape, device=self.relcat_config.general.device)
    if encoder_attention_mask is None:
        encoder_attention_mask = torch.ones(
            input_shape, device=self.relcat_config.general.device)
    if token_type_ids is None:
        token_type_ids = torch.zeros(
            input_shape, dtype=torch.long,
            device=self.relcat_config.general.device)

    input_ids = input_ids.to(self.relcat_config.general.device)
    attention_mask = attention_mask.to(self.relcat_config.general.device)
    encoder_attention_mask = encoder_attention_mask.to(
        self.relcat_config.general.device)

    # NOTE: the wrapping of the method means that mypy can't
    #       properly understand it
    self.hf_model = self.hf_model.to(
        self.relcat_config.general.device)  # type: ignore

    model_output = self.hf_model(
        input_ids=input_ids, attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask)

    # (batch_size, sequence_length, hidden_size)
    sequence_output = model_output[0]
    pooled_output = model_output[1]

    classification_logits = self.output2logits(
        pooled_output, sequence_output, input_ids, e1_e2_start)

    return model_output, classification_logits.to(
        self.relcat_config.general.device)
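The defaulting at the top of forward() can be sketched in plain Python (lists standing in for torch tensors, and `default_masks` a hypothetical helper name, not part of the API): a missing attention mask becomes all-ones and missing token type ids become all-zeros, both matching the shape of `input_ids`, while a missing `input_ids` raises.

```python
def default_masks(input_ids, attention_mask=None, token_type_ids=None):
    """Mimic forward()'s defaulting of masks to the input_ids shape."""
    if input_ids is None:
        raise ValueError("You have to specify input_ids")
    rows, cols = len(input_ids), len(input_ids[0])
    if attention_mask is None:
        # attend to every token by default
        attention_mask = [[1] * cols for _ in range(rows)]
    if token_type_ids is None:
        # single-segment input by default
        token_type_ids = [[0] * cols for _ in range(rows)]
    return attention_mask, token_type_ids

mask, types = default_masks([[101, 2054, 102]])
# mask  == [[1, 1, 1]]
# types == [[0, 0, 0]]
```

In the real method the same defaults are built with `torch.ones` / `torch.zeros` on `relcat_config.general.device`.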

load classmethod

load(pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: RelExtrBaseConfig) -> RelExtrBaseModel

Load the model from the given path

Parameters:

  • pretrained_model_name_or_path

    (str) –

    path to load the model from.

  • relcat_config

    (ConfigRelCAT) –

    relcat config.

  • model_config

    (RelExtrBaseConfig) –

    The model-specific config.

Returns:

  • RelExtrBaseModel –

    The loaded model.
Source code in medcat-v2/medcat/components/addons/relation_extraction/models.py
@classmethod
def load(cls, pretrained_model_name_or_path: str,
         relcat_config: ConfigRelCAT,
         model_config: RelExtrBaseConfig
         ) -> "RelExtrBaseModel":
    """ Load the model from the given path

    Args:
        pretrained_model_name_or_path (str): path to load the model from.
        relcat_config (ConfigRelCAT): relcat config.
        model_config (RelExtrBaseConfig):
            The model-specific config.

    Returns:
        RelExtrBaseModel: The loaded model.
    """

    model = RelExtrBaseModel(
        relcat_config=relcat_config, model_config=model_config,
        pretrained_model_name_or_path=pretrained_model_name_or_path)

    if "modern-bert" in relcat_config.general.tokenizer_name or \
            "modern-bert" in relcat_config.general.model_name:
        from medcat.components.addons.relation_extraction.modernbert.model import RelExtrModernBertModel  # noqa
        from medcat.components.addons.relation_extraction.modernbert.config import RelExtrModernBertConfig  # noqa
        model = RelExtrModernBertModel.load_specific(
            pretrained_model_name_or_path, relcat_config=relcat_config,
            model_config=cast(RelExtrModernBertConfig, model_config))
    elif "bert" in relcat_config.general.tokenizer_name or \
         "bert" in relcat_config.general.model_name:
        from medcat.components.addons.relation_extraction.bert.model import RelExtrBertModel # noqa
        from medcat.components.addons.relation_extraction.bert.config import RelExtrBertConfig  # noqa
        model = RelExtrBertModel.load_specific(
            pretrained_model_name_or_path, relcat_config=relcat_config,
            model_config=cast(RelExtrBertConfig, model_config))
    elif "llama" in relcat_config.general.tokenizer_name or \
         "llama" in relcat_config.general.model_name:
        from medcat.components.addons.relation_extraction.llama.model import RelExtrLlamaModel # noqa
        from medcat.components.addons.relation_extraction.llama.config import RelExtrLlamaConfig # noqa
        model = RelExtrLlamaModel.load_specific(
            pretrained_model_name_or_path, relcat_config=relcat_config,
            model_config=cast(RelExtrLlamaConfig, model_config))
    else:
        if pretrained_model_name_or_path:
            model.hf_model = PreTrainedModel.from_pretrained(
                pretrained_model_name_or_path=pretrained_model_name_or_path, config=model_config)  # noqa
        else:
            model_name = relcat_config.general.model_name
            model.hf_model = PreTrainedModel.from_pretrained(
                pretrained_model_name_or_path=model_name,
                config=model_config)
            logger.info("Loaded model from relcat_config: %s",
                        relcat_config.general.model_name)

    logger.info("Loaded %s from pretrained_model_name_or_path: %s",
                str(model.__class__.__name__),
                pretrained_model_name_or_path)

    model._reinitialize_dense_and_frozen_layers(
        relcat_config=relcat_config)

    return model
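The branching in load() dispatches on substring matches against the configured tokenizer and model names, and the order matters: "modern-bert" must be tested before the plain "bert" branch, because every modern-bert name also contains "bert". A minimal sketch of that dispatch, with `pick_backend` a hypothetical stand-in that returns a label instead of a model class:

```python
def pick_backend(tokenizer_name: str, model_name: str) -> str:
    """Mirror the substring dispatch order used by RelExtrBaseModel.load()."""
    names = (tokenizer_name, model_name)
    if any("modern-bert" in n for n in names):
        return "modernbert"
    elif any("bert" in n for n in names):
        return "bert"
    elif any("llama" in n for n in names):
        return "llama"
    return "generic"  # fall back to PreTrainedModel.from_pretrained

pick_backend("modern-bert-base", "")   # "modernbert", not "bert"
pick_backend("bert-base-uncased", "")  # "bert"
```

Note that the substring test is broad: a name like "roberta" also contains "bert" and would take the BERT branch.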

output2logits

output2logits(pooled_output: Tensor, sequence_output: Tensor, input_ids: Tensor, e1_e2_start: Tensor) -> Tensor

Parameters:

  • pooled_output

    (Tensor) –

    embedding of the CLS token

  • sequence_output

    (Tensor) –

    hidden states/embeddings for each token in the input text

  • input_ids

    (Tensor) –

    input token ids.

  • e1_e2_start

    (Tensor) –

    token positions of the entity annotation tags

Returns:

  • Tensor

    torch.Tensor: classification logits for the relation classification task.

Source code in medcat-v2/medcat/components/addons/relation_extraction/models.py
def output2logits(self, pooled_output: torch.Tensor,
                  sequence_output: torch.Tensor, input_ids: torch.Tensor,
                  e1_e2_start: torch.Tensor) -> torch.Tensor:
    """

    Args:
        pooled_output (torch.Tensor): embedding of the CLS token
        sequence_output (torch.Tensor): hidden states/embeddings for
            each token in the input text
        input_ids (torch.Tensor): input token ids.
        e1_e2_start (torch.Tensor): token positions of the entity
            annotation tags

    Returns:
        torch.Tensor: classification logits for the relation
            classification task.
    """

    new_pooled_output = pooled_output

    if self.relcat_config.general.annotation_schema_tag_ids:
        rel_range = range(
            0,
            len(self.relcat_config.general.annotation_schema_tag_ids), 2)
        annotation_schema_tag_ids_ = [
            self.relcat_config.general.annotation_schema_tag_ids[i:i + 2]
            for i in rel_range]
        seq_tags = []

        # for each pair of tags (e1,s1) and (e2,s2)
        for each_tags in annotation_schema_tag_ids_:
            seq_tags.append(get_annotation_schema_tag(
                sequence_output, input_ids, each_tags))

        stacked_tensor = torch.stack(seq_tags, dim=0)

        new_pooled_output = torch.cat(
            (pooled_output, *stacked_tensor), dim=1)
    else:
        e1e2_output = []
        temp_e1 = []
        temp_e2 = []

        for i, seq in enumerate(sequence_output):
            # e1e2 token sequences
            temp_e1.append(seq[e1_e2_start[i][0]])
            temp_e2.append(seq[e1_e2_start[i][1]])

        e1e2_output.append(torch.stack(temp_e1, dim=0))
        e1e2_output.append(torch.stack(temp_e2, dim=0))

        new_pooled_output = torch.cat((pooled_output, *e1e2_output), dim=1)

        del e1e2_output
        del temp_e2
        del temp_e1

    x = self.drop_out(new_pooled_output)
    x = self.fc1(x)
    x = self.drop_out(x)
    x = self.fc2(x)
    classification_logits = self.fc3(x)

    return classification_logits.to(self.relcat_config.general.device)
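The no-annotation-tag branch of output2logits can be sketched in plain Python (lists in place of tensors, `gather_entity_starts` a hypothetical helper name): for each sequence, the hidden states at the two entity start positions are picked out and concatenated onto the pooled output along the hidden dimension, before the dropout/fc stack produces the logits.

```python
def gather_entity_starts(sequence_output, pooled_output, e1_e2_start):
    """For each sequence, concatenate [pooled; h(e1_start); h(e2_start)]."""
    rows = []
    for seq, (e1, e2), pooled in zip(sequence_output, e1_e2_start,
                                     pooled_output):
        # list concatenation here mirrors torch.cat(..., dim=1)
        rows.append(pooled + seq[e1] + seq[e2])
    return rows

# One sequence of 3 tokens with 2-dim hidden states;
# the entities start at tokens 0 and 2.
seq_out = [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]]
pooled = [[9.0, 9.0]]
gather_entity_starts(seq_out, pooled, [(0, 2)])
# [[9.0, 9.0, 1.0, 1.0, 3.0, 3.0]]
```

When `annotation_schema_tag_ids` is set, the method instead pairs the tag ids two at a time (`ids[i:i + 2] for i in range(0, len(ids), 2)`) and concatenates one extracted span embedding per (start, end) tag pair onto the pooled output.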