diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index 579eada3d7..51ba5720b0 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -15,3 +15,4 @@ Advanced guide advanced_guide/piechart_icons advanced_guide/polygons_from_list_of_points advanced_guide/customize_javascript_and_css + advanced_guide/override_leaflet_class_methods diff --git a/docs/advanced_guide/override_leaflet_class_methods.md b/docs/advanced_guide/override_leaflet_class_methods.md new file mode 100644 index 0000000000..de605aee52 --- /dev/null +++ b/docs/advanced_guide/override_leaflet_class_methods.md @@ -0,0 +1,46 @@ +# Overriding Leaflet class methods + +```{code-cell} ipython3 +--- +nbsphinx: hidden +--- +import folium +``` + +## Customizing Leaflet behavior +Sometimes you want to override Leaflet's javascript behavior. This can be done using the `Class.include` statement. This mimics Leaflet's +`L.Class.include` method. See [here](https://leafletjs.com/examples/extending/extending-1-classes.html) for more details. + +### Example: adding an authentication header to a TileLayer +One such use case is if you need to override the `createTile` on `L.TileLayer`, because your tiles are hosted on an oauth2 protected +server. This can be done like this: + +```{code-cell} +create_tile = folium.JsCode(""" + function(coords, done) { + const url = this.getTileUrl(coords); + const img = document.createElement('img'); + fetch(url, { + headers: { + "Authorization": "Bearer " + }, + }) + .then((response) => { + img.src = URL.createObjectURL(response.body); + done(null, img); + }) + return img; + } +""") + +folium.TileLayer.include(create_tile=create_tile) +tiles = folium.TileLayer( + tiles="OpenStreetMap", +) +m = folium.Map( + tiles=tiles, +) + + +m = folium.Map() +``` diff --git a/folium/elements.py b/folium/elements.py index 8344abc3a8..f52e8b6fa0 100644 --- a/folium/elements.py +++ b/folium/elements.py @@ -159,6 +159,26 @@ def __init__(self, element_name: str, element_parent_name: str): self.element_parent_name = element_parent_name +class IncludeStatement(MacroElement): + """Generate an include statement on a class.""" + + _template = Template( + """ + {{ this.leaflet_class_name }}.include( + {{ this.options | tojavascript }} + ) + """ + ) + + def __init__(self, leaflet_class_name: str, **kwargs): + super().__init__() + self.leaflet_class_name = leaflet_class_name + self.options = kwargs + + def render(self, *args, **kwargs): + return super().render(*args, **kwargs) + + class MethodCall(MacroElement): """Abstract class to add an element to another element.""" diff --git a/folium/features.py b/folium/features.py index 982b7b3e54..8f7e5230c7 100644 --- a/folium/features.py +++ b/folium/features.py @@ -36,7 +36,7 @@ from folium.elements import JSCSSMixin from folium.folium import Map -from folium.map import FeatureGroup, Icon, Layer, Marker, Popup, Tooltip +from folium.map import Class, FeatureGroup, Icon, Layer, Marker, Popup, Tooltip from folium.template import Template from folium.utilities import ( JsCode, @@ -2023,7 +2023,7 @@ def __init__( self.add_child(PolyLine(val, color=key, weight=weight, opacity=opacity)) -class Control(JSCSSMixin, MacroElement): +class Control(JSCSSMixin, Class): """ Add a Leaflet Control object to the map diff --git a/folium/map.py b/folium/map.py index 0d57822d37..278b97a1bb 100644 --- a/folium/map.py +++ b/folium/map.py @@ -4,12 +4,12 @@ """ import warnings -from collections import OrderedDict -from typing import TYPE_CHECKING, Optional, Sequence, Union, cast +from collections import OrderedDict, defaultdict +from typing import TYPE_CHECKING, DefaultDict, Optional, Sequence, Union, cast from branca.element import Element, Figure, Html, MacroElement -from folium.elements import ElementAddToElement, EventHandler +from folium.elements import ElementAddToElement, EventHandler, IncludeStatement from folium.template import Template from folium.utilities import ( JsCode, @@ -22,11 +22,58 @@ validate_location, ) + +class classproperty: + def __init__(self, f): + self.f = f + + def __get__(self, obj, owner): + return self.f(owner) + + if TYPE_CHECKING: from folium.features import CustomIcon, DivIcon -class Evented(MacroElement): +class Class(MacroElement): + """The root class of the leaflet class hierarchy""" + + _includes: DefaultDict[str, dict] = defaultdict(dict) + + @classmethod + def include(cls, **kwargs): + cls._includes[cls].update(**kwargs) + + @classproperty + def includes(cls): + return cls._includes[cls] + + @property + def leaflet_class_name(self): + # TODO: I did not check all Folium classes to see if + # this holds up. This breaks at least for CustomIcon. + return f"L.{self._name}" + + def render(self, **kwargs): + figure = self.get_root() + assert isinstance( + figure, Figure + ), "You cannot render this Element if it is not in a Figure." + if self.includes: + stmt = IncludeStatement(self.leaflet_class_name, **self.includes) + # A bit weird. I tried adding IncludeStatement directly to both + # figure and script, but failed. So we render this ourself. + figure.script.add_child( + Element(stmt._template.render(this=stmt, kwargs=self.includes)), + # make sure each class include gets rendered only once + name=self._name + "_includes", + # make sure this renders before the element itself + index=-1, + ) + super().render(**kwargs) + + +class Evented(Class): """The base class for Layer and Map Adds the `on` and `once` methods for event handling capabilities. diff --git a/tests/test_map.py b/tests/test_map.py index cc3728586a..cf6635a0b9 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -10,8 +10,8 @@ import pytest from folium import GeoJson, Map, TileLayer -from folium.map import CustomPane, Icon, LayerControl, Marker, Popup -from folium.utilities import normalize +from folium.map import Class, CustomPane, Icon, LayerControl, Marker, Popup +from folium.utilities import JsCode, normalize tmpl = """
" + }, + }) + .then((response) => { + img.src = URL.createObjectURL(response.body); + done(null, img); + }) + return img; + } + """ + TileLayer.include(create_tile=JsCode(create_tile)) + tiles = TileLayer( + tiles="OpenStreetMap", + ) + m = Map( + tiles=tiles, + ) + rendered = m.get_root().render() + Class._includes.clear() + expected = """ + L.TileLayer.include({ + "createTile": + function(coords, done) { + const url = this.getTileUrl(coords); + const img = document.createElement('img'); + fetch(url, { + headers: { + "Authorization": "Bearer " + }, + }) + .then((response) => { + img.src = URL.createObjectURL(response.body); + done(null, img); + }) + return img; + }, + }) + """ + assert normalize(expected) in normalize(rendered) + + +def test_include_once(): + abc = "MY BEAUTIFUL SENTINEL" + TileLayer.include(abc=abc) + tiles = TileLayer( + tiles="OpenStreetMap", + ) + m = Map( + tiles=tiles, + ) + TileLayer( + tiles="OpenStreetMap", + ).add_to(m) + + rendered = m.get_root().render() + Class._includes.clear() + + assert rendered.count(abc) == 1, "Includes should happen only once per class" + + def test_popup_backticks(): m = Map() popup = Popup("back`tick`tick").add_to(m)