Skip to content

medcat.components.addons.relation_extraction.config

Classes:

Attributes:

logger module-attribute

logger = getLogger(__name__)

RelExtrBaseConfig

RelExtrBaseConfig(pretrained_model_name_or_path, **kwargs)

Bases: PretrainedConfig

Base configuration class for the RelCAT models.

Methods:

Attributes:

Source code in medcat-v2/medcat/components/addons/relation_extraction/config.py
16
17
18
19
20
21
def __init__(self, pretrained_model_name_or_path, **kwargs):
    """Initialise the base RelCAT model configuration.

    Args:
        pretrained_model_name_or_path: Stored unchanged on the instance.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
            If a ``model_config`` entry is present it becomes
            ``hf_model_config``; otherwise a fresh default
            ``PretrainedConfig()`` is used.
    """
    super().__init__(**kwargs)
    self.pretrained_model_name_or_path = pretrained_model_name_or_path
    self.model_type = "relcat"
    # Build the default only when the caller did not supply one; an explicit
    # ``model_config=None`` is kept as-is.
    hf_config: PretrainedConfig
    if "model_config" in kwargs:
        hf_config = kwargs["model_config"]
    else:
        hf_config = PretrainedConfig()
    self.hf_model_config: PretrainedConfig = hf_config

hf_model_config instance-attribute

hf_model_config: PretrainedConfig = kwargs.get('model_config', PretrainedConfig())

model_type instance-attribute

model_type = 'relcat'

name class-attribute instance-attribute

name = 'base-config-relcat'

pretrained_model_name_or_path instance-attribute

pretrained_model_name_or_path = pretrained_model_name_or_path

load classmethod

load(pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, **kwargs) -> RelExtrBaseConfig
Source code in medcat-v2/medcat/components/addons/relation_extraction/config.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
@classmethod
def load(cls, pretrained_model_name_or_path: str,
         relcat_config: ConfigRelCAT, **kwargs
         ) -> "RelExtrBaseConfig":
    """Load the RelCAT model config for ``pretrained_model_name_or_path``.

    When ``<pretrained_model_name_or_path>/model_config.json`` exists, it is
    loaded by the architecture-specific config class chosen from
    ``relcat_config.general.tokenizer_name`` / ``model_name``. The class is
    picked by the first of "modern-bert", "bert" or "llama" that occurs as a
    substring of either name. If none matches, a warning is logged and a
    base config with a default ``hf_model_config`` is returned. In that
    case the file is not read.

    When the file does not exist, a base config is returned whose
    ``hf_model_config`` comes from ``PretrainedConfig.from_pretrained``.
    The source is ``pretrained_model_name_or_path`` or, if that is empty,
    ``relcat_config.general.model_name``.

    Args:
        pretrained_model_name_or_path (str): Directory searched for
            ``model_config.json``; also the preferred source for
            ``PretrainedConfig.from_pretrained``.
        relcat_config (ConfigRelCAT): RelCAT config whose ``general``
            section supplies ``tokenizer_name`` and ``model_name``.
        **kwargs: Forwarded to the base config constructor, to the
            architecture-specific ``load`` and to ``from_pretrained``.

    Returns:
        RelExtrBaseConfig: The loaded configuration.
    """
    model_config_path = os.path.join(
        pretrained_model_name_or_path, "model_config.json")
    # NOTE(review): an empty pretrained_model_name_or_path makes this
    # "model_config.json" relative to the working directory - confirm intended.
    model_config = RelExtrBaseConfig(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        relcat_config=relcat_config, **kwargs)

    if not os.path.exists(model_config_path):
        # No saved config: fetch one via transformers, preferring the given
        # path and falling back to the configured model name.
        source = (pretrained_model_name_or_path
                  or relcat_config.general.model_name)
        model_config.hf_model_config = PretrainedConfig.from_pretrained(
            pretrained_model_name_or_path=source, **kwargs)
        # Report the source actually used, not model_config_path (which does
        # not exist on this branch).
        logger.info("Loaded config from : %s", source)
        return model_config

    names = (relcat_config.general.tokenizer_name,
             relcat_config.general.model_name)

    def _mentions(key: str) -> bool:
        """Return True if ``key`` occurs in the tokenizer or model name."""
        return any(key in name for name in names)

    # "modern-bert" must be checked before "bert", which it contains.
    if _mentions("modern-bert"):
        from medcat.components.addons.relation_extraction.modernbert.config import RelExtrModernBertConfig  # noqa
        model_config = RelExtrModernBertConfig.load(
            model_config_path, relcat_config=relcat_config, **kwargs)
    elif _mentions("bert"):
        from medcat.components.addons.relation_extraction.bert.config import RelExtrBertConfig  # noqa
        model_config = RelExtrBertConfig.load(
            model_config_path, relcat_config=relcat_config, **kwargs)
    elif _mentions("llama"):
        from medcat.components.addons.relation_extraction.llama.config import RelExtrLlamaConfig  # noqa
        model_config = RelExtrLlamaConfig.load(
            model_config_path, relcat_config=relcat_config, **kwargs)
    else:
        # Previously a silent fall-through: make the ignored file visible.
        logger.warning(
            "Found %s but no known architecture ('modern-bert', 'bert', "
            "'llama') in tokenizer_name=%r or model_name=%r; the saved "
            "config is ignored and a default base config is used.",
            model_config_path, names[0], names[1])
        return model_config

    logger.info("Loaded config from : %s", model_config_path)
    return model_config

save

save(save_path: str)
Source code in medcat-v2/medcat/components/addons/relation_extraction/config.py
31
32
33
def save(self, save_path: str):
    """Write the wrapped Hugging Face config to ``<save_path>/model_config.json``.

    The directory itself is not created here.

    Args:
        save_path (str): Directory to write ``model_config.json`` into.
    """
    target_file = os.path.join(save_path, "model_config.json")
    self.hf_model_config.to_json_file(target_file)

to_dict

to_dict()
Source code in medcat-v2/medcat/components/addons/relation_extraction/config.py
23
24
25
26
27
28
29
def to_dict(self):
    """Serialise the config to a dict, extending ``PretrainedConfig.to_dict``.

    Starting from the parent's dict, the ``model_type``,
    ``pretrained_model_name_or_path`` and ``model_config`` entries are set
    from this instance (overwriting any values the parent produced).

    NOTE(review): ``model_config`` holds the ``hf_model_config`` object
    itself rather than a nested dict, so the result may not be directly
    JSON-serialisable - confirm whether callers rely on the object.
    """
    serialised = super().to_dict()
    serialised.update(
        model_type=self.model_type,
        pretrained_model_name_or_path=self.pretrained_model_name_or_path,
        model_config=self.hf_model_config,
    )
    return serialised