Skip to content

Commit c92d0f6

Browse files
authored
Added Langchain Embedding Plugin (#1045)
2 parents 7ce9342 + 61a5e04 commit c92d0f6

File tree

6 files changed

+284
-8
lines changed

6 files changed

+284
-8
lines changed

‎ads/llm/__init__.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
try:
87
import langchain
9-
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
10-
OCIModelDeploymentVLLM,
11-
OCIModelDeploymentTGI,
12-
)
8+
9+
from ads.llm.chat_template import ChatTemplates
1310
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
1411
ChatOCIModelDeployment,
15-
ChatOCIModelDeploymentVLLM,
1612
ChatOCIModelDeploymentTGI,
13+
ChatOCIModelDeploymentVLLM,
14+
)
15+
from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import (
16+
OCIDataScienceEmbedding,
17+
)
18+
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
19+
OCIModelDeploymentTGI,
20+
OCIModelDeploymentVLLM,
1721
)
18-
from ads.llm.chat_template import ChatTemplates
1922
except ImportError as ex:
2023
if ex.name == "langchain":
2124
raise ImportError(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
from typing import Any, Callable, Dict, List, Mapping, Optional
7+
8+
import requests
9+
from langchain_core.embeddings import Embeddings
10+
from langchain_core.language_models.llms import create_base_retry_decorator
11+
from pydantic import BaseModel, Field
12+
13+
# Headers used for a request when ``endpoint_kwargs`` carries no "headers" entry.
DEFAULT_HEADER = {"Content-Type": "application/json"}
16+
17+
18+
class TokenExpiredError(Exception):
    """Signals that a request was rejected with HTTP 401 and the signer was refreshed.

    It is one of the error types handed to the retry decorator, so raising it
    lets the failed call be attempted again.
    """
20+
21+
22+
def _create_retry_decorator(llm) -> Callable[[Any], Any]:
    """Build the retry decorator used around the endpoint call.

    Args:
        llm: Any object exposing a ``max_retries`` attribute; its value is
            forwarded as the retry limit.

    Returns:
        The decorator produced by ``create_base_retry_decorator`` for
        connection timeouts and expired tokens.
    """
    retryable_errors = [requests.exceptions.ConnectTimeout, TokenExpiredError]
    return create_base_retry_decorator(
        error_types=retryable_errors,
        max_retries=llm.max_retries,
    )
29+
30+
31+
class OCIDataScienceEmbedding(BaseModel, Embeddings):
    """Embedding model deployed on OCI Data Science Model Deployment.

    Example:

        .. code-block:: python

            from ads.llm import OCIDataScienceEmbedding

            embeddings = OCIDataScienceEmbedding(
                endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
            )
    """  # noqa: E501

    auth: dict = Field(default_factory=dict, exclude=True)
    """ADS auth dictionary for OCI authentication:
    https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
    This can be generated by calling `ads.common.auth.api_keys()`
    or `ads.common.auth.resource_principal()`. The value stored under the
    ``"signer"`` key is passed to ``requests.post`` as ``auth``.

    NOTE(review): nothing in this class falls back to
    `ads.common.default_signer()`; when no ``"signer"`` key is present the
    request is sent with ``auth=None``. Confirm whether a default is intended.
    """

    endpoint: str = ""
    """The uri of the endpoint from the deployed Model Deployment model."""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_kwargs: Optional[Dict] = None
    """Optional attributes (except for headers) passed to the request.post
    function.
    """

    max_retries: int = 1
    """The maximum number of retries to make when generating."""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters (endpoint and model kwargs)."""
        return {
            "endpoint": self.endpoint,
            "model_kwargs": self.model_kwargs or {},
        }

    def _refresh_signer(self) -> bool:
        """Refresh the signer's security token when the signer supports it.

        Returns:
            True if ``auth["signer"]`` exists and exposes
            ``refresh_security_token`` (which is then called), else False.
        """
        signer = self.auth.get("signer")
        if signer is not None and hasattr(signer, "refresh_security_token"):
            signer.refresh_security_token()
            return True
        return False

    def _embed_with_retry(self, **kwargs) -> Any:
        """POST to the endpoint, retrying on connect timeouts and expired tokens.

        Args:
            **kwargs: Keyword arguments forwarded to ``requests.post``.

        Returns:
            The successful ``requests.Response``.

        Raises:
            TokenExpiredError: If the endpoint keeps answering 401 after the
                signer was refreshed and the retries are used up.
            ValueError: For any other HTTP error, or any other failure while
                calling the endpoint (including a connect timeout that
                persists through all retries).
        """
        retry_decorator = _create_retry_decorator(self)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            try:
                response = requests.post(self.endpoint, **kwargs)
                response.raise_for_status()
                return response
            except requests.exceptions.ConnectTimeout:
                # Must escape unchanged: the retry decorator only retries the
                # error types it was given, and a ValueError is not one of them.
                raise
            except requests.exceptions.HTTPError as http_err:
                if response.status_code == 401 and self._refresh_signer():
                    raise TokenExpiredError() from http_err
                raise ValueError(
                    f"Server error: {str(http_err)}. Message: {response.text}"
                ) from http_err
            except Exception as e:
                raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e

        try:
            return _completion_with_retry(**kwargs)
        except requests.exceptions.ConnectTimeout as e:
            # Retries exhausted: surface the same ValueError callers received
            # before timeouts were allowed to reach the retry decorator.
            raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e

    def _embedding(self, texts: List[str]) -> List[List[float]]:
        """Call out to OCI Data Science Model Deployment Endpoint.

        Args:
            texts: A list of texts to embed.

        Returns:
            A list of list of floats representing the embeddings.

        Raises:
            ValueError: If the request fails or the response cannot be parsed.
        """
        body = self._construct_request_body(texts, self.model_kwargs or {})
        request_kwargs = self._construct_request_kwargs(body)
        response = self._embed_with_retry(**request_kwargs)
        return self._proceses_response(response)

    def _construct_request_kwargs(self, body: Any) -> dict:
        """Constructs the request kwargs as a dictionary.

        The body is sent as ``json`` when it is JSON serializable and as
        ``data`` otherwise. Headers come from ``endpoint_kwargs["headers"]``
        when present, else ``DEFAULT_HEADER``.
        """
        from ads.model.common.utils import _is_json_serializable

        # Work on a copy: popping "headers" from self.endpoint_kwargs itself
        # would drop the caller's custom headers after the first request.
        request_kwargs = dict(self.endpoint_kwargs or {})
        headers = request_kwargs.pop("headers", DEFAULT_HEADER)
        payload_key = "json" if _is_json_serializable(body) else "data"
        return dict(
            headers=headers,
            auth=self.auth.get("signer"),
            **{payload_key: body},
            **request_kwargs,
        )

    def _construct_request_body(self, texts: List[str], params: dict) -> Any:
        """Constructs the request body.

        NOTE(review): ``params`` (the model kwargs) is accepted but not
        included in the body; only ``{"input": texts}`` is sent.
        """
        return {"input": texts}

    def _proceses_response(self, response: requests.Response) -> List[List[float]]:
        """Extracts results from requests.Response.

        Returns the value found at ``data[0]["embedding"]`` of the JSON body.

        Raises:
            ValueError: If the body is not JSON or lacks that entry.
        """
        try:
            res_json = response.json()
            embeddings = res_json["data"][0]["embedding"]
        except Exception as e:
            raise ValueError(
                f"Error raised by inference API: {e}.\nResponse: {response.text}"
            ) from e
        return embeddings

    def embed_documents(
        self,
        texts: List[str],
        chunk_size: Optional[int] = None,
    ) -> List[List[float]]:
        """Compute doc embeddings using OCI Data Science Model Deployment Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size defines how many input texts will
                be grouped together as request. If None, non-positive, or
                larger than ``len(texts)``, all texts go in a single request.

        Returns:
            List of embeddings, one for each text. Empty if ``texts`` is empty.
        """
        if not texts:
            # Avoids a zero step in range() below and a pointless request.
            return []

        total = len(texts)
        if chunk_size is None or chunk_size <= 0 or chunk_size > total:
            batch_size = total
        else:
            batch_size = chunk_size

        results: List[List[float]] = []
        for start in range(0, total, batch_size):
            results.extend(self._embedding(texts[start : start + batch_size]))
        return results

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using OCI Data Science Model Deployment Endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embedding([text])[0]

‎docs/source/user_guide/large_language_model/langchain_models.rst

+20
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,26 @@ Chat models takes `chat messages <https://python.langchain.com/docs/concepts/#me
127127
print(chunk.content, end="")
128128
129129
130+
Embedding Models
131+
================
132+
133+
You can also use an embedding model that's hosted on an `OCI Data Science Model Deployment <https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm>`_.
134+
135+
136+
.. code-block:: python3
137+
138+
from ads.llm import OCIDataScienceEmbedding
139+
140+
# Create an instance of OCI Model Deployment Endpoint
141+
# Replace the endpoint uri with your own
142+
embeddings = OCIDataScienceEmbedding(
143+
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<MD_OCID>/predict",
144+
)
145+
146+
query = "Hello World!"
147+
embeddings.embed_query(query)
148+
149+
130150
Tool Calling
131151
============
132152

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2025 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2025 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
"""Test OCI Data Science Model Deployment Endpoint."""
8+
9+
import pytest
10+
import sys
11+
from unittest.mock import MagicMock, patch
12+
13+
if sys.version_info < (3, 9):
14+
pytest.skip(allow_module_level=True)
15+
16+
from ads.llm import OCIDataScienceEmbedding
17+
18+
19+
@patch("ads.llm.OCIDataScienceEmbedding._embed_with_retry")
def test_embed_documents(mock_embed_with_retry) -> None:
    """embed_documents returns the embeddings carried by the mocked response."""
    expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    # Stand-in for the requests.Response returned by the (patched) HTTP call.
    result = MagicMock()
    result.json = MagicMock(
        return_value={
            "data": [{"embedding": expected_output}],
        }
    )
    mock_embed_with_retry.return_value = result
    endpoint = "https://MD_OCID/predict"
    documents = ["Hello", "World"]

    embeddings = OCIDataScienceEmbedding(
        endpoint=endpoint,
    )

    output = embeddings.embed_documents(documents)
    assert output == expected_output
39+
40+
41+
@patch("ads.llm.OCIDataScienceEmbedding._embed_with_retry")
def test_embed_query(mock_embed_with_retry) -> None:
    """embed_query returns the first embedding carried by the mocked response."""
    expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
    # Stand-in for the requests.Response returned by the (patched) HTTP call.
    result = MagicMock()
    result.json = MagicMock(
        return_value={
            "data": [{"embedding": expected_output}],
        }
    )
    mock_embed_with_retry.return_value = result
    endpoint = "https://MD_OCID/predict"
    query = "Hello world"

    embeddings = OCIDataScienceEmbedding(
        endpoint=endpoint,
    )

    output = embeddings.embed_query(query)
    assert output == expected_output[0]

0 commit comments

Comments
 (0)