Signed-off-by: ngn <ngn@ngn.tf>
This commit is contained in:
90
lib/__init__.py
Normal file
90
lib/__init__.py
Normal file
@ -0,0 +1,90 @@
|
||||
from .provider import provider
|
||||
from .config import config
|
||||
|
||||
from .gitea import gitea
|
||||
from .github import github
|
||||
|
||||
from os import remove, listdir, rmdir
|
||||
from tempfile import mkdtemp
|
||||
from os.path import join
|
||||
import subprocess as sp
|
||||
from typing import List
|
||||
import atexit
|
||||
|
||||
PROVIDERS = {
|
||||
"gitea": gitea,
|
||||
"github": github,
|
||||
}
|
||||
|
||||
|
||||
class upstream:
|
||||
def __init__(self, dir=".") -> None:
|
||||
self.config: config = config(dir=dir)
|
||||
self.provider: provider
|
||||
self._tempdir: str = ""
|
||||
|
||||
url = self.config.url("upstream")
|
||||
prov = self.config.str("provider")
|
||||
|
||||
for n, p in PROVIDERS.items():
|
||||
if n == prov.lower():
|
||||
self.provider = p(url)
|
||||
break
|
||||
|
||||
if self.provider is None:
|
||||
raise Exception("invalid provider %s" % provider)
|
||||
|
||||
if not self.provider.check():
|
||||
raise Exception("upstream is not available")
|
||||
|
||||
# self.cleanup() will run when the program exits
|
||||
atexit.register(self.cleanup)
|
||||
|
||||
def tempdir(self, path="") -> str:
|
||||
if self._tempdir == "":
|
||||
self._tempdir = mkdtemp(prefix="ups_")
|
||||
return self._tempdir if path == "" else join(self._tempdir, path)
|
||||
|
||||
def commit(self, commit="") -> str:
|
||||
if commit != "":
|
||||
self.config.set_str("commit", commit)
|
||||
return commit
|
||||
return self.config.str("commit")
|
||||
|
||||
def until(self, until: str) -> List[str]:
|
||||
return self.provider.until(until)
|
||||
|
||||
def last(self, count=1) -> List[str]:
|
||||
return self.provider.last(count)
|
||||
|
||||
# download a patch for the commit and run all the scripts on it
|
||||
def download(self, commit: str) -> str:
|
||||
path = self.tempdir("%s.patch" % commit)
|
||||
self.provider.download(commit, path)
|
||||
|
||||
for s in self.config.list("scripts"):
|
||||
proc = sp.run(["sed", "-e", s, "-i", path], stderr=sp.PIPE)
|
||||
if proc.returncode != 0:
|
||||
raise Exception("script '%s': %s" % (s, proc.stderr))
|
||||
|
||||
return path
|
||||
|
||||
# apply the patch file using ups-apply
|
||||
def apply(self, commit: str) -> None:
|
||||
path = self.tempdir("%s.patch" % commit)
|
||||
proc = sp.run(["ups-apply", path])
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise Exception(
|
||||
"apply script returned non-zero exit code: %d" % proc.returncode
|
||||
)
|
||||
|
||||
# cleanup the temp directory
|
||||
def cleanup(self) -> None:
|
||||
if self._tempdir == "":
|
||||
return
|
||||
|
||||
for f in listdir(self._tempdir):
|
||||
remove(join(self._tempdir, f))
|
||||
|
||||
rmdir(self._tempdir)
|
81
lib/config.py
Normal file
81
lib/config.py
Normal file
@ -0,0 +1,81 @@
|
||||
from .util import validate_url, validate_list
|
||||
from os import getcwd, path
|
||||
from typing import List
|
||||
import json
|
||||
|
||||
CONFIG_FILE = "ups.json"
|
||||
|
||||
|
||||
class config:
|
||||
def __init__(self, dir="") -> None:
|
||||
if dir == "":
|
||||
dir = getcwd()
|
||||
|
||||
self.path = path.join(dir, CONFIG_FILE)
|
||||
self.conf = {}
|
||||
|
||||
self.load()
|
||||
|
||||
# load the configuration from the file
|
||||
def load(self) -> None:
|
||||
try:
|
||||
f = open(self.path, "r")
|
||||
content = f.read()
|
||||
f.close()
|
||||
except Exception as e:
|
||||
raise Exception("failed to read config: %s" % e)
|
||||
else:
|
||||
self.conf = json.loads(content)
|
||||
f.close()
|
||||
|
||||
# save the current configuration
|
||||
def save(self) -> None:
|
||||
content = json.dumps(self.conf, indent=2)
|
||||
|
||||
try:
|
||||
f = open(self.path, "w")
|
||||
f.write(content)
|
||||
f.close()
|
||||
except Exception as e:
|
||||
raise Exception("failed to save config: %s" % e)
|
||||
|
||||
# set a URL key in the configuration
|
||||
def set_url(self, key: str, val: str) -> None:
|
||||
if not validate_url(val):
|
||||
raise Exception("expected a valid URL for %s" % key)
|
||||
|
||||
self.conf[key] = val
|
||||
self.save()
|
||||
|
||||
# set a string key in the configuration
|
||||
def set_str(self, key: str, val: str) -> None:
|
||||
if val == "" or val is None:
|
||||
raise Exception("expected a non-empty string for %s" % key)
|
||||
|
||||
self.conf[key] = val
|
||||
self.save()
|
||||
|
||||
def set_list(self, key: str, val: List[str]) -> None:
|
||||
if not validate_list(val):
|
||||
raise Exception("expected a non-empty list for %s" % key)
|
||||
|
||||
self.conf[key] = val
|
||||
self.save()
|
||||
|
||||
# read a URL key from the configuration
|
||||
def url(self, key: str) -> str:
|
||||
if key not in self.conf.keys() or not validate_url(self.conf[key]):
|
||||
return ""
|
||||
return self.conf[key]
|
||||
|
||||
# read a string key from the configuration file
|
||||
def str(self, key: str) -> str:
|
||||
if key not in self.conf.keys():
|
||||
return ""
|
||||
return self.conf[key]
|
||||
|
||||
# read a list from the configuration file
|
||||
def list(self, key: str) -> List[str]:
|
||||
if key not in self.conf.keys() or not validate_list(self.conf[key]):
|
||||
return []
|
||||
return self.conf[key]
|
63
lib/gitea.py
Normal file
63
lib/gitea.py
Normal file
@ -0,0 +1,63 @@
|
||||
from .provider import provider
|
||||
from urllib import request
|
||||
from typing import List
|
||||
|
||||
COMMITS_URL = "PROTO://HOST/api/v1/repos/OWNER/REPO/commits"
|
||||
REPO_URL = "PROTO://HOST/api/v1/repos/OWNER/REPO"
|
||||
PATCH_URL = ""
|
||||
|
||||
|
||||
class gitea(provider):
|
||||
def __init__(self, url: str) -> None:
|
||||
super().__init__(url)
|
||||
|
||||
def check(self) -> bool:
|
||||
res = self.GET(REPO_URL)
|
||||
return not res["private"]
|
||||
|
||||
def until(self, commit: str) -> List[str]:
|
||||
commits = []
|
||||
found = False
|
||||
page = 1
|
||||
|
||||
while not found:
|
||||
res = self.GET(
|
||||
COMMITS_URL,
|
||||
{
|
||||
"limit": 50,
|
||||
"stat": False,
|
||||
"page": page,
|
||||
},
|
||||
)
|
||||
|
||||
if len(res) <= 0:
|
||||
break
|
||||
|
||||
for c in res:
|
||||
if c["sha"] == commit:
|
||||
found = True
|
||||
break
|
||||
commits.append(c["sha"])
|
||||
|
||||
page += 1
|
||||
|
||||
return commits
|
||||
|
||||
def last(self, count=1) -> List[str]:
|
||||
commits = []
|
||||
res = self.GET(
|
||||
COMMITS_URL,
|
||||
{
|
||||
"limit": count,
|
||||
"stat": False,
|
||||
},
|
||||
)
|
||||
|
||||
[commits.append(c["sha"]) for c in res]
|
||||
return commits
|
||||
|
||||
def download(self, commit: str, file: str) -> None:
|
||||
url = self.url(
|
||||
"PROTO://HOST/api/v1/repos/OWNER/REPO/git/commits/%s.patch" % commit
|
||||
)
|
||||
request.urlretrieve(url, file)
|
58
lib/github.py
Normal file
58
lib/github.py
Normal file
@ -0,0 +1,58 @@
|
||||
from .provider import provider
|
||||
from urllib import request
|
||||
from typing import List
|
||||
|
||||
REPO_URL = "https://api.github.com/repos/OWNER/REPO"
|
||||
COMMITS_URL = "https://api.github.com/repos/OWNER/REPO/commits"
|
||||
|
||||
|
||||
class github(provider):
|
||||
def __init__(self, url: str) -> None:
|
||||
super().__init__(url)
|
||||
|
||||
def check(self) -> bool:
|
||||
res = self.GET(REPO_URL)
|
||||
return res["full_name"] == "%s/%s" % (self.owner, self.repo)
|
||||
|
||||
def until(self, commit: str) -> List[str]:
|
||||
commits = []
|
||||
found = False
|
||||
page = 1
|
||||
|
||||
while not found:
|
||||
res = self.GET(
|
||||
COMMITS_URL,
|
||||
{
|
||||
"per_page": 50,
|
||||
"page": page,
|
||||
},
|
||||
)
|
||||
|
||||
if len(res) <= 0:
|
||||
break
|
||||
|
||||
for c in res:
|
||||
if c["sha"] == commit:
|
||||
found = True
|
||||
break
|
||||
commits.append(c["sha"])
|
||||
|
||||
page += 1
|
||||
|
||||
return commits
|
||||
|
||||
def last(self, count=1) -> List[str]:
|
||||
commits = []
|
||||
res = self.GET(
|
||||
COMMITS_URL,
|
||||
{
|
||||
"per_page": count,
|
||||
},
|
||||
)
|
||||
|
||||
[commits.append(c["sha"]) for c in res]
|
||||
return commits
|
||||
|
||||
def download(self, commit: str, file: str) -> None:
|
||||
url = self.url("PROTO://github.com/OWNER/REPO/commit/%s.patch" % commit)
|
||||
request.urlretrieve(url, file)
|
40
lib/log.py
Normal file
40
lib/log.py
Normal file
@ -0,0 +1,40 @@
|
||||
from time import strftime, localtime
|
||||
from sys import stdout, stderr
|
||||
from os import path
|
||||
import inspect
|
||||
|
||||
BLUE = "\033[34m"
|
||||
YELLOW = "\033[33m"
|
||||
RED = "\033[31m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
|
||||
def _log(color: str, level: str, msg: str, err=False) -> int:
|
||||
frame = inspect.stack()[2]
|
||||
filename = path.basename(inspect.getmodule(frame[0]).__file__)
|
||||
|
||||
timestr = strftime("%H:%M:%S", localtime())
|
||||
funcstr = "%s:%d" % (filename, frame.lineno)
|
||||
|
||||
write = stderr.write if err else stdout.write
|
||||
return write(
|
||||
"%s[%s]%s %s %s: %s\n"
|
||||
% (color, level.lower(), RESET, timestr, funcstr, msg)
|
||||
)
|
||||
|
||||
|
||||
def info(msg: str) -> None:
|
||||
_log(BLUE, "info", msg)
|
||||
|
||||
|
||||
def warn(msg: str) -> None:
|
||||
_log(YELLOW, "warn", msg, err=True)
|
||||
|
||||
|
||||
def fail(msg: str, exception=None) -> None:
|
||||
size = _log(RED, "fail", msg, err=True)
|
||||
|
||||
if exception is not None:
|
||||
size -= len(msg) + len(RED) + len(RESET) + 1 + len("details: ")
|
||||
stderr.write(" " * size + "details: ")
|
||||
stderr.write(exception.__str__() + "\n")
|
78
lib/provider.py
Normal file
78
lib/provider.py
Normal file
@ -0,0 +1,78 @@
|
||||
from os.path import dirname, basename
|
||||
from urllib.parse import urlencode
|
||||
import urllib.parse as urlparse
|
||||
from .util import validate_url
|
||||
import requests as req
|
||||
from typing import List
|
||||
|
||||
|
||||
class provider:
|
||||
def __init__(self, url: str) -> None:
|
||||
parsed = validate_url(url)
|
||||
|
||||
if parsed is None:
|
||||
raise Exception("invalid provider URL: %s" % url)
|
||||
|
||||
self.host = parsed.hostname
|
||||
self.protocol = parsed.scheme
|
||||
self.owner = dirname(parsed.path)
|
||||
self.repo = basename(parsed.path)
|
||||
|
||||
if self.owner[0] == "/":
|
||||
self.owner = self.owner[1:]
|
||||
|
||||
if self.repo[0] == "/":
|
||||
self.repo = self.repo[1:]
|
||||
|
||||
if self.owner == "" or self.repo == "":
|
||||
raise Exception("invalid provider URL: %s" % url)
|
||||
|
||||
def url(self, url: str, queries={}) -> str:
|
||||
url = url.replace("HOST", self.host)
|
||||
url = url.replace("PROTO", self.protocol)
|
||||
url = url.replace("OWNER", self.owner)
|
||||
url = url.replace("REPO", self.repo)
|
||||
|
||||
parsed = list(urlparse.urlparse(url))
|
||||
|
||||
query = dict(urlparse.parse_qsl(parsed[4]))
|
||||
query.update(queries)
|
||||
|
||||
parsed[4] = urlencode(query)
|
||||
url = urlparse.urlunparse(parsed)
|
||||
|
||||
return url
|
||||
|
||||
# send a HTTP GET request to the provider
|
||||
def GET(self, url: str, queries={}, code=200) -> dict:
|
||||
url = self.url(url, queries=queries)
|
||||
|
||||
# res = req.get(url, headers={
|
||||
# "User-Agent": "ups (https://git.ngn.tf/ngn/ups)",
|
||||
# })
|
||||
|
||||
res = req.get(url)
|
||||
|
||||
if res.status_code != code:
|
||||
raise Exception(
|
||||
"expected %d from %s, received %d"
|
||||
% (code, url, res.status_code)
|
||||
)
|
||||
|
||||
return res.json()
|
||||
|
||||
# check if the provider is available
|
||||
def check(self) -> bool:
|
||||
return False
|
||||
|
||||
# get all the commits until the target commit
|
||||
def until(self, commit: str) -> List[str]:
|
||||
return []
|
||||
|
||||
# get last "count" commits
|
||||
def last(self, count=1) -> List[str]:
|
||||
return []
|
||||
|
||||
# download the patch for a commit
|
||||
def download(self, commit: str, file: str) -> None:
|
||||
return False
|
16
lib/util.py
Normal file
16
lib/util.py
Normal file
@ -0,0 +1,16 @@
|
||||
from urllib.parse import urlparse, ParseResult
|
||||
from typing import List
|
||||
|
||||
|
||||
def validate_url(url: str) -> ParseResult:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if all([parsed.scheme, parsed.netloc]):
|
||||
return parsed
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def validate_list(_list: List[str]) -> bool:
|
||||
return _list is not None and len(_list) != 0
|
Reference in New Issue
Block a user