
medcat.stats.kfold

Classes:

Functions:

Attributes:

FloatValuedMetric module-attribute

FloatValuedMetric = Union[dict[str, float], dict[str, tuple[float, float]]]

IntValuedMetric module-attribute

IntValuedMetric = Union[dict[str, int], dict[str, tuple[int, float]]]

FoldCreator

Bases: ABC

The FoldCreator based on an MCT export.

Parameters:

  • mct_export

    (MedCATTrainerExport) –

    The MCT export dict.

  • nr_of_folds

    (int) –

    Number of folds to create.

  • use_annotations

    (bool) –

    Whether to fold on number of annotations or documents.

Methods:

Attributes:

Source code in medcat-v2/medcat/stats/kfold.py
def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int
             ) -> None:
    self.mct_export = mct_export
    self.nr_of_folds = nr_of_folds

mct_export instance-attribute

mct_export = mct_export

nr_of_folds instance-attribute

nr_of_folds = nr_of_folds

create_folds abstractmethod

create_folds() -> list[MedCATTrainerExport]

Create folds.

Raises:

Returns:

Source code in medcat-v2/medcat/stats/kfold.py
@abstractmethod
def create_folds(self) -> list[MedCATTrainerExport]:
    """Create folds.

    Raises:
        ValueError: If something went wrong.

    Returns:
        list[MedCATTrainerExport]: The created folds.
    """

PerAnnsFoldCreator

PerAnnsFoldCreator(mct_export: MedCATTrainerExport, nr_of_folds: int)

Bases: SimpleFoldCreator

Source code in medcat-v2/medcat/stats/kfold.py
def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int
             ) -> None:
    super().__init__(mct_export, nr_of_folds, count_all_annotations)

PerCUIMetrics

Bases: BaseModel

Methods:

Attributes:

vals class-attribute instance-attribute

vals: list[Union[int, float]] = []

weights class-attribute instance-attribute

weights: list[Union[int, float]] = []

add

add(val, weight: int = 1)
Source code in medcat-v2/medcat/stats/kfold.py
def add(self, val, weight: int = 1):
    self.weights.append(weight)
    self.vals.append(val)

get_mean

get_mean()
Source code in medcat-v2/medcat/stats/kfold.py
def get_mean(self):
    return sum(w * v for w, v
               in zip(self.weights, self.vals)) / sum(self.weights)

get_std

get_std()
Source code in medcat-v2/medcat/stats/kfold.py
def get_std(self):
    mean = self.get_mean()
    return (sum(w * (v - mean)**2 for w, v
                in zip(self.weights, self.vals)) / sum(self.weights))**.5

PerDocsFoldCreator

PerDocsFoldCreator(mct_export: MedCATTrainerExport, nr_of_folds: int)

Bases: FoldCreator

Methods:

Attributes:

Source code in medcat-v2/medcat/stats/kfold.py
def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int
             ) -> None:
    super().__init__(mct_export, nr_of_folds)
    self.nr_of_docs = count_all_docs(self.mct_export)
    self.per_doc_simple = self.nr_of_docs // self.nr_of_folds
    self._all_docs = list(iter_docs(self.mct_export))

nr_of_docs instance-attribute

nr_of_docs = count_all_docs(mct_export)

per_doc_simple instance-attribute

per_doc_simple = nr_of_docs // nr_of_folds

create_folds

create_folds() -> list[MedCATTrainerExport]
Source code in medcat-v2/medcat/stats/kfold.py
def create_folds(self) -> list[MedCATTrainerExport]:
    return [
        self._create_fold(fold_nr) for fold_nr in range(self.nr_of_folds)
    ]
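The `_create_fold` helper is not shown on this page; given `per_doc_simple = nr_of_docs // nr_of_folds`, a plausible per-document split might look like the following sketch (hypothetical `split_docs`, with plain lists standing in for export documents):

```python
def split_docs(docs, nr_of_folds):
    # each fold gets nr_of_docs // nr_of_folds documents;
    # the last fold absorbs the remainder
    per_fold = len(docs) // nr_of_folds
    folds = [docs[i * per_fold:(i + 1) * per_fold]
             for i in range(nr_of_folds - 1)]
    folds.append(docs[(nr_of_folds - 1) * per_fold:])
    return folds


print(split_docs(list(range(7)), 3))  # → [[0, 1], [2, 3], [4, 5, 6]]
```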

SimpleFoldCreator

SimpleFoldCreator(mct_export: MedCATTrainerExport, nr_of_folds: int, counter: Callable[[MedCATTrainerExport], int])

Bases: FoldCreator

Methods:

Attributes:

Source code in medcat-v2/medcat/stats/kfold.py
def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int,
             counter: Callable[[MedCATTrainerExport], int]) -> None:
    super().__init__(mct_export, nr_of_folds)
    self._counter = counter
    self.total = self._counter(mct_export)
    self.per_fold = self._init_per_fold()

per_fold instance-attribute

per_fold = _init_per_fold()

total instance-attribute

total = _counter(mct_export)

create_folds

create_folds() -> list[MedCATTrainerExport]
Source code in medcat-v2/medcat/stats/kfold.py
def create_folds(self) -> list[MedCATTrainerExport]:
    return [
        self._create_fold(fold_nr) for fold_nr in range(self.nr_of_folds)
    ]

SplitType

Bases: Enum

The split type.

Attributes:

ANNOTATIONS class-attribute instance-attribute

ANNOTATIONS = auto()

Split over number of annotations.

DOCUMENTS class-attribute instance-attribute

DOCUMENTS = auto()

Split over number of documents.

DOCUMENTS_WEIGHTED class-attribute instance-attribute

DOCUMENTS_WEIGHTED = auto()

Split over number of documents, weighted by the number of annotations. Essentially this ensures that the same document isn't in 2 folds while more evenly distributing documents with different numbers of annotations.

For example, say we have 6 documents that we want to split into 3 folds, with per-document annotation counts [40, 40, 20, 10, 5, 5]. If we were to split trivially over documents, we'd end up with 3 folds whose annotation counts are far from even: [80, 30, 10]. However, if we use the annotation counts as weights, we can create folds with much more evenly distributed annotations, e.g. [[D1], [D2], [D3, D4, D5, D6]] (where D# denotes the document number), with equal annotation counts: [40, 40, 20 + 10 + 5 + 5 = 40].
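The greedy assignment this split type uses (see `WeightedDocumentsCreator` below) can be sketched with plain lists; `greedy_folds` is an illustrative name, not part of the module:

```python
def greedy_folds(weights, nr_of_folds):
    # sort documents by weight (descending) and repeatedly assign the
    # heaviest remaining document to the currently lightest fold
    folds = [[] for _ in range(nr_of_folds)]
    totals = [0] * nr_of_folds
    order = sorted(range(len(weights)), key=lambda i: weights[i],
                   reverse=True)
    for idx in order:
        lightest = totals.index(min(totals))
        folds[lightest].append(idx)
        totals[lightest] += weights[idx]
    return folds, totals


# the worked example from the docstring: [40, 40, 20, 10, 5, 5] into 3 folds
folds, totals = greedy_folds([40, 40, 20, 10, 5, 5], 3)
print(totals)  # → [40, 40, 40]
```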

WeightedDocumentsCreator

WeightedDocumentsCreator(mct_export: MedCATTrainerExport, nr_of_folds: int, weight_calculator: Callable[[MedCATTrainerExportDocument], int])

Bases: FoldCreator

Methods:

Source code in medcat-v2/medcat/stats/kfold.py
def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int,
             weight_calculator: Callable[
                 [MedCATTrainerExportDocument], int]) -> None:
    super().__init__(mct_export, nr_of_folds)
    self._weight_calculator = weight_calculator
    docs = [(doc, self._weight_calculator(doc[1]))
            for doc in iter_docs(self.mct_export)]
    # descending order in weight
    self._weighted_docs = sorted(docs, key=lambda d: d[1], reverse=True)

create_folds

create_folds() -> list[MedCATTrainerExport]
Source code in medcat-v2/medcat/stats/kfold.py
def create_folds(self) -> list[MedCATTrainerExport]:
    doc_folds: list[list[tuple[MedCATTrainerExportProjectInfo,
                               MedCATTrainerExportDocument]]]
    doc_folds = [[] for _ in range(self.nr_of_folds)]
    fold_weights = [0] * self.nr_of_folds

    for item, weight in self._weighted_docs:
        # Find the subset with the minimum total weight
        min_subset_idx = np.argmin(fold_weights)
        # add the most heavily weighted document
        doc_folds[min_subset_idx].append(item)
        fold_weights[min_subset_idx] += weight

    return [self._create_export_with_documents(docs) for docs in doc_folds]

get_fold_creator

Get the appropriate fold creator.

Parameters:

Raises:

  • ValueError

    In case of an unknown split type.

Returns:

  • FoldCreator ( FoldCreator ) –

    The corresponding fold creator.

Source code in medcat-v2/medcat/stats/kfold.py
def get_fold_creator(mct_export: MedCATTrainerExport,
                     nr_of_folds: int,
                     split_type: SplitType) -> FoldCreator:
    """Get the appropriate fold creator.

    Args:
        mct_export (MedCATTrainerExport): The MCT export.
        nr_of_folds (int): Number of folds to use.
        split_type (SplitType): The type of split to use.

    Raises:
        ValueError: In case of an unknown split type.

    Returns:
        FoldCreator: The corresponding fold creator.
    """
    if split_type is SplitType.DOCUMENTS:
        return PerDocsFoldCreator(
            mct_export=mct_export, nr_of_folds=nr_of_folds)
    elif split_type is SplitType.ANNOTATIONS:
        return PerAnnsFoldCreator(
            mct_export=mct_export, nr_of_folds=nr_of_folds)
    elif split_type is SplitType.DOCUMENTS_WEIGHTED:
        return WeightedDocumentsCreator(
            mct_export=mct_export, nr_of_folds=nr_of_folds,
            weight_calculator=get_nr_of_annotations)
    else:
        raise ValueError(f"Unknown Split Type: {split_type}")

get_k_fold_stats

Get the k-fold stats for the model with the specified data.

First this will split the MCT export into k folds. You can do this either per document or per annotation.

For each of the k folds, it will start from the base model, train it with the other k-1 folds, and record the metrics. After that, the base model state is restored before doing the next fold. After all the folds have been done, the metrics are averaged.

Parameters:

  • cat

    (CAT) –

    The model pack.

  • mct_export_data

    (MedCATTrainerExport) –

    The MCT export.

  • k

    (int, default: 3 ) –

    The number of folds. Defaults to 3.

  • use_project_filters

    (bool, default: False ) –

    Whether to use per project filters. Defaults to False.

  • split_type

    (SplitType, default: DOCUMENTS_WEIGHTED ) –

Whether to use annotations or docs. Defaults to DOCUMENTS_WEIGHTED.

  • include_std

    (bool, default: False ) –

Whether to include the standard deviation. Defaults to False.

  • *args

    Arguments passed to the CAT.train_supervised_raw method.

  • **kwargs

    Keyword arguments passed to the CAT.train_supervised_raw method.

Returns:

  • tuple ( tuple ) –

    The averaged metrics. Potentially with their corresponding standard deviations.

Source code in medcat-v2/medcat/stats/kfold.py
def get_k_fold_stats(cat: CAT, mct_export_data: MedCATTrainerExport,
                     k: int = 3, use_project_filters: bool = False,
                     split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED,
                     include_std: bool = False, *args, **kwargs) -> tuple:
    """Get the k-fold stats for the model with the specified data.

    First this will split the MCT export into `k` folds. You can do
    this either per document or per-annotation.

    For each of the `k` folds, it will start from the base model,
    train it with the other `k-1` folds and record the metrics.
    After that the base model state is restored before doing the next fold.
    After all the folds have been done, the metrics are averaged.

    Args:
        cat (CAT): The model pack.
        mct_export_data (MedCATTrainerExport): The MCT export.
        k (int): The number of folds. Defaults to 3.
        use_project_filters (bool): Whether to use per project filters.
            Defaults to `False`.
        split_type (SplitType): Whether to use annotations or docs.
            Defaults to DOCUMENTS_WEIGHTED.
        include_std (bool): Whether to include the standard deviation.
            Defaults to False.
        *args: Arguments passed to the `CAT.train_supervised_raw` method.
        **kwargs: Keyword arguments passed to the `CAT.train_supervised_raw`
            method.

    Returns:
        tuple: The averaged metrics. Potentially with their corresponding
            standard deviations.
    """
    creator = get_fold_creator(mct_export_data, k, split_type=split_type)
    folds = creator.create_folds()
    per_fold_metrics = get_per_fold_metrics(
        cat, folds, use_project_filters, *args, **kwargs)
    means = get_metrics_mean(per_fold_metrics, include_std)
    return means

get_metrics_mean

Get the mean of the provided metrics.

Parameters:

  • metrics

    (list[tuple[dict, dict, dict, dict, dict, dict, dict, dict]]) –

    The metrics.

  • include_std

    (bool) –

    Whether to include the standard deviation.

Returns:

  • fps ( dict ) –

    False positives for each CUI.

  • fns ( dict ) –

    False negatives for each CUI.

  • tps ( dict ) –

    True positives for each CUI.

  • cui_prec ( dict ) –

    Precision for each CUI.

  • cui_rec ( dict ) –

    Recall for each CUI.

  • cui_f1 ( dict ) –

    F1 for each CUI.

  • cui_counts ( dict ) –

    Number of occurrences for each CUI.

  • examples ( dict ) –

    Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][<list_of_examples>].

Source code in medcat-v2/medcat/stats/kfold.py
def get_metrics_mean(metrics: list[tuple[dict, dict, dict,
                                         dict, dict, dict, dict, dict]],
                     include_std: bool) -> tuple[dict, dict, dict,
                                                 dict, dict, dict, dict, dict]:
    """The the mean of the provided metrics.

    Args:
        metrics (list[tuple[dict, dict, dict, dict, dict, dict, dict, dict]]):
            The metrics.
        include_std (bool): Whether to include the standard deviation.

    Returns:
        fps (dict):
            False positives for each CUI.
        fns (dict):
            False negatives for each CUI.
        tps (dict):
            True positives for each CUI.
        cui_prec (dict):
            Precision for each CUI.
        cui_rec (dict):
            Recall for each CUI.
        cui_f1 (dict):
            F1 for each CUI.
        cui_counts (dict):
            Number of occurrences for each CUI.
        examples (dict):
            Examples for each of the fp, fn, tp.
            Format will be examples['fp']['cui'][<list_of_examples>].
    """
    # additives
    all_fps: dict[str, PerCUIMetrics] = {}
    all_fns: dict[str, PerCUIMetrics] = {}
    all_tps: dict[str, PerCUIMetrics] = {}
    # weighted-averages
    all_cui_prec: dict[str, PerCUIMetrics] = {}
    all_cui_rec: dict[str, PerCUIMetrics] = {}
    all_cui_f1: dict[str, PerCUIMetrics] = {}
    # additive
    all_cui_counts: dict[str, PerCUIMetrics] = {}
    # combined
    all_additives = [
        all_fps, all_fns, all_tps, all_cui_counts
    ]
    all_weighted_averages = [
        all_cui_prec, all_cui_rec, all_cui_f1
    ]
    # examples
    all_examples: dict = {}
    for current in metrics:
        cur_wa: list = list(current[3:-2])
        cur_counts = current[-2]
        cur_adds = list(current[:3]) + [cur_counts]
        _add_helper(all_additives, cur_adds)
        _add_weighted_helper(all_weighted_averages, cur_wa, cur_counts)
        cur_examples = current[-1]
        _merge_examples(all_examples, cur_examples)
    # conversion from PerCUI metrics to int/float and (if needed) STD
    cui_fps: IntValuedMetric = {}
    cui_fns: IntValuedMetric = {}
    cui_tps: IntValuedMetric = {}
    cui_prec: FloatValuedMetric = {}
    cui_rec: FloatValuedMetric = {}
    cui_f1: FloatValuedMetric = {}
    final_counts: IntValuedMetric = {}
    to_change: list[Union[IntValuedMetric, FloatValuedMetric]] = [
        cui_fps, cui_fns, cui_tps, final_counts,
        cui_prec, cui_rec, cui_f1,
    ]
    # get the mean and/or std
    for nr, (df, d) in enumerate(zip(to_change, all_additives +
                                     all_weighted_averages)):
        for k, v in d.items():
            if nr == 3 and not include_std:
                # counts need to be added up
                # NOTE: the type:ignore comment _shouldn't_ be necessary
                #       but mypy thinks we're setting a float or integer
                #       where a tuple is expected
                df[k] = sum(v.vals)  # type: ignore
                # NOTE: The current implementation shows the sum for counts
                #       if no STD is required, but the mean along with the
                #       standard deviation if the latter is required.
            elif not include_std:
                df[k] = v.get_mean()
            else:
                # NOTE: the type:ignore comment _shouldn't_ be necessary
                #       but mypy thinks we're setting a tuple
                #       where a float or integer is expected
                df[k] = (v.get_mean(), v.get_std())  # type: ignore
    return (cui_fps, cui_fns, cui_tps, cui_prec, cui_rec, cui_f1,
            final_counts, all_examples)
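The weighted-average branch above can be illustrated standalone: per-fold precision (or recall/F1) values are averaged per CUI, with the per-fold CUI counts as weights. `merge_precisions` is an illustrative name, with plain dicts standing in for the `PerCUIMetrics` bookkeeping:

```python
def merge_precisions(fold_precisions, fold_counts):
    # fold_precisions / fold_counts: one {cui: value} dict per fold
    merged = {}
    for cui in {c for prec in fold_precisions for c in prec}:
        vals, weights = [], []
        for prec, counts in zip(fold_precisions, fold_counts):
            if cui in prec:
                vals.append(prec[cui])
                # fall back to weight 1 if the fold lacks a count
                weights.append(counts.get(cui, 1))
        merged[cui] = (sum(w * v for w, v in zip(weights, vals))
                       / sum(weights))
    return merged


# a CUI seen 3 times in fold 1 (precision 1.0), once in fold 2 (0.5)
print(merge_precisions([{"C1": 1.0}, {"C1": 0.5}],
                       [{"C1": 3}, {"C1": 1}]))  # → {'C1': 0.875}
```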

get_per_fold_metrics

get_per_fold_metrics(cat: CAT, folds: list[MedCATTrainerExport], use_project_filters: bool, *args, **kwargs) -> list[tuple]

Get per fold metrics for a given set of folds.

This method captures the state of the model before processing each fold. For each fold, it trains on all other folds and runs metrics on the fold itself.

Parameters:

Returns:

  • list[tuple]

    list[tuple]: The metrics for each fold.

Source code in medcat-v2/medcat/stats/kfold.py
def get_per_fold_metrics(cat: CAT, folds: list[MedCATTrainerExport],
                         use_project_filters: bool,
                         *args, **kwargs) -> list[tuple]:
    """Get per fold metrics for a given set of folds.

    This method captures the state of the model before processing each fold.
    For each fold, it trains on all other folds, and runs metrics on
    the fold itself.

    Args:
        cat (CAT): The model pack.
        folds (list[MedCATTrainerExport]): The folds.
        use_project_filters (bool): Whether to use project filters.

    Returns:
        list[tuple]: The metrics for each fold.
    """
    metrics = []
    for fold_nr, cur_fold in enumerate(folds):
        others = list(folds)
        others.pop(fold_nr)
        with captured_state_cdb(cat.cdb):
            for other in others:
                cat.trainer.train_supervised_raw(
                    cast(dict[str, Any], other), *args, **kwargs)
            stats = get_stats(cat, cast(MedCATTrainerExport, cur_fold),
                              use_project_filters=use_project_filters)
            metrics.append(stats)
    return metrics
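The capture/train/restore pattern above relies on `captured_state_cdb` to roll the model back between folds. A minimal analogue using `copy.deepcopy` on a plain dict (the `train` and `evaluate` callables are hypothetical stand-ins for `CAT.trainer.train_supervised_raw` and `get_stats`):

```python
import copy


def k_fold_eval(model, folds, train, evaluate):
    metrics = []
    for i, held_out in enumerate(folds):
        snapshot = copy.deepcopy(model)  # capture the base state
        for j, fold in enumerate(folds):
            if j != i:
                train(model, fold)  # train on the other k-1 folds
        metrics.append(evaluate(model, held_out))
        model.clear()
        model.update(snapshot)  # restore before the next fold
    return metrics


# toy model: just remembers which items it was trained on
model = {"seen": []}
train = lambda m, fold: m["seen"].extend(fold)
evaluate = lambda m, fold: len(m["seen"])
print(k_fold_eval(model, [[1], [2], [3]], train, evaluate))  # → [2, 2, 2]
```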