diff --git a/agents/l4m_agent.py b/agents/l4m_agent.py
deleted file mode 100644
index 302da89..0000000
--- a/agents/l4m_agent.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from langchain.agents import initialize_agent
-from langchain.agents import AgentType
-from langchain.prompts import MessagesPlaceholder
-from langchain.memory import ConversationBufferMemory
-
-
-def base_agent(
-    llm, tools, agent_type=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION
-):
-    """Base agent to perform xyz slippy map tiles operations.
-
-    llm: LLM object
-    tools: List of tools to use by the agent
-    """
-    # chat_history = MessagesPlaceholder(variable_name="chat_history")
-    # memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
-    agent = initialize_agent(
-        llm=llm,
-        tools=tools,
-        agent=agent_type,
-        max_iterations=5,
-        early_stopping_method="generate",
-        verbose=True,
-        # memory=memory,
-        # agent_kwargs={
-        #     "memory_prompts": [chat_history],
-        #     "input_variables": ["input", "agent_scratchpad", "chat_history"],
-        # },
-    )
-    print("agent initialized")
-    return agent
diff --git a/app.py b/app.py
index cf916ae..c0c688e 100644
--- a/app.py
+++ b/app.py
@@ -1,211 +1,20 @@
-import os
-
-import rasterio as rio
-import folium
 import streamlit as st
-from streamlit_folium import folium_static
-
-import langchain
-from langchain.agents import AgentType
-from langchain.chat_models import ChatOpenAI
-from langchain.tools import Tool, DuckDuckGoSearchRun
-from langchain.callbacks import (
-    StreamlitCallbackHandler,
-    AimCallbackHandler,
-    get_openai_callback,
-)
-
-from tools.mercantile_tool import MercantileTool
-from tools.geopy.geocode import GeopyGeocodeTool
-from tools.geopy.distance import GeopyDistanceTool
-from tools.osmnx.geometry import OSMnxGeometryTool
-from tools.osmnx.network import OSMnxNetworkTool
-from tools.stac.search import STACSearchTool
-from agents.l4m_agent import base_agent
-
-# DEBUG
-langchain.debug = True
-
-
-@st.cache_resource(ttl="1h")
-def get_agent(
-    openai_api_key, agent_type=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION
-):
-    llm = ChatOpenAI(
-        temperature=0,
-        openai_api_key=openai_api_key,
-        model_name="gpt-3.5-turbo-0613",
-    )
-    # define a set of tools the agent has access to for queries
-    duckduckgo_tool = Tool(
-        name="DuckDuckGo",
-        description="Use this tool to answer questions about current events and places. \
-        Please ask targeted questions.",
-        func=DuckDuckGoSearchRun().run,
-    )
-    geocode_tool = GeopyGeocodeTool()
-    distance_tool = GeopyDistanceTool()
-    mercantile_tool = MercantileTool()
-    geometry_tool = OSMnxGeometryTool()
-    network_tool = OSMnxNetworkTool()
-    search_tool = STACSearchTool()
-
-    tools = [
-        duckduckgo_tool,
-        geocode_tool,
-        distance_tool,
-        mercantile_tool,
-        geometry_tool,
-        network_tool,
-        search_tool,
-    ]
-
-    agent = base_agent(llm, tools, agent_type=agent_type)
-    return agent
-
-
-def run_query(agent, query):
-    return response
-
-
-def plot_raster(items):
-    st.subheader("Preview of the first item sorted by cloud cover")
-    selected_item = min(items, key=lambda item: item.properties["eo:cloud_cover"])
-    href = selected_item.assets["rendered_preview"].href
-    # arr = rio.open(href).read()
-
-    # m = folium.Map(location=[28.6, 77.7], zoom_start=6)
-
-    # img = folium.raster_layers.ImageOverlay(
-    #     name="Sentinel 2",
-    #     image=arr.transpose(1, 2, 0),
-    #     bounds=selected_item.bbox,
-    #     opacity=0.9,
-    #     interactive=True,
-    #     cross_origin=False,
-    #     zindex=1,
-    # )
-
-    # img.add_to(m)
-    # folium.LayerControl().add_to(m)
-
-    # folium_static(m)
-    st.image(href)
-
-
-def plot_vector(df):
-    st.subheader("Add the geometry to the Map")
-    center = df.centroid.iloc[0]
-    m = folium.Map(location=[center.y, center.x], zoom_start=12)
-    folium.GeoJson(df).add_to(m)
-    folium_static(m)
-
-
-st.set_page_config(page_title="LLLLM", page_icon="🤖", layout="wide")
-st.subheader("🤖 I am Geo LLM Agent!")
-
-if "msgs" not in st.session_state:
-    st.session_state.msgs = []
-
-if "total_tokens" not in st.session_state:
-    st.session_state.total_tokens = 0
-
-if "prompt_tokens" not in st.session_state:
-    st.session_state.prompt_tokens = 0
-
-if "completion_tokens" not in st.session_state:
-    st.session_state.completion_tokens = 0
-
-if "total_cost" not in st.session_state:
-    st.session_state.total_cost = 0
-
-with st.sidebar:
-    openai_api_key = os.getenv("OPENAI_API_KEY")
-    if not openai_api_key:
-        openai_api_key = st.text_input("OpenAI API Key", type="password")
-
-    st.subheader("OpenAI Usage")
-    total_tokens = st.empty()
-    prompt_tokens = st.empty()
-    completion_tokens = st.empty()
-    total_cost = st.empty()
-
-    total_tokens.write(f"Total Tokens: {st.session_state.total_tokens:,.0f}")
-    prompt_tokens.write(f"Prompt Tokens: {st.session_state.prompt_tokens:,.0f}")
-    completion_tokens.write(
-        f"Completion Tokens: {st.session_state.completion_tokens:,.0f}"
-    )
-    total_cost.write(f"Total Cost (USD): ${st.session_state.total_cost:,.4f}")
-
-
-for msg in st.session_state.msgs:
-    with st.chat_message(name=msg["role"], avatar=msg["avatar"]):
-        st.markdown(msg["content"])
-
-if prompt := st.chat_input("Ask me anything about the flat world..."):
-    with st.chat_message(name="user", avatar="🧑‍💻"):
-        st.markdown(prompt)
-
-    st.session_state.msgs.append({"role": "user", "avatar": "🧑‍💻", "content": prompt})
-
-    if not openai_api_key:
-        st.info("Please add your OpenAI API key to continue.")
-        st.stop()
-
-    aim_callback = AimCallbackHandler(
-        repo=".",
-        experiment_name="LLLLLM: Base Agent v0.1",
-    )
-
-    agent = get_agent(openai_api_key)
-
-    with get_openai_callback() as cb:
-        st_callback = StreamlitCallbackHandler(st.container())
-        response = agent.run(prompt, callbacks=[st_callback, aim_callback])
-
-    aim_callback.flush_tracker(langchain_asset=agent, reset=False, finish=True)
-
-    # Log OpenAI stats
-    # print(f"Model name: {response.llm_output.get('model_name', '')}")
-    st.session_state.total_tokens += cb.total_tokens
-    st.session_state.prompt_tokens += cb.prompt_tokens
-    st.session_state.completion_tokens += cb.completion_tokens
-    st.session_state.total_cost += cb.total_cost
-
-    total_tokens.write(f"Total Tokens: {st.session_state.total_tokens:,.0f}")
-    prompt_tokens.write(f"Prompt Tokens: {st.session_state.prompt_tokens:,.0f}")
-    completion_tokens.write(
-        f"Completion Tokens: {st.session_state.completion_tokens:,.0f}"
-    )
-    total_cost.write(f"Total Cost (USD): ${st.session_state.total_cost:,.4f}")
-
-    with st.chat_message(name="assistant", avatar="🤖"):
-        if type(response) == str:
-            content = response
-            st.markdown(response)
-        else:
-            tool, result = response
-
-            match tool:
-                case "stac-search":
-                    content = f"Found {len(result)} items from the catalog."
-                    st.markdown(content)
-                    if len(result) > 0:
-                        plot_raster(result)
-                case "geometry":
-                    content = f"Found {len(result)} geometries."
-                    gdf = result
-                    st.markdown(content)
-                    plot_vector(gdf)
-                case "network":
-                    content = f"Found {len(result)} network geometries."
-                    ndf = result
-                    st.markdown(content)
-                    plot_vector(ndf)
-                case _:
-                    content = response
-                    st.markdown(content)
-
-    st.session_state.msgs.append(
-        {"role": "assistant", "avatar": "🤖", "content": content}
-    )
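+# The LangChain agent loop was replaced by a LangGraph graph; see graphs/l4m_graph.py.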
+from langchain_core.messages import HumanMessage
+
+from graphs.l4m_graph import graph
+
+if prompt := st.chat_input():
+    st.chat_message("user").write(prompt)
+    config = {"configurable": {"thread_id": "1"}}
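+    # stream_mode="updates" yields one chunk per executed node, keyed by node name,
+    # e.g. {"assistant": {"messages": [...]}} or {"tools": {"messages": [...]}}.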
+    for chunk in graph.stream(
+        {"messages": [HumanMessage(content=prompt)]}, config, stream_mode="updates"
+    ):
+        # for chunk in graph.invoke(
+        #     {"messages": [HumanMessage(content=prompt)]}, config, stream_mode="updates"
+        # ):
+        #     st.markdown(chunk)
+
+        node = "assistant" if "assistant" in chunk else "tools"
+        with st.chat_message(node):
+            for msg in chunk[node]["messages"]:
+                st.markdown(msg.content)
diff --git a/environment.yaml b/environment.yaml
index 631315c..49757d0 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -2,20 +2,22 @@ name: llllm-env
 channels:
   - conda-forge
 dependencies:
-  - python=3
+  - python=3.12
   - pip
-  - osmnx=1.3.1
+  - rasterio
   - pip:
-      - openai==0.27.8
-      - langchain==0.0.215
-      - duckduckgo-search==3.8.3
-      - mercantile==1.2.1
-      - geopy==2.3.0
-      - ipywidgets==8.0.6
-      - jupyterlab==4.0.2
-      - planetary-computer==0.5.1
-      - pystac-client==0.7.2
-      - streamlit==1.24.1
-      - streamlit-folium==0.12.0
-      - watchdog==3.0.0
-      - aim==3.17.5
+      - langchain
+      - langchain-ollama
+      - langchain-community
+      - langgraph
+      - duckduckgo-search
+      - mercantile
+      - geopy
+      - ipywidgets
+      - jupyterlab
+      - planetary-computer
+      - pystac-client
+      - streamlit
+      - streamlit-folium
+      - watchdog
+      - altair
+      - osmnx
diff --git a/agents/.gitkeep b/graphs/.gitkeep
similarity index 100%
rename from agents/.gitkeep
rename to graphs/.gitkeep
diff --git a/graphs/l4m_graph.py b/graphs/l4m_graph.py
new file mode 100644
index 0000000..376293c
--- /dev/null
+++ b/graphs/l4m_graph.py
@@ -0,0 +1,72 @@
+import langchain
+from langchain_core.messages import SystemMessage
+from langchain_ollama import ChatOllama
+from langgraph.graph import START, MessagesState, StateGraph
+from langgraph.prebuilt import ToolNode, tools_condition
+
+from tools.geopy.distance import distance_tool
+from tools.geopy.geocode import geocode_tool
+from tools.mercantile_tool import mercantile_tool
+from tools.osmnx.geometry import geometry_tool
+from tools.osmnx.network import network_tool
+from tools.stac.search import stac_search
+
+# from tools.duck_tool import duckduckgo_tool
+
+# DEBUG
+langchain.debug = True
+
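+# Assumes an Ollama server running locally with the llama3.2 model pulled.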
model="llama3.2", + temperature=0, +) + +tools = [ + # duckduckgo_tool, + geocode_tool, + distance_tool, + mercantile_tool, + geometry_tool, + network_tool, + stac_search, +] + +# For this ipynb we set parallel tool calling to false as math generally is done sequentially, and this time we have 3 tools that can do math +# the OpenAI model specifically defaults to parallel tool calling for efficiency, see https://python.langchain.com/docs/how_to/tool_calling_parallel/ +# play around with it and see how the model behaves with math equations! +llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False) + + +# System message +sys_msg = SystemMessage( + content="You are a helpful assistant tasked with answering questions on a set of geographic inputs." + # "You are a helpful assistant tasked with performing arithmetic on a set of inputs. " + "do not use tools unless the message does not contain geographic inputs" + # "do NOT use tools unless strictly necessary to answer the question" + # " Do NOT answer the question, just reformulate it if needed and otherwise return it as is.." +) + + +# Node +def assistant(state: MessagesState): + return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]} + + +# Graph +builder = StateGraph(MessagesState) + +# Define nodes: these do the work +builder.add_node("assistant", assistant) +builder.add_node("tools", ToolNode(tools)) + +# Define edges: these determine how the control flow moves +builder.add_edge(START, "assistant") +builder.add_conditional_edges( + "assistant", + # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools + # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END + tools_condition, +) +builder.add_edge("tools", "assistant") + +graph = builder.compile() diff --git a/tools/duck_tool.py b/tools/duck_tool.py new file mode 100644 index 0000000..5a24c03 --- /dev/null +++ b/tools/duck_tool.py @@ -0,0 +1,8 @@ +from langchain.tools import DuckDuckGoSearchRun, Tool + +duckduckgo_tool = Tool( + name="DuckDuckGo", + description="Use this tool to answer questions about current events and places. \ + Please ask targeted questions.", + func=DuckDuckGoSearchRun().run, +) diff --git a/tools/geopy/distance.py b/tools/geopy/distance.py index 79cbf4f..471d3a8 100644 --- a/tools/geopy/distance.py +++ b/tools/geopy/distance.py @@ -1,26 +1,22 @@ -from typing import Type +from typing import Union from geopy.distance import distance +from langchain_core.tools import tool from pydantic import BaseModel, Field -from langchain.tools import BaseTool class GeopyDistanceInput(BaseModel): """Input for GeopyDistanceTool.""" - point_1: tuple[float, float] = Field(..., description="lat,lng of a place") - point_2: tuple[float, float] = Field(..., description="lat,lng of a place") + lat1: float = Field(description="Latitude of a first location") + lon1: float = Field(description="Longitude of a first location") + lat2: float = Field(description="Latitude of a second location") + lon2: float = Field(description="Longitude of a second location") -class GeopyDistanceTool(BaseTool): - """Custom tool to calculate geodesic distance between two points.""" - - name: str = "distance" - args_schema: Type[BaseModel] = GeopyDistanceInput - description: str = "Use this tool to compute distance between two points available in lat,lng format." 
+@tool("distance-tool", args_schema=GeopyDistanceInput, return_direct=False)
+def distance_tool(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
+    """
+    Tool to calculate distance in kilometers between two points.
+    """
+    return distance((lat1, lon1), (lat2, lon2)).km
diff --git a/tools/geopy/geocode.py b/tools/geopy/geocode.py
index 8b14198..f113532 100644
--- a/tools/geopy/geocode.py
+++ b/tools/geopy/geocode.py
@@ -1,29 +1,24 @@
-from typing import Type
+from typing import Tuple
 
 from geopy.geocoders import Nominatim
+from langchain_core.tools import tool
 from pydantic import BaseModel, Field
-from langchain.tools import BaseTool
 
 
 class GeopyGeocodeInput(BaseModel):
     """Input for GeopyGeocodeTool."""
 
-    place: str = Field(..., description="name of a place")
+    place: str = Field(description="name of a place")
 
 
-class GeopyGeocodeTool(BaseTool):
-    """Custom tool to perform geocoding."""
+@tool("geocode-tool", args_schema=GeopyGeocodeInput, return_direct=True)
+def geocode_tool(place: str) -> Tuple[float, float]:
+    """
+    Custom tool to perform geocoding.
 
-    name: str = "geocode"
-    args_schema: Type[BaseModel] = GeopyGeocodeInput
-    description: str = "Use this tool for geocoding."
-
-    def _run(self, place: str) -> tuple:
-        locator = Nominatim(user_agent="geocode")
-        location = locator.geocode(place)
-        if location is None:
-            return ("geocode", "Not a recognised address in Nomatim.")
-        return ("geocode", (location.latitude, location.longitude))
-
-    def _arun(self, place: str):
-        raise NotImplementedError
+    Use this tool for geocoding the address of a place.
+    """
+    locator = Nominatim(user_agent="geocode")
+    location = locator.geocode(place)
+    if location is None:
+        raise ValueError(f"Not a recognised address in Nominatim: {place}")
+    return location.latitude, location.longitude
diff --git a/tools/mercantile_tool.py b/tools/mercantile_tool.py
index 63e8bcd..e8ecf5f 100644
--- a/tools/mercantile_tool.py
+++ b/tools/mercantile_tool.py
@@ -1,19 +1,23 @@
 import mercantile
-from langchain.tools import BaseTool
+from langchain_core.tools import tool
+from pydantic import BaseModel, Field
 
 
-class MercantileTool(BaseTool):
-    """Tool to perform mercantile operations."""
-
-    name = "mercantile"
-    description = "use this tool to get the xyz tiles for a place. \
-        To use this tool you need to provide lng,lat,zoom level of the place separated by comma."
+class MercantileToolInput(BaseModel):
+    latitude: float = Field(description="Latitude of a location")
+    longitude: float = Field(description="Longitude of a location")
+    zoom: int = Field(description="Zoom level for mercantile")
 
-    def _run(self, query):
-        lng, lat, zoom = map(float, query.split(","))
-        return ("mercantile", mercantile.tile(lng, lat, zoom))
 
-    def _arun(self, query):
-        raise NotImplementedError(
-            "Mercantile tool doesn't have an async implementation."
-        )
+@tool("mercantile-tool", args_schema=MercantileToolInput, return_direct=True)
+def mercantile_tool(latitude: float, longitude: float, zoom: int) -> mercantile.Tile:
+    """
+    Tool to perform mercantile operations.
+    Use this tool to get the xyz tile for a place given its latitude,
+    longitude and zoom level.
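+    Example: latitude=48.85, longitude=2.35, zoom=10 -> Tile(x=518, y=352, z=10).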
+ """ + lng, lat, zoom = map(float, query.split(",")) + return mercantile.tile(lng, lat, zoom) diff --git a/tools/osmnx/geometry.py b/tools/osmnx/geometry.py index ecc1287..ff23ea7 100644 --- a/tools/osmnx/geometry.py +++ b/tools/osmnx/geometry.py @@ -1,32 +1,26 @@ -from typing import Type, Dict +from typing import Dict, Type -import osmnx as ox import geopandas as gpd +import osmnx as ox +from geopandas import GeoDataFrame +from langchain_core.tools import tool from pydantic import BaseModel, Field -from langchain.tools import BaseTool class PlaceWithTags(BaseModel): "Name of a place on the map and tags in OSM." - - place: str = Field(..., description="name of a place on the map.") - tags: Dict[str, str] = Field(..., description="open street maps tags.") - - -class OSMnxGeometryTool(BaseTool): - """Tool to query geometries from Open Street Map (OSM).""" - - name: str = "geometry" - args_schema: Type[BaseModel] = PlaceWithTags - description: str = "Use this tool to get geometry of different features of the place like building footprints, parks, lakes, hospitals, schools etc. \ - Pass the name of the place & tags of OSM as args." - return_direct = True - - def _run(self, place: str, tags: Dict[str, str]) -> gpd.GeoDataFrame: - gdf = ox.geometries_from_place(place, tags) - gdf = gdf[gdf["geometry"].type.isin({"Polygon", "MultiPolygon"})] - gdf = gdf[["name", "geometry"]].reset_index(drop=True) - return ("geometry", gdf) - - def _arun(self, place: str): - raise NotImplementedError + place: str = Field(description="name of a place on the map.") + tags: Dict[str, str] = Field(description="open street maps tags.") + + +@tool("geometry-tool", args_schema=PlaceWithTags, return_direct=True) +def geometry_tool(place: str, tags: Dict[str, str]) -> GeoDataFrame: + """ + Tool to query geometries from Open Street Map (OSM). + Use this tool to get geometry of different features of the place + like building footprints, parks, lakes, hospitals, schools etc. + Pass the name of the place & tags of OSM as args. + """ + gdf = ox.geometries_from_place(place, tags) + gdf = gdf[gdf["geometry"].type.isin({"Polygon", "MultiPolygon"})] + return gdf[["name", "geometry"]].reset_index(drop=True) diff --git a/tools/osmnx/network.py b/tools/osmnx/network.py index 24f8375..8cc1f26 100644 --- a/tools/osmnx/network.py +++ b/tools/osmnx/network.py @@ -1,34 +1,26 @@ -from typing import Type, Dict - +import geopandas as gpd import osmnx as ox +from geopandas import GeoDataFrame +from langchain_core.tools import tool from osmnx import utils_graph -import geopandas as gpd from pydantic import BaseModel, Field -from langchain.tools import BaseTool class PlaceWithNetworktype(BaseModel): "Name of a place on the map" - place: str = Field(..., description="name of a place on the map") + place: str = Field(description="name of a place on the map") network_type: str = Field( - ..., description="network type: one of walk, bike, drive or all" + description="network type: one of walk, bike, drive or all" ) -class OSMnxNetworkTool(BaseTool): - """Custom tool to query road networks from OSM.""" - - name: str = "network" - args_schema: Type[BaseModel] = PlaceWithNetworktype - description: str = "Use this tool to get road network of a place. \ - Pass the name of the place & type of road network i.e walk, bike, drive or all." 
+@tool("network-tool", args_schema=PlaceWithNetworktype, return_direct=True)
+def network_tool(place: str, network_type: str) -> GeoDataFrame:
+    """
+    Custom tool to query road networks from OSM.
+    Use this tool to get the road network of a place.
+    Pass the name of the place & the type of road network, i.e. walk, bike, drive or all.
+    """
+    G = ox.graph_from_place(place, network_type=network_type, simplify=True)
+    network = ox.graph_to_gdfs(G, nodes=False)
+    return network[["name", "geometry"]].reset_index(drop=True)
diff --git a/tools/stac/search.py b/tools/stac/search.py
index ce8d285..222a67d 100644
--- a/tools/stac/search.py
+++ b/tools/stac/search.py
@@ -1,41 +1,46 @@
-from typing import Type
+from datetime import datetime
+from typing import List
 
-from pystac_client import Client
 import planetary_computer as pc
+from langchain_core.tools import tool
 from pydantic import BaseModel, Field
-from langchain.tools import BaseTool
+from pystac import Item
+from pystac_client import Client
 
 PC_STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1"
 
+STAC_API = "https://earth-search.aws.element84.com/v1"
+COLLECTION = "sentinel-2-l2a"
+
 
-class PlaceWithDatetimeAndBBox(BaseModel):
-    "Name of a place and date."
-
-    bbox: str = Field(..., description="bbox of the place")
-    datetime: str = Field(..., description="datetime for the stac catalog search")
+class StacSearchInput(BaseModel):
+    latitude: float = Field(description="Latitude of a location")
+    longitude: float = Field(description="Longitude of a location")
+    start: datetime = Field(description="Start date")
+    end: datetime = Field(description="End date")
 
 
-class STACSearchTool(BaseTool):
-    """Tool to search for STAC items in a catalog."""
+@tool("stac-search-tool", args_schema=StacSearchInput, return_direct=True)
+def stac_search(
+    latitude: float, longitude: float, start: datetime, end: datetime
+) -> List[Item]:
+    """
+    Search Sentinel-2 STAC items.
 
-    name: str = "stac-search"
-    args_schema: Type[BaseModel] = PlaceWithDatetimeAndBBox
-    description: str = "Use this tool to search for STAC items in a catalog. \
-        Pass the bbox of the place & date as args."
-    return_direct = True
+    Use this tool to perform a STAC scene search for Sentinel-2 images at a
+    latitude and longitude and between a start and an end date.
+    """
 
-    def _run(self, bbox: str, datetime: str):
-        catalog = Client.open(PC_STAC_API, modifier=pc.sign_inplace)
+    catalog = Client.open(STAC_API)
 
-        search = catalog.search(
-            collections=["sentinel-2-l2a"],
-            bbox=bbox,
-            datetime=datetime,
-            max_items=10,
-        )
-        items = search.get_all_items()
+    search = catalog.search(
+        collections=[COLLECTION],
+        datetime=f"{start.date()}/{end.date()}",
+        bbox=(longitude - 1e-5, latitude - 1e-5, longitude + 1e-5, latitude + 1e-5),
+        max_items=100,
+    )
 
-        return ("stac-search", items)
+    items = search.get_all_items()
 
-    def _arun(self, bbox: str, datetime: str):
-        raise NotImplementedError
+    return list(items)