|
| 1 | +local config = require('orgmode.config') |
| 2 | +local fs = require('orgmode.utils.fs') |
| 3 | +local utils = require('orgmode.utils') |
| 4 | +local Menu = require('orgmode.ui.menu') |
| 5 | +local State = require('orgmode.state.state') |
| 6 | +local Promise = require('orgmode.utils.promise') |
| 7 | + |
| 8 | +local M = {} |
| 9 | + |
| 10 | +---Return true if the URI should be fetched. |
| 11 | +---@param uri string |
| 12 | +---@return OrgPromise<boolean> safe |
| 13 | +function M.should_fetch(uri) |
| 14 | + local policy = config.org_resource_download_policy |
| 15 | + return Promise.resolve(policy == 'always' or M.is_uri_safe(uri)):next(function(safe) |
| 16 | + if safe then |
| 17 | + return true |
| 18 | + end |
| 19 | + if policy == 'prompt' then |
| 20 | + return M.confirm_safe(uri) |
| 21 | + end |
| 22 | + return false |
| 23 | + end) |
| 24 | +end |
| 25 | + |
| 26 | +---@param resource_uri string |
| 27 | +---@param file_uri string | false |
| 28 | +---@param patterns string[] |
| 29 | +---@return boolean matches |
| 30 | +local function check_patterns(resource_uri, file_uri, patterns) |
| 31 | + for _, pattern in ipairs(patterns) do |
| 32 | + local re = vim.regex(pattern) |
| 33 | + if re:match_str(resource_uri) or (file_uri and re:match_str(file_uri)) then |
| 34 | + return true |
| 35 | + end |
| 36 | + end |
| 37 | + return false |
| 38 | +end |
| 39 | + |
| 40 | +---Check the uri matches any of the (configured or cached) safe patterns. |
| 41 | +---@param uri string |
| 42 | +---@return OrgPromise<boolean> safe |
| 43 | +function M.is_uri_safe(uri) |
| 44 | + local current_file = fs.get_real_path(utils.current_file_path()) |
| 45 | + ---@type string | false # deduced type is `string | boolean` |
| 46 | + local file_uri = current_file and vim.uri_from_fname(current_file) or false |
| 47 | + local uri_patterns = {} |
| 48 | + if config.org_safe_remote_resources then |
| 49 | + vim.list_extend(uri_patterns, config.org_safe_remote_resources) |
| 50 | + end |
| 51 | + return State:load():next(function(state) |
| 52 | + local cached = state['org_safe_remote_resources'] |
| 53 | + if cached then |
| 54 | + vim.list_extend(uri_patterns, cached) |
| 55 | + end |
| 56 | + return check_patterns(uri, file_uri, uri_patterns) |
| 57 | + end) |
| 58 | +end |
| 59 | + |
| 60 | +---@param uri string |
| 61 | +---@return string escaped |
| 62 | +local function uri_to_pattern(uri) |
| 63 | + -- Escape backslashes, disable magic characters, anchor front and back of the |
| 64 | + -- pattern. |
| 65 | + return string.format([[\V\^%s\$]], uri:gsub([[\]], [[\\]])) |
| 66 | +end |
| 67 | + |
| 68 | +---@param filename string |
| 69 | +---@return string escaped |
| 70 | +local function filename_to_pattern(filename) |
| 71 | + return uri_to_pattern(vim.uri_from_fname(filename)) |
| 72 | +end |
| 73 | + |
| 74 | +---@param domain string |
| 75 | +---@return string escaped |
| 76 | +local function domain_to_pattern(domain) |
| 77 | + -- We construct the following regex: |
| 78 | + -- 1. http or https protocol; |
| 79 | + -- 2. followed by userinfo (`name:password@`), |
| 80 | + -- 3. followed by potentially `www.` (for convenience), |
| 81 | + -- 4. followed by the domain (in very-nomagic mode) |
| 82 | + -- 5. followed by either a slash or nothing at all. |
| 83 | + return string.format( |
| 84 | + [[\v^https?://([^@/?#]*\@)?(www\.)?(\V%s\v)($|/)]], |
| 85 | + -- `domain` here includes the host name and port. If it doesn't contain |
| 86 | + -- characters illegal in a host or port, this encoding should do nothing. |
| 87 | + -- If it contains illegal characters, the domain is broken in a safe way. |
| 88 | + vim.uri_encode(domain) |
| 89 | + ) |
| 90 | +end |
| 91 | + |
| 92 | +---@param pattern string |
| 93 | +---@return OrgPromise<OrgState> |
| 94 | +local function cache_safe_pattern(pattern) |
| 95 | + ---@param state OrgState |
| 96 | + return State:load():next(function(state) |
| 97 | + -- We manipulate `cached` in a strange way here to ensure that `state` gets |
| 98 | + -- marked as dirty. |
| 99 | + local patterns = { pattern } |
| 100 | + local cached = state['org_safe_remote_resources'] |
| 101 | + if cached then |
| 102 | + vim.list_extend(patterns, cached) |
| 103 | + end |
| 104 | + state['org_safe_remote_resources'] = patterns |
| 105 | + end) |
| 106 | +end |
| 107 | + |
| 108 | +---Ask the user if URI should be considered safe. |
| 109 | +---@param uri string |
| 110 | +---@return OrgPromise<boolean> safe |
| 111 | +function M.confirm_safe(uri) |
| 112 | + ---@type OrgMenu |
| 113 | + return Promise.new(function(resolve) |
| 114 | + local menu = Menu:new({ |
| 115 | + title = string.format('An org-mode document would like to download %s, which is not considered safe.', uri), |
| 116 | + prompt = 'Do you want to download this?', |
| 117 | + }) |
| 118 | + menu:add_option({ |
| 119 | + key = '!', |
| 120 | + label = 'Yes, and mark it as safe.', |
| 121 | + action = function() |
| 122 | + cache_safe_pattern(uri_to_pattern(uri)) |
| 123 | + return true |
| 124 | + end, |
| 125 | + }) |
| 126 | + local authority = uri:match('^https?://([^/?#]*)') |
| 127 | + -- `domain` here includes the host name and port. |
| 128 | + local domain = authority and authority:match('^[^@]*@(.*)$') or authority |
| 129 | + if domain then |
| 130 | + menu:add_option({ |
| 131 | + key = 'd', |
| 132 | + label = string.format('Yes, and mark the domain as safe. (%s)', domain), |
| 133 | + action = function() |
| 134 | + cache_safe_pattern(domain_to_pattern(domain)) |
| 135 | + return true |
| 136 | + end, |
| 137 | + }) |
| 138 | + end |
| 139 | + local filename = fs.get_real_path(utils.current_file_path()) |
| 140 | + if filename then |
| 141 | + menu:add_option({ |
| 142 | + key = 'f', |
| 143 | + label = string.format('Yes, and mark the org file as safe. (%s)', filename), |
| 144 | + action = function() |
| 145 | + cache_safe_pattern(filename_to_pattern(filename)) |
| 146 | + return true |
| 147 | + end, |
| 148 | + }) |
| 149 | + end |
| 150 | + menu:add_option({ |
| 151 | + key = 'y', |
| 152 | + label = 'Yes, just this once.', |
| 153 | + action = function() |
| 154 | + return true |
| 155 | + end, |
| 156 | + }) |
| 157 | + menu:add_option({ |
| 158 | + key = 'n', |
| 159 | + label = 'No, skip this resource.', |
| 160 | + action = function() |
| 161 | + return false |
| 162 | + end, |
| 163 | + }) |
| 164 | + menu:add_separator({ icon = ' ', length = 1 }) |
| 165 | + resolve(menu:open()) |
| 166 | + end) |
| 167 | +end |
| 168 | + |
| 169 | +return M |
0 commit comments