Skip to content

Latest commit

 

History

History
273 lines (182 loc) · 8.97 KB

pytorch.md

File metadata and controls

273 lines (182 loc) · 8.97 KB
layout title permalink redirect_from
post
PYTORCH
/docs/pytorch
/pytorch.md/
/docs/pytorch.md/

The AIStore PyTorch integration is a growing set of datasets, samplers, and more that allow you to use easily add AIStore support to a codebase using PyTorch. This document contains API documentation for the AIStore PyTorch integration.

For usage examples, please see:

PyTorch Structure

Base class for AIS Map Style Datasets

Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

Class: AISBaseMapDataset

class AISBaseMapDataset(ABC, Dataset)

A base class for creating map-style AIS Datasets. Should not be instantiated directly. Subclasses should implement :meth:__getitem__ which fetches a samples given a key from the dataset and can optionally override other methods from torch Dataset such as :meth:__len__ and :meth:__getitems__.

Arguments:

  • ais_source_list Union[AISSource, List[AISSource]] - Single or list of AISSource objects to load data prefix_map (Dict(AISSource, List[str]), optional): Map of AISSource objects to list of prefixes that only allows objects with the specified prefixes to be used from each source

get_obj_list

def get_obj_list() -> List[Object]

Getter for internal object data list.

Returns:

  • List[Object] - Object data of the dataset

Base class for AIS Iterable Style Datasets

Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

Class: AISBaseIterDataset

class AISBaseIterDataset(ABC, torch_utils.IterableDataset)

A base class for creating AIS Iterable Datasets. Should not be instantiated directly. Subclasses should implement :meth:__iter__ which returns the samples from the dataset and can optionally override other methods from torch IterableDataset such as :meth:__len__.

Arguments:

  • ais_source_list Union[AISSource, List[AISSource]] - Single or list of AISSource objects to load data prefix_map (Dict(AISSource, List[str]), optional): Map of AISSource objects to list of prefixes that only allows objects with the specified prefixes to be used from each source

__iter__

@abstractmethod
def __iter__() -> Iterator

Return iterator with samples in this dataset.

Returns:

  • Iterator - Iterator of samples

__len__

def __len__()

Returns the length of the dataset. Note that calling this will iterate through the dataset, taking O(N) time.

NOTE: If you want the length of the dataset after iterating through it, use for i, data in enumerate(dataset) instead.

PyTorch Dataset for AIS.

Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.

Class: AISMapDataset

class AISMapDataset(AISBaseMapDataset)

A map-style dataset for objects in AIS. If etl_name is provided, that ETL must already exist on the AIStore cluster.

Arguments:

  • ais_source_list Union[AISSource, List[AISSource]] - Single or list of AISSource objects to load data prefix_map (Dict(AISSource, List[str]), optional): Map of AISSource objects to list of prefixes that only allows objects with the specified prefixes to be used from each source

  • etl_name str, optional - Optional ETL on the AIS cluster to apply to each object

    NOTE: Each object is represented as a tuple of object_name (str) and object_content (bytes)

Iterable Dataset for AIS

Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

Class: AISIterDataset

class AISIterDataset(AISBaseIterDataset)

An iterable-style dataset that iterates over objects in AIS and yields samples represented as a tuple of object_name (str) and object_content (bytes). If etl_name is provided, that ETL must already exist on the AIStore cluster.

Arguments:

  • ais_source_list Union[AISSource, List[AISSource]] - Single or list of AISSource objects to load data prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to list of prefixes that only allows objects with the specified prefixes to be used from each source
  • etl_name str, optional - Optional ETL on the AIS cluster to apply to each object
  • show_progress bool, optional - Enables console dataset reading progress indicator

Yields:

Tuple[str, bytes]: Each item is a tuple where the first element is the name of the object and the second element is the byte representation of the object data.

AIS Shard Reader for PyTorch

PyTorch Dataset and DataLoader for AIS.

Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

Class: AISShardReader

class AISShardReader(AISBaseIterDataset)

An iterable-style dataset that iterates over objects stored as Webdataset shards and yields samples represented as a tuple of basename (str) and contents (dictionary).

Arguments:

  • bucket_list Union[Bucket, List[Bucket]] - Single or list of Bucket objects to load data prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of Bucket objects to list of prefixes that only allows objects with the specified prefixes to be used from each source
  • etl_name str, optional - Optional ETL on the AIS cluster to apply to each object
  • show_progress bool, optional - Enables console shard reading progress indicator

Yields:

Tuple[str, Dict(str, bytes)]: Each item is a tuple where the first element is the basename of the shard and the second element is a dictionary mapping strings of file extensions to bytes.

__len__

def __len__()

Returns the length of the dataset. Note that calling this will iterate through the dataset, taking O(N) time.

NOTE: If you want the length of the dataset after iterating through it, use for i, data in enumerate(dataset) instead.

Class: ZeroDict

class ZeroDict(dict)

When collate_fn is called while using ShardReader with a dataloader, the content dictionaries for each sample are merged into a single dictionary with file extensions as keys and lists of contents as values. This means, however, that each sample must have a value for that file extension in the batch at iteration time or else collation will fail. To avoid forcing the user to pass in a custom collation function, we workaround the default implementation of collation.

As such, we define a dictionary that has a default value of b"" (zero bytes) for every key that we have seen so far. We cannot use None as collation does not accept None. Initially, when we open a shard tar, we collect every file type (pre-processing pass) from its members and cache those. Then, we read the shard files. Lastly, before yielding the sample, we wrap its content dictionary with this custom dictionary to insert any keys that it does not contain, hence ensuring consistent keys across samples.

NOTE: For our use case, defaultdict does not work due to needing a lambda which cannot be pickled in multithreaded contexts.

Multishard Stream Dataset for AIS.

Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.

Class: AISMultiShardStream

class AISMultiShardStream(IterableDataset)

An iterable-style dataset that iterates over multiple shard streams and yields combined samples.

Arguments:

  • data_sources List[DataShard] - List of DataShard objects

Returns:

  • Iterable - Iterable over the combined samples, where each sample is a tuple of one object bytes from each shard stream