Skip to content

medcat.components.addons.relation_extraction.ml_utils

Functions:

Attributes:

logger module-attribute

logger = getLogger(__name__)

create_dense_layers

create_dense_layers(relcat_config: ConfigRelCAT)
Source code in medcat-v2/medcat/components/addons/relation_extraction/ml_utils.py
271
272
273
274
275
276
277
278
279
280
281
282
283
284
def create_dense_layers(relcat_config: ConfigRelCAT):

    # dense layers
    fc1 = nn.Linear(
        relcat_config.model.model_size,
        relcat_config.model.hidden_size)
    fc2 = nn.Linear(
        relcat_config.model.hidden_size,
        int(relcat_config.model.hidden_size / 2))
    fc3 = nn.Linear(
        int(relcat_config.model.hidden_size / 2),
        relcat_config.train.nclasses)

    return fc1, fc2, fc3

create_tokenizer_pretrain

This method simply adds the default special tokens that we encounter.

Parameters:

Returns:

Source code in medcat-v2/medcat/components/addons/relation_extraction/ml_utils.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
def create_tokenizer_pretrain(tokenizer: BaseTokenizerWrapper,
                              relcat_config: ConfigRelCAT
                              ) -> BaseTokenizerWrapper:
    """This method simply adds the default special tokens that we ecounter.

    Args:
        tokenizer (BaseTokenizerWrapper):
            BERT/Llama tokenizer.
        relcat_config (ConfigRelCAT): The RelCAT config.

    Returns:
        BaseTokenizerWrapper: The same tokenizer.
    """
    special_tags = (
        relcat_config.general.tokenizer_relation_annotation_special_tokens_tags
    )
    tokenizer.hf_tokenizers.add_tokens(
        special_tags, special_tokens=True)

    # used in llama tokenizer, may produce issues with other tokenizers
    tokenizer.hf_tokenizers.add_special_tokens(
        relcat_config.general.tokenizer_other_special_tokens)

    return tokenizer

get_annotation_schema_tag

get_annotation_schema_tag(sequence_output: Tensor, input_ids: Tensor, special_tag: list) -> Tensor

Gets the token sequences from the sequence_output for the specific token tag ids in self.relcat_config.general.annotation_schema_tag_ids.

Parameters:

  • sequence_output

    (Tensor) –

    hidden states/embeddings for each token in the input text

  • input_ids

    (Tensor) –

    input token ids

  • special_tag

    (list) –

    special annotation token id pairs

Returns:

  • Tensor

    torch.Tensor: new seq_tags

Source code in medcat-v2/medcat/components/addons/relation_extraction/ml_utils.py
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
def get_annotation_schema_tag(sequence_output: torch.Tensor,
                              input_ids: torch.Tensor, special_tag: list
                              ) -> torch.Tensor:
    """ Gets the token sequences from the sequence_output for the specific
        token tag ids in relcat_config.general.annotation_schema_tag_ids.

        For each row, the hidden states strictly between the start and end
        tag tokens are max-pooled into a single vector.

    Args:
        sequence_output (torch.Tensor): hidden states/embeddings for
            each token in the input text
        input_ids (torch.Tensor): input token ids
        special_tag (list): special annotation token id pairs
            (``special_tag[0]`` = start tag id, ``special_tag[1]`` = end tag id)

    Returns:
        torch.Tensor: new seq_tags, shape (batch, hidden)
    """

    def _keep_first_per_row(idx: tuple) -> tuple:
        """Keep only the first tag occurrence within each row.

        BUGFIX: the previous implementation removed duplicates by their
        position in the ORIGINAL index tensors while shortening the tensors
        in between, so with two or more duplicates the later indices were
        stale and some duplicates survived (tripping the batch-size assert
        below). A single boolean keep-mask avoids index shifting entirely.
        """
        keep = []
        seen_rows: set = set()
        for row in idx[0].tolist():
            is_new = row not in seen_rows
            keep.append(is_new)
            if is_new:
                seen_rows.add(row)
        if all(keep):
            return idx
        # same logger instance as the module-level `logger`
        logging.getLogger(__name__).info(
            "Duplicate entities found, removing them...")
        mask = torch.tensor(keep, dtype=torch.bool, device=idx[0].device)
        return idx[0][mask], idx[1][mask]

    # (row ids, column of start tag in row), deduplicated to one per row
    idx_start = _keep_first_per_row(torch.where(input_ids == special_tag[0]))
    # (row ids, column of end tag in row), deduplicated to one per row
    idx_end = _keep_first_per_row(torch.where(input_ids == special_tag[1]))

    # exactly one start/end tag pair must remain per batch row
    assert len(idx_start[0]) == input_ids.shape[0]
    assert len(idx_start[0]) == len(idx_end[0])

    sequence_output_entities: list[torch.Tensor] = []

    for i in range(len(idx_start[0])):
        # tokens strictly between the start and end tags of row i
        span = sequence_output[i, idx_start[1][i] + 1:idx_end[1][i], :]
        # max-pool over the token dimension into one entity embedding
        pooled, _ = torch.max(span, dim=0)
        sequence_output_entities.append(pooled)
    return torch.stack(sequence_output_entities)

load_bin_file

load_bin_file(file_name, path='./') -> Any
Source code in medcat-v2/medcat/components/addons/relation_extraction/ml_utils.py
100
101
102
103
def load_bin_file(file_name, path="./") -> Any:
    """Unpickle and return the contents of ``path/file_name``."""
    full_path = os.path.join(path, file_name)
    with open(full_path, 'rb') as handle:
        return pickle.load(handle)

load_results

load_results(path: str, model_name: str = 'BERT', file_prefix: str = 'train') -> tuple[list, list, list]
Source code in medcat-v2/medcat/components/addons/relation_extraction/ml_utils.py
229
230
231
232
233
234
235
236
237
238
239
240
241
242
def load_results(path: str, model_name: str = "BERT",
                 file_prefix: str = "train"
                 ) -> tuple[list, list, list]:
    """Load the per-epoch loss/accuracy/f1 lists written by save_results.

    If the metrics file does not exist, three empty lists are returned.
    """
    file_name = f"_losses_accuracy_f1_per_epoch_{model_name}.dat"
    full_path = os.path.join(path, file_prefix + file_name)

    if os.path.isfile(full_path):
        results: dict = load_bin_file(full_path)
    else:
        # no saved metrics yet: fall back to empty histories
        results = {"losses_per_epoch": [],
                   "accuracy_per_epoch": [],
                   "f1_per_epoch": []}

    return (results["losses_per_epoch"],
            results["accuracy_per_epoch"],
            results["f1_per_epoch"])

load_state

load_state(model, optimizer, scheduler, path: str = './', model_name: str = 'BERT', file_prefix: str = 'train', load_best: bool = False, relcat_config: ConfigRelCAT = ConfigRelCAT()) -> tuple[int, int]

Used by RelCAT.load() and RelCAT.train()

Parameters:

  • model

    (RelExtrBaseModel) –

    RelExtrBaseModel, it has to be initialized before calling this method via RelExtr(Bert/Llama)Model(...)

  • optimizer

    (_type_) –

    optimizer

  • scheduler

    (_type_) –

    scheduler

  • path

    (str, default: './' ) –

    Defaults to "./".

  • model_name

    (str, default: 'BERT' ) –

    Defaults to "BERT".

  • file_prefix

    (str, default: 'train' ) –

    Defaults to "train".

  • load_best

    (bool, default: False ) –

    Defaults to False.

  • relcat_config

    (ConfigRelCAT, default: ConfigRelCAT() ) –

    Defaults to ConfigRelCAT().

Returns:

  • tuple ( (int, int) ) –

    last epoch and f1 score.

Source code in medcat-v2/medcat/components/addons/relation_extraction/ml_utils.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def load_state(model, optimizer, scheduler, path: str = "./",
               model_name: str = "BERT", file_prefix: str = "train",
               load_best: bool = False,
               relcat_config: "ConfigRelCAT | None" = None
               ) -> tuple[int, int]:
    """ Used by RelCAT.load() and RelCAT.train()

    Args:
        model (RelExtrBaseModel): RelExtrBaseModel,
            it has to be initialized before calling this method via
            RelExtr(Bert/Llama)Model(...)
        optimizer (_type_): optimizer, rebuilt from relcat_config when None
            and a checkpoint is found.
        scheduler (_type_): scheduler, rebuilt from relcat_config when None
            and a checkpoint is found.
        path (str, optional): Defaults to "./".
        model_name (str, optional): Defaults to "BERT".
        file_prefix (str, optional): Defaults to "train".
        load_best (bool, optional): Defaults to False.
        relcat_config (ConfigRelCAT, optional): Defaults to None, in which
            case a fresh ConfigRelCAT() is created.

    Returns:
        tuple (int, int): last epoch and f1 score ((0, 0) when no
        checkpoint/best file exists at ``path``).
    """
    # BUGFIX: the default used to be `ConfigRelCAT()` in the signature — a
    # mutable default evaluated once at import and shared by every call.
    # Build a fresh config per call instead.
    if relcat_config is None:
        relcat_config = ConfigRelCAT()

    device: torch.device = torch.device(relcat_config.general.device)

    # model_name may contain "/" (HF hub style) which would break the path
    model_name = model_name.replace("/", "_")
    logging.info("Attempting to load RelCAT model on device: %s", device)
    checkpoint_path = os.path.join(
        path, file_prefix + "_checkpoint_%s.dat" % model_name)
    best_path = os.path.join(
        path, file_prefix + "_best_%s.dat" % model_name)
    start_epoch, best_f1, checkpoint = 0, 0, None

    # prefer the "best" snapshot when requested, else the latest checkpoint
    if load_best and os.path.isfile(best_path):
        checkpoint = torch.load(best_path, map_location=device)
        logging.info("Loaded best model.")
    elif os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        logging.info("Loaded checkpoint model.")

    if checkpoint is not None:
        start_epoch = checkpoint['epoch']
        best_f1 = checkpoint['best_f1']
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)

        # rebuild optimizer/scheduler from config when the caller did not
        # supply them, so their saved states can be restored below
        if optimizer is None:
            parameters = filter(lambda p: p.requires_grad, model.parameters())
            optimizer = torch.optim.AdamW(
                params=parameters, lr=relcat_config.train.lr,
                weight_decay=relcat_config.train.adam_weight_decay,
                betas=relcat_config.train.adam_betas,
                eps=relcat_config.train.adam_epsilon)

        if scheduler is None:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=relcat_config.train.multistep_milestones,
                gamma=relcat_config.train.multistep_lr_gamma)
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logging.info("Loaded model and optimizer.")

    return start_epoch, best_f1

save_bin_file

save_bin_file(file_name, data, path='./')
Source code in medcat-v2/medcat/components/addons/relation_extraction/ml_utils.py
106
107
108
def save_bin_file(file_name, data, path="./"):
    """Pickle ``data`` into ``path/file_name``."""
    target = os.path.join(path, file_name)
    with open(target, "wb") as handle:
        pickle.dump(data, handle)

save_results

save_results(data, model_name: str = 'BERT', path: str = './', file_prefix: str = 'train')
Source code in medcat-v2/medcat/components/addons/relation_extraction/ml_utils.py
223
224
225
226
def save_results(data, model_name: str = "BERT", path: str = "./",
                 file_prefix: str = "train"):
    """Persist per-epoch metrics (counterpart of load_results)."""
    file_name = (file_prefix
                 + "_losses_accuracy_f1_per_epoch_%s.dat" % model_name)
    save_bin_file(file_name, data, path)

save_state

save_state(model, optimizer: AdamW, scheduler: MultiStepLR, epoch: int = 1, best_f1: float = 0.0, path: str = './', model_name: str = 'BERT', task: str = 'train', is_checkpoint=False, final_export=False) -> None

Used by RelCAT.save() and RelCAT.train() Saves the RelCAT model state. For checkpointing multiple files are created, best_f1, loss etc. score. If you want to export the model after training set final_export=True and leave is_checkpoint=False.

Parameters:

  • model

    (BaseModel) –

    BertModel | LlamaModel etc.

  • optimizer

    (AdamW) –

    Defaults to None.

  • scheduler

    (MultiStepLR) –

    Defaults to None.

  • epoch

    (int, default: 1 ) –

    Defaults to None.

  • best_f1

    (float, default: 0.0 ) –

    Defaults to None.

  • path

    (str, default: './' ) –

    Defaults to "./".

  • model_name

    (str, default: 'BERT' ) –

    . Defaults to "BERT". This is used to checkpointing only.

  • task

    (str, default: 'train' ) –

    Defaults to "train". This is used to checkpointing only.

  • is_checkpoint

    (bool, default: False ) –

    Defaults to False.

  • final_export

    (bool, default: False ) –

    Defaults to False, if True then is_checkpoint must be False also. Exports model.state_dict(), out into "model.dat".

Source code in medcat-v2/medcat/components/addons/relation_extraction/ml_utils.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def save_state(model, optimizer: torch.optim.AdamW,
               scheduler: torch.optim.lr_scheduler.MultiStepLR,
               epoch: int = 1, best_f1: float = 0.0, path: str = "./",
               model_name: str = "BERT", task: str = "train",
               is_checkpoint=False, final_export=False) -> None:
    """ Used by RelCAT.save() and RelCAT.train()
        Saves the RelCAT model state.
        For checkpointing multiple files are created, best_f1, loss etc. score.
        If you want to export the model after training set final_export=True
        and leave is_checkpoint=False.

    Args:
        model (BaseModel): BertMode | LlamaModel etc.
        optimizer (torch.optim.AdamW, optional): Defaults to None.
        scheduler (torch.optim.lr_scheduler.MultiStepLR, optional):
            Defaults to None.
        epoch (int): Defaults to None.
        best_f1 (float): Defaults to None.
        path (str): Defaults to "./".
        model_name (str): Defaults to "BERT". This is used for
            checkpointing only.
        task (str): Defaults to "train". This is used for checkpointing only.
        is_checkpoint (bool): Defaults to False.
        final_export (bool): Defaults to False, if True then is_checkpoint
            must be False also. Exports model.state_dict(),
            out into "model.dat".
    """

    # "/" in hub-style model names would otherwise create sub-directories
    safe_name = model_name.replace("/", "_")

    if is_checkpoint:
        # full training checkpoint: everything needed to resume training
        checkpoint = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'best_f1':  best_f1,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
        }
        target = os.path.join(path,
                              "%s_checkpoint_%s.dat" % (task, safe_name))
        torch.save(checkpoint, target)
    elif final_export:
        # final export: weights only, fixed file name
        torch.save(model.state_dict(), os.path.join(path, "model.dat"))

split_list_train_test_by_class

split_list_train_test_by_class(data: list, sample_limit: int = -1, test_size: float = 0.2, shuffle: bool = True) -> tuple[list, list]

Parameters:

  • data

    (list) –

    "output_relations": relation_instances, see create_base_relations_from_doc/csv for data columns

  • sample_limit

    (int, default: -1 ) –

    Limit the number of samples per class, useful for dataset balancing . Defaults to -1.

  • test_size

    (float, default: 0.2 ) –

    Defaults to 0.2.

  • shuffle

    (bool, default: True ) –

    Shuffle data randomly. Defaults to True.

Returns:

  • tuple[list, list]

    tuple[list, list]: Train and test datasets

Source code in medcat-v2/medcat/components/addons/relation_extraction/ml_utils.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def split_list_train_test_by_class(data: list, sample_limit: int = -1,
                                   test_size: float = 0.2,
                                   shuffle: bool = True) -> tuple[list, list]:
    """Split relation rows into train/test sets, stratified per class.

    Args:
        data (list): "output_relations": relation_instances,
            see create_base_relations_from_doc/csv for data columns
        sample_limit (int): Limit the number of samples per class,
            useful for dataset balancing. Defaults to -1.
        test_size (float): Defaults to 0.2.
        shuffle (bool): Shuffle data randomly. Defaults to True.

    Returns:
        tuple[list, list]: Train and test datasets
    """

    train_rows: list = []
    test_rows: list = []

    # column 5 is the label id, column 4 the human-readable label name
    labels_by_row = {idx: row[5] for idx, row in enumerate(data)}
    label_names = {row[5]: row[4] for row in data}

    all_labels = list(labels_by_row.values())
    label_totals = {lbl: all_labels.count(lbl) for lbl in set(all_labels)}

    train_counts: dict = {}
    test_counts: dict = {}

    for label, total in label_totals.items():
        # cap the per-class population before computing the split size
        if sample_limit != -1 and total > sample_limit:
            total = sample_limit

        target_test = int(total * test_size)
        n_test = 0
        n_train = 0

        if target_test in (0, 1):
            # class too small for a meaningful split: mirror every row
            # of this class into BOTH sets
            for idx, row_label in labels_by_row.items():
                if row_label != label:
                    continue
                train_rows.append(data[idx])
                test_rows.append(data[idx])
                n_train += 1
                n_test += 1
        else:
            # first target_test rows of the class go to test, the rest to
            # train (train capped at sample_limit when one is set)
            for idx, row_label in labels_by_row.items():
                if row_label != label:
                    continue
                if n_test < target_test:
                    test_rows.append(data[idx])
                    n_test += 1
                elif sample_limit == -1 or n_train < sample_limit:
                    train_rows.append(data[idx])
                    n_train += 1

        test_counts[label] = n_test
        train_counts[label] = n_train

    logging.info("Relations after train, test split :  train - %s | test - %s",
                 str(sum(train_counts.values())),
                 str(sum(test_counts.values())))

    for label in list(label_names.keys()):
        logging.info(" label: %s samples | train %s | test %s",
                     label_names[label],
                     str(train_counts[label]),
                     str(test_counts[label]))

    if shuffle:
        random.shuffle(train_rows)
        random.shuffle(test_rows)

    return train_rows, test_rows