Skip to content

Commit 3d8c97c

Browse files
committed
implement pygit2 backend for fetch_refspec
fix #168 1. Add order select for `_backend_func`. 2. Raise exception for fetch_refspec for ssh:// repo on Windows. 3. Add order select for _backend_func
1 parent 4984eb3 commit 3d8c97c

File tree

4 files changed

+139
-6
lines changed

4 files changed

+139
-6
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ install_requires=
3232
pathspec>=0.9.0
3333
asyncssh>=2.7.1,<3
3434
funcy>=1.14
35+
shortuuid>=0.5.0
3536

3637
[options.extras_require]
3738
tests =

src/scmrepo/git/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def __getitem__(self, key: str) -> BaseGitBackend:
4141
"""Lazily initialize backends and cache it afterwards"""
4242
initialized = self.initialized.get(key)
4343
if not initialized:
44+
if key not in self.backends and key in self.DEFAULT:
45+
raise NotImplementedError
4446
backend = self.backends[key]
4547
initialized = backend(*self.args, **self.kwargs)
4648
self.initialized[key] = initialized
@@ -266,11 +268,13 @@ def no_commits(self):
266268
# https://github.com/iterative/dvc/issues/5641
267269
# https://github.com/iterative/dvc/issues/7458
268270
def _backend_func(self, name, *args, **kwargs):
269-
for key, backend in self.backends.items():
271+
backends: Iterable[str] = kwargs.pop("backends", self.backends)
272+
for key in backends:
270273
if self._last_backend is not None and key != self._last_backend:
271274
self.backends[self._last_backend].close()
272275
self._last_backend = None
273276
try:
277+
backend = self.backends[key]
274278
func = getattr(backend, name)
275279
result = func(*args, **kwargs)
276280
self._last_backend = key
@@ -333,7 +337,9 @@ def add_commit(
333337
iter_remote_refs = partialmethod(_backend_func, "iter_remote_refs")
334338
get_refs_containing = partialmethod(_backend_func, "get_refs_containing")
335339
push_refspecs = partialmethod(_backend_func, "push_refspecs")
336-
fetch_refspecs = partialmethod(_backend_func, "fetch_refspecs")
340+
fetch_refspecs = partialmethod(
341+
_backend_func, "fetch_refspecs", backends=["pygit2", "dulwich"]
342+
)
337343
_stash_iter = partialmethod(_backend_func, "_stash_iter")
338344
_stash_push = partialmethod(_backend_func, "_stash_push")
339345
_stash_apply = partialmethod(_backend_func, "_stash_apply")

src/scmrepo/git/backend/pygit2.py

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import (
88
TYPE_CHECKING,
99
Callable,
10+
Dict,
11+
Generator,
1012
Iterable,
1113
List,
1214
Mapping,
@@ -15,7 +17,8 @@
1517
Union,
1618
)
1719

18-
from funcy import cached_property
20+
from funcy import cached_property, reraise
21+
from shortuuid import uuid
1922

2023
from scmrepo.exceptions import CloneError, MergeConflictError, RevError, SCMError
2124
from scmrepo.utils import relpath
@@ -27,6 +30,8 @@
2730

2831

2932
if TYPE_CHECKING:
33+
from pygit2.remote import Remote # type: ignore
34+
3035
from scmrepo.progress import GitProgressEvent
3136

3237

@@ -412,6 +417,52 @@ def push_refspecs(
412417
) -> Mapping[str, SyncStatus]:
413418
raise NotImplementedError
414419

420+
def _merge_remote_branch(
421+
self,
422+
rh: str,
423+
lh: str,
424+
force: bool = False,
425+
on_diverged: Optional[Callable[[str, str], bool]] = None,
426+
) -> SyncStatus:
427+
import pygit2
428+
429+
rh_rev = self.resolve_rev(rh)
430+
431+
if force:
432+
self.set_ref(lh, rh_rev)
433+
return SyncStatus.SUCCESS
434+
435+
try:
436+
merge_result, _ = self.repo.merge_analysis(rh_rev, lh)
437+
except KeyError:
438+
self.set_ref(lh, rh_rev)
439+
return SyncStatus.SUCCESS
440+
441+
if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
442+
return SyncStatus.UP_TO_DATE
443+
if merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
444+
self.set_ref(lh, rh_rev)
445+
return SyncStatus.SUCCESS
446+
if merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
447+
if on_diverged and on_diverged(lh, rh_rev):
448+
return SyncStatus.SUCCESS
449+
return SyncStatus.DIVERGED
450+
logger.debug("Unexpected merge result: %s", pygit2.GIT_MERGE_ANALYSIS_NORMAL)
451+
raise SCMError("Unknown merge analysis result")
452+
453+
@contextmanager
454+
def get_remote(self, url: str) -> Generator["Remote", None, None]:
455+
try:
456+
yield self.repo.remotes[url]
457+
except ValueError:
458+
try:
459+
remote_name = uuid()
460+
yield self.repo.remotes.create(remote_name, url)
461+
finally:
462+
self.repo.remotes.delete(remote_name)
463+
except KeyError:
464+
raise SCMError(f"'{url}' is not a valid Git remote or URL")
465+
415466
def fetch_refspecs(
416467
self,
417468
url: str,
@@ -421,7 +472,58 @@ def fetch_refspecs(
421472
progress: Callable[["GitProgressEvent"], None] = None,
422473
**kwargs,
423474
) -> Mapping[str, SyncStatus]:
424-
raise NotImplementedError
475+
from pygit2 import GitError
476+
477+
if isinstance(refspecs, str):
478+
refspecs = [refspecs]
479+
480+
with self.get_remote(url) as remote:
481+
if os.name == "nt" and remote.url.startswith("ssh://"):
482+
raise NotImplementedError
483+
484+
if os.name == "nt" and remote.url.startswith("file://"):
485+
url = remote.url[len("file://") :]
486+
self.repo.remotes.set_url(remote.name, url)
487+
remote = self.repo.remotes[remote.name]
488+
489+
fetch_refspecs: List[str] = []
490+
for refspec in refspecs:
491+
if ":" in refspec:
492+
lh, rh = refspec.split(":")
493+
else:
494+
lh = rh = refspec
495+
if not rh.startswith("refs/"):
496+
rh = f"refs/heads/{rh}"
497+
if not lh.startswith("refs/"):
498+
lh = f"refs/heads/{lh}"
499+
rh = rh[len("refs/") :]
500+
refspec = f"+{lh}:refs/remotes/{remote.name}/{rh}"
501+
fetch_refspecs.append(refspec)
502+
503+
logger.debug("fetch_refspecs: %s", fetch_refspecs)
504+
with reraise(
505+
GitError,
506+
SCMError(f"Git failed to fetch ref from '{url}'"),
507+
):
508+
remote.fetch(refspecs=fetch_refspecs)
509+
510+
result: Dict[str, "SyncStatus"] = {}
511+
for refspec in fetch_refspecs:
512+
_, rh = refspec.split(":")
513+
if not rh.endswith("*"):
514+
refname = rh.split("/", 3)[-1]
515+
refname = f"refs/{refname}"
516+
result[refname] = self._merge_remote_branch(
517+
rh, refname, force, on_diverged
518+
)
519+
continue
520+
rh = rh.rstrip("*").rstrip("/") + "/"
521+
for branch in self.iter_refs(base=rh):
522+
refname = f"refs/{branch[len(rh):]}"
523+
result[refname] = self._merge_remote_branch(
524+
branch, refname, force, on_diverged
525+
)
526+
return result
425527

426528
def _stash_iter(self, ref: str):
427529
raise NotImplementedError

tests/test_git.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import os
22
import shutil
3-
from typing import Any, Dict, Optional, Type
3+
from typing import Any, Dict, List, Optional, Type
44

55
import pytest
66
from asyncssh import SFTPClient
77
from asyncssh.connection import SSHClientConnection
88
from dulwich.client import LocalGitClient
99
from git import Repo as GitPythonRepo
10+
from pygit2 import GitError
11+
from pygit2.remote import Remote # type: ignore
1012
from pytest_mock import MockerFixture
1113
from pytest_test_utils import TempDirFactory, TmpDir
1214
from pytest_test_utils.matchers import Matcher
@@ -306,7 +308,7 @@ def test_push_refspecs(
306308
assert remote_scm.get_ref("refs/foo/baz") is None
307309

308310

309-
@pytest.mark.skip_git_backend("pygit2", "gitpython")
311+
@pytest.mark.skip_git_backend("gitpython")
310312
@pytest.mark.parametrize("use_url", [True, False])
311313
def test_fetch_refspecs(
312314
tmp_dir: TmpDir,
@@ -362,8 +364,11 @@ def test_fetch_refspecs(
362364

363365
with pytest.raises(SCMError):
364366
mocker.patch.object(LocalGitClient, "fetch", side_effect=KeyError)
367+
mocker.patch.object(Remote, "fetch", side_effect=GitError)
365368
git.fetch_refspecs(remote, "refs/foo/bar:refs/foo/bar")
366369

370+
assert len(scm.pygit2.repo.remotes) == 1
371+
367372

368373
@pytest.mark.skip_git_backend("pygit2", "gitpython")
369374
@pytest.mark.parametrize("use_url", [True, False])
@@ -1046,3 +1051,22 @@ def test_is_dirty_untracked(
10461051
tmp_dir.gen("untracked", "untracked")
10471052
assert git.is_dirty(untracked_files=True)
10481053
assert not git.is_dirty(untracked_files=False)
1054+
1055+
1056+
@pytest.mark.parametrize(
1057+
"backends", [["gitpython", "dulwich"], ["dulwich", "gitpython"]]
1058+
)
1059+
def test_backend_func(
1060+
tmp_dir: TmpDir,
1061+
scm: Git,
1062+
backends: List[str],
1063+
mocker: MockerFixture,
1064+
):
1065+
from functools import partial
1066+
1067+
scm.add = partial(scm._backend_func, "add", backends=backends)
1068+
tmp_dir.gen({"foo": "foo"})
1069+
backend = getattr(scm, backends[0])
1070+
mock = mocker.spy(backend, "add")
1071+
scm.add(["foo"])
1072+
mock.assert_called_once_with(["foo"])

0 commit comments

Comments
 (0)