rikai.contrib.torchhub.pytorch.vision package

Submodules

rikai.contrib.torchhub.pytorch.vision.resnet module

class rikai.contrib.torchhub.pytorch.vision.resnet.ResnetModelType(name: str, pretrained_fn: Optional[Callable] = None, label_fn: Optional[Callable[[int], str]] = None, collate_fn: Optional[Callable] = None, register: bool = True)

Bases: TorchModelType

predict(images, *args, **kwargs) → Any

Run model inference and convert return types into Rikai-compatible types.
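For an image classifier, "Rikai-compatible types" means plain Python values rather than framework tensors. The constructor's `label_fn` argument maps an integer class id to a string label. Below is a minimal sketch of such a post-processing step, assuming the model emits one row of raw logits per image. `to_rikai_rows`, `softmax`, and the output keys (`label_id`, `score`, `label`) are hypothetical illustrations, not Rikai's actual implementation of `predict`.

```python
import math
from typing import Callable, Dict, List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_rikai_rows(
    batch_logits: Sequence[Sequence[float]],
    label_fn: Optional[Callable[[int], str]] = None,
) -> List[Dict[str, object]]:
    """Hypothetical helper: turn one row of logits per image into a plain
    dict holding the top-1 class id, its probability, and (optionally) a label."""
    rows = []
    for logits in batch_logits:
        probs = softmax(logits)
        # Index of the highest probability (first one wins on ties).
        label_id = max(range(len(probs)), key=probs.__getitem__)
        row: Dict[str, object] = {"label_id": label_id, "score": float(probs[label_id])}
        if label_fn is not None:
            row["label"] = label_fn(label_id)
        rows.append(row)
    return rows


# Usage: a toy 3-class label list stands in for a real label_fn.
labels = ["cat", "dog", "bird"]
rows = to_rikai_rows([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]], label_fn=labels.__getitem__)
```

Returning built-in `int`, `float`, and `str` values keeps the result independent of PyTorch, which is the point of the conversion step.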

schema() → str

Return the model's output schema as a string.

Examples

>>> model_type.schema()
"array<struct<box:box2d, score:float, label_id:int>>"

transform() → Callable

Return the callable that pre-processes each input image before inference. All pre-trained models expect input images normalized in the same way: mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images must first be scaled to the range [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
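The two normalization steps (scale to [0, 1], then subtract the per-channel mean and divide by the per-channel std) can be sketched in plain Python as follows. This is an illustration only, assuming the input is an H x W x 3 nested list of 0-255 integers; `normalize_hwc_uint8` is a hypothetical helper, and the resize to at least 224 x 224 is not shown.

```python
from typing import List, Sequence, Tuple

# Per-channel ImageNet statistics (R, G, B) from the description above.
MEAN: Tuple[float, float, float] = (0.485, 0.456, 0.406)
STD: Tuple[float, float, float] = (0.229, 0.224, 0.225)


def normalize_hwc_uint8(
    image: Sequence[Sequence[Sequence[int]]],
) -> List[List[List[float]]]:
    """Hypothetical helper: convert an H x W x 3 image of 0-255 ints into a
    3 x H x W nested list, scaled to [0, 1] and normalized per channel."""
    height = len(image)
    width = len(image[0])
    return [
        [
            # Scale to [0, 1] first, then apply (x - mean) / std for channel c.
            [(image[y][x][c] / 255.0 - MEAN[c]) / STD[c] for x in range(width)]
            for y in range(height)
        ]
        for c in range(3)  # channels-first (CHW) output layout
    ]


# Usage: a 1 x 2 image with one white pixel and one black pixel.
img = [[[255, 255, 255], [0, 0, 0]]]
out = normalize_hwc_uint8(img)  # shape 3 x 1 x 2
```

With torchvision installed, the same two steps are typically written as `transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=MEAN, std=STD)])`, which also performs the HWC-to-CHW conversion on tensors instead of nested lists.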

Module contents