API Reference

This reference is automatically generated from the docstrings in the kraken source code.

High-Level API

These modules provide the main entry points for using kraken programmatically.

kraken.containers

Container classes replacing the old dictionaries returned by kraken’s functional blocks.

class kraken.containers.BBoxLine

Bounding box-type line record.

A container class for a single line in axis-aligned bounding box format, optionally containing a transcription, tags, or associated regions.

bbox

Tuple in the form (xmin, ymin, xmax, ymax) defining the bounding box.

text_direction

Sets the principal orientation (of the line) and reading direction (of the document).

bbox: tuple[int, int, int, int] | None = None
text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] = 'horizontal-lr'
type: str = 'bbox'
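
Example

A minimal sketch of constructing a bounding box line record; the id field is an assumption based on the shared line record fields (cf. kraken.containers.ocr_line below):

>>> from kraken.containers import BBoxLine
>>> line = BBoxLine(id='line_0',
...                 bbox=(10, 20, 340, 60),
...                 text='a line of text',
...                 text_direction='horizontal-lr')
>>> line.type
'bbox'
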
class kraken.containers.BBoxOCRRecord(prediction, cuts, confidences, line, base_dir=None, display_order=True, logits=None, image=None)

A record object containing the recognition result of a single line in bbox format.

Parameters:
  • prediction (str)

  • cuts (list[tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int]]])

  • confidences (list[float])

  • line (Union[BBoxLine, dict[str, Any]])

  • base_dir (Optional[Literal['L', 'R']])

  • display_order (bool)

  • logits (Optional[torch.FloatTensor])

  • image (Optional[PIL.Image.Image])

type

‘bbox’ to indicate a bounding box record

prediction

The text predicted by the network as one continuous string.

Return type:

str

cuts

The absolute bounding polygons for each code point in prediction as a list of 4-tuples ((x0, y0), (x1, y0), (x1, y1), (x0, y1)).

Return type:

list

confidences

A list of floats indicating the confidence value of each code point.

Return type:

list[float]

base_dir

An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.

display_order

Flag indicating the order of the code points in the prediction. In display order (True) the n-th code point in the string corresponds to the n-th leftmost code point, in logical order (False) the n-th code point corresponds to the n-th read code point. See [UAX #9](https://unicode.org/reports/tr9) for more details.

logits

The logits for the prediction.

image

The line image used to produce the prediction.

Notes

When slicing the record, the behavior of cuts differs from earlier versions of kraken. Instead of returning per-character bounding polygons, a single polygonal section of the line bounding polygon is returned, starting at the first and extending to the last code point emitted by the network. This aids numerical stability when computing aggregated bounding polygons, such as for words. Individual code point bounding polygons are still accessible through the cuts attribute or by iterating over the record code point by code point.

base_dir = None
display_order(base_dir=None)

Returns the OCR record in Unicode display order, i.e. ordered from left to right inside the line.

Parameters:

base_dir (Optional[Literal['L', 'R']]) – An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.

Return type:

BBoxOCRRecord

logical_order(base_dir=None)

Returns the OCR record in Unicode logical order, i.e. in the order the characters in the line would be read by a human.

Parameters:

base_dir (Optional[Literal['L', 'R']]) – An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.

Return type:

BBoxOCRRecord

type = 'bbox'
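
Example

A minimal sketch of switching between the two code point orders, assuming record is a BBoxOCRRecord produced by a recognition task in display order (the default):

>>> logical = record.logical_order()     # code points in reading order
>>> display = logical.display_order()    # back to left-to-right display order
>>> len(record.prediction) == len(record.confidences)
True
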
class kraken.containers.BaselineLine

Baseline-type line record.

A container class for a single line in baseline + bounding polygon format, optionally containing a transcription, tags, or associated regions.

baseline

list of tuples (x_n, y_n) defining the baseline.

boundary

list of tuples (x_n, y_n) defining the bounding polygon of the line. The first and last points should be identical.

baseline: list[tuple[int, int]] | None = None
boundary: list[tuple[int, int]] | None = None
type: str = 'baselines'
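
Example

A minimal sketch of constructing a baseline line record; as above, the id field is an assumption based on the shared line record fields. Note that the first and last boundary points are identical:

>>> from kraken.containers import BaselineLine
>>> line = BaselineLine(id='line_0',
...                     baseline=[(10, 40), (300, 40)],
...                     boundary=[(10, 20), (300, 20), (300, 60), (10, 60), (10, 20)],
...                     text='a line of text')
>>> line.type
'baselines'
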
class kraken.containers.BaselineOCRRecord(prediction, cuts, confidences, line, base_dir=None, display_order=True, logits=None, image=None)

A record object containing the recognition result of a single line in baseline format.

Parameters:
  • prediction (str)

  • cuts (list[tuple[int, int]])

  • confidences (list[float])

  • line (Union[BaselineLine, dict[str, Any]])

  • base_dir (Optional[Literal['L', 'R']])

  • display_order (bool)

  • logits (Optional[torch.FloatTensor])

  • image (Optional[PIL.Image.Image])

type

‘baselines’ to indicate a baseline record

prediction

The text predicted by the network as one continuous string.

Return type:

str

cuts

The absolute bounding polygons for each code point in prediction as a list of tuples [(x0, y0), (x1, y1), …].

Return type:

tuple

confidences

A list of floats indicating the confidence value of each code point.

Return type:

list[float]

base_dir

An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.

display_order

Flag indicating the order of the code points in the prediction. In display order (True) the n-th code point in the string corresponds to the n-th leftmost code point, in logical order (False) the n-th code point corresponds to the n-th read code point. See [UAX #9](https://unicode.org/reports/tr9) for more details.

logits

The logits for the prediction.

image

The line image used to produce the prediction.

Notes

When slicing the record, the behavior of cuts differs from earlier versions of kraken. Instead of returning per-character bounding polygons, a single polygonal section of the line bounding polygon is returned, starting at the first and extending to the last code point emitted by the network. This aids numerical stability when computing aggregated bounding polygons, such as for words. Individual code point bounding polygons are still accessible through the cuts attribute or by iterating over the record code point by code point.

base_dir = None
property cuts: tuple
Return type:

tuple

display_order(base_dir=None)

Returns the OCR record in Unicode display order, i.e. ordered from left to right inside the line.

Parameters:

base_dir (Optional[Literal['L', 'R']]) – An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.

Return type:

BaselineOCRRecord

logical_order(base_dir=None)

Returns the OCR record in Unicode logical order, i.e. in the order the characters in the line would be read by a human.

Parameters:

base_dir (Optional[Literal['L', 'R']]) – An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.

Return type:

BaselineOCRRecord

type = 'baselines'
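
Example

A minimal sketch of per-code-point access and slicing, assuming record is a BaselineOCRRecord produced by a recognition task:

>>> for char, cut, confidence in record:   # iterate code point by code point
...     print(char, confidence)
>>> word = record[0:5]                     # slicing returns a new record with a
...                                        # single polygonal section (see Notes)
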
class kraken.containers.ProcessingStep

A processing step in the recognition pipeline.

id

Unique identifier

category

Category of processing step that has been performed.

description

Natural-language description of the process.

settings

dict describing the parameters of the processing step.

category: Literal['preprocessing', 'processing', 'postprocessing']
description: str
id: str
settings: dict[str, dict | str | float | int | bool]
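
Example

A minimal sketch of recording a processing step; the id and settings values are illustrative:

>>> from kraken.containers import ProcessingStep
>>> step = ProcessingStep(id='step_0',
...                       category='preprocessing',
...                       description='binarization of the input image',
...                       settings={'threshold': 0.5})
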
class kraken.containers.Region

Container class of a single polygonal region.

id

Unique identifier

boundary

list of tuples (x_n, y_n) defining the bounding polygon of the region. The first and last points should be identical.

imagename

Path to the image associated with the region.

tags

A dict mapping types to values.

boundary: list[tuple[int, int]]
id: str
imagename: str | os.PathLike | None = None
language: list[str] | None = None
tags: dict[str, list[dict[str, str]]] | None = None
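
Example

A minimal sketch of a polygonal region; note that the first and last boundary points are identical:

>>> from kraken.containers import Region
>>> region = Region(id='region_0',
...                 boundary=[(0, 0), (400, 0), (400, 200), (0, 200), (0, 0)])
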
class kraken.containers.Segmentation

A container class for segmentation or recognition results.

In order to allow easy JSON de-/serialization, nested classes for lines (BaselineLine/BBoxLine) and regions (Region) are reinstantiated from their dictionaries.

type

Field indicating if baselines (kraken.containers.BaselineLine) or bbox (kraken.containers.BBoxLine) line records are in the segmentation.

imagename

Path to the image associated with the segmentation.

text_direction

Sets the principal orientation (of the line), i.e. horizontal/vertical, and reading direction (of the document), i.e. lr/rl.

script_detection

Flag indicating if the line records have tags.

lines

list of line records. Records are expected to be in a valid reading order.

regions

dict mapping types to lists of regions.

line_orders

list of alternative reading orders for the segmentation. Each reading order is a list of line indices.

imagename: str | os.PathLike
language: list[str] | None = None
line_orders: list[list[int]] | None = None
lines: list[BaselineLine | BBoxLine] | None = None
regions: dict[str, list[Region]] | None = None
script_detection: bool
text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl']
type: Literal['baselines', 'bbox']
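
Example

A minimal sketch of assembling a baseline-type segmentation from the container classes above; the BaselineLine id field is an assumption as noted earlier:

>>> from kraken.containers import Segmentation, BaselineLine
>>> line = BaselineLine(id='line_0',
...                     baseline=[(10, 40), (300, 40)],
...                     boundary=[(10, 20), (300, 20), (300, 60), (10, 60), (10, 20)])
>>> seg = Segmentation(type='baselines',
...                    imagename='page.png',
...                    text_direction='horizontal-lr',
...                    script_detection=False,
...                    lines=[line])
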
kraken.containers.compute_polygon_section(baseline, boundary, dist1, dist2)

Given a baseline, a polygonal boundary, and two points on the baseline, return the rectangle formed by the orthogonal cuts through those points on the baseline segment. The resulting polygon is not guaranteed to have a non-zero area.

The distances can be larger than the actual length of the baseline if the baseline endpoints are inside the bounding polygon. In that case the baseline will be extrapolated to the polygon edge.

Parameters:
  • baseline (collections.abc.Sequence[tuple[int, int]]) – A polyline ((x1, y1), …, (xn, yn))

  • boundary (collections.abc.Sequence[tuple[int, int]]) – A bounding polygon around the baseline (same format as baseline). Last and first point are automatically connected.

  • dist1 (int) – Absolute distance along the baseline of the first point.

  • dist2 (int) – Absolute distance along the baseline of the second point.

Returns:

A sequence of polygon points.

Return type:

tuple[tuple[int, int]]
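
Example

A minimal sketch with a horizontal baseline inside a rectangular boundary; the exact vertex coordinates and ordering of the result depend on the intersection computation:

>>> from kraken.containers import compute_polygon_section
>>> baseline = [(0, 50), (200, 50)]
>>> boundary = [(0, 30), (200, 30), (200, 70), (0, 70)]
>>> section = compute_polygon_section(baseline, boundary, 20, 80)
>>> # roughly ((20, 30), (80, 30), (80, 70), (20, 70)) for this input
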

class kraken.containers.ocr_line

A line record.

id

Unique identifier

text

Transcription of this line.

base_dir

An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.

imagename

Path to the image associated with the line.

tags

A dict mapping types to values.

split

Defines whether this line is in the train, validation, or test set during training.

regions

A list of identifiers of regions the line is associated with.

languages

A list of language identifiers associated with the line.

base_dir: Literal['L', 'R'] | None = None
id: str
imagename: str | os.PathLike | None = None
language: list[str] | None = None
regions: list[str] | None = None
split: Literal['train', 'validation', 'test'] | None = None
tags: dict[str, list[dict[str, str]]] | None = None
text: str | None = None
class kraken.containers.ocr_record(prediction, cuts, confidences, display_order=True, logits=None, image=None)

A record object containing the recognition result of a single line

Parameters:
  • prediction (str)

  • cuts (list[Union[tuple[int, int], tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int]]]])

  • confidences (list[float])

  • display_order (bool)

  • logits (Optional[torch.FloatTensor])

  • image (Optional[PIL.Image.Image])

base_dir = None
property confidences: list[float]
Return type:

list[float]

property cuts: list
Return type:

list

abstractmethod display_order(base_dir)
Return type:

ocr_record

image = None
abstractmethod logical_order(base_dir)
Return type:

ocr_record

logits = None
property prediction: str
Return type:

str

abstract property type
kraken.containers.precompute_polygon_sections(baseline, boundary, cut_pairs)

Batch-precompute polygon sections for all characters, amortizing the expensive baseline extension and cumulative distance computation.

Parameters:
  • baseline (collections.abc.Sequence[tuple[int, int]]) – A polyline ((x1, y1), …, (xn, yn))

  • boundary (collections.abc.Sequence[tuple[int, int]]) – A bounding polygon around the baseline.

  • cut_pairs (list[tuple[int, int]]) – List of (dist1, dist2) tuples for each character.

Returns:

A tuple of (char_polygons, intersection_cache, bl_length) where:

  • char_polygons is a list of polygon tuples, one per character

  • intersection_cache maps clamped distance values to raw _test_intersect result arrays (or None on failure)

  • bl_length is the total baseline length after extension

Return type:

tuple
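
Example

A minimal sketch reusing the baseline and boundary from the compute_polygon_section example above; the cut pairs are illustrative distances for three characters:

>>> from kraken.containers import precompute_polygon_sections
>>> baseline = [(0, 50), (200, 50)]
>>> boundary = [(0, 30), (200, 30), (200, 70), (0, 70)]
>>> cut_pairs = [(0, 20), (20, 45), (45, 80)]   # one (dist1, dist2) pair per character
>>> polys, cache, length = precompute_polygon_sections(baseline, boundary, cut_pairs)
>>> len(polys) == len(cut_pairs)
True
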

kraken.lib.xml

ALTO/Page data loaders for segmentation training

class kraken.lib.xml.Segmentation

A container class for segmentation or recognition results.

In order to allow easy JSON de-/serialization, nested classes for lines (BaselineLine/BBoxLine) and regions (Region) are reinstantiated from their dictionaries.

type

Field indicating if baselines (kraken.containers.BaselineLine) or bbox (kraken.containers.BBoxLine) line records are in the segmentation.

imagename

Path to the image associated with the segmentation.

text_direction

Sets the principal orientation (of the line), i.e. horizontal/vertical, and reading direction (of the document), i.e. lr/rl.

script_detection

Flag indicating if the line records have tags.

lines

list of line records. Records are expected to be in a valid reading order.

regions

dict mapping types to lists of regions.

line_orders

list of alternative reading orders for the segmentation. Each reading order is a list of line indices.

imagename: str | os.PathLike
language: list[str] | None = None
line_orders: list[list[int]] | None = None
lines: list[BaselineLine | BBoxLine] | None = None
regions: dict[str, list[Region]] | None = None
script_detection: bool
text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl']
type: Literal['baselines', 'bbox']
class kraken.lib.xml.XMLPage(filename, filetype='xml', linetype='baselines')

Parses XML facsimiles in ALTO or PageXML format.

The parser is able to deal with most (but not all) features supported by those standards. In particular, any data below the line level is discarded.

Parameters:
  • filename (Union[str, os.PathLike]) – Path to the XML file

  • filetype (Literal['xml', 'alto', 'page']) – Selector for explicit subparser choice.

  • linetype (Literal['baselines', 'bbox']) – Parse line data as baselines or bounding box type.

type

Either ‘baselines’ or ‘bbox’.

imagename

Path to the image associated with the XML file.

image_size

Size of the image as a (width, height) tuple

has_tags

Indicates if the source document contains tag information

has_splits

Indicates if the source document contains explicit training splits

base_dir: Literal['L', 'R'] | None = None
filename
filetype = 'xml'
get_lines_by_split(split)
Parameters:

split (Literal['train', 'validation', 'test'])

get_lines_by_tag(key, value)
get_sorted_lines(ro='line_implicit')

Returns ordered baselines from particular reading order.

get_sorted_lines_by_region(region, ro='line_implicit')

Returns ordered lines in region.

get_sorted_regions(ro='region_implicit')

Returns ordered regions from particular reading order.

has_splits: bool = False
has_tags: bool = False
image_size: tuple[int, int] = None
imagename: os.PathLike = None
property lines
property reading_orders
property regions
property splits
property tags
to_container()

Returns a Segmentation object.

Return type:

kraken.containers.Segmentation

type: Literal['baselines', 'bbox'] = 'baselines'
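
Example

A minimal sketch of parsing a facsimile and converting it into a Segmentation; 'page_0001.xml' is a placeholder path:

>>> from kraken.lib.xml import XMLPage
>>> page = XMLPage('page_0001.xml')        # ALTO/PageXML detected automatically
>>> page.type
'baselines'
>>> lines = page.get_sorted_lines()        # implicit line reading order
>>> seg = page.to_container()              # kraken.containers.Segmentation
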
kraken.lib.xml.alto_regions
kraken.lib.xml.flatten_order_to_lines(raw_order, lines_dict, region_ids, line_implicit_order, string_to_line_map=None, missing_region_ids=None)

Flatten a raw reading order (list of IDs) to line-level.

For each ID:

  • Line ID: append directly

  • Region ID: expand to contained lines using the implicit order

  • String ID (ALTO only): map to the parent TextLine, deduplicating consecutive duplicates

  • Unknown ID: log a warning and skip

Parameters:
  • raw_order (list[str])

  • lines_dict (dict)

  • region_ids (set[str])

  • line_implicit_order (list[str])

  • string_to_line_map (Optional[dict[str, str]])

  • missing_region_ids (Optional[set[str]])

Return type:

list[str]

kraken.lib.xml.flatten_order_to_regions(raw_order, lines_dict, region_ids, string_to_line_map=None, missing_region_ids=None)

Flatten a raw reading order (list of IDs) to region-level.

For each ID:

  • Region ID: append directly

  • Line ID: resolve to the parent region, deduplicating consecutive duplicates

  • String ID (ALTO only): resolve to the parent line, then its region, deduplicating consecutive duplicates

  • Unknown ID: log a warning and skip

Parameters:
  • raw_order (list[str])

  • lines_dict (dict)

  • region_ids (set[str])

  • string_to_line_map (Optional[dict[str, str]])

  • missing_region_ids (Optional[set[str]])

Return type:

list[str]

kraken.lib.xml.logger
kraken.lib.xml.page_regions
kraken.lib.xml.parse_alto(doc, filename, linetype)

Parse an ALTO XML document.

Parameters:
  • doc – Parsed lxml document.

  • filename – Path to the XML file (for error messages and resolving image paths).

  • linetype – ‘baselines’ or ‘bbox’.

Returns:

A dict with the keys imagename, image_size, regions, lines, orders, tag_set, raw_orders, and string_to_line_map.

Return type:

dict

kraken.lib.xml.parse_page(doc, filename, linetype)

Parse a PageXML document.

Parameters:
  • doc – Parsed lxml document.

  • filename – Path to the XML file (for error messages and resolving image paths).

  • linetype – ‘baselines’ or ‘bbox’.

Returns:

A dict with the keys imagename, image_size, regions, lines, orders, tag_set, and raw_orders.

Return type:

dict

kraken.lib.xml.validate_and_clean_order(flat_order, valid_ids)

Validate a flattened order.

Checks:

  • All IDs exist in valid_ids

  • No duplicate IDs (duplicates indicate a circular reference)

Returns:

(cleaned_order, is_valid)

Parameters:
  • flat_order (list[str])

  • valid_ids (set[str])

Return type:

tuple[list[str], bool]
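
Example

A minimal sketch, assuming a well-formed order (all IDs known, no duplicates) passes through unchanged:

>>> from kraken.lib.xml import validate_and_clean_order
>>> validate_and_clean_order(['r1', 'l1', 'l2'], {'r1', 'l1', 'l2'})
(['r1', 'l1', 'l2'], True)
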

kraken.tasks

class kraken.tasks.ForcedAlignmentTaskModel(models)

A wrapper for forced alignment of CTC output.

Using a text recognition model the existing transcription of a page will be aligned to the character positions of the network output.

Raises:

ValueError – Is raised when the model type is not a sequence recognizer.

Parameters:

models (list[torch.nn.Module])

classmethod load_model(path)
Parameters:

path (Union[str, os.PathLike])

net
one_channel_mode
predict(im, segmentation, config)

Aligns the transcription of an image with the output of the text recognition model, producing approximate character locations.

When the character sets of the transcription and the recognition model differ, the affected code points in the furnished transcription will be silently ignored. If inference fails on a line, a record without cuts and confidences is returned.

Parameters:
  • im (PIL.Image.Image) – The input image

  • segmentation (kraken.containers.Segmentation) – A segmentation with transcriptions to align.

  • config (kraken.configs.RecognitionInferenceConfig) – A recognition inference configuration. The task model will automatically set some required configuration flags on it.

Returns:

A single segmentation that contains the aligned ocr_record objects.

Return type:

kraken.containers.Segmentation

Example

>>> from PIL import Image
>>> from kraken.tasks import ForcedAlignmentTaskModel
>>> from kraken.containers import Segmentation, BaselineLine
>>> from kraken.configs import RecognitionInferenceConfig
>>> # Assume `model.mlmodel` is a recognition model
>>> model = ForcedAlignmentTaskModel.load_model('model.mlmodel')
>>> im = Image.open('image.png')
>>> # Create a dummy segmentation with a line and a transcription
>>> line = BaselineLine(id='line_0', baseline=[(0, 0), (100, 0)], boundary=[(0, -10), (100, -10), (100, 10), (0, 10)], text='Hello World')
>>> segmentation = Segmentation(type='baselines', imagename='image.png', text_direction='horizontal-lr', script_detection=False, lines=[line])
>>> config = RecognitionInferenceConfig()
>>> aligned_segmentation = model.predict(im, segmentation, config)
>>> record = aligned_segmentation.lines[0]
>>> print(record.prediction)
>>> print(record.cuts)
seg_type
class kraken.tasks.RecognitionTaskModel(models)

A wrapper for a model performing a recognition task.

A recognition task is the process of transcribing a line of text from an image. This class provides a high-level interface for running a recognition model on an image, given a segmentation.

Raises:

ValueError – Is raised when the model type is not a sequence recognizer.

Parameters:

models (list[torch.nn.Module])

classmethod load_model(path)
Parameters:

path (Union[str, os.PathLike])

net
one_channel_mode
predict(im, segmentation, config)

Inference using a recognition model.

Parameters:
  • im (PIL.Image.Image) – Input image

  • segmentation (kraken.containers.Segmentation) – The segmentation corresponding to the input image.

  • config (kraken.configs.RecognitionInferenceConfig) – A configuration object containing inference parameters, such as the batch size and the precision.

Yields:

One ocr_record for each line.

Return type:

collections.abc.Generator[kraken.containers.ocr_record, None, None]

Example

>>> from PIL import Image
>>> from kraken.tasks import RecognitionTaskModel
>>> from kraken.containers import Segmentation
>>> from kraken.configs import RecognitionInferenceConfig
>>> model = RecognitionTaskModel.load_model('model.mlmodel')
>>> im = Image.open('image.png')
>>> segmentation = Segmentation(...)
>>> config = RecognitionInferenceConfig()
>>> for record in model.predict(im, segmentation, config):
...     print(record.prediction)
seg_type
class kraken.tasks.SegmentationTaskModel(models)

A wrapper class collecting one or more models that perform segmentation.

A segmentation task is the process of identifying the regions and lines of text in an image. This class provides a high-level interface for running segmentation models on an image.

It deals with the following tasks:
  • region segmentation

  • line detection

  • line reading order

If no neural reading order model is part of the model collection handed to the task, a simple heuristic will be used.

Parameters:

models (list[torch.nn.Module]) – A collection of models performing segmentation tasks.

Raises:

ValueError – Is raised when no segmentation models are in the model list.

classmethod load_model(path=None)

Loads a collection of layout analysis models from the given file path.

If no path is provided, the default BLLA segmentation model will be loaded.

Parameters:

path (Optional[Union[str, os.PathLike]]) – Path to model weights file.

Return type:

SegmentationTaskModel

predict(im, config)

Runs all models associated with the task to produce a segmentation for the input page.

Parameters:
  • im (PIL.Image.Image) – Input image with an arbitrary color mode and size

  • config (kraken.configs.SegmentationInferenceConfig) – A configuration object for the segmentation task containing inference parameters such as the batch size and the precision.

Returns:

A single Segmentation object that contains the merged output of all associated segmentation models.

Return type:

kraken.containers.Segmentation

Example

>>> from PIL import Image
>>> from kraken.tasks import SegmentationTaskModel
>>> from kraken.configs import SegmentationInferenceConfig
>>> model = SegmentationTaskModel.load_model()
>>> im = Image.open('image.png')
>>> config = SegmentationInferenceConfig()
>>> segmentation = model.predict(im, config)
ro_models
seg_models

kraken.train

class kraken.train.BLLASegmentationDataModule(data_config)

A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.

Example:

import lightning as L
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset

class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...
Parameters:

data_config (kraken.configs.BLLASegmentationTrainingDataConfig)

setup(stage=None)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage (str) – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
test_dataloader()

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see this section.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

train_dataloader()

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs argument to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader()

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs argument to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). This hook is used by fit() and validate(), both of which ensure that prepare_data() and setup() have been called first.

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

class kraken.train.BLLASegmentationModel(config, model=None)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

Note

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Variables:

training (bool) – Boolean represents whether this module is in training or evaluation mode.

Parameters:
  • config (kraken.configs.BLLASegmentationTrainingConfig)

  • model (Optional[kraken.models.BaseModel])

configure_callbacks()

Configure model-specific callbacks. When the model gets attached, e.g., when .fit() or .test() gets called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer’s callbacks argument. If a callback returned here has the same type as one or several callbacks already present in the Trainer’s callbacks list, it will take priority and replace them. In addition, Lightning will make sure ModelCheckpoint callbacks run last.

Returns:

A callback or a list of callbacks which will extend the list of callbacks in the Trainer.

Example:

def configure_callbacks(self):
    early_stop = EarlyStopping(monitor="val_acc", mode="max")
    checkpoint = ModelCheckpoint(monitor="val_loss")
    return [early_stop, checkpoint]
configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Returns:

Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.

Note

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.

  • If you need to control how often the optimizer steps, override the optimizer_step() hook.

criterion
example_input_array

The example input array is a specification of what the module can consume in the forward() method. The return type is interpreted as follows:

  • Single tensor: It is assumed the model takes a single argument, i.e., model.forward(model.example_input_array)

  • Tuple: The input array should be interpreted as a sequence of positional arguments, i.e., model.forward(*model.example_input_array)

  • Dict: The input array represents named keyword arguments, i.e., model.forward(**model.example_input_array)

forward(x)

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

  • x (torch.Tensor)

Returns:

Your model’s output

Return type:

tuple[torch.Tensor, torch.Tensor]

classmethod load_from_weights(path, config)

Initializes the module from a model weights file.

Parameters:
  • path (Union[str, os.PathLike])

  • config (kraken.configs.BLLASegmentationTrainingConfig)

Return type:

BLLASegmentationModel
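
Example

A minimal sketch of initializing the module from an existing weights file; 'blla.mlmodel' is a placeholder path and the assumption is that the config's default values suffice:

>>> from kraken.train import BLLASegmentationModel
>>> from kraken.configs import BLLASegmentationTrainingConfig
>>> config = BLLASegmentationTrainingConfig()
>>> model = BLLASegmentationModel.load_from_weights('blla.mlmodel', config)
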

lr_scheduler_step(scheduler, metric)

Override this method to adjust the default way the Trainer calls each scheduler. By default, Lightning calls step() as shown in the example for each scheduler based on its interval.

Parameters:
  • scheduler – Learning rate scheduler.

  • metric – Value of the monitor used for schedulers like ReduceLROnPlateau.

Examples:

# DEFAULT
def lr_scheduler_step(self, scheduler, metric):
    if metric is None:
        scheduler.step()
    else:
        scheduler.step(metric)

# Alternative way to update schedulers if it requires an epoch value
def lr_scheduler_step(self, scheduler, metric):
    scheduler.step(epoch=self.current_epoch)
on_load_checkpoint(checkpoint)

Reconstruct the model from the spec here and not in setup() as otherwise the weight loading will fail.

on_save_checkpoint(checkpoint)

Save hyperparameters a second time so we can set parameters that shouldn’t be overwritten in on_load_checkpoint.

on_test_epoch_end()

Called in the test loop at the very end of the epoch.

on_test_epoch_start()

Called in the test loop at the very beginning of the epoch.

on_validation_epoch_end()

Called in the validation loop at the very end of the epoch.

optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)

Override this method to adjust the default way the Trainer calls the optimizer.

By default, Lightning calls step() and zero_grad() as shown in the example. This method (and zero_grad()) won’t be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.

Parameters:
  • epoch – Current epoch

  • batch_idx – Index of current batch

  • optimizer – A PyTorch optimizer

  • optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().

Examples:

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # Add your custom logic to run directly before `optimizer.step()`

    optimizer.step(closure=optimizer_closure)

    # Add your custom logic to run directly after `optimizer.step()`
setup(stage=None)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage (Optional[str]) – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
test_step(batch, batch_idx, test_dataloader=0)

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})

Note

If you don’t need to test you don’t need to implement this method.

Note

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

training_step(batch, batch_idx)

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

validation_step(batch, batch_idx)

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})

Note

If you don’t need to validate you don’t need to implement this method.

Note

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

class kraken.train.KrakenTrainer(enable_progress_bar=True, enable_summary=True, min_epochs=5, max_epochs=100, freeze_backbone=-1, pl_logger=None, log_dir=None, *args, **kwargs)
Parameters:
  • enable_progress_bar (bool)

  • enable_summary (bool)

  • min_epochs (int)

  • max_epochs (int)

  • pl_logger (Union[lightning.pytorch.loggers.logger.Logger, str, None])

  • log_dir (Optional[os.PathLike])

automatic_optimization = False
fit(*args, **kwargs)

Runs the full optimization routine.

Parameters:
  • model – Model to fit.

  • train_dataloaders – An iterable or collection of iterables specifying training samples. Alternatively, a LightningDataModule that defines the train_dataloader hook.

  • val_dataloaders – An iterable or collection of iterables specifying validation samples.

  • datamodule – A LightningDataModule that defines the train_dataloader hook.

  • ckpt_path

Path/URL of the checkpoint from which training is resumed. Could also be one of the special keywords "best", "last", "hpc" and "registry". Otherwise, if there is no checkpoint file at the path, an exception is raised.

    • best: the best model checkpoint from the previous trainer.fit call will be loaded

    • last: the last model checkpoint from the previous trainer.fit call will be loaded

    • registry: the model will be downloaded from the Lightning Model Registry with following notations:

      • 'registry': uses the latest/default version of default model set with Trainer(..., model_registry="my-model")

      • 'registry:model-name': uses the latest/default version of this model model-name

      • 'registry:model-name:version:v2': uses the specific version ‘v2’ of the model model-name

      • 'registry:version:v2': uses the default model set with Trainer(..., model_registry="my-model") and version ‘v2’

  • weights_only – Defaults to None. If True, restricts loading to state_dicts of plain torch.Tensor and other primitive types. If loading a checkpoint from a trusted source that contains an nn.Module, use weights_only=False. If loading checkpoint from an untrusted source, we recommend using weights_only=True. For more information, please refer to the PyTorch Developer Notes on Serialization Semantics.

For more information about multiple dataloaders, see this section.

Return type:

None

Raises:

TypeError – If model is not a LightningModule for torch versions below 2.0.0, or if model is not a LightningModule or torch._dynamo.OptimizedModule for torch versions 2.0.0 and above.

test(*args, **kwargs)

Perform one evaluation epoch over the test set. It’s separated from fit to make sure you never run on your test set until you want to.

Parameters:
  • model – The model to test.

  • dataloaders – An iterable or collection of iterables specifying test samples. Alternatively, a LightningDataModule that defines the test_dataloader hook.

  • ckpt_path – Either "best", "last", "hpc", "registry" or path to the checkpoint you wish to test. If None and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous trainer.fit call will be loaded if a checkpoint callback is configured.

  • verbose – If True, prints the test results.

  • datamodule – A LightningDataModule that defines the test_dataloader hook.

  • weights_only

    Defaults to None. If True, restricts loading to state_dicts of plain torch.Tensor and other primitive types. If loading a checkpoint from a trusted source that contains an nn.Module, use weights_only=False. If loading checkpoint from an untrusted source, we recommend using weights_only=True. For more information, please refer to the PyTorch Developer Notes on Serialization Semantics.

For more information about multiple dataloaders, see this section.

Returns:

List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks like test_step() etc. The length of the list corresponds to the number of test dataloaders used.

Raises:
  • TypeError – If no model is passed and there was no LightningModule passed in the previous run. If model passed is not LightningModule or torch._dynamo.OptimizedModule.

  • MisconfigurationException – If both dataloaders and datamodule are passed. Pass only one of these.

  • RuntimeError – If a compiled model is passed and the strategy is not supported.

Return type:

TestMetrics
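
Example

A minimal sketch of a recognition training run wiring together the classes documented in this module; the assumption is that both config classes can be instantiated with their defaults and that the dataset paths are set on the data config:

>>> from kraken.train import (KrakenTrainer, VGSLRecognitionModel,
...                           VGSLRecognitionDataModule)
>>> from kraken.configs import (VGSLRecognitionTrainingConfig,
...                             VGSLRecognitionTrainingDataConfig)
>>> config = VGSLRecognitionTrainingConfig()
>>> data_config = VGSLRecognitionTrainingDataConfig()
>>> model = VGSLRecognitionModel(config)
>>> dm = VGSLRecognitionDataModule(data_config)
>>> trainer = KrakenTrainer(min_epochs=5, max_epochs=50)
>>> trainer.fit(model, datamodule=dm)
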

class kraken.train.VGSLRecognitionDataModule(data_config)

A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.

Example:

import lightning as L
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset

class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...
Parameters:

data_config (kraken.configs.VGSLRecognitionTrainingDataConfig)

setup(stage=None)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage (str) – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
test_dataloader()

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see this section.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

train_dataloader()

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs argument to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader()

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs argument to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). This hook is used by fit() and validate(), both of which ensure that prepare_data() and setup() have been called first.

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

class kraken.train.VGSLRecognitionModel(config, model=None)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

Note

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Variables:

training (bool) – Boolean represents whether this module is in training or evaluation mode.

Parameters:
  • config (kraken.configs.VGSLRecognitionTrainingConfig)

  • model (Optional[kraken.models.BaseModel])

configure_callbacks()

Configure model-specific callbacks. When the model gets attached, e.g., when .fit() or .test() gets called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer’s callbacks argument. If a callback returned here has the same type as one or several callbacks already present in the Trainer’s callbacks list, it will take priority and replace them. In addition, Lightning will make sure ModelCheckpoint callbacks run last.

Returns:

A callback or a list of callbacks which will extend the list of callbacks in the Trainer.

Example:

def configure_callbacks(self):
    early_stop = EarlyStopping(monitor="val_acc", mode="max")
    checkpoint = ModelCheckpoint(monitor="val_loss")
    return [early_stop, checkpoint]
configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Returns:

Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.

Note

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.

  • If you need to control how often the optimizer steps, override the optimizer_step() hook.

example_input_array

The example input array is a specification of what the module can consume in the forward() method. The return type is interpreted as follows:

  • Single tensor: It is assumed the model takes a single argument, i.e., model.forward(model.example_input_array)

  • Tuple: The input array should be interpreted as a sequence of positional arguments, i.e., model.forward(*model.example_input_array)

  • Dict: The input array represents named keyword arguments, i.e., model.forward(**model.example_input_array)
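For example, the attribute can be set in the module’s __init__ (a minimal sketch; the tensor shape is an arbitrary placeholder, not the input geometry of an actual kraken model):

import torch

def __init__(self):
    super().__init__()
    # dummy (batch, channels, height, width) input consumed by forward()
    self.example_input_array = torch.randn(1, 1, 48, 320)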

forward(x, seq_lens=None)

Same as torch.nn.Module.forward().

Parameters:
  • x – The input batch passed to the wrapped network.

  • seq_lens – Optional sequence lengths of the items in the batch.

Returns:

Your model’s output

classmethod load_from_weights(path, config)

Initializes the module from a model weights file.

Parameters:
  • path (Union[str, os.PathLike])

  • config (kraken.configs.VGSLRecognitionTrainingConfig)

Return type:

VGSLRecognitionModel
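A hypothetical usage sketch (the file name and config construction are placeholders; see kraken.configs for the actual fields of VGSLRecognitionTrainingConfig):

from kraken.configs import VGSLRecognitionTrainingConfig

# placeholder construction; the config may require explicit fields
config = VGSLRecognitionTrainingConfig()
model = VGSLRecognitionModel.load_from_weights('weights.mlmodel', config)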

lr_scheduler_step(scheduler, metric)

Override this method to adjust the default way the Trainer calls each scheduler. By default, Lightning calls step(), as shown in the example, for each scheduler based on its interval.

Parameters:
  • scheduler – Learning rate scheduler.

  • metric – Value of the monitor used for schedulers like ReduceLROnPlateau.

Examples:

# DEFAULT
def lr_scheduler_step(self, scheduler, metric):
    if metric is None:
        scheduler.step()
    else:
        scheduler.step(metric)

# Alternative way to update schedulers if they require an epoch value
def lr_scheduler_step(self, scheduler, metric):
    scheduler.step(epoch=self.current_epoch)
on_load_checkpoint(checkpoint)

Reconstruct the model from the spec here and not in setup() as otherwise the weight loading will fail.

on_save_checkpoint(checkpoint)

Save hyperparameters a second time so we can set parameters that shouldn’t be overwritten in on_load_checkpoint.
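A minimal sketch of the pattern these two hooks implement (spec and build_net are hypothetical placeholders, not kraken’s actual internals):

def on_save_checkpoint(self, checkpoint):
    # persist the model spec next to Lightning's saved hyperparameters
    checkpoint['hyper_parameters']['spec'] = self.spec

def on_load_checkpoint(self, checkpoint):
    # rebuild the network before Lightning restores the state dict
    self.net = build_net(checkpoint['hyper_parameters']['spec'])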

on_test_epoch_end()

Called in the test loop at the very end of the epoch.

on_test_epoch_start()

Called in the test loop at the very beginning of the epoch.

on_validation_epoch_end()

Called in the validation loop at the very end of the epoch.

optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)

Override this method to adjust the default way the Trainer calls the optimizer.

By default, Lightning calls step() and zero_grad() as shown in the example. This method (and zero_grad()) won’t be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.

Parameters:
  • epoch – Current epoch

  • batch_idx – Index of current batch

  • optimizer – A PyTorch optimizer

  • optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().

Examples:

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # Add your custom logic to run directly before `optimizer.step()`

    optimizer.step(closure=optimizer_closure)

    # Add your custom logic to run directly after `optimizer.step()`
setup(stage=None)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage (Optional[str]) – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: prepare_data() runs on a single process,
        # so state set here is not available to the other processes
        self.something = compute_something()

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
test_step(batch, batch_idx, test_dataloader=0)

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single test dataset (torchvision is assumed to be imported)
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})

Note

If you don’t need to test you don’t need to implement this method.

Note

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

training_step(batch, batch_idx)

Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    loss1 = ...
    opt1.zero_grad()
    self.manual_backward(loss1)
    opt1.step()

    # do training_step with decoder
    loss2 = ...
    opt2.zero_grad()
    self.manual_backward(loss2)
    opt2.step()

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
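For illustration, with the Trainer configured as below the loss returned from each training_step is divided by 4 before backward(), so the accumulated gradient matches that of a single batch four times as large:

from pytorch_lightning import Trainer

# accumulate gradients over 4 batches before each optimizer step
trainer = Trainer(accumulate_grad_batches=4)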

validation_step(batch, batch_idx)

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, like accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single validation dataset (torchvision is assumed to be imported)
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})

Note

If you don’t need to validate you don’t need to implement this method.

Note

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

Low-Level API

These modules provide lower-level access to the core components of kraken. In most cases, it is recommended to use the high-level API instead.

kraken.ketos.pretrain

Command line driver for unsupervised recognition pretraining.

kraken.lib.dataset

Top-level module containing datasets for recognition and segmentation training.

kraken.lib.codec

PyTorch compatible codec with many-to-many mapping between labels and graphemes.

kraken.lib.exceptions

All custom exceptions raised by kraken’s modules and packages. Packages should always define their exceptions here.
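As an illustration, client code typically catches these exceptions directly; KrakenInputException is one of the exceptions defined in this module:

from kraken.lib.exceptions import KrakenInputException

try:
    ...  # some call into kraken that validates its input
except KrakenInputException as exc:
    print(f'kraken rejected the input: {exc}')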

kraken.lib.segmentation

Processing for baseline segmenter output.

kraken.lib.ctc_decoder

Decoders for softmax outputs of CTC trained networks.

Decoders extract label sequences out of the raw output matrix of the line recognition network. Multiple approaches are implemented here, ranging from a simple greedy decoder over the legacy ocropy thresholding decoder to a more complex beam search decoder.

Extracted label sequences are converted into the code point domain using kraken.lib.codec.PytorchCodec.
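A minimal sketch of the greedy decoder (the output matrix below is random noise; only its (C, W) softmax-like shape matters):

import numpy as np
from kraken.lib.ctc_decoder import greedy_decoder

# fake network output: C=10 classes over W=50 time steps,
# columns normalized to resemble a softmax distribution
outputs = np.random.rand(10, 50)
outputs /= outputs.sum(axis=0)

# yields a list of (label, start, end, confidence) tuples
labels = greedy_decoder(outputs)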

Legacy Modules

These modules are retained for compatibility reasons or highly specialized use cases. Their use is not recommended.

kraken.binarization

An adaptive binarization algorithm. This code is legacy and only remains for historical reasons. Binarization is no longer necessary for most workflows.

kraken.pageseg

The legacy bounding box segmentation method using conventional image processing techniques.

kraken.rpred

Legacy line text recognition API. New code should use the RecognitionTaskModel from kraken.tasks which is more versatile and offers higher performance.

kraken.lib.models

Wrapper around TorchVGSLModel including a variety of forward pass helpers for sequence classification.
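In practice such a wrapper is obtained through the module’s load_any helper, which transparently handles kraken’s supported serialization formats (the file name below is a placeholder):

from kraken.lib import models

# load_any returns a TorchSeqRecognizer wrapping the underlying TorchVGSLModel
net = models.load_any('model.mlmodel')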