|
| 1 | +#!/usr/bin/env python |
| 2 | + |
| 3 | +# Copyright (c) 2025 Oracle and/or its affiliates. |
| 4 | +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
| 5 | + |
| 6 | +from typing import Any, Callable, Dict, List, Mapping, Optional |
| 7 | + |
| 8 | +import requests |
| 9 | +from langchain_core.embeddings import Embeddings |
| 10 | +from langchain_core.language_models.llms import create_base_retry_decorator |
| 11 | +from pydantic import BaseModel, Field |
| 12 | + |
# Default HTTP headers applied to every model-deployment request unless the
# caller supplies its own "headers" entry via `endpoint_kwargs`.
DEFAULT_HEADER = {"Content-Type": "application/json"}
| 16 | + |
| 17 | + |
class TokenExpiredError(Exception):
    """Signals that the OCI auth token has expired.

    Listed among the retryable error types in ``_create_retry_decorator``,
    so raising it triggers a retry of the request.
    """
| 20 | + |
| 21 | + |
def _create_retry_decorator(llm) -> Callable[[Any], Any]:
    """Build a retry decorator for transient endpoint failures.

    Retries on connection timeouts and on ``TokenExpiredError`` (raised after
    a 401 once the signer has been refreshed), up to ``llm.max_retries``
    attempts.
    """
    retryable_errors = [requests.exceptions.ConnectTimeout, TokenExpiredError]
    return create_base_retry_decorator(
        error_types=retryable_errors,
        max_retries=llm.max_retries,
    )
| 29 | + |
| 30 | + |
class OCIDataScienceEmbedding(BaseModel, Embeddings):
    """Embedding model deployed on OCI Data Science Model Deployment.

    Example:

        .. code-block:: python

            from ads.llm import OCIDataScienceEmbedding

            embeddings = OCIDataScienceEmbedding(
                endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
            )
    """  # noqa: E501

    auth: dict = Field(default_factory=dict, exclude=True)
    """ADS auth dictionary for OCI authentication:
    https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
    This can be generated by calling `ads.common.auth.api_keys()`
    or `ads.common.auth.resource_principal()`. If this is not
    provided then the `ads.common.default_signer()` will be used."""

    endpoint: str = ""
    """The uri of the endpoint from the deployed Model Deployment model."""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_kwargs: Optional[Dict] = None
    """Optional attributes (except for headers) passed to the request.post
    function.
    """

    max_retries: int = 1
    """The maximum number of retries to make when generating."""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            "endpoint": self.endpoint,
            "model_kwargs": _model_kwargs,
        }

    def _embed_with_retry(self, **kwargs) -> Any:
        """POST to the endpoint, retrying on timeouts and expired tokens.

        Args:
            kwargs: Keyword arguments forwarded to ``requests.post``.

        Returns:
            The successful ``requests.Response``.

        Raises:
            ValueError: For non-retryable HTTP errors or any other request
                failure, with the server response text included when available.
        """
        retry_decorator = _create_retry_decorator(self)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            try:
                response = requests.post(self.endpoint, **kwargs)
                response.raise_for_status()
                return response
            except requests.exceptions.HTTPError as http_err:
                # A 401 is treated as an expired auth token: refresh the
                # signer and raise TokenExpiredError so the retry decorator
                # re-issues the request with fresh credentials.
                # NOTE(review): `_refresh_signer` is not defined on this class
                # or its visible bases — confirm a mixin/subclass provides it.
                if response.status_code == 401 and self._refresh_signer():
                    raise TokenExpiredError() from http_err
                else:
                    raise ValueError(
                        f"Server error: {str(http_err)}. Message: {response.text}"
                    ) from http_err
            except Exception as e:
                raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e

        return _completion_with_retry(**kwargs)

    def _embedding(self, texts: List[str]) -> List[List[float]]:
        """Call out to OCI Data Science Model Deployment Endpoint.

        Args:
            texts: A list of texts to embed.

        Returns:
            A list of list of floats representing the embeddings, one per
            input text.
        """
        _model_kwargs = self.model_kwargs or {}
        body = self._construct_request_body(texts, _model_kwargs)
        request_kwargs = self._construct_request_kwargs(body)
        response = self._embed_with_retry(**request_kwargs)
        return self._proceses_response(response)

    def _construct_request_kwargs(self, body: Any) -> dict:
        """Constructs the request kwargs as a dictionary.

        Sends the body as a JSON payload when it is JSON-serializable,
        otherwise as raw request data.
        """
        from ads.model.common.utils import _is_json_serializable

        # Copy before popping: popping "headers" from self.endpoint_kwargs
        # directly would permanently mutate the caller's dict, silently
        # dropping custom headers on every request after the first.
        _endpoint_kwargs = dict(self.endpoint_kwargs or {})
        headers = _endpoint_kwargs.pop("headers", DEFAULT_HEADER)
        return (
            dict(
                headers=headers,
                json=body,
                auth=self.auth.get("signer"),
                **_endpoint_kwargs,
            )
            if _is_json_serializable(body)
            else dict(
                headers=headers,
                data=body,
                auth=self.auth.get("signer"),
                **_endpoint_kwargs,
            )
        )

    def _construct_request_body(self, texts: List[str], params: dict) -> Any:
        """Constructs the request body.

        NOTE(review): ``params`` (the model kwargs) is currently ignored —
        confirm whether it should be merged into the payload.
        """
        return {"input": texts}

    def _proceses_response(self, response: requests.Response) -> List[List[float]]:
        """Extracts embedding results from requests.Response.

        Returns one embedding (list of floats) per input text. Reads every
        entry of the response's ``data`` list rather than only ``data[0]``,
        which previously flattened batched results and broke both
        ``embed_documents`` and ``embed_query``.

        Raises:
            ValueError: If the response payload cannot be parsed.
        """
        try:
            res_json = response.json()
            embeddings = [item["embedding"] for item in res_json["data"]]
        except Exception as e:
            raise ValueError(
                f"Error raised by inference API: {e}.\nResponse: {response.text}"
            ) from e
        return embeddings

    def embed_documents(
        self,
        texts: List[str],
        chunk_size: Optional[int] = None,
    ) -> List[List[float]]:
        """Compute doc embeddings using OCI Data Science Model Deployment Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size defines how many input texts will
                be grouped together as request. If None, will use the
                chunk size specified by the class.

        Returns:
            List of embeddings, one for each text.
        """
        results: List[List[float]] = []
        # Guard the empty batch: len(texts) == 0 would make the range step
        # zero and raise ValueError.
        if not texts:
            return results
        _chunk_size = (
            len(texts) if (not chunk_size or chunk_size > len(texts)) else chunk_size
        )
        for i in range(0, len(texts), _chunk_size):
            results.extend(self._embedding(texts[i : i + _chunk_size]))
        return results

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using OCI Data Science Model Deployment Endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embedding([text])[0]
0 commit comments