# MLflow API notes

## Tracking store

There are currently three implementations of the tracking store, with the
following interfaces:

* `FileStore(root_directory, artifact_root_uri)`
* `SqlAlchemyStore(db_uri, default_artifact_root)`
* `RestStore(get_host_creds)`

In each case, the first argument provides information on the location of
tracking information storage. The second argument to `FileStore` and
`SqlAlchemyStore` sets the root directory within which artifacts will be
stored. This default is overridable when creating an experiment.

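The override behaviour can be illustrated with a minimal sketch (a toy
stand-in for illustration only, not MLflow's actual `FileStore`):

```python
class ToyFileStore:
    """Toy stand-in illustrating the store-level artifact root default."""

    def __init__(self, root_directory, artifact_root_uri):
        self.root_directory = root_directory
        self.artifact_root_uri = artifact_root_uri

    def create_experiment(self, name, artifact_location=None):
        # A per-experiment location, when given, overrides the store default.
        return {"name": name,
                "artifact_location": artifact_location or self.artifact_root_uri}


store = ToyFileStore("/tmp/mlruns", "s3://bucket/artifacts")
print(store.create_experiment("demo")["artifact_location"])        # s3://bucket/artifacts
print(store.create_experiment("demo", "s3://other")["artifact_location"])  # s3://other
```
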
The principal inconsistency in the current interface is that default artifact
root URIs are determined on construction of the file and sqlalchemy stores,
but by the server in the REST store case.

In addition, a custom root artifact URI passed to
`mlflow.tracking.utils._get_store()` is ignored in the file store case (the
`store_uri` is passed as both `store_uri` and `artifact_root_uri` to
`FileStore`).

Proposals:

1. The default root artifact URI is made configurable when building a REST
   store.
2. The default root artifact URI is removed as an option when constructing a
   file or sqlalchemy store and is either:
   a. Configured in the actual file / sqlalchemy stores (stored in a config
      file in the file store case and in a config table in the sqlalchemy
      store case). This would mirror the behaviour of the REST store, which
      determines the default location on the server side.
   b. Read from an environment variable instead of being passed in through
      code.
3. Leave the current interface as-is.

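Proposal 2b could look roughly like the following sketch. The variable name
`MLFLOW_DEFAULT_ARTIFACT_ROOT` is an assumption for illustration, not an
existing MLflow setting:

```python
import os


def default_artifact_root(fallback="./mlruns"):
    # Sketch of proposal 2b: the default artifact root is read from the
    # environment rather than passed to the store constructor.
    return os.environ.get("MLFLOW_DEFAULT_ARTIFACT_ROOT", fallback)


os.environ["MLFLOW_DEFAULT_ARTIFACT_ROOT"] = "s3://bucket/artifacts"
print(default_artifact_root())  # s3://bucket/artifacts
```
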
## Artifact repository

The API for building artifact repositories makes sense, except for the need to
pass through an associated tracking store. This is only used by the DBFS
artifact repository, which uses the tracking store's `get_host_creds` attribute
(only present on a `RestStore`) to avoid loading the host credentials multiple
times.

Since the motivation for accessing host credentials through the tracking store
is simply to avoid reloading them, we propose moving this implicit caching to
another part of the code base: each artifact repository would take only a URI
as an argument and would be responsible for loading any extra credentials it
needs to access that URI.

This is implicitly already the case with artifact repositories like the S3
artifact repository, which relies on boto to load AWS credentials from the
caller's environment as appropriate.

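The environment-driven pattern boto follows can be sketched as below (the
helper is illustrative; the variable names are the standard AWS conventions):

```python
import os


def creds_from_env():
    # Credentials are discovered from the caller's environment, so nothing
    # credential-shaped ever crosses the repository's constructor.
    return {
        "access_key": os.environ.get("AWS_ACCESS_KEY_ID"),
        "secret_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
    }


os.environ["AWS_ACCESS_KEY_ID"] = "example-access-key"
print(creds_from_env()["access_key"])  # example-access-key
```
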
Proposal:

* Remove `store` as an argument from
  `mlflow.store.artifact_repo.ArtifactRepository.from_artifact_uri()`
* Remove `get_host_creds` as an argument from
  `mlflow.store.dbfs_artifact_repo.DbfsArtifactRepository()`
* Have `DbfsArtifactRepository` call
  `mlflow.utils.databricks_utils.get_databricks_host_creds` to get host
  credentials instead of receiving them at construction time (optionally
  adding caching around that call if it proves expensive)

This would result in a simpler, more consistent interface for constructing
`ArtifactRepository` objects: only the URI would be needed.

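A minimal sketch of the proposed shape, with the optional caching provided by
`functools.lru_cache` (the class body and credential helper are illustrative
assumptions, not MLflow code):

```python
import functools


@functools.lru_cache(maxsize=None)
def cached_databricks_host_creds(profile):
    # Stand-in for mlflow.utils.databricks_utils.get_databricks_host_creds;
    # lru_cache means repeated lookups reuse the first result.
    return {"profile": profile}


class DbfsArtifactRepositorySketch:
    def __init__(self, artifact_uri):
        # Only the URI is needed; credentials are resolved on demand.
        self.artifact_uri = artifact_uri

    def _host_creds(self):
        return cached_databricks_host_creds("default")


repo = DbfsArtifactRepositorySketch("dbfs:/mlflow/artifacts")
print(repo._host_creds() is repo._host_creds())  # True (cached)
```
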
## Back to tracking stores

The proposed change to the artifact repository interface gives all
repositories a consistent construction API: only a URI needs to be passed.
The same simplification could be brought to the tracking store by having the
`RestStore` take a tracking URI instead of a `get_host_creds` function.

The difference in logic required between Databricks and non-Databricks REST
stores can be provided by two slightly different implementations:

```python
import os
import urllib.parse
from abc import abstractmethod

from mlflow.store.abstract_store import AbstractStore
from mlflow.utils import rest_utils
from mlflow.utils.databricks_utils import get_databricks_host_creds


class AbstractRestStore(AbstractStore):
    def __init__(self, store_uri):
        super(AbstractRestStore, self).__init__()
        self.store_uri = store_uri

    @abstractmethod
    def _get_host_creds(self):
        pass


class RestStore(AbstractRestStore):
    def _get_host_creds(self):
        # Currently in mlflow.tracking.utils._get_rest_store; the
        # _TRACKING_*_ENV_VAR constants are defined in mlflow.tracking.utils.
        return rest_utils.MlflowHostCreds(
            host=self.store_uri,
            username=os.environ.get(_TRACKING_USERNAME_ENV_VAR),
            password=os.environ.get(_TRACKING_PASSWORD_ENV_VAR),
            token=os.environ.get(_TRACKING_TOKEN_ENV_VAR),
            ignore_tls_verification=os.environ.get(_TRACKING_INSECURE_TLS_ENV_VAR) == 'true',
        )


class DatabricksRestStore(AbstractRestStore):
    def _get_host_creds(self):
        # Get the Databricks profile specified by the tracking URI
        parsed_uri = urllib.parse.urlparse(self.store_uri)
        profile = parsed_uri.netloc
        return get_databricks_host_creds(profile)
```
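
The profile extraction in `DatabricksRestStore` relies on the profile name
riding in the netloc of a `databricks://` tracking URI, which the standard
library parses directly:

```python
import urllib.parse

# urlparse treats anything after "scheme://" and before the first "/" as the
# netloc, so the Databricks profile name falls out of the tracking URI.
parsed = urllib.parse.urlparse("databricks://my-profile")
print(parsed.scheme)  # databricks
print(parsed.netloc)  # my-profile
```
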