"""CertDeploy Server daemon and SFTP parts."""
import json
import os
import socket
import time
from datetime import datetime, timedelta
from threading import Semaphore, Thread
from typing import Any, Optional
import paramiko
import schedule
from paramiko.ssh_exception import NoValidConnectionsError, SSHException
from .. import format_error
from ..errors import CertDeployError, ConfigError
from . import log
from .config import ServerConfig
from .config.client import ClientConnection
from .config.server import PushMode
from .renew import renew_certs
# Debug-only hard limit on `Server.serve_forever(one_shot=True)`; it is checked
# once per main loop iteration in one-shot mode only. `None` (or a value <= 0)
# disables it.
ONE_SHOT_TIMEOUT: Optional[int] = None  # seconds
# Main loop delay used while there is queued work or at least one worker.
GO_FAST_SLEEP: float = 0.1  # seconds
# Main loop delay used once the queue is empty and no workers are left.
SLOW_DOWN_SLEEP: int = 30  # seconds
def _sftp_mkdir(sftp, path, mode=None):
    """Recursively make a remote directory (`path`) if needed.

    Missing parent directories are created first, all with the same `mode`.
    Directories that already exist are left untouched (their mode is not
    changed).

    Arguments:
        sftp: An open SFTP client providing `stat` and `mkdir` (e.g. the
            object returned by `paramiko.SSHClient.open_sftp`).
        path: The remote directory to create.
        mode: The permission bits for every directory that gets created.
            Defaults to `0o700` (owner-only access), because these
            directories receive private key material.
    """
    log.debug('_sftp_mkdir: path=%s, mode=%s', path, mode)
    if mode is None:
        mode = 0o700
    # A trailing separator would make `dirname` return `path` itself and lead
    # to a second `mkdir` of a directory that was just created.
    path = path.rstrip('/')
    if not path:
        # Empty relative path or the filesystem root: nothing to create.
        return
    try:
        sftp.stat(path)
        return
    except FileNotFoundError:
        pass
    _sftp_mkdir(sftp, os.path.dirname(path), mode)
    sftp.mkdir(path, mode=mode)
class _TimeoutTimer:
    """Timeout timer that uses elapsed time instead of a counter.

    Elapsed time is measured with `time.monotonic()`, so changes to the
    system clock (NTP corrections, daylight saving time transitions, manual
    adjustments) cannot make the timer expire early or never.
    """

    def __init__(self, max_runtime: Optional[float] = None):
        """Prepare the timer.

        Arguments:
            max_runtime: The maximum time to wait in seconds. `None`, zero or
                a negative value disables the timer. Defaults to `None`.
        """
        self._start = time.monotonic()
        self._enabled = max_runtime is not None and max_runtime > 0
        self._max_runtime = float(max_runtime) if self._enabled else 0.0

    def check(self, message=None):
        """Check if the timer has expired.

        Arguments:
            message: A message to pass to the `TimeoutError` if the timer has
                expired.

        Raises:
            TimeoutError: When the timer is enabled and more than
                `max_runtime` seconds have elapsed since it was created.
        """
        if not self._enabled:
            return
        if time.monotonic() - self._start > self._max_runtime:
            raise TimeoutError(message or '')

    @classmethod
    def start(cls, max_runtime: Optional[float] = None) -> '_TimeoutTimer':
        """Start a new timer.

        Arguments:
            max_runtime: The maximum time to wait in seconds. `None`, zero or
                a negative value disables the timer. Defaults to `None`.

        Returns:
            A new timer that started counting now.
        """
        return cls(max_runtime)
class Queue:
    """A queue of push jobs.

    Maps client hashes (`ClientConnection.hash`) to the list of lineage paths
    that still need to be pushed to that client. The queue is stored as JSON
    in `<queue_dir>/queue.json`.
    """

    # Serializes queue file access between threads of this process; the lock
    # file next to the queue file serializes access between processes.
    lock: Semaphore = Semaphore()

    def __init__(self, server: 'ServerConfig', mode: str = 'r'):
        """Queue of client hash to lineages to be pushed to clients.

        Arguments:
            server: The config of the parent server. Only `queue_dir` is used.
            mode: The access mode of the queue file. Valid values are 'r'
                (read only; nothing is written back) and 'w' (read-modify-write
                through the context manager; the file is written back on
                exit). Defaults to 'r'.

        Raises:
            ValueError: When `mode` is neither 'r' nor 'w'.

        Note:
            Both modes hold the lock while the file is being accessed. The
            access locking uses lock files to avoid issues with filesystem
            locks on NFS. A lock file left behind by a crashed process blocks
            all access until it is removed.
        """
        self._queue: dict[str, list[str]] = {}
        self._filename: str = os.path.join(server.queue_dir, 'queue.json')
        self._lock_filename = f'{self._filename}.lock'
        if mode not in ('r', 'w'):
            raise ValueError('`mode` must be either "r" or "w".')
        self._mode: str = mode

    @property
    def clients(self) -> list[str]:
        """The client hashes in the queue."""
        return list(self._queue)

    def append(self, client_hash: str, lineage: str):
        """Append a job (`lineage`) to the queue for a given client.

        Arguments:
            client_hash: The value of `ClientConnection.hash` for the client
                that needs the update.
            lineage: The lineage path that needs syncing to the client.
        """
        self._queue.setdefault(client_hash, []).append(lineage)

    def get(self, client_hash: str, default: Any = None) -> list[str]:
        """Get a list of lineages that need to be pushed to a client.

        Arguments:
            client_hash: The value of `ClientConnection.hash` for the client
                being requested.
            default: This will be returned when no client matching
                `client_hash` is found. Defaults to `None`.

        Returns:
            A copy of the list of lineages that need to be pushed for the
            given client, or `default`.
        """
        client_queue = self._queue.get(client_hash)
        if client_queue is None:
            return default
        return list(client_queue)

    def count(self, client_hash: str) -> int:
        """Return the number of lineages left to push for the given client."""
        return len(self._queue.get(client_hash) or [])

    def next(self, client_hash: str) -> Optional[str]:
        """Pop and return the next lineage to push for the given client.

        The client's entry is removed once its list is empty so that the
        client no longer counts as queued (`in`, `len`, `clients`).

        Returns:
            The next lineage, or `None` when nothing is queued for the client.
        """
        client_queue = self._queue.get(client_hash)
        if not client_queue:
            # Also drop a stale empty entry (e.g. from a hand-edited file).
            self._queue.pop(client_hash, None)
            return None
        lineage = client_queue.pop(0)
        if not client_queue:
            del self._queue[client_hash]
        return lineage

    def load(self) -> 'Queue':
        """Load the queue from the path configured with `queue_dir`.

        Waits for the lock so a file that is being written is never read.

        Returns:
            This queue, to allow `Queue(config).load()`.

        Raises:
            ValueError: When called on a queue opened in 'w' mode.
            CertDeployError: When the queue file contains invalid data.
        """
        if self._mode == 'w':
            raise ValueError('Can\'t use Queue.load() in writable mode.')
        # Lock outside of `try` so a failed lock is never "released".
        self._lock()
        try:
            self._load()
        finally:
            self._unlock()
        return self

    def _load(self):
        """Load the queue from disk; the backend of `self.load`.

        A missing queue file yields an empty queue.

        Raises:
            CertDeployError: When the file is not JSON or is not a mapping
                of client hashes to lists.
        """
        if not os.path.exists(self._filename):
            self._queue = {}
            return
        with open(self._filename, 'r') as queue_file:
            try:
                queue = json.load(queue_file)
            except json.JSONDecodeError as err:
                raise CertDeployError(
                    'The queue file contains invalid data.'
                ) from err
        if not isinstance(queue, dict) or not all(
            isinstance(lineages, list) for lineages in queue.values()
        ):
            raise CertDeployError(
                'The queue file contains invalid data.',
            )
        self._queue = queue

    def _dump(self):
        """Write the queue to disk.

        The data goes to a temporary file that then replaces the queue file,
        so a crash mid-write cannot leave a truncated/corrupt queue behind.
        """
        tmp_filename = f'{self._filename}.tmp'
        with open(tmp_filename, 'w') as queue_file:
            json.dump(self._queue, queue_file)
        os.replace(tmp_filename, self._filename)

    def _lock(self):
        """Acquire the thread lock and then the lock file, waiting as needed.

        The lock file is created with `O_CREAT | O_EXCL`, so checking for and
        creating it is a single atomic step. If acquiring the lock file fails
        for any reason, the thread lock is released again before re-raising.
        """
        self.lock.acquire()
        try:
            while True:
                try:
                    fd = os.open(
                        self._lock_filename,
                        os.O_CREAT | os.O_EXCL | os.O_WRONLY,
                    )
                except FileExistsError:
                    # Another process holds the lock; poll until it is gone.
                    time.sleep(0.01)
                else:
                    os.close(fd)
                    return
        except BaseException:
            self.lock.release()
            raise

    def _unlock(self):
        """Release the lock file and then the thread lock.

        The lock file is removed first so that the next thread acquiring the
        thread lock does not have to wait for it.
        """
        try:
            os.remove(self._lock_filename)
        except FileNotFoundError:
            pass
        finally:
            self.lock.release()

    def __contains__(self, key: str) -> bool:
        """Test if some `key` (`ClientConnection.hash`) exists in the queue."""
        return key in self._queue

    def __enter__(self) -> 'Queue':
        """Lock and load the queue file.

        If loading fails, the lock is released before the error propagates,
        because `__exit__` is not called when `__enter__` raises.
        """
        self._lock()
        try:
            self._load()
        except BaseException:
            self._unlock()
            raise
        return self

    def __exit__(self, _exc_type, _exc_value, _traceback):
        """Write the queue file in 'w' mode and always release the lock."""
        try:
            if self._mode == 'w':
                self._dump()
        finally:
            self._unlock()

    def __len__(self) -> int:
        """Return the number of clients with queued lineages."""
        return len(self._queue)
class PushWorker(Thread):
    """A worker thread to push lineages to a single client."""

    def __init__(
        self, server: 'Server', client: ClientConnection, config: ServerConfig
    ):
        """Prepare the worker.

        Arguments:
            server: The `Server` instance creating this worker.
            client: The client connection information.
            config: The CertDeploy server config for the server creating this
                worker.
        """
        Thread.__init__(self, daemon=True)
        self._server = server
        self._client = client
        self._config = config
        # The lineage currently being pushed (set by `_next`).
        self._lineage: Optional[str] = None
        # Prefer the client config over the server config.
        self._retries: int = self._config.push_retries
        if isinstance(self._client.push_retries, int):
            self._retries = self._client.push_retries
        # Prefer the client config over the server config.
        self._retry_interval: int = self._config.push_retry_interval
        if isinstance(self._client.push_retry_interval, int):
            self._retry_interval = self._client.push_retry_interval
        # The last error recorded by `run`; re-raised by `join` on fail_fast.
        self._exception: Optional[Exception] = None
        # Zero-based index of the current/last attempt (shown in `repr`).
        self._attempt: Optional[int] = None

    @property
    def client_hash(self) -> str:
        """The hash of the associated client."""
        return self._client.hash

    @property
    def has_error(self) -> bool:
        """Return `True` if there has been an exception in the thread."""
        return self._exception is not None

    def _sync_client(self):
        """Sync the current lineage to the client over SFTP.

        Only the files the client asked for (`needs_chain`, `needs_fullchain`,
        `needs_privkey`) are copied into `<client.path>/<lineage name>/`.
        The host key must match the ed25519 key configured for the client;
        unknown keys are rejected. The SFTP channel and the SSH connection
        are always closed, even when a transfer fails.
        """
        cert_dir = os.path.join(
            self._client.path,
            os.path.basename(self._lineage),
        )
        ssh = paramiko.client.SSHClient()
        try:
            # known_hosts uses "[host]:port" for non-default ports.
            if self._client.port == 22:
                hostkey_name = self._client.address
            else:
                hostkey_name = f'[{self._client.address}]:{self._client.port}'
            ssh.get_host_keys().add(
                hostkey_name,
                'ssh-ed25519',
                self._client.pubkey_blob,
            )
            # Set the safest policy by default
            ssh.set_missing_host_key_policy(paramiko.client.RejectPolicy())
            ssh.connect(
                hostname=self._client.address,
                port=self._client.port,
                username=self._client.username,
                key_filename=self._config.privkey_filename,
                auth_timeout=self._config.sftp_auth_timeout,
                banner_timeout=self._config.sftp_banner_timeout,
                timeout=self._config.sftp_tcp_timeout,
            )
            sftp = ssh.open_sftp()
            try:
                # Make the destination directory
                _sftp_mkdir(sftp, cert_dir)
                # Transfer certificates as needed (order matters for logs).
                for filename, needed in (
                    ('chain.pem', self._client.needs_chain),
                    ('fullchain.pem', self._client.needs_fullchain),
                    ('privkey.pem', self._client.needs_privkey),
                ):
                    if not needed:
                        continue
                    src = os.path.join(self._lineage, filename)
                    dst = os.path.join(cert_dir, filename)
                    log.debug('Copying %s to %s', src, dst)
                    sftp.put(src, dst)
            finally:
                sftp.close()
        finally:
            ssh.close()

    def _next(self) -> bool:
        """Return `True` if there is another lineage to push.

        This also pops the lineage from the queue file into `self._lineage`.
        """
        with Queue(self._config, 'w') as queue:
            self._lineage = queue.next(self._client.hash)
        return self._lineage is not None

    def run(self):
        """Run the main loop.

        Pops lineages for this client from the queue one by one and pushes
        each one with up to `retries + 1` attempts, sleeping
        `retry_interval` seconds between attempts. Connection/transfer
        errors move on to the next lineage once retries are exhausted (or
        immediately with `fail_fast`, recording the error). Any other error
        is recorded and ends the thread.

        Note:
            This is called automatically by `self.start`.
        """
        # A negative retry count must still allow the initial attempt;
        # otherwise the already dequeued lineage would be dropped silently.
        retries = max(self._retries, 0)
        tries_str = f'{retries + 1}'
        while self._next():
            log.info('Pushing %s to %s', self._lineage, self._client)
            for self._attempt in range(retries + 1):
                attempt_str = f'{self._attempt + 1}'
                log.debug(
                    'Attempt #%s of %s tries.',
                    attempt_str,
                    tries_str,
                )
                try:
                    self._sync_client()
                except (
                    CertDeployError,
                    socket.gaierror,
                    SSHException,
                    NoValidConnectionsError,
                    # Other socket errors (connect timeouts, resets,
                    # unreachable networks) are re-raised unwrapped by
                    # paramiko and are just as transient.
                    OSError,
                ) as err:
                    log.error(
                        'Error syncing with %s:%s: %s',
                        self._client.address,
                        self._client.port,
                        format_error(err),
                        exc_info=err,
                    )
                    if self._config.fail_fast:
                        self._exception = err
                        break  # Go to the next lineage
                    if self._attempt == retries:
                        log.warning(
                            'Attempt #%s of %s failed. Not retrying '
                            'sync %s to %s.',  # fmt: skip
                            attempt_str,
                            tries_str,
                            self._lineage,
                            self._client,
                        )
                        break  # Go to the next lineage
                    log.info(
                        'Attempt #%s failed. Retrying sync to %s in '
                        '%s seconds.',  # fmt: skip
                        attempt_str,
                        self._client,
                        self._retry_interval,
                    )
                    # Wait between attempts
                    time.sleep(self._retry_interval)
                except Exception as err:
                    self._exception = err
                    if not self._config.fail_fast:
                        log.error(
                            'Error syncing with %s:%s: %s',
                            self._client.address,
                            self._client.port,
                            format_error(err),
                            exc_info=err,
                        )
                    return  # End the thread
                else:
                    log.info(
                        'Pushed %s to %s in %s attempts',
                        self._lineage,
                        self._client,
                        attempt_str,
                    )
                    break  # Go to the next lineage
            log.debug('Done pushing %s to %s', self._lineage, self._client)
        log.info('Done pushing all lineages to %s', self._client)

    def join(self, timeout: Optional[float] = None):
        """Join the worker thread and raise its recorded exception.

        Arguments:
            timeout: The number of seconds to wait for the thread to end.
                Like `Thread.join`, this simply returns when the timeout
                elapses; use `is_alive()` to tell whether the thread ended.
                Defaults to `None` (wait forever).

        Raises:
            The exception recorded by `run`, if any, when `fail_fast` is
            enabled.
        """
        Thread.join(self, timeout)
        if self._config.fail_fast and self._exception:
            raise self._exception

    def __repr__(self):
        """Return a pragmatic representation of this object."""
        return (
            f'<{self.__class__.__name__} address={self._client.address}, '
            f'port={self._client.port}, username={self._client.username}, '
            f'attempts={self._attempt}, exception={self._exception}>'
        )
class Server:
    """Accept new sync requests and push new certs to clients.

    Lineages are queued with `sync` and pushed by `serve_forever`, which runs
    one `PushWorker` per client that has queued lineages. Creating a `Server`
    also registers the scheduled cert renewal.
    """

    # Just for testing so that the daemon can be shutdown cleanly.
    _stop_running: bool = False

    def __init__(self, config: ServerConfig):
        """Prepare the server.

        Arguments:
            config: Server config.

        Raises:
            ConfigError: When the renew related configs are invalid.
        """
        self._config = config
        # Running workers keyed by `ClientConnection.hash`.
        self._workers: dict[str, PushWorker] = {}
        self._schedule_renew()

    def serve_forever(self, one_shot: bool = False):
        """Push queued lineages to clients.

        Every iteration reloads the queue, starts a worker for each configured
        client that has queued lineages but no worker yet, reaps finished
        workers and runs pending scheduled jobs (cert renewal).

        Arguments:
            one_shot: Push lineages in the queue and exit when the queue has
                been fully processed. Defaults to `False`.

        Raises:
            Exception: An error recorded by a worker when `fail_fast` is
                enabled (re-raised by `PushWorker.join`).
            TimeoutError: In one-shot mode when `ONE_SHOT_TIMEOUT` is set and
                has been exceeded.
        """
        # This is used in tests to determine if the server has started.
        log.debug('Server.serve_forever: one_shot=%s', one_shot)
        # `timeout` is just for debugging. It makes the whole server fail hard
        # if it takes too long. In order to use it set `ONE_SHOT_TIMEOUT`
        # to a positive integer.
        timeout = _TimeoutTimer.start(ONE_SHOT_TIMEOUT)
        while not self._stop_running:
            main_loop_sleep = GO_FAST_SLEEP
            queue = Queue(self._config, 'r').load()
            log.debug(
                'Queue length is %s, worker count is %s',
                len(queue),
                len(self._workers),
            )
            if len(queue) < 1 and len(self._workers) < 1:
                # Slow down when the queue is empty and there are no workers.
                main_loop_sleep = SLOW_DOWN_SLEEP
                if one_shot:
                    # Once the queue is empty and there are no more workers,
                    # all the jobs sent with the push only option have
                    # completed and this loop doesn't need to run anymore.
                    return
            for client in self._config.clients:
                if client.hash not in queue:
                    continue
                if client.hash not in self._workers:
                    log.debug(
                        'Adding worker for %s@%s:%s',
                        client.username,
                        client.address,
                        client.port,
                    )
                    self._add_worker(client)
                    if self._config.push_mode == PushMode.SERIAL:
                        log.debug(
                            'Waiting for push to %s@%s:%s to finish',
                            client.username,
                            client.address,
                            client.port,
                        )
                        self._workers[client.hash].join(
                            self._config.join_timeout,
                        )
                        log.debug(
                            'Finished pushing to %s@%s:%s',
                            client.username,
                            client.address,
                            client.port,
                        )
                    # Only delay when adding a new worker
                    time.sleep(self._config.push_interval)
            # Cleanup workers if idle
            for worker in list(self._workers.values()):
                if not worker.is_alive():
                    self._remove_worker(worker)
            schedule.run_pending()  # Renew certs if needed
            # This is just for debugging. It makes the whole server fail hard
            # if it takes too long. In order to use it set `ONE_SHOT_TIMEOUT`
            # to a positive integer.
            if one_shot:
                timeout.check()
            # End just for debugging
            time.sleep(main_loop_sleep)
        log.debug('Done serving')

    def sync(self, lineage: os.PathLike, domains: list[str]):
        """Synchronize clients that need updates based on domains.

        Every configured client that serves at least one of `domains` gets
        `lineage` appended to its queue. All clients are queued within one
        locked update of the queue file, so readers never see a partially
        queued sync.

        Arguments:
            lineage: The full path of a lineage.
            domains: A `list` of domain names to use to find clients to push
                to.
        """
        targets = [
            client
            for client in self._config.clients
            if any(domain in client.domains for domain in domains)
        ]
        if not targets:
            # Nothing to queue; don't touch (or create) the queue file.
            return
        with Queue(self._config, 'w') as queue:
            for client in targets:
                queue.append(client.hash, lineage)
                log.debug(
                    'Queued lineage %s for client %s',
                    lineage,
                    client,
                )

    def _add_worker(self, client: ClientConnection):
        """Kickstart a new `PushWorker` for `client`."""
        worker = PushWorker(self, client, self._config)
        self._workers[client.hash] = worker
        worker.start()

    def _remove_worker(self, worker):
        """End `worker` and pop it out of the pool.

        The worker is unregistered even when `join` re-raises its recorded
        error (`fail_fast`), so the same error is not raised again for an
        already finished worker.
        """
        try:
            worker.join(self._config.join_timeout)
        finally:
            self._workers.pop(worker.client_hash, None)

    def _schedule_renew(self):
        """Attempt to configure a scheduled cert renewal.

        Builds `schedule.every(renew_every).<renew_unit>[.at(renew_at)]`
        and registers `renew_certs(config=...)` with it.

        Raises:
            ConfigError: When renew related configs are invalid.
        """
        # Catch config related errors and add some context before reraising.
        try:
            every = schedule.every(self._config.renew_every)
        except schedule.ScheduleValueError as err:
            raise ConfigError(f'Invalid `renew_every` value: {err}') from err
        try:
            when = getattr(every, self._config.renew_unit)
        except (
            # Unknown unit name or a non-string unit.
            AttributeError,
            TypeError,
            schedule.ScheduleValueError,
        ) as err:
            raise ConfigError(f'Invalid `renew_unit` value: {err}') from err
        # Unit properties return the job itself; anything else (a method
        # such as `do`, or a plain attribute) is not a time unit.
        if not isinstance(when, schedule.Job):
            raise ConfigError(
                f'Invalid `renew_unit` value: {self._config.renew_unit!r}'
            )
        if self._config.renew_at:
            try:
                when = when.at(self._config.renew_at)
            except schedule.ScheduleValueError as err:
                raise ConfigError(f'Invalid `renew_at` value: {err}') from err
        when.do(renew_certs, config=self._config)