rikai.pytorch package

Subpackages

Submodules

rikai.pytorch.data module

PyTorch Dataset and DataLoader

class rikai.pytorch.data.Dataset(data_ref: Union[str, Path, pyspark.sql.DataFrame], columns: List[str] = None, transform: Callable = ToTensor)

Bases: IterableDataset

Rikai PyTorch Dataset.

A torch.utils.data.IterableDataset that reads the Rikai data format. This Dataset supports multi-process data loading through torch.utils.data.DataLoader.
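Multi-process loading with an iterable-style dataset generally requires each worker to read a disjoint share of the rows; otherwise every worker yields the full dataset. The pure-Python sketch below illustrates one common strategy, round-robin sharding by worker id. It is not taken from Rikai's implementation, and shard_rows, worker_id, and num_workers are hypothetical names used only for illustration.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def shard_rows(rows: Iterable[T], worker_id: int, num_workers: int) -> Iterator[T]:
    """Return every num_workers-th row, starting at offset worker_id.

    Hypothetical helper: shows one way to give each worker a disjoint
    slice of a stream. Rikai's actual strategy may differ.
    """
    if not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    return islice(rows, worker_id, None, num_workers)


# Split ten rows across three simulated workers.
rows = list(range(10))
shards = [list(shard_rows(rows, w, 3)) for w in range(3)]
```

Inside a real torch.utils.data.IterableDataset, the worker id and the worker count are available in __iter__ through torch.utils.data.get_worker_info().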

Parameters
  • data_ref (str, Path, pyspark.sql.DataFrame) – URI to the data files or the dataframe

  • columns (list of str, optional) – An optional list of columns to load from the parquet files.

  • transform (Callable, default instance of RikaiToTensor) – Row-level transformation applied before yielding each sample.

Note

Up to pytorch==1.7, an IterableDataset cannot be combined with a torch.utils.data.Sampler in torch.utils.data.DataLoader.

To shuffle samples, wrap the Rikai dataset in torch.utils.data.BufferedShuffleDataset (torch>=1.8).
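To make the idea of buffered shuffling concrete, here is a minimal pure-Python sketch of the general technique: keep a fixed-size buffer, and once it is full, emit a randomly chosen element each time a new one arrives. This is an illustration only, not PyTorch's or Rikai's code; buffered_shuffle and its seed parameter are hypothetical.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(
    items: Iterable[T], buffer_size: int, seed: Optional[int] = None
) -> Iterator[T]:
    """Approximately shuffle a stream using a bounded buffer (hypothetical helper)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer; nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and reuse its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain whatever remains in random order.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(100), buffer_size=16, seed=0))
```

A larger buffer brings the result closer to a uniform shuffle at the cost of memory; a buffer of size 1 leaves the order unchanged.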

Example

>>> from rikai.pytorch.data import Dataset
>>> from torch.utils.data import DataLoader
>>>
>>> dataset = Dataset("dataset", columns=["image", "label_id"])
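>>> # To shuffle (torch>=1.8): from torch.utils.data import BufferedShuffleDataset
>>> # dataset = BufferedShuffleDataset(dataset, buffer_size=1024)  # size is illustrative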
>>> loader = DataLoader(dataset, num_workers=8)

rikai.pytorch.pandas module

class rikai.pytorch.pandas.PandasDataset(data: Union[DataFrame, Series], transform: Optional[Callable] = None, unpickle: bool = False, use_pil: bool = False)

Bases: Dataset

A map-style PyTorch dataset built from a pandas.DataFrame or a pandas.Series.

Note

This class is used in Rikai’s SQL-ML Spark implementation, which uses pandas UDFs to run inference.
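In PyTorch's terminology, a map-style dataset implements __getitem__ and __len__ so that samples can be fetched by index. The sketch below shows that protocol in plain Python over a list of rows. MapStyleSketch is a hypothetical class; it does not reproduce PandasDataset or its handling of unpickle and use_pil.

```python
from typing import Any, Callable, Optional, Sequence


class MapStyleSketch:
    """Hypothetical index-addressable dataset with an optional per-row transform."""

    def __init__(
        self,
        rows: Sequence[Any],
        transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self.rows = rows
        self.transform = transform

    def __len__(self) -> int:
        # Number of samples available for random access.
        return len(self.rows)

    def __getitem__(self, index: int) -> Any:
        # Fetch one row by position and apply the transform, if any.
        row = self.rows[index]
        return self.transform(row) if self.transform is not None else row


dataset = MapStyleSketch([1, 2, 3], transform=lambda x: x * 10)
```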

rikai.pytorch.transforms module

class rikai.pytorch.transforms.RikaiToTensor(use_pil: bool = False)

Bases: object

Convert a Row in the Rikai parquet dataset into PyTorch Tensors.

Warning

For internal use only.

rikai.pytorch.vision module

Torchvision-compatible Dataset

class rikai.pytorch.vision.Dataset(uri_or_df: Union[str, Path, pyspark.sql.DataFrame], image_column: str, target_column: Optional[Union[str, List[str]]] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)

Bases: Dataset

A Rikai Dataset compatible with torchvision.

Parameters
  • uri_or_df (str, Path, or pyspark.sql.DataFrame) – URI of the dataset or the dataset as a pyspark DataFrame

  • image_column (str) – The column name for the image data.

  • target_column (str or list[str], optional) – The column(s) of the target / label.

  • transform (Callable, optional) – A function/transform that takes in a PIL.Image.Image and returns a transformed version, e.g., torchvision.transforms.ToTensor.

  • target_transform (Callable, optional) – A function/transform that takes in the target and transforms it.

Yields

(image, target)
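The way the two transforms combine into an (image, target) pair can be sketched in plain Python as follows. This is an assumption-laden illustration rather than Rikai's code: iter_image_target is a hypothetical function, rows are modeled as dictionaries, only a single string target_column is handled, and yielding None as the target when no target column is given is an assumption.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Tuple


def iter_image_target(
    rows: Iterable[Dict[str, Any]],
    image_column: str,
    target_column: Optional[str] = None,
    transform: Optional[Callable[[Any], Any]] = None,
    target_transform: Optional[Callable[[Any], Any]] = None,
) -> Iterator[Tuple[Any, Any]]:
    """Yield (image, target) pairs, applying each transform to its own field."""
    for row in rows:
        image = row[image_column]
        if transform is not None:
            image = transform(image)
        # Assumption: a missing target column produces a None target.
        target = row[target_column] if target_column is not None else None
        if target is not None and target_transform is not None:
            target = target_transform(target)
        yield image, target


# Strings stand in for image objects to keep the sketch dependency-free.
rows = [{"image": "img0", "label": "dog"}]
pairs = list(
    iter_image_target(
        rows, "image", "label", transform=str.upper, target_transform=len
    )
)
```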

Examples

>>> from torchvision import transforms
>>> from rikai.pytorch.vision import Dataset
>>> transform = transforms.Compose(
...     transforms=[
...         transforms.Resize(128),
...         transforms.ToTensor(),
...         transforms.Normalize(
...             (0.485, 0.456, 0.406),
...             (0.229, 0.224, 0.225)
...         ),
...     ])
>>> dataset = Dataset("out", "image", ["label"], transform=transform)
>>> next(iter(dataset))
... tensor([[[-1.8610, -0.8678, -0.4226,  ..., -1.7583,  0.0569, -0.6794],
     [-1.5870, -1.8782, -1.7069,  ..., -1.1075, -1.1760, -1.8782],
     [-2.1179, -0.5253, -1.7925,  ..., -0.3712, -1.4843, -1.2959],
     ...,
     [-1.1073, -0.3927, -0.8110,  ..., -0.9853,  0.1128, -1.0027]]]) dog

Module contents

PyTorch support