rikai.pytorch package¶
Subpackages¶
- rikai.pytorch.models package
- Submodules
- rikai.pytorch.models.convnext module
- rikai.pytorch.models.efficientnet module
- rikai.pytorch.models.fasterrcnn module
- rikai.pytorch.models.feature_extractor module
- rikai.pytorch.models.keypointrcnn module
- rikai.pytorch.models.maskrcnn module
- rikai.pytorch.models.resnet module
- rikai.pytorch.models.retinanet module
- rikai.pytorch.models.ssd module
- rikai.pytorch.models.ssd_class_scores module
- rikai.pytorch.models.torch module
- Module contents
Submodules¶
rikai.pytorch.data module¶
Pytorch Dataset and DataLoader
- class rikai.pytorch.data.Dataset(data_ref: Union[str, Path, pyspark.sql.DataFrame], columns: List[str] = None, transform: Callable = ToTensor)¶
Bases: IterableDataset

Rikai Pytorch Dataset.

A torch.utils.data.IterableDataset that reads the Rikai data format. This Dataset works with multi-process data loading using torch.utils.data.DataLoader.

- Parameters
data_ref (str, Path, or pyspark.sql.DataFrame) – URI of the data files, or the dataframe itself
columns (list of str, optional) – An optional list of columns to load from the parquet files.
transform (Callable, default instance of RikaiToTensor) – A row-level transformation applied before yielding each sample.
Note

Up to pytorch==1.7, IterableDataset does not work with torch.utils.data.Sampler in torch.utils.data.DataLoader. Use torch.utils.data.BufferedShuffleDataset (torch>=1.8) with the Rikai dataset for randomness.

Example

>>> from rikai.pytorch.data import Dataset
>>> from torch.utils.data import DataLoader
>>>
>>> dataset = Dataset("dataset", columns=["image", "label_id"])
>>> # dataset = BufferedShuffleDataset(dataset)
>>> loader = DataLoader(dataset, num_workers=8)
rikai.pytorch.pandas module¶
- class rikai.pytorch.pandas.PandasDataset(data: Union[DataFrame, Series], transform: Optional[Callable] = None, unpickle: bool = False, use_pil: bool = False)¶
Bases: Dataset

A map-style Pytorch dataset built from a pandas.DataFrame or a pandas.Series.

Note

This class is used in Rikai's SQL-ML Spark implementation, which utilizes pandas UDFs to run inference.
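The generated docs give no usage example for this class. The sketch below illustrates the map-style contract a dataset like this follows: it wraps tabular data and exposes `__len__` and `__getitem__`, which is all `torch.utils.data.Dataset` requires. The `PandasDatasetSketch` class and sample DataFrame are illustrative stand-ins, not Rikai's actual implementation:

```python
# Minimal sketch of a map-style dataset over a pandas.DataFrame.
# This is NOT rikai.pytorch.pandas.PandasDataset, only a hypothetical
# illustration of the __len__ / __getitem__ contract it satisfies.
import pandas as pd


class PandasDatasetSketch:
    """Hypothetical stand-in for a map-style dataset over pandas data."""

    def __init__(self, data, transform=None):
        # Normalize a Series into a single-column DataFrame.
        self.data = data.to_frame() if isinstance(data, pd.Series) else data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Each sample is one row, returned as a plain dict.
        row = self.data.iloc[idx].to_dict()
        return self.transform(row) if self.transform else row


df = pd.DataFrame({"image": ["a.png", "b.png"], "label": [0, 1]})
ds = PandasDatasetSketch(df)
print(len(ds))  # 2
print(ds[0])    # {'image': 'a.png', 'label': 0}
```

Because the dataset is map-style (random access by index), it can be paired with samplers and shufflers, unlike the iterable-style `rikai.pytorch.data.Dataset` above.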
rikai.pytorch.transforms module¶
rikai.pytorch.vision module¶
Torchvision compatible Dataset
- class rikai.pytorch.vision.Dataset(uri_or_df: Union[str, Path, pyspark.sql.DataFrame], image_column: str, target_column: Optional[Union[str, List[str]]] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)¶
Bases: Dataset
A Rikai Dataset compatible with torchvision.
- Parameters
uri_or_df (str, Path, or pyspark.sql.DataFrame) – URI of the dataset, or the dataset as a pyspark DataFrame
image_column (str) – The column name for the image data.
target_column (str or list[str], optional) – The column(s) of the target / label.
transform (Callable, optional) – A function/transform that takes in a PIL.Image.Image and returns a transformed version, e.g., torchvision.transforms.ToTensor.
target_transform (Callable, optional) – A function/transform that takes in the target and transforms it.
- Yields
(image, target)
Examples
>>> from torchvision import transforms
>>> from rikai.pytorch.vision import Dataset
>>> transform = transforms.Compose(
...     transforms=[
...         transforms.Resize(128),
...         transforms.ToTensor(),
...         transforms.Normalize(
...             (0.485, 0.456, 0.406),
...             (0.229, 0.224, 0.225)
...         ),
...     ])
>>> dataset = Dataset("out", "image", ["label"], transform=transform)
>>> next(iter(dataset))
tensor([[[-1.8610, -0.8678, -0.4226,  ..., -1.7583,  0.0569, -0.6794],
         [-1.5870, -1.8782, -1.7069,  ..., -1.1075, -1.1760, -1.8782],
         [-2.1179, -0.5253, -1.7925,  ..., -0.3712, -1.4843, -1.2959],
         ...,
         [-1.1073, -0.3927, -0.8110,  ..., -0.9853,  0.1128, -1.0027]]])
dog
Module contents¶
Pytorch support