medcat.stats

Functions:

  • get_k_fold_stats

    Get the k-fold stats for the model with the specified data.

  • get_stats

    Print metrics on a dataset (F1, P, R).

get_k_fold_stats

Get the k-fold stats for the model with the specified data.

First this will split the MCT export into k folds. You can do this either per document or per annotation.

For each of the k folds, it will start from the base model, train it with the other k-1 folds, and record the metrics. After that, the base model state is restored before doing the next fold. After all the folds have been done, the metrics are averaged.

Parameters:

  • cat

    (CAT) –

    The model pack.

  • mct_export_data

    (MedCATTrainerExport) –

    The MCT export.

  • k

    (int, default: 3 ) –

    The number of folds. Defaults to 3.

  • use_project_filters

    (bool, default: False ) –

    Whether to use per project filters. Defaults to False.

  • split_type

    (SplitType, default: DOCUMENTS_WEIGHTED ) –

Whether to split per annotation or per document. Defaults to DOCUMENTS_WEIGHTED.

  • include_std

    (bool, default: False ) –

Whether to include the standard deviation. Defaults to False.

  • *args

    Arguments passed to the CAT.train_supervised_raw method.

  • **kwargs

    Keyword arguments passed to the CAT.train_supervised_raw method.

Returns:

  • tuple ( tuple ) –

    The averaged metrics. Potentially with their corresponding standard deviations.
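The split/train/restore/average loop described above can be sketched independently of MedCAT. The helper below is a minimal illustration, not part of the medcat API; the `train_and_eval` callback and the metric it returns are hypothetical stand-ins:

```python
from statistics import mean, stdev

def k_fold_metrics(folds, train_and_eval, include_std=False):
    # For each fold: train on the other k-1 folds (train_and_eval is
    # expected to start from the same base model state every time) and
    # evaluate on the held-out fold.
    per_fold = []
    for i, held_out in enumerate(folds):
        train_folds = [f for j, f in enumerate(folds) if j != i]
        per_fold.append(train_and_eval(train_folds, held_out))
    # Average the per-fold metrics, optionally with standard deviation.
    if include_std:
        return (mean(per_fold), stdev(per_fold))
    return (mean(per_fold),)
```

With `include_std=True` the result carries both the average and the spread across folds, mirroring the `include_std` flag of `get_k_fold_stats`.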

Source code in medcat-v2/medcat/stats/kfold.py
def get_k_fold_stats(cat: CAT, mct_export_data: MedCATTrainerExport,
                     k: int = 3, use_project_filters: bool = False,
                     split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED,
                     include_std: bool = False, *args, **kwargs) -> tuple:
    """Get the k-fold stats for the model with the specified data.

    First this will split the MCT export into `k` folds. You can do
    this either per document or per annotation.

    For each of the `k` folds, it will start from the base model,
    train it with the other `k-1` folds and record the metrics.
    After that the base model state is restored before doing the next fold.
    After all the folds have been done, the metrics are averaged.

    Args:
        cat (CAT): The model pack.
        mct_export_data (MedCATTrainerExport): The MCT export.
        k (int): The number of folds. Defaults to 3.
        use_project_filters (bool): Whether to use per project filters.
            Defaults to `False`.
        split_type (SplitType): Whether to split per annotation or per
            document. Defaults to DOCUMENTS_WEIGHTED.
        include_std (bool): Whether to include the standard deviation.
            Defaults to False.
        *args: Arguments passed to the `CAT.train_supervised_raw` method.
        **kwargs: Keyword arguments passed to the `CAT.train_supervised_raw`
            method.

    Returns:
        tuple: The averaged metrics. Potentially with their corresponding
            standard deviations.
    """
    creator = get_fold_creator(mct_export_data, k, split_type=split_type)
    folds = creator.create_folds()
    per_fold_metrics = get_per_fold_metrics(
        cat, folds, use_project_filters, *args, **kwargs)
    means = get_metrics_mean(per_fold_metrics, include_std)
    return means

get_stats

Print metrics on a dataset (F1, P, R); it will also print the concepts that have the most FPs, FNs, and TPs.

Parameters:

  • cat

    (CAT) –

The model pack.

  • data

    (dict) –

The JSON object that we get from MedCATtrainer on export.

  • epoch

    (int, default: 0 ) –

Used during training so we know which epoch it is.

  • use_project_filters

    (bool, default: False ) –

Each project in MedCATtrainer can have filters; whether to respect those filters when calculating metrics.

  • use_overlaps

    (bool, default: False ) –

Allow overlapping entities. Nearly always False, as it is very difficult to annotate overlapping entities.

  • use_cui_doc_limit

    (bool) –

If True, the metrics for a CUI will only be calculated if that CUI appears in a document, in other words if the document was annotated for that CUI. Useful in very specific situations when the set of CUIs changed during the annotation process.

  • use_groups

    (bool) –

If True, concepts that have groups will be combined and stats will be reported on groups.

  • extra_cui_filter

    (Optional[set], default: None ) –

    This filter will be intersected with all other filters, or if all others are not set then only this one will be used.

  • do_print

    (bool, default: True ) –

    Whether to print stats out. Defaults to True.

Returns:

  • fps ( dict ) –

    False positives for each CUI.

  • fns ( dict ) –

    False negatives for each CUI.

  • tps ( dict ) –

    True positives for each CUI.

  • cui_prec ( dict ) –

    Precision for each CUI.

  • cui_rec ( dict ) –

    Recall for each CUI.

  • cui_f1 ( dict ) –

    F1 for each CUI.

  • cui_counts ( dict ) –

Number of occurrences for each CUI.

  • examples ( dict ) –

Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][&lt;list_of_examples&gt;].
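Given the returned fps, fns and tps dicts, the per-CUI precision, recall and F1 follow the standard definitions, which is what cui_prec, cui_rec and cui_f1 hold. A minimal sketch with made-up counts (the helper below is illustrative, not part of medcat):

```python
def per_cui_prf1(fps: dict, fns: dict, tps: dict) -> tuple[dict, dict, dict]:
    # Compute per-CUI precision, recall and F1 from FP/FN/TP counts,
    # mirroring the cui_prec / cui_rec / cui_f1 dicts described above.
    prec, rec, f1 = {}, {}, {}
    for cui in set(tps) | set(fps) | set(fns):
        tp = tps.get(cui, 0)
        fp = fps.get(cui, 0)
        fn = fns.get(cui, 0)
        prec[cui] = tp / (tp + fp) if tp + fp else 0.0
        rec[cui] = tp / (tp + fn) if tp + fn else 0.0
        p, r = prec[cui], rec[cui]
        f1[cui] = 2 * p * r / (p + r) if p + r else 0.0
    return prec, rec, f1
```

For example, a CUI with 8 true positives, 2 false positives and 2 false negatives gets precision 0.8, recall 0.8 and F1 0.8.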

Source code in medcat-v2/medcat/stats/stats.py
def get_stats(cat: CAT,
              data: MedCATTrainerExport,
              epoch: int = 0,
              use_project_filters: bool = False,
              use_overlaps: bool = False,
              #   use_cui_doc_limit: bool = False,
              #   use_groups: bool = False,
              extra_cui_filter: Optional[set[str]] = None,
              do_print: bool = True) -> tuple[
        dict[str, int], dict[str, int], dict[str, int],
        dict[str, float], dict[str, float], dict[str, float],
        dict[str, int], dict
]:
    """TODO: Refactor and make nice
    Print metrics on a dataset (F1, P, R), it will also print the concepts
    that have the most FP,FN,TP.

    Args:
        cat (CAT):
            The model pack.
        data (dict):
            The JSON object that we get from MedCATtrainer on export.
        epoch (int):
            Used during training so we know which epoch it is.
        use_project_filters (bool):
            Each project in MedCATtrainer can have filters; whether to
            respect those filters when calculating metrics.
        use_overlaps (bool):
            Allow overlapping entities. Nearly always False, as it is very
            difficult to annotate overlapping entities.
        use_cui_doc_limit (bool):
            If True, the metrics for a CUI will only be calculated if that
            CUI appears in a document, in other words if the document was
            annotated for that CUI. Useful in very specific situations when
            the set of CUIs changed during the annotation process.
        use_groups (bool):
            If True, concepts that have groups will be combined and stats
            will be reported on groups.
        extra_cui_filter (Optional[set]):
            This filter will be intersected with all other filters, or if all
            others are not set then only this one will be used.
        do_print (bool):
            Whether to print stats out. Defaults to True.

    Returns:
        fps (dict):
            False positives for each CUI.
        fns (dict):
            False negatives for each CUI.
        tps (dict):
            True positives for each CUI.
        cui_prec (dict):
            Precision for each CUI.
        cui_rec (dict):
            Recall for each CUI.
        cui_f1 (dict):
            F1 for each CUI.
        cui_counts (dict):
            Number of occurrences for each CUI.
        examples (dict):
            Examples for each of the fp, fn, tp.
            Format will be examples['fp']['cui'][<list_of_examples>].
    """
    builder = StatsBuilder.from_cat(cat,
                                    use_project_filters=use_project_filters,
                                    use_overlaps=use_overlaps,
                                    # use_cui_doc_limit=use_cui_doc_limit,
                                    # use_groups=use_groups,
                                    extra_cui_filter=extra_cui_filter)
    for pind, project in tqdm(enumerate(data['projects']),
                              desc="Stats project",
                              total=len(data['projects']),
                              leave=False):
        with project_filters(cat.config.components.linking.filters,
                             project,
                             builder.extra_cui_filter,
                             builder.use_project_filters):
            builder.process_project(project)
    # this is the part that prints out the stats
    builder.finalise_report(epoch, do_print=do_print)
    return builder.unwrap()