rikai.spark.sql.codegen package¶
Submodules¶
rikai.spark.sql.codegen.base module¶
rikai.spark.sql.codegen.dummy module¶
rikai.spark.sql.codegen.fs module¶
rikai.spark.sql.codegen.mlflow_logger module¶
Custom MLflow model logger that ensures models are logged with the metadata Rikai SQL ML needs
- class rikai.spark.sql.codegen.mlflow_logger.MlflowLogger(flavor: str)¶
Bases:
object
An alternative model logger to use during training in place of the standard MLflow logger.
- log_model(model: Any, artifact_path: str, schema: Optional[str] = None, registered_model_name: Optional[str] = None, customized_flavor: Optional[str] = None, model_type: Optional[str] = None, labels: Optional[dict] = None, **kwargs)¶
Convenience function to log the model with the tags Rikai needs. Call this during training, once the model is produced.
- Parameters
model (Any) – The model artifact object
artifact_path (str) – The relative (to the run) artifact path
schema (str) – Output schema (pyspark DataType)
registered_model_name (str, default None) – Model name in the mlflow model registry
model_type (str) – Model type identifier (e.g. "ssd")
kwargs (dict) – Passed to mlflow.<flavor>.log_model
Examples
    import mlflow
    import rikai.mlflow

    # Log PyTorch model
    with mlflow.start_run() as run:
        # Training loop
        # ...
        # Assume `model` is the trained model from the training loop
        rikai.mlflow.pytorch.log_model(
            model,
            "model",
            model_type="ssd",
            registered_model_name="MyPytorchModel",
        )
For more details, see the MLflow docs.
rikai.spark.sql.codegen.mlflow_registry module¶
rikai.spark.sql.codegen.pytorch module¶
- rikai.spark.sql.codegen.pytorch.move_tensor_to_device(data, device)¶
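The entry above gives only the signature. A helper like this typically walks nested containers and moves anything tensor-like to the target device; the sketch below is a hypothetical illustration of that pattern, not Rikai's actual implementation. It duck-types on a `.to(device)` method (which `torch.Tensor` and `torch.nn.Module` both provide) so it runs without PyTorch installed:

```python
def move_tensor_to_device(data, device):
    """Recursively move tensor-like objects in nested containers to a device.

    Hypothetical sketch: any object exposing a ``.to(device)`` method is
    moved; dicts, lists, and tuples are traversed; other values pass through.
    """
    if hasattr(data, "to"):
        # torch.Tensor.to(device) returns a copy on the target device
        return data.to(device)
    if isinstance(data, dict):
        return {k: move_tensor_to_device(v, device) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(move_tensor_to_device(v, device) for v in data)
    # Plain Python values (ints, strings, None, ...) are left untouched
    return data
```

With real PyTorch, a typical call would be `batch = move_tensor_to_device(batch, "cuda:0")` on a batch that mixes tensors with plain Python metadata.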