API Reference¶
This reference is automatically generated from the docstrings in the kraken source code.
High-Level API¶
These modules provide the main entry points for using kraken programmatically.
kraken.containers¶
Container classes replacing the old dictionaries returned by kraken’s functional blocks.
- class kraken.containers.BBoxLine¶
Bounding box-type line record.
A container class for a single line in axis-aligned bounding box format, optionally containing a transcription, tags, or associated regions.
- bbox¶
tuple in form (xmin, ymin, xmax, ymax) defining the bounding box.
- text_direction¶
Sets the principal orientation (of the line) and reading direction (of the document).
- bbox: tuple[int, int, int, int] | None = None¶
- text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] = 'horizontal-lr'¶
- type: str = 'bbox'¶
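Example (illustrative sketch; the id and text fields are assumed by analogy with the other line records in this module and are not part of the attribute list above):
>>> from kraken.containers import BBoxLine
>>> line = BBoxLine(id='line_0',
...                 bbox=(10, 20, 400, 60),  # (xmin, ymin, xmax, ymax)
...                 text='a transcription',
...                 text_direction='horizontal-lr')
>>> line.type
'bbox'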
- class kraken.containers.BBoxOCRRecord(prediction, cuts, confidences, line, base_dir=None, display_order=True, logits=None, image=None)¶
A record object containing the recognition result of a single line in bbox format.
- Parameters:
prediction (str)
cuts (list[tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int]]])
confidences (list[float])
line (Union[BBoxLine, dict[str, Any]])
base_dir (Optional[Literal['L', 'R']])
display_order (bool)
logits (Optional[torch.FloatTensor])
image (Optional[PIL.Image.Image])
- type¶
‘bbox’ to indicate a bounding box record
- prediction¶
The text predicted by the network as one continuous string.
- Return type:
str
- cuts¶
The absolute bounding polygons for each code point in prediction as a list of 4-tuples ((x0, y0), (x1, y0), (x1, y1), (x0, y1)).
- Return type:
list
- confidences¶
A list of floats indicating the confidence value of each code point.
- Return type:
list[float]
- base_dir¶
An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.
- display_order¶
Flag indicating the order of the code points in the prediction. In display order (True) the n-th code point in the string corresponds to the n-th leftmost code point, in logical order (False) the n-th code point corresponds to the n-th read code point. See [UAX #9](https://unicode.org/reports/tr9) for more details.
- logits¶
The logits for the prediction.
- image¶
The line image used to produce the prediction.
Notes
When slicing the record, the behavior of the cuts differs from earlier versions of kraken. Instead of returning per-character bounding polygons, a single polygonal section of the line bounding polygon is returned, starting at the first and extending to the last code point emitted by the network. This aids numerical stability when computing aggregated bounding polygons, such as for words. Individual code point bounding polygons are still accessible through the cuts attribute or by iterating over the record code point by code point.
- base_dir = None¶
- display_order(base_dir=None)¶
Returns the OCR record in Unicode display order, i.e. ordered from left to right inside the line.
- Parameters:
base_dir (Optional[Literal['L', 'R']]) – An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.
- Return type:
BBoxOCRRecord
- logical_order(base_dir=None)¶
Returns the OCR record in Unicode logical order, i.e. in the order the characters in the line would be read by a human.
- Parameters:
base_dir (Optional[Literal['L', 'R']]) – An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.
- Return type:
BBoxOCRRecord
- type = 'bbox'¶
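Example (sketch of converting between the two orderings; `record` is assumed to be a BBoxOCRRecord returned by a recognition task):
>>> # Records are produced in display order by default.
>>> logical = record.logical_order(base_dir='R')  # human reading order for an RTL line
>>> display = logical.display_order()             # back to visual left-to-right order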
- class kraken.containers.BaselineLine¶
Baseline-type line record.
A container class for a single line in baseline + bounding polygon format, optionally containing a transcription, tags, or associated regions.
- baseline¶
list of tuples (x_n, y_n) defining the baseline.
- boundary¶
list of tuples (x_n, y_n) defining the bounding polygon of the line. The first and last points should be identical.
- baseline: list[tuple[int, int]] | None = None¶
- boundary: list[tuple[int, int]] | None = None¶
- type: str = 'baselines'¶
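Example (illustrative sketch mirroring the BBoxLine example above; id and text are again assumed fields):
>>> from kraken.containers import BaselineLine
>>> line = BaselineLine(id='line_0',
...                     baseline=[(0, 50), (400, 50)],
...                     boundary=[(0, 30), (400, 30), (400, 70), (0, 70), (0, 30)],
...                     text='a transcription')
>>> line.type
'baselines'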
- class kraken.containers.BaselineOCRRecord(prediction, cuts, confidences, line, base_dir=None, display_order=True, logits=None, image=None)¶
A record object containing the recognition result of a single line in baseline format.
- Parameters:
prediction (str)
cuts (list[tuple[int, int]])
confidences (list[float])
line (Union[BaselineLine, dict[str, Any]])
base_dir (Optional[Literal['L', 'R']])
display_order (bool)
logits (Optional[torch.FloatTensor])
image (Optional[PIL.Image.Image])
- type¶
‘baselines’ to indicate a baseline record
- prediction¶
The text predicted by the network as one continuous string.
- Return type:
str
- cuts¶
The absolute bounding polygons for each code point in prediction as a list of tuples [(x0, y0), (x1, y1), …].
- Return type:
tuple
- confidences¶
A list of floats indicating the confidence value of each code point.
- Return type:
list[float]
- base_dir¶
An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.
- display_order¶
Flag indicating the order of the code points in the prediction. In display order (True) the n-th code point in the string corresponds to the n-th leftmost code point, in logical order (False) the n-th code point corresponds to the n-th read code point. See [UAX #9](https://unicode.org/reports/tr9) for more details.
- logits¶
The logits for the prediction.
- image¶
The line image used to produce the prediction.
Notes
When slicing the record, the behavior of the cuts differs from earlier versions of kraken. Instead of returning per-character bounding polygons, a single polygonal section of the line bounding polygon is returned, starting at the first and extending to the last code point emitted by the network. This aids numerical stability when computing aggregated bounding polygons, such as for words. Individual code point bounding polygons are still accessible through the cuts attribute or by iterating over the record code point by code point.
- base_dir = None¶
- property cuts: tuple¶
- Return type:
tuple
- display_order(base_dir=None)¶
Returns the OCR record in Unicode display order, i.e. ordered from left to right inside the line.
- Parameters:
base_dir (Optional[Literal['L', 'R']]) – An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.
- Return type:
BaselineOCRRecord
- logical_order(base_dir=None)¶
Returns the OCR record in Unicode logical order, i.e. in the order the characters in the line would be read by a human.
- Parameters:
base_dir (Optional[Literal['L', 'R']]) – An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.
- Return type:
BaselineOCRRecord
- type = 'baselines'¶
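Example (sketch of the slicing behavior described in the Notes above; `record` is assumed to be a BaselineOCRRecord from a recognition run, and iteration is assumed to yield (code point, cut, confidence) triples):
>>> word = record[0:5]  # one merged polygon section covering code points 0-4
>>> word.prediction
>>> word.cuts
>>> for char, cut, confidence in record:  # per-code point polygons
...     print(char, cut, confidence)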
- class kraken.containers.ProcessingStep¶
A processing step in the recognition pipeline.
- id¶
Unique identifier
- category¶
Category of processing step that has been performed.
- description¶
Natural-language description of the process.
- settings¶
dict describing the parameters of the processing step.
- category: Literal['preprocessing', 'processing', 'postprocessing']¶
- description: str¶
- id: str¶
- settings: dict[str, dict | str | float | int | bool]¶
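Example (construction sketch using only the documented fields; the concrete settings keys are illustrative):
>>> from kraken.containers import ProcessingStep
>>> step = ProcessingStep(id='_e5b9c7a1',
...                       category='processing',
...                       description='Text line recognition',
...                       settings={'text_direction': 'horizontal-lr',
...                                 'model': 'print_best.mlmodel'})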
- class kraken.containers.Region¶
Container class of a single polygonal region.
- id¶
Unique identifier
- boundary¶
list of tuples (x_n, y_n) defining the bounding polygon of the region. The first and last points should be identical.
- imagename¶
Path to the image associated with the region.
- tags¶
A dict mapping types to values.
- boundary: list[tuple[int, int]]¶
- id: str¶
- imagename: str | os.PathLike | None = None¶
- language: list[str] | None = None¶
- tags: dict[str, list[dict[str, str]]] | None = None¶
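Example (construction sketch; the tag layout follows the declared dict[str, list[dict[str, str]]] type, but the key names are illustrative):
>>> from kraken.containers import Region
>>> region = Region(id='region_0',
...                 boundary=[(0, 0), (500, 0), (500, 200), (0, 200), (0, 0)],
...                 tags={'type': [{'type': 'paragraph'}]})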
- class kraken.containers.Segmentation¶
A container class for segmentation or recognition results.
In order to allow easy JSON de-/serialization, nested classes for lines (BaselineLine/BBoxLine) and regions (Region) are reinstantiated from their dictionaries.
- type¶
Field indicating if baselines (kraken.containers.BaselineLine) or bbox (kraken.containers.BBoxLine) line records are in the segmentation.
- imagename¶
Path to the image associated with the segmentation.
- text_direction¶
Sets the principal orientation (of the line), i.e. horizontal/vertical, and reading direction (of the document), i.e. lr/rl.
- script_detection¶
Flag indicating if the line records have tags.
- lines¶
list of line records. Records are expected to be in a valid reading order.
- regions¶
dict mapping types to lists of regions.
- line_orders¶
list of alternative reading orders for the segmentation. Each reading order is a list of line indices.
- imagename: str | os.PathLike¶
- language: list[str] | None = None¶
- line_orders: list[list[int]] | None = None¶
- lines: list[BaselineLine | BBoxLine] | None = None¶
- script_detection: bool¶
- text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl']¶
- type: Literal['baselines', 'bbox']¶
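Example (sketch assembling a Segmentation from the containers above; all values are illustrative and the optional fields may be omitted):
>>> from kraken.containers import Segmentation, BaselineLine
>>> line = BaselineLine(id='line_0',
...                     baseline=[(0, 50), (400, 50)],
...                     boundary=[(0, 30), (400, 30), (400, 70), (0, 70)])
>>> seg = Segmentation(type='baselines',
...                    imagename='page.png',
...                    text_direction='horizontal-lr',
...                    script_detection=False,
...                    lines=[line])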
- kraken.containers.compute_polygon_section(baseline, boundary, dist1, dist2)¶
Given a baseline, a polygonal boundary, and two points on the baseline, return the rectangle formed by the orthogonal cuts on that baseline segment. The resulting polygon is not guaranteed to have a non-zero area.
The distance can be larger than the actual length of the baseline if the baseline endpoints are inside the bounding polygon. In that case the baseline will be extrapolated to the polygon edge.
- Parameters:
baseline (collections.abc.Sequence[tuple[int, int]]) – A polyline ((x1, y1), …, (xn, yn))
boundary (collections.abc.Sequence[tuple[int, int]]) – A bounding polygon around the baseline (same format as baseline). Last and first point are automatically connected.
dist1 (int) – Absolute distance along the baseline of the first point.
dist2 (int) – Absolute distance along the baseline of the second point.
- Returns:
A sequence of polygon points.
- Return type:
tuple[tuple[int, int]]
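Example (usage sketch with a toy horizontal baseline; the coordinates are illustrative):
>>> from kraken.containers import compute_polygon_section
>>> baseline = [(0, 50), (200, 50)]
>>> boundary = [(0, 30), (200, 30), (200, 70), (0, 70)]
>>> # Polygon between 20 and 80 pixels along the baseline, cut
>>> # orthogonally to it and clipped to the bounding polygon.
>>> section = compute_polygon_section(baseline, boundary, 20, 80)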
- class kraken.containers.ocr_line¶
A line record.
- id¶
Unique identifier
- text¶
Transcription of this line.
- base_dir¶
An optional string defining the base direction (also called paragraph direction) for the BiDi algorithm. Valid values are ‘L’ or ‘R’. If None is given the default auto-resolution will be used.
- imagename¶
Path to the image associated with the line.
- tags¶
A dict mapping types to values.
- split¶
Defines whether this line is in the train, validation, or test set during training.
- regions¶
A list of identifiers of regions the line is associated with.
- language¶
A list of language identifiers associated with the line.
- base_dir: Literal['L', 'R'] | None = None¶
- id: str¶
- imagename: str | os.PathLike | None = None¶
- language: list[str] | None = None¶
- regions: list[str] | None = None¶
- split: Literal['train', 'validation', 'test'] | None = None¶
- tags: dict[str, list[dict[str, str]]] | None = None¶
- text: str | None = None¶
- class kraken.containers.ocr_record(prediction, cuts, confidences, display_order=True, logits=None, image=None)¶
A record object containing the recognition result of a single line
- Parameters:
prediction (str)
cuts (list[Union[tuple[int, int], tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int]]]])
confidences (list[float])
display_order (bool)
logits (Optional[torch.FloatTensor])
image (Optional[PIL.Image.Image])
- base_dir = None¶
- property confidences: list[float]¶
- Return type:
list[float]
- property cuts: list¶
- Return type:
list
- abstractmethod display_order(base_dir)¶
- Return type:
ocr_record
- image = None¶
- abstractmethod logical_order(base_dir)¶
- Return type:
ocr_record
- logits = None¶
- property prediction: str¶
- Return type:
str
- abstract property type¶
- kraken.containers.precompute_polygon_sections(baseline, boundary, cut_pairs)¶
Batch-precompute polygon sections for all characters, amortizing the expensive baseline extension and cumulative distance computation.
- Parameters:
baseline (collections.abc.Sequence[tuple[int, int]]) – A polyline ((x1, y1), …, (xn, yn))
boundary (collections.abc.Sequence[tuple[int, int]]) – A bounding polygon around the baseline.
cut_pairs (list[tuple[int, int]]) – List of (dist1, dist2) tuples for each character.
- Returns:
A tuple of (char_polygons, intersection_cache, bl_length), where char_polygons is a list of polygon tuples, one per character; intersection_cache maps clamped distance values to raw _test_intersect result arrays (or None on failure); and bl_length is the total baseline length after extension.
- Return type:
tuple
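Example (a batched counterpart to the compute_polygon_section sketch above; the cut_pairs distances are illustrative per-character extents):
>>> from kraken.containers import precompute_polygon_sections
>>> baseline = [(0, 50), (200, 50)]
>>> boundary = [(0, 30), (200, 30), (200, 70), (0, 70)]
>>> cut_pairs = [(0, 20), (20, 45), (45, 80)]  # (dist1, dist2) per character
>>> char_polygons, cache, bl_length = precompute_polygon_sections(baseline, boundary, cut_pairs)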
kraken.lib.xml¶
ALTO/Page data loaders for segmentation training
- class kraken.lib.xml.Segmentation¶
A container class for segmentation or recognition results.
In order to allow easy JSON de-/serialization, nested classes for lines (BaselineLine/BBoxLine) and regions (Region) are reinstantiated from their dictionaries.
- type¶
Field indicating if baselines (kraken.containers.BaselineLine) or bbox (kraken.containers.BBoxLine) line records are in the segmentation.
- imagename¶
Path to the image associated with the segmentation.
- text_direction¶
Sets the principal orientation (of the line), i.e. horizontal/vertical, and reading direction (of the document), i.e. lr/rl.
- script_detection¶
Flag indicating if the line records have tags.
- lines¶
list of line records. Records are expected to be in a valid reading order.
- regions¶
dict mapping types to lists of regions.
- line_orders¶
list of alternative reading orders for the segmentation. Each reading order is a list of line indices.
- imagename: str | os.PathLike¶
- language: list[str] | None = None¶
- line_orders: list[list[int]] | None = None¶
- lines: list[BaselineLine | BBoxLine] | None = None¶
- script_detection: bool¶
- text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl']¶
- type: Literal['baselines', 'bbox']¶
- class kraken.lib.xml.XMLPage(filename, filetype='xml', linetype='baselines')¶
Parses XML facsimiles in ALTO or PageXML format.
The parser is able to deal with most (but not all) features supported by those standards. In particular, any data below the line level is discarded.
- Parameters:
filename (Union[str, os.PathLike]) – Path to the XML file
filetype (Literal['xml', 'alto', 'page']) – Selector for explicit subparser choice.
linetype (Literal['baselines', 'bbox']) – Parse line data as baselines or bounding box type.
- type¶
Either ‘baselines’ or ‘bbox’.
- imagename¶
Path to the image associated with the XML file.
- image_size¶
Size of the image as a (width, height) tuple
- has_tags¶
Indicates if the source document contains tag information
- has_splits¶
Indicates if the source document contains explicit training splits
- base_dir: Literal['L', 'R'] | None = None¶
- filename¶
- filetype = 'xml'¶
- get_lines_by_split(split)¶
- Parameters:
split (Literal['train', 'validation', 'test'])
- get_lines_by_tag(key, value)¶
- get_sorted_lines(ro='line_implicit')¶
Returns ordered baselines from a particular reading order.
- get_sorted_lines_by_region(region, ro='line_implicit')¶
Returns ordered lines in region.
- get_sorted_regions(ro='region_implicit')¶
Returns ordered regions from a particular reading order.
- has_splits: bool = False¶
- has_tags: bool = False¶
- image_size: tuple[int, int] = None¶
- imagename: os.PathLike = None¶
- property lines¶
- property reading_orders¶
- property regions¶
- property splits¶
- property tags¶
- to_container()¶
Returns a Segmentation object.
- Return type:
Segmentation
- type: Literal['baselines', 'bbox'] = 'baselines'¶
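Example (typical parse-and-convert sketch; 'page.xml' is a placeholder path):
>>> from kraken.lib.xml import XMLPage
>>> page = XMLPage('page.xml')  # subparser auto-detected with filetype='xml'
>>> seg = page.to_container()   # a kraken.containers.Segmentation
>>> for line in seg.lines:
...     print(line.id, line.text)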
- kraken.lib.xml.alto_regions¶
- kraken.lib.xml.flatten_order_to_lines(raw_order, lines_dict, region_ids, line_implicit_order, string_to_line_map=None, missing_region_ids=None)¶
Flatten a raw reading order (list of IDs) to line-level.
For each ID:
- Line ID: append directly
- Region ID: expand to contained lines using implicit order
- String ID (ALTO only): map to parent TextLine, deduplicate consecutive
- Unknown ID: log warning, skip
- Parameters:
raw_order (list[str])
lines_dict (dict)
region_ids (set[str])
line_implicit_order (list[str])
string_to_line_map (Optional[dict[str, str]])
missing_region_ids (Optional[set[str]])
- Return type:
list[str]
- kraken.lib.xml.flatten_order_to_regions(raw_order, lines_dict, region_ids, string_to_line_map=None, missing_region_ids=None)¶
Flatten a raw reading order (list of IDs) to region-level.
For each ID:
- Region ID: append directly
- Line ID: resolve to parent region, deduplicate consecutive
- String ID (ALTO only): resolve to parent line then region, deduplicate
- Unknown ID: log warning, skip
- Parameters:
raw_order (list[str])
lines_dict (dict)
region_ids (set[str])
string_to_line_map (Optional[dict[str, str]])
missing_region_ids (Optional[set[str]])
- Return type:
list[str]
- kraken.lib.xml.logger¶
- kraken.lib.xml.page_regions¶
- kraken.lib.xml.parse_alto(doc, filename, linetype)¶
Parse an ALTO XML document.
- Parameters:
doc – Parsed lxml document.
filename – Path to the XML file (for error messages and resolving image paths).
linetype – ‘baselines’ or ‘bbox’.
- Returns:
A dict with keys imagename, image_size, regions, lines, orders, tag_set, raw_orders, and string_to_line_map.
- Return type:
dict
- kraken.lib.xml.parse_page(doc, filename, linetype)¶
Parse a PageXML document.
- Parameters:
doc – Parsed lxml document.
filename – Path to the XML file (for error messages and resolving image paths).
linetype – ‘baselines’ or ‘bbox’.
- Returns:
A dict with keys imagename, image_size, regions, lines, orders, tag_set, and raw_orders.
- Return type:
dict
- kraken.lib.xml.validate_and_clean_order(flat_order, valid_ids)¶
Validate a flattened order.
Checks:
- All IDs exist in valid_ids
- No duplicate IDs (indicates circular reference)
- Returns:
(cleaned_order, is_valid)
- Parameters:
flat_order (list[str])
valid_ids (set[str])
- Return type:
tuple[list[str], bool]
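Example (small sketch of the validation helper; the IDs are illustrative, and the duplicated 'l1' would be flagged as invalid):
>>> from kraken.lib.xml import validate_and_clean_order
>>> cleaned, is_valid = validate_and_clean_order(['r1', 'l1', 'l1'], {'r1', 'l1'})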
kraken.tasks¶
- class kraken.tasks.ForcedAlignmentTaskModel(models)¶
A wrapper for forced alignment of CTC output.
Using a text recognition model, the existing transcription of a page is aligned to the character positions of the network output.
- Raises:
ValueError – Raised when the model type is not a sequence recognizer.
- Parameters:
models (list[torch.nn.Module])
- classmethod load_model(path)¶
- Parameters:
path (Union[str, os.PathLike])
- net¶
- one_channel_mode¶
- predict(im, segmentation, config)¶
Aligns the transcription of an image with the output of the text recognition model, producing approximate character locations.
When the character sets of transcription and recognition model differ, the affected code points in the furnished transcription will silently be ignored. In case inference fails on a line, a record without cuts/confidences is returned.
- Parameters:
im (PIL.Image.Image) – The input image
segmentation (kraken.containers.Segmentation) – A segmentation with transcriptions to align.
config (kraken.configs.RecognitionInferenceConfig) – A recognition inference configuration. The task model will automatically set some required configuration flags on it.
- Returns:
A single segmentation that contains the aligned ocr_record objects.
- Return type:
kraken.containers.Segmentation
Example
>>> from PIL import Image
>>> from kraken.tasks import ForcedAlignmentTaskModel
>>> from kraken.containers import Segmentation, BaselineLine
>>> from kraken.configs import RecognitionInferenceConfig
>>> # Assume `model.mlmodel` is a recognition model
>>> model = ForcedAlignmentTaskModel.load_model('model.mlmodel')
>>> im = Image.open('image.png')
>>> # Create a dummy segmentation with a line and a transcription
>>> line = BaselineLine(baseline=[(0, 0), (100, 0)],
...                     boundary=[(0, -10), (100, -10), (100, 10), (0, 10)],
...                     text='Hello World')
>>> segmentation = Segmentation(lines=[line])
>>> config = RecognitionInferenceConfig()
>>> aligned_segmentation = model.predict(im, segmentation, config)
>>> record = aligned_segmentation.lines[0]
>>> print(record.prediction)
>>> print(record.cuts)
- seg_type¶
- class kraken.tasks.RecognitionTaskModel(models)¶
A wrapper for a model performing a recognition task.
A recognition task is the process of transcribing a line of text from an image. This class provides a high-level interface for running a recognition model on an image, given a segmentation.
- Raises:
ValueError – Raised when the model type is not a sequence recognizer.
- Parameters:
models (list[torch.nn.Module])
- classmethod load_model(path)¶
- Parameters:
path (Union[str, os.PathLike])
- net¶
- one_channel_mode¶
- predict(im, segmentation, config)¶
Inference using a recognition model.
- Parameters:
im (PIL.Image.Image) – Input image
segmentation (kraken.containers.Segmentation) – The segmentation corresponding to the input image.
config (kraken.configs.RecognitionInferenceConfig) – A configuration object containing inference parameters, such as the batch size and the precision.
- Yields:
One ocr_record for each line.
- Return type:
collections.abc.Generator[kraken.containers.ocr_record, None, None]
Example
>>> from PIL import Image
>>> from kraken.tasks import RecognitionTaskModel
>>> from kraken.containers import Segmentation
>>> from kraken.configs import RecognitionInferenceConfig
>>> model = RecognitionTaskModel.load_model('model.mlmodel')
>>> im = Image.open('image.png')
>>> segmentation = Segmentation(...)
>>> config = RecognitionInferenceConfig()
>>> for record in model.predict(im, segmentation, config):
...     print(record.prediction)
- seg_type¶
- class kraken.tasks.SegmentationTaskModel(models)¶
A wrapper class collecting one or more models that perform segmentation.
A segmentation task is the process of identifying the regions and lines of text in an image. This class provides a high-level interface for running segmentation models on an image.
- It deals with the following tasks:
- region segmentation
- line detection
- line reading order
If no neural reading order model is part of the model collection handed to the task, a simple heuristic will be used.
- Parameters:
models (list[torch.nn.Module]) – A collection of models performing segmentation tasks.
- Raises:
ValueError – Raised when no segmentation models are in the model list.
- classmethod load_model(path=None)¶
Loads a collection of layout analysis models from the given file path.
If no path is provided, the default BLLA segmentation model will be loaded.
- Parameters:
path (Optional[Union[str, os.PathLike]]) – Path to model weights file.
- Return type:
SegmentationTaskModel
- predict(im, config)¶
Runs all models associated with the task to produce a segmentation for the input page.
- Parameters:
im (PIL.Image.Image) – Input image with an arbitrary color mode and size
config (kraken.configs.SegmentationInferenceConfig) – A configuration object containing inference parameters for the segmentation task, such as the batch size and the precision.
- Returns:
A single Segmentation object that contains the merged output of all associated segmentation models.
- Return type:
kraken.containers.Segmentation
Example
>>> from PIL import Image
>>> from kraken.tasks import SegmentationTaskModel
>>> from kraken.configs import SegmentationInferenceConfig
>>> model = SegmentationTaskModel.load_model()
>>> im = Image.open('image.png')
>>> config = SegmentationInferenceConfig()
>>> segmentation = model.predict(im, config)
- ro_models¶
- seg_models¶
kraken.train¶
- class kraken.train.BLLASegmentationDataModule(data_config)¶
A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.
Example:
import torch
import lightning as L
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset


class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...
- Parameters:
data_config (kraken.configs.BLLASegmentationTrainingDataConfig)
- setup(stage=None)¶
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage (str) – either
'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- test_dataloader()¶
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
- test()
- prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a
test_step(), you don’t need to implement this method.
- train_dataloader()¶
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
- fit()
- prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- val_dataloader()¶
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
- fit()
- validate()
- prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a
validation_step(), you don’t need to implement this method.
- class kraken.train.BLLASegmentationModel(config, model=None)¶
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
Note
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
- Variables:
training (bool) – Boolean represents whether this module is in training or evaluation mode.
- Parameters:
config (kraken.configs.BLLASegmentationTrainingConfig)
model (Optional[kraken.models.BaseModel])
- configure_callbacks()¶
Configure model-specific callbacks. When the model gets attached, e.g., when .fit() or .test() gets called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer’s callbacks argument. If a callback returned here has the same type as one or several callbacks already present in the Trainer’s callbacks list, it will take priority and replace them. In addition, Lightning will make sure ModelCheckpoint callbacks run last.
- Returns:
A callback or a list of callbacks which will extend the list of callbacks in the Trainer.
Example:
def configure_callbacks(self):
    early_stop = EarlyStopping(monitor="val_acc", mode="max")
    checkpoint = ModelCheckpoint(monitor="val_loss")
    return [early_stop, checkpoint]
- configure_optimizers()¶
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
- Returns:
Any of these 6 options.
- Single optimizer.
- List or Tuple of optimizers.
- Two lists - the first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
- Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
- None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
Some things to know:
Lightning calls .backward() and .step() automatically in case of automatic optimization.
If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.
If you need to control how often the optimizer steps, override the optimizer_step() hook.
- criterion¶
- example_input_array¶
The example input array is a specification of what the module can consume in the forward() method. The return type is interpreted as follows:
Single tensor: It is assumed the model takes a single argument, i.e., model.forward(model.example_input_array)
Tuple: The input array should be interpreted as a sequence of positional arguments, i.e., model.forward(*model.example_input_array)
Dict: The input array represents named keyword arguments, i.e., model.forward(**model.example_input_array)
- forward(x)¶
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
x (torch.Tensor)
- Returns:
Your model’s output
- Return type:
tuple[torch.Tensor, torch.Tensor]
- classmethod load_from_weights(path, config)¶
Initializes the module from a model weights file.
- Parameters:
path (Union[str, os.PathLike])
config (kraken.configs.BLLASegmentationTrainingConfig)
- Return type:
BLLASegmentationModel
- lr_scheduler_step(scheduler, metric)¶
Override this method to adjust the default way the Trainer calls each scheduler. By default, Lightning calls step() as shown in the example, for each scheduler based on its interval.
- Parameters:
scheduler – Learning rate scheduler.
metric – Value of the monitor used for schedulers like
ReduceLROnPlateau.
Examples:
# DEFAULT
def lr_scheduler_step(self, scheduler, metric):
    if metric is None:
        scheduler.step()
    else:
        scheduler.step(metric)


# Alternative way to update schedulers if it requires an epoch value
def lr_scheduler_step(self, scheduler, metric):
    scheduler.step(epoch=self.current_epoch)
- on_load_checkpoint(checkpoint)¶
Reconstruct the model from the spec here and not in setup() as otherwise the weight loading will fail.
- on_save_checkpoint(checkpoint)¶
Save hyperparameters a second time so we can set parameters that shouldn’t be overwritten in on_load_checkpoint.
- on_test_epoch_end()¶
Called in the test loop at the very end of the epoch.
- on_test_epoch_start()¶
Called in the test loop at the very beginning of the epoch.
- on_validation_epoch_end()¶
Called in the validation loop at the very end of the epoch.
- optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)¶
Override this method to adjust the default way the Trainer calls the optimizer.
By default, Lightning calls step() and zero_grad() as shown in the example. This method (and zero_grad()) won’t be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.
- Parameters:
epoch – Current epoch
batch_idx – Index of current batch
optimizer – A PyTorch optimizer
optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().
Examples:
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # Add your custom logic to run directly before `optimizer.step()`

    optimizer.step(closure=optimizer_closure)

    # Add your custom logic to run directly after `optimizer.step()`
- setup(stage=None)¶
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage (Optional[str]) – either
'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- test_step(batch, batch_idx, test_dataloader=0)¶
Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'.
None - Skip to the next batch.
# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})
Note
If you don’t need to test you don’t need to implement this method.
Note
When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
- training_step(batch, batch_idx)¶
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor - The loss tensor
dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
- validation_step(batch, batch_idx)¶
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'.
None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
Note
If you don’t need to validate you don’t need to implement this method.
Note
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
- class kraken.train.KrakenTrainer(enable_progress_bar=True, enable_summary=True, min_epochs=5, max_epochs=100, freeze_backbone=-1, pl_logger=None, log_dir=None, *args, **kwargs)¶
- Parameters:
enable_progress_bar (bool)
enable_summary (bool)
min_epochs (int)
max_epochs (int)
pl_logger (Union[lightning.pytorch.loggers.logger.Logger, str, None])
log_dir (Optional[os.PathLike])
- automatic_optimization = False¶
- fit(*args, **kwargs)¶
Runs the full optimization routine.
- Parameters:
model – Model to fit.
train_dataloaders – An iterable or collection of iterables specifying training samples. Alternatively, a LightningDataModule that defines the train_dataloader hook.
val_dataloaders – An iterable or collection of iterables specifying validation samples.
datamodule – A LightningDataModule that defines the train_dataloader hook.
ckpt_path – Path/URL of the checkpoint from which training is resumed. Could also be one of three special keywords "last", "hpc" and "registry". Otherwise, if there is no checkpoint file at the path, an exception is raised.
- best: the best model checkpoint from the previous trainer.fit call will be loaded
- last: the last model checkpoint from the previous trainer.fit call will be loaded
- registry: the model will be downloaded from the Lightning Model Registry with the following notations:
- 'registry': uses the latest/default version of the default model set with Trainer(..., model_registry="my-model")
- 'registry:model-name': uses the latest/default version of the model model-name
- 'registry:model-name:version:v2': uses the specific version 'v2' of the model model-name
- 'registry:version:v2': uses the default model set with Trainer(..., model_registry="my-model") and version 'v2'
weights_only – Defaults to None. If True, restricts loading to state_dicts of plain torch.Tensor and other primitive types. If loading a checkpoint from a trusted source that contains an nn.Module, use weights_only=False. If loading a checkpoint from an untrusted source, we recommend using weights_only=True. For more information, please refer to the PyTorch Developer Notes on Serialization Semantics.
For more information about multiple dataloaders, see this section.
- Return type:
None
- Raises:
TypeError – If model is not a LightningModule for torch versions less than 2.0.0, or if model is not a LightningModule or torch._dynamo.OptimizedModule for torch versions greater than or equal to 2.0.0.
- test(*args, **kwargs)¶
Perform one evaluation epoch over the test set. It’s separated from fit to make sure you never run on your test set until you want to.
- Parameters:
model – The model to test.
dataloaders – An iterable or collection of iterables specifying test samples. Alternatively, a LightningDataModule that defines the test_dataloader hook.
ckpt_path – Either "best", "last", "hpc", "registry" or path to the checkpoint you wish to test. If None and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous trainer.fit call will be loaded if a checkpoint callback is configured.
verbose – If True, prints the test results.
datamodule – A LightningDataModule that defines the test_dataloader hook.
weights_only – Defaults to None. If True, restricts loading to state_dicts of plain torch.Tensor and other primitive types. If loading a checkpoint from a trusted source that contains an nn.Module, use weights_only=False. If loading a checkpoint from an untrusted source, we recommend using weights_only=True. For more information, please refer to the PyTorch Developer Notes on Serialization Semantics.
For more information about multiple dataloaders, see this section.
- Returns:
List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks like test_step() etc. The length of the list corresponds to the number of test dataloaders used.
- Raises:
TypeError – If no model is passed and there was no LightningModule passed in the previous run, or if the model passed is not a LightningModule or torch._dynamo.OptimizedModule.
MisconfigurationException – If both dataloaders and datamodule are passed. Pass only one of these.
RuntimeError – If a compiled model is passed and the strategy is not supported.
- Return type:
TestMetrics
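Example (training sketch wiring the trainer to the recognition module and datamodule documented below; the config constructors' arguments are elided since their exact fields live in kraken.configs and should be treated as assumptions here):
>>> from kraken.train import (KrakenTrainer, VGSLRecognitionModel,
...                           VGSLRecognitionDataModule)
>>> from kraken.configs import (VGSLRecognitionTrainingConfig,
...                             VGSLRecognitionTrainingDataConfig)
>>> config = VGSLRecognitionTrainingConfig(...)           # hyperparameters
>>> data_config = VGSLRecognitionTrainingDataConfig(...)  # training/eval data
>>> model = VGSLRecognitionModel(config)
>>> dm = VGSLRecognitionDataModule(data_config)
>>> trainer = KrakenTrainer(max_epochs=50)
>>> trainer.fit(model, datamodule=dm)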
- class kraken.train.VGSLRecognitionDataModule(data_config)¶
A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.
Example:
import torch
import lightning as L
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset


class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...
- Parameters:
data_config (kraken.configs.VGSLRecognitionTrainingDataConfig)
- setup(stage=None)¶
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage (str) – either
'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- test_dataloader()¶
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
- test()
- prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a
test_step(), you don’t need to implement this method.
- train_dataloader()¶
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
- fit()
- prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- val_dataloader()¶
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
- fit()
- validate()
- prepare_data()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a
validation_step(), you don’t need to implement this method.
- class kraken.train.VGSLRecognitionModel(config, model=None)¶
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
Note
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
- Variables:
training (bool) – Boolean represents whether this module is in training or evaluation mode.
- Parameters:
config (kraken.configs.VGSLRecognitionTrainingConfig)
model (Optional[kraken.models.BaseModel])
- configure_callbacks()¶
Configure model-specific callbacks. When the model gets attached, e.g., when .fit() or .test() gets called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer’s callbacks argument. If a callback returned here has the same type as one or several callbacks already present in the Trainer’s callbacks list, it will take priority and replace them. In addition, Lightning will make sure ModelCheckpoint callbacks run last.
- Returns:
A callback or a list of callbacks which will extend the list of callbacks in the Trainer.
Example:
def configure_callbacks(self):
    early_stop = EarlyStopping(monitor="val_acc", mode="max")
    checkpoint = ModelCheckpoint(monitor="val_loss")
    return [early_stop, checkpoint]
- configure_optimizers()¶
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
- Returns:
Any of these 6 options.
- Single optimizer.
- List or Tuple of optimizers.
- Two lists - the first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
- Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
- None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
Some things to know:
Lightning calls .backward() and .step() automatically in case of automatic optimization.
If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.
If you need to control how often the optimizer steps, override the optimizer_step() hook.
- example_input_array¶
The example input array is a specification of what the module can consume in the forward() method. The return type is interpreted as follows:
Single tensor: It is assumed the model takes a single argument, i.e., model.forward(model.example_input_array)
Tuple: The input array should be interpreted as a sequence of positional arguments, i.e., model.forward(*model.example_input_array)
Dict: The input array represents named keyword arguments, i.e., model.forward(**model.example_input_array)
- forward(x, seq_lens=None)¶
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- classmethod load_from_weights(path, config)¶
Initializes the module from a model weights file.
- Parameters:
path (Union[str, os.PathLike])
config (kraken.configs.VGSLRecognitionTrainingConfig)
- Return type:
VGSLRecognitionModel
- lr_scheduler_step(scheduler, metric)¶
Override this method to adjust the default way the Trainer calls each scheduler. By default, Lightning calls step() as shown in the example, for each scheduler based on its interval.
- Parameters:
scheduler – Learning rate scheduler.
metric – Value of the monitor used for schedulers like
ReduceLROnPlateau.
Examples:
# DEFAULT
def lr_scheduler_step(self, scheduler, metric):
    if metric is None:
        scheduler.step()
    else:
        scheduler.step(metric)


# Alternative way to update schedulers if it requires an epoch value
def lr_scheduler_step(self, scheduler, metric):
    scheduler.step(epoch=self.current_epoch)
- on_load_checkpoint(checkpoint)¶
Reconstruct the model from the spec here and not in setup() as otherwise the weight loading will fail.
- on_save_checkpoint(checkpoint)¶
Save hyperparameters a second time so we can set parameters that shouldn’t be overwritten in on_load_checkpoint.
- on_test_epoch_end()¶
Called in the test loop at the very end of the epoch.
- on_test_epoch_start()¶
Called in the test loop at the very beginning of the epoch.
- on_validation_epoch_end()¶
Called in the validation loop at the very end of the epoch.
- optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)¶
Override this method to adjust the default way the Trainer calls the optimizer.

By default, Lightning calls step() and zero_grad() as shown in the example. This method (and zero_grad()) won't be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.

- Parameters:
epoch – Current epoch
batch_idx – Index of current batch
optimizer – A PyTorch optimizer
optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().
Examples:
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # Add your custom logic to run directly before `optimizer.step()`

    optimizer.step(closure=optimizer_closure)

    # Add your custom logic to run directly after `optimizer.step()`
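A common use of this hook is manual learning-rate warm-up, sketched below following the pattern from the Lightning documentation; the step count of 500 and the self.hparams.learning_rate hyperparameter are illustrative assumptions.

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # update params
    optimizer.step(closure=optimizer_closure)

    # manually warm up the learning rate over the first 500 steps
    # (500 and self.hparams.learning_rate are assumed values)
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            pg["lr"] = lr_scale * self.hparams.learning_rate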
- setup(stage=None)¶
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage (Optional[str]) – either 'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: prepare_data() runs on a single process,
        # so state assigned here is not replicated to other workers
        # self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- test_step(batch, batch_idx, test_dataloader=0)¶
Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.

batch_idx – The index of this batch.

dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
- Tensor - The loss tensor
- dict - A dictionary. Can include any keys, but must include the key 'loss'.
- None - Skip to the next batch.
# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...

# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})
Note
If you don’t need to test, you don’t need to implement this method.
Note
When test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
- training_step(batch, batch_idx)¶
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.

batch_idx – The index of this batch.

dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
- Tensor - The loss tensor
- dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
- None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model-specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()

    # do training_step with decoder
    ...
    opt2.step()
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
- validation_step(batch, batch_idx)¶
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.

batch_idx – The index of this batch.

dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
- Tensor - The loss tensor
- dict - A dictionary. Can include any keys, but must include the key 'loss'.
- None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
Note
If you don’t need to validate, you don’t need to implement this method.
Note
When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
Low-Level API¶
These modules provide lower-level access to the core components of kraken. In most cases, it is recommended to use the high-level API instead.
kraken.ketos.pretrain¶
Command line driver for unsupervised recognition pretraining.
kraken.lib.dataset¶
Top-level module containing datasets for recognition and segmentation training.
kraken.lib.codec¶
PyTorch-compatible codec with many-to-many mapping between labels and graphemes.
kraken.lib.exceptions¶
All custom exceptions raised by kraken’s modules and packages. Packages should always define their exceptions here.
kraken.lib.segmentation¶
Processing for baseline segmenter output.
kraken.lib.ctc_decoder¶
Decoders for softmax outputs of CTC-trained networks.

Decoders extract label sequences out of the raw output matrix of the line recognition network. Several approaches are implemented here, ranging from a simple greedy decoder over the legacy ocropy thresholding decoder to a more complex beam search decoder.
Extracted label sequences are converted into the code point domain using kraken.lib.codec.PytorchCodec.
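A hedged sketch of how these two pieces fit together; the four-class output matrix and the three-character alphabet are dummies, and the exact return shapes may vary between kraken versions:

import torch
from kraken.lib.codec import PytorchCodec
from kraken.lib.ctc_decoder import greedy_decoder

# dummy softmax output matrix of shape (classes, sequence length);
# class 0 is the CTC blank label
outputs = torch.softmax(torch.randn(4, 10), dim=0).numpy()

# extract (label, start, end, confidence) tuples from the raw matrix
labels = greedy_decoder(outputs)

# map labels back into the code point domain
codec = PytorchCodec(['a', 'b', 'c'])
print(codec.decode(labels))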
Legacy Modules¶
These modules are retained for compatibility reasons or highly specialized use cases. Their use is not recommended.
kraken.binarization¶
An adaptive binarization algorithm. This code is legacy and only remains for historical reasons. Binarization is no longer necessary for most workflows.
kraken.pageseg¶
The legacy bounding box segmentation method using conventional image processing techniques.
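For the rare cases where the legacy pipeline is still needed, a minimal sketch of binarizing and segmenting a page; the input filename is a placeholder, and the structure of the returned segmentation differs between kraken versions:

from PIL import Image
from kraken.binarization import nlbin
from kraken.pageseg import segment

im = Image.open('page.png')  # placeholder path
# adaptive binarization (legacy; most current models accept color/grayscale input)
bw_im = nlbin(im)
# bounding box segmentation of the binarized page
seg = segment(bw_im)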
kraken.rpred¶
Legacy line text recognition API. New code should use the RecognitionTaskModel from kraken.tasks, which is more versatile and offers higher performance.
kraken.lib.models¶
Wrapper around TorchVGSLModel including a variety of forward pass helpers for sequence classification.
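A brief loading sketch; the filename is a placeholder, and load_any autodetects the serialization format of the model file:

from kraken.lib import models

# load a recognition model from disk; returns a wrapper around
# TorchVGSLModel with forward pass helpers
net = models.load_any('model.mlmodel')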