Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#!/usr/bin/env python3
"""Tiny async Gentoo distfiles cache proxy.

Serves files from an on-disk cache, fetching misses from the upstream
mirrors in ORIGINS and writing them through to the cache.
"""
import asyncio
from pathlib import Path, PurePosixPath
import aiohttp
from aiohttp import web

# Upstream mirrors, tried in order until one answers with something other than 404.
ORIGINS = [
    "https://gentoo.mirrors.ovh.net/gentoo-distfiles/",
    "https://distfiles.gentoo.org/releases/amd64/binpackages/23.0/x86-64/",
]
CACHE_DIR = Path("cache").resolve()  # root of the on-disk cache
HOST = "0.0.0.0"
PORT = 80  # NOTE(review): binding port 80 usually needs elevated privileges
CHUNK_SIZE = 1 << 16  # 64 KiB streaming chunk size
CACHE_DIR.mkdir(parents=True, exist_ok=True)
# One lock per relative path so only a single download runs per file.
locks: dict[str, asyncio.Lock] = {}
# Lazily created shared HTTP client session (see get_session()).
session: aiohttp.ClientSession | None = None
def clean_relpath(raw: str) -> str:
    """Validate *raw* as a safe cache-relative path and return it normalized.

    Rejects absolute paths, any ``..`` component (directory traversal), and
    empty paths. Note ``PurePosixPath("")`` normalizes to ``"."``, so the
    original ``rel.as_posix() == ""`` comparison could never be true and an
    empty path would have mapped onto the cache root directory itself.

    Raises:
        web.HTTPBadRequest: if the path is absolute, empty, or escapes
            the cache directory.
    """
    rel = PurePosixPath(raw)
    if rel.is_absolute() or ".." in rel.parts or rel.as_posix() in ("", "."):
        raise web.HTTPBadRequest(text="invalid path")
    return rel.as_posix()
async def get_session() -> aiohttp.ClientSession:
    """Return the shared :class:`aiohttp.ClientSession`, creating it lazily."""
    global session
    if session is not None:
        return session
    session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=None, connect=30),
        connector=aiohttp.TCPConnector(limit=0, ttl_dns_cache=300),
    )
    return session
async def handler(request: web.Request) -> web.StreamResponse:
    """Serve a cached file when present, otherwise consult the upstream mirrors."""
    method = request.method
    if method not in ("GET", "HEAD"):
        raise web.HTTPMethodNotAllowed(method, ["GET", "HEAD"])
    rel = clean_relpath(request.match_info.get("tail", ""))
    local_path = CACHE_DIR / rel
    cached = local_path.exists()
    if method == "HEAD":
        return web.FileResponse(local_path) if cached else await head_upstream(rel)
    if cached:
        print(f"cached -> {rel}")
        return web.FileResponse(local_path)
    return await stream_and_cache(request, rel, local_path)
async def head_upstream(rel: str) -> web.Response:
    """Relay a HEAD request for *rel* to the first origin that has the file.

    Raises:
        web.HTTPBadGateway: on a non-404 upstream error status.
        web.HTTPNotFound: when every origin answers 404.
    """
    sess = await get_session()
    forwarded = ("Content-Length", "Content-Type", "ETag", "Last-Modified")
    for origin in ORIGINS:
        url = origin.rstrip("/") + "/" + rel
        async with sess.head(url) as upstream:
            status = upstream.status
            if status == 404:
                continue  # try the next mirror
            if status >= 400:
                raise web.HTTPBadGateway(text=f"upstream status {status}")
            headers = {}
            for name in forwarded:
                value = upstream.headers.get(name)
                if value is not None:
                    headers[name] = value
            return web.Response(status=status, headers=headers)
    raise web.HTTPNotFound()
async def stream_and_cache(
    request: web.Request, rel: str, local_path: Path
) -> web.StreamResponse:
    """Download *rel* from the first origin that has it, writing the body to
    the cache while streaming it to the client.

    A per-path lock serializes concurrent requests for the same file. A
    request that waited on the lock re-checks the cache and serves the
    freshly written file instead of downloading again — the original
    ``assert not local_path.exists()`` crashed with AssertionError in exactly
    that case (and asserts are stripped entirely under ``python -O``).

    Raises:
        web.HTTPBadGateway: on a non-404 upstream error status.
        web.HTTPNotFound: when every origin answers 404.
    """
    sess = await get_session()
    lock = locks.setdefault(rel, asyncio.Lock())
    async with lock:
        # Another request may have finished caching this file while we waited.
        if local_path.exists():
            print(f"cached -> {rel}")
            return web.FileResponse(local_path)
        tmp_path = local_path.with_name(local_path.name + ".part")
        tmp_path.parent.mkdir(parents=True, exist_ok=True)
        response = web.StreamResponse()
        client_connected = True
        try:
            for origin in ORIGINS:
                url = origin.rstrip("/") + "/" + rel
                print(f"cache miss -> {url}")
                async with sess.get(url) as upstream:
                    if upstream.status == 404:
                        continue  # try the next mirror
                    if upstream.status >= 400:
                        raise web.HTTPBadGateway(
                            text=f"upstream status {upstream.status}"
                        )
                    for name in ("Content-Length", "Content-Type"):
                        if name in upstream.headers:
                            response.headers[name] = upstream.headers[name]
                    response.set_status(upstream.status)
                    await response.prepare(request)
                    with tmp_path.open("wb") as fh:
                        # Keep writing the cache file even if the client drops,
                        # so the download is not wasted.
                        async for chunk in upstream.content.iter_chunked(CHUNK_SIZE):
                            fh.write(chunk)
                            if client_connected:
                                try:
                                    await response.write(chunk)
                                except ConnectionResetError:
                                    client_connected = False
                    # Atomic publish: the final name only ever holds a complete file.
                    tmp_path.replace(local_path)
                    if client_connected:
                        await response.write_eof()
                    print(f"cached {rel}")
                    return response
            raise web.HTTPNotFound()
        except Exception:
            # Drop the partial download so a retry starts from scratch.
            if tmp_path.exists():
                try:
                    tmp_path.unlink()
                except OSError:
                    pass
            raise
        finally:
            locks.pop(rel, None)
app = web.Application()
# Catch-all route; handler() itself rejects methods other than GET/HEAD.
app.router.add_route("*", "/{tail:.*}", handler)


async def _cleanup(_: web.Application) -> None:
    """Close the shared client session on server shutdown."""
    if session:
        await session.close()


app.on_cleanup.append(_cleanup)
web.run_app(app, host=HOST, port=PORT)
Advertisement