Skip to content

medcat.components.addons.relation_extraction.base_component

Classes:

Attributes:

logger module-attribute

logger = getLogger(__name__)

RelExtrBaseComponent

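Component that holds the model and everything else RelCAT needs: the tokenizer, the RelCAT and model-specific configs, the optimizer and scheduler, and the training state (task, epoch, best F1).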

Parameters:

Methods:

Attributes:

Source code in medcat-v2/medcat/components/addons/relation_extraction/base_component.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def __init__(
        self,
        tokenizer: BaseTokenizerWrapper = None,  # type: ignore
        model: RelExtrBaseModel = None,  # type: ignore
        model_config: RelExtrBaseConfig = None,  # type: ignore
        config: ConfigRelCAT = None,  # type: ignore
        task: str = "train",
        init_model: bool = False):
    """ Component that holds the model and everything for RelCAT.

    Args:
        tokenizer (BaseTokenizerWrapper): The base tokenizer for RelCAT.
            When omitted, a new ``BaseTokenizerWrapper()`` is created
            for this instance.
        model (RelExtrBaseModel): The model wrapper.
        model_config (RelExtrBaseConfig):
            The model-specific config.
        config (ConfigRelCAT): The RelCAT config. When omitted, a new
            ``ConfigRelCAT()`` is created for this instance.
        task (str): The task - used for checkpointing.
        init_model (bool): If True, the model config, tokenizer and
            model are loaded from ``config.general.model_name`` and
            replace whatever was passed in. Defaults to False.
    """
    # Build the defaults per call. Default expressions in the signature
    # are evaluated only once, so a default ConfigRelCAT() would be
    # shared (and mutated below) by every component created without an
    # explicit config.
    if tokenizer is None:
        tokenizer = BaseTokenizerWrapper()
    if config is None:
        config = ConfigRelCAT()

    self.model: RelExtrBaseModel = model
    self.tokenizer: BaseTokenizerWrapper = tokenizer
    self.relcat_config: ConfigRelCAT = config
    self.model_config: RelExtrBaseConfig = model_config
    self.optimizer: AdamW = None  # type: ignore
    self.scheduler: MultiStepLR = None  # type: ignore
    self.task: str = task
    self.epoch: int = 0
    self.best_f1: float = 0.0

    if init_model:
        model_name = self.relcat_config.general.model_name
        self.model_config = RelExtrBaseConfig.load(
            pretrained_model_name_or_path=model_name,
            relcat_config=self.relcat_config)

        self.tokenizer = BaseTokenizerWrapper.load(
            tokenizer_path=model_name,
            relcat_config=self.relcat_config)

        # Register the relation-annotation tags with the tokenizer.
        special_tokens = self.relcat_config.general.tokenizer_relation_annotation_special_tokens_tags  # noqa
        self.tokenizer.hf_tokenizers.add_tokens(
            special_tokens,
            special_tokens=True)

        # used in llama tokenizer, may produce issues with other tokenizers
        self.tokenizer.hf_tokenizers.add_special_tokens(
            self.relcat_config.general.tokenizer_other_special_tokens)
        # Store the ids the tokenizer assigned to the annotation tags.
        self.relcat_config.general.annotation_schema_tag_ids = (
            self.tokenizer.hf_tokenizers.convert_tokens_to_ids(
                special_tokens))
        # Keep the padding id and vocabulary size in the configs in sync
        # with the (now extended) tokenizer.
        pad_idx = self.tokenizer.get_pad_id()
        self.relcat_config.model.padding_idx = pad_idx
        self.model_config.pad_token_id = pad_idx
        self.model_config.hf_model_config.vocab_size = (
            self.tokenizer.get_size())

        self.model = RelExtrBaseModel.load(
            pretrained_model_name_or_path=model_name,
            model_config=self.model_config,
            relcat_config=self.relcat_config)

        # Tokens were added above, so resize the embeddings to match.
        self.model.hf_model.resize_token_embeddings(
            self.tokenizer.get_size())

    # The same padding id is used for sequences and for labels.
    self.pad_id = self.relcat_config.model.padding_idx
    self.padding_seq = Pad_Sequence(seq_pad_value=self.pad_id,
                                    label_pad_value=self.pad_id)

    logging.basicConfig(level=self.relcat_config.general.log_level)
    logger.setLevel(self.relcat_config.general.log_level)

    logger.info("RelExtrBaseComponent initialized")

best_f1 instance-attribute

best_f1: float = 0.0

epoch instance-attribute

epoch: int = 0

model instance-attribute

model_config instance-attribute

name class-attribute instance-attribute

name = 'base_component_rel'

optimizer instance-attribute

optimizer: AdamW = None

pad_id instance-attribute

pad_id = padding_idx

padding_seq instance-attribute

padding_seq = Pad_Sequence(seq_pad_value=pad_id, label_pad_value=pad_id)

relcat_config instance-attribute

relcat_config: ConfigRelCAT = config

scheduler instance-attribute

scheduler: MultiStepLR = None

task instance-attribute

task: str = task

tokenizer instance-attribute

from_relcat_config classmethod

from_relcat_config(relcat_config: ConfigRelCAT, pretrained_model_name_or_path: str = './') -> RelExtrBaseComponent
Source code in medcat-v2/medcat/components/addons/relation_extraction/base_component.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
@classmethod
def from_relcat_config(cls, relcat_config: ConfigRelCAT,
                       pretrained_model_name_or_path: str = './'
                       ) -> 'RelExtrBaseComponent':
    """ Build a component from an already loaded RelCAT config.

    Loads the model config, tokenizer and model from
    ``pretrained_model_name_or_path``, restores the saved state via
    ``load_state`` and returns a component carrying the restored
    ``epoch`` and ``best_f1``.

    Args:
        relcat_config (ConfigRelCAT): The RelCAT config.
        pretrained_model_name_or_path (str): Path to the RelCAT model.
            Defaults to "./".

    Returns:
        RelExtrBaseComponent: The component.
    """
    model_config = RelExtrBaseConfig.load(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        relcat_config=relcat_config)

    tokenizer = BaseTokenizerWrapper.load(
        tokenizer_path=pretrained_model_name_or_path,
        relcat_config=relcat_config)

    model = RelExtrBaseModel.load(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        model_config=model_config, relcat_config=relcat_config)

    # Make the embedding matrix match the loaded tokenizer's size.
    model.hf_model.resize_token_embeddings(len(tokenizer.hf_tokenizers))

    # No optimizer / scheduler exists at this point; load_state is
    # called with None for both.
    optimizer = None
    scheduler = None

    epoch, best_f1 = load_state(
        model, optimizer, scheduler, path=pretrained_model_name_or_path,
        model_name=relcat_config.general.model_name,
        file_prefix=relcat_config.general.task,
        relcat_config=relcat_config)

    component = cls(model=model, tokenizer=tokenizer,
                    model_config=model_config, config=relcat_config)
    # Assign on the instance, not the class: __init__ has already set
    # instance-level epoch / best_f1, which would shadow class-level
    # values and drop the restored state.
    component.epoch = epoch
    component.best_f1 = best_f1

    return component

load classmethod

Parameters:

  • pretrained_model_name_or_path

    (str, default: './' ) –

    Path to RelCAT model. Defaults to "./".

Returns:

Source code in medcat-v2/medcat/components/addons/relation_extraction/base_component.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
@classmethod
def load(cls, pretrained_model_name_or_path: str = "./"
         ) -> "RelExtrBaseComponent":
    """ Load a component from a saved RelCAT model folder.

    The RelCAT config is read from the given path first; the component
    is then built from that config and the same path.

    Args:
        pretrained_model_name_or_path (str): Path to RelCAT model.
            Defaults to "./".

    Returns:
        RelExtrBaseComponent: component.
    """
    model_path = pretrained_model_name_or_path
    loaded_config = ConfigRelCAT.load(load_path=model_path)
    return cls.from_relcat_config(loaded_config, model_path)

save

save(save_path: str) -> None

Saves the model and its dependencies to the specified save_path folder. The CDB is not saved; it is, however, necessary to save the tokenizer used.

Parameters:

  • save_path

    (str) –

    folder path in which to save the model & deps.

Source code in medcat-v2/medcat/components/addons/relation_extraction/base_component.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def save(self, save_path: str) -> None:
    """ Write the model and everything it depends on into save_path.

        The CDB is not part of what gets written; the tokenizer in use
        is, since it has to be stored together with the model.

    Args:
        save_path (str): folder path in which to save the model & deps.
    """
    # 1) RelCAT config, serialised into its own "config" sub-folder.
    relcat_cnf = self.relcat_config
    assert relcat_cnf is not None
    config_dir = os.path.join(save_path, "config")
    os.makedirs(config_dir, exist_ok=True)
    serialise('json', relcat_cnf, config_dir)

    # 2) Tokenizer, written directly into save_path.
    tokenizer_wrapper = self.tokenizer
    assert tokenizer_wrapper is not None
    tokenizer_dir = os.path.join(save_path)
    tokenizer_wrapper.save(tokenizer_dir)

    # 3) Bring the embedding matrix in line with the tokenizer size.
    assert self.model is not None and self.model.hf_model is not None
    hf_model = self.model.hf_model
    hf_model.resize_token_embeddings(tokenizer_wrapper.get_size())

    # 4) Model config, with vocabulary size and padding id refreshed.
    assert self.model_config is not None
    vocab_size = tokenizer_wrapper.get_size()
    hf_model_cnf = self.model_config.hf_model_config
    hf_model_cnf.vocab_size = vocab_size
    hf_model_cnf.pad_token_id = self.pad_id
    self.model_config.save(save_path)

    # 5) Model / optimizer / scheduler state as a final export.
    save_state(
        self.model,
        optimizer=self.optimizer,
        scheduler=self.scheduler,
        epoch=self.epoch,
        best_f1=self.best_f1,
        path=save_path,
        model_name=relcat_cnf.general.model_name,
        task=self.task,
        is_checkpoint=False,
        final_export=True)