Source code for pendant.aws.batch

import inspect
from abc import abstractmethod
from datetime import datetime
from typing import Dict, List, Mapping, Optional, Tuple

import boto3

from custom_inherit import DocInheritMeta

from pendant.aws.exception import BatchJobSubmissionError
from pendant.aws.logs import AwsLogUtil, LogEvent
from pendant.aws.response import SubmitJobResponse
from pendant.util import format_ISO8601

# Names exported by a star-import of this module.
__all__ = ['BatchJob', 'JobDefinition']

# Log group name passed to ``AwsLogUtil.get_log_events`` by
# ``BatchJob.log_stream_events``.
CLOUDWATCH_LOG_GROUP = '/aws/batch/job'
# Job status strings. ``BatchJob.is_running`` and ``BatchJob.is_runnable``
# compare the described job's 'status' field against the RUNNING and RUNNABLE
# values; the other service statuses are not referenced in the visible part of
# this module.
BATCH_STATUS_SUBMITTED = 'SUBMITTED'
BATCH_STATUS_PENDING = 'PENDING'
BATCH_STATUS_RUNNABLE = 'RUNNABLE'
BATCH_STATUS_STARTING = 'STARTING'
BATCH_STATUS_RUNNING = 'RUNNING'
BATCH_STATUS_FAILED = 'FAILED'
# Fallback returned by ``BatchJob.status`` when the job description has no
# 'status' key.
BATCH_STATUS_NOTFOUND = 'NOTFOUND'


class JobDefinition(
    metaclass=DocInheritMeta(style="google", abstract_base_class=True)  # type: ignore
):
    """A Batch job definition.

    Subclasses must provide ``name`` and ``validate``. The parameter names of
    the instance's ``__init__`` are treated as the definition's parameters:
    ``to_dict`` and ``__repr__`` read each one back from an instance attribute
    of the same name, so a subclass must store every ``__init__`` argument
    under its own name.
    """

    def __new__(cls, *args: str, **kwargs: str) -> 'JobDefinition':
        """Create a new Batch job definition.

        The revision is assigned here, before ``__init__`` runs, so the
        attribute exists regardless of what a subclass ``__init__`` does.
        """
        this: JobDefinition = super().__new__(cls)
        this._revision = '0'
        return this

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the name of the job definition."""

    @property
    def parameters(self) -> Tuple[str, ...]:
        """Return the parameters of the job definition.

        Returns:
            The parameter names of this instance's ``__init__``, in signature
            order. ``self`` is not included because the bound method is
            inspected.
        """
        # Iterating the signature's parameter mapping yields its keys (names).
        return tuple(inspect.signature(self.__init__).parameters)  # type: ignore

    @property
    def revision(self) -> str:
        """Return the revision of the job definition."""
        return self._revision

    @abstractmethod
    def validate(self) -> None:
        """Validate this job definition after initialization."""

    def at_revision(self, revision: str) -> 'JobDefinition':
        """Set this job definition to a specific revision.

        Args:
            revision: The revision to record on this definition.

        Returns:
            This same instance (mutated in place), to allow call chaining.
        """
        self._revision = revision
        return self

    def make_job_name(self, moment: Optional[datetime] = None) -> str:
        """Format a Batch job name from this definition.

        Args:
            moment: The timestamp to embed in the name. Defaults to
                ``datetime.now()`` at call time.

        Returns:
            The formatted timestamp, an underscore, then the definition name.
        """
        if moment is None:
            moment = datetime.now()
        return format_ISO8601(moment) + '_' + self.name

    def to_dict(self) -> Dict[str, str]:
        """Return a dictionary of all parameters and their values as strings.

        Returns:
            A mapping from each name in ``parameters`` to ``str()`` of the
            instance attribute with that name.
        """
        mapping: Dict[str, str] = {
            key: str(getattr(self, key)) for key in self.parameters
        }
        return mapping

    def __str__(self) -> str:
        """Return ``name:revision``."""
        return f'{self.name}:{self.revision}'

    def __repr__(self) -> str:
        """Return a constructor-style representation built from ``parameters``."""
        signature = ', '.join(
            f'{key}={getattr(self, key)!r}' for key in self.parameters
        )
        return f'{self.__class__.__qualname__}({signature})'
class BatchJob(object):
    """An AWS Batch job.

    A Batch job can be instantiated and then submitted against the Batch
    service. After submission, the job's status can be queried, the job's logs
    can be read, and other methods can be called to understand the state of
    the job.

    Args:
        definition: A Batch job definition.
    """

    def __init__(self, definition: JobDefinition):
        # Validate eagerly so an invalid definition fails before a client is
        # created or anything is submitted.
        definition.validate()
        self.definition = definition
        self._client = boto3.client('batch')
        self._is_submitted: bool = False
        self._container_overrides: Mapping = dict()
        self._job_id: Optional[str] = None
        self._queue: Optional[str] = None
        self._submit_response: Optional[SubmitJobResponse] = None

    @property
    def container_overrides(self) -> Optional[Mapping]:
        """Return container overriding parameters."""
        return self._container_overrides

    @property
    def job_id(self) -> Optional[str]:
        """Return the job ID, or ``None`` if the job has not been submitted."""
        return self._job_id

    @property
    def queue(self) -> Optional[str]:
        """Return the job queue, or ``None`` if ``submit`` was never called."""
        return self._queue

    @staticmethod
    def describe_job(job_id: str) -> Dict:
        """Describe a single Batch job by job ID.

        Args:
            job_id: The ID of the job to describe.

        Returns:
            The first job description returned by the service, or an empty
            dictionary when the service returns no description for the ID.
        """
        jobs = BatchJob.describe_jobs([job_id])
        # An unknown job ID yields an empty list; indexing guarded by the
        # emptiness check avoids the ValueError that tuple-unpacking raised.
        job = jobs[0] if jobs else None
        return job if job else dict()

    @staticmethod
    def describe_jobs(job_ids: List[str]) -> List[Dict]:
        """Describe Batch jobs by job ID.

        Args:
            job_ids: The IDs of the jobs to describe.

        Returns:
            The ``'jobs'`` list from the service response.
        """
        # Static method: there is no instance client, so a fresh one is made.
        jobs: List[Dict] = boto3.client('batch').describe_jobs(jobs=job_ids)['jobs']
        return jobs

    def status(self) -> str:
        """Return the job status.

        Returns:
            The ``'status'`` field of the job description, or
            ``BATCH_STATUS_NOTFOUND`` when the description has no such field.

        Raises:
            BatchJobSubmissionError: If the job has not been submitted.
        """
        if self.job_id is None:
            raise BatchJobSubmissionError(
                'Cannot check status of a job that has not been submitted.'
            )
        job = BatchJob.describe_job(self.job_id)
        status: str = job.get('status', BATCH_STATUS_NOTFOUND)
        return status

    def cancel(self, reason: str) -> Dict:
        """Cancel this job.

        Args:
            reason: The reason why the job must be canceled.

        Returns:
            The service response to job cancellation.

        Raises:
            AssertionError: If the job has not been submitted.
        """
        # NOTE: assertions are skipped when Python runs with -O; kept as an
        # assert so callers see the same exception type as before.
        assert self.is_submitted(), 'Cannot cancel a job that has not been submitted.'
        response: Dict = self._client.cancel_job(jobId=self.job_id, reason=reason)
        return response

    def terminate(self, reason: str) -> Dict:
        """Terminate this job.

        Jobs that are in the STARTING or RUNNING state are terminated, which
        causes them to transition to FAILED. Jobs that have not progressed to
        the STARTING state are cancelled.

        Args:
            reason: The reason why the job must be terminated.

        Returns:
            The service response to job termination.

        Raises:
            AssertionError: If the job has not been submitted.
        """
        # NOTE: assertions are skipped when Python runs with -O; kept as an
        # assert so callers see the same exception type as before.
        assert self.is_submitted(), 'Cannot terminate a job that has not been submitted.'
        response: Dict = self._client.terminate_job(jobId=self.job_id, reason=reason)
        return response

    def is_running(self) -> bool:
        """Return if this job's state is RUNNING or not."""
        return self.status() == BATCH_STATUS_RUNNING

    def is_runnable(self) -> bool:
        """Return if this job's state is RUNNABLE or not."""
        return self.status() == BATCH_STATUS_RUNNABLE

    def is_submitted(self) -> bool:
        """Return if this job has been submitted to Batch."""
        return self._is_submitted

    def submit(
        self, queue: str, container_overrides: Optional[Mapping] = None
    ) -> SubmitJobResponse:
        """Submit this job to Batch.

        Args:
            queue: The Batch job queue to use.
            container_overrides: The values to override in the spawned
                container. A falsy value is replaced with an empty dict.

        Returns:
            The service response to job submission.

        Raises:
            AssertionError: If this job was already submitted.
            BatchJobSubmissionError: If the submission response is not OK.
        """
        assert not self.is_submitted(), 'Cannot submit already submitted job!'

        # The queue and overrides are recorded before the service call, so
        # they remain set even when the submission fails.
        self._queue = queue
        self._container_overrides = container_overrides if container_overrides else {}

        job_name = self.definition.make_job_name()
        response: Mapping = self._client.submit_job(
            jobName=job_name,
            jobQueue=queue,
            # str(definition) renders as 'name:revision'.
            jobDefinition=str(self.definition),
            parameters=self.definition.to_dict(),
            containerOverrides=self.container_overrides,
        )

        submit_response = SubmitJobResponse(response)
        if not submit_response.is_ok():
            raise BatchJobSubmissionError(f'Batch job failed to submit!\n{response}')

        self._is_submitted = True
        self._job_id = submit_response.job_id
        self._submit_response = submit_response
        return submit_response

    def log_stream_name(self) -> str:
        """Return the Batch log stream name for this job.

        Returns:
            The ``['container']['logStreamName']`` field of the job
            description.

        Raises:
            BatchJobSubmissionError: If the job has not been submitted.
        """
        if self.job_id is None:
            raise BatchJobSubmissionError(
                'Cannot get the log stream name of a job that has not been submitted.'
            )
        job = BatchJob.describe_job(self.job_id)
        log_stream_name: str = job['container']['logStreamName']
        return log_stream_name

    def log_stream_events(self) -> List[LogEvent]:
        """Return all log events for this job.

        Returns:
            events: All log events, to date.
        """
        log_util = AwsLogUtil()
        log_stream_name = self.log_stream_name()
        events = log_util.get_log_events(
            group_name=CLOUDWATCH_LOG_GROUP, stream_name=log_stream_name
        )
        return events

    def __repr__(self) -> str:
        """Return a constructor-style representation of this job."""
        return f'{self.__class__.__qualname__}(definition={self.definition!r})'