# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from future.standard_library import install_aliases

install_aliases()
from builtins import str
from builtins import object, bytes
import copy
from collections import namedtuple
from datetime import datetime, timedelta
import dill
import functools
import getpass
import imp
import importlib
import inspect
import zipfile
import jinja2
import json
import logging
import os
import pickle
import re
import signal
import socket
import sys
import textwrap
import traceback
import warnings
import hashlib

from urllib.parse import urlparse

from sqlalchemy import (
    Column, Integer, String, DateTime, Text, Boolean, ForeignKey, PickleType,
    Index, Float)
from sqlalchemy import func, or_, and_
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.orm import reconstructor, relationship, synonym

from croniter import croniter
import six

from airflow import settings, utils
from airflow.executors import DEFAULT_EXECUTOR, LocalExecutor
from airflow import configuration
from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout
from airflow.dag.base_dag import BaseDag, BaseDagBag
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
from airflow.utils.dates import cron_presets, date_range as utils_date_range
from airflow.utils.db import provide_session
from airflow.utils.decorators import apply_defaults
from airflow.utils.email import send_email
from airflow.utils.helpers import (
    as_tuple, is_container, is_in, validate_key, pprinttable)
from airflow.utils.logging import LoggingMixin
from airflow.utils.operator_resources import Resources
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.trigger_rule import TriggerRule

Base = declarative_base()
ID_LEN = 250
XCOM_RETURN_KEY = 'return_value'

Stats = settings.Stats

ENCRYPTION_ON = False
try:
    from cryptography.fernet import Fernet
    FERNET = Fernet(configuration.get('core', 'FERNET_KEY').encode('utf-8'))
    ENCRYPTION_ON = True
except Exception:
    # cryptography is not installed or FERNET_KEY is missing/invalid;
    # connections will be stored unencrypted
    pass

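# A minimal sketch (illustrative, not part of this module) of producing a
# value for the FERNET_KEY setting read above, assuming the `cryptography`
# package is installed:
#
#     from cryptography.fernet import Fernet
#     print(Fernet.generate_key().decode())  # paste into airflow.cfg [core]
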
if 'mysql' in settings.SQL_ALCHEMY_CONN:
    LongText = LONGTEXT
else:
    LongText = Text

# used by DAG context_managers
_CONTEXT_MANAGER_DAG = None


def clear_task_instances(tis, session, activate_dag_runs=True):
    """
    Clears a set of task instances, but makes sure the running ones
    get killed.
    """
    job_ids = []
    for ti in tis:
        if ti.state == State.RUNNING:
            if ti.job_id:
                ti.state = State.SHUTDOWN
                job_ids.append(ti.job_id)
        # todo: this creates an issue with the webui tests
        # elif ti.state != State.REMOVED:
        #     ti.state = State.NONE
        #     session.merge(ti)
        else:
            session.delete(ti)
    if job_ids:
        from airflow.jobs import BaseJob as BJ
        for job in session.query(BJ).filter(BJ.id.in_(job_ids)).all():
            job.state = State.SHUTDOWN
    if activate_dag_runs:
        execution_dates = {ti.execution_date for ti in tis}
        dag_ids = {ti.dag_id for ti in tis}
        drs = session.query(DagRun).filter(
            DagRun.dag_id.in_(dag_ids),
            DagRun.execution_date.in_(execution_dates),
        ).all()
        for dr in drs:
            dr.state = State.RUNNING
            dr.start_date = datetime.now()

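# A minimal usage sketch, assuming an open session from settings.Session()
# and a hypothetical dag_id 'my_dag'; clears that DAG's task instances so
# they can be re-run:
#
#     session = settings.Session()
#     tis = session.query(TaskInstance).filter(
#         TaskInstance.dag_id == 'my_dag').all()
#     clear_task_instances(tis, session)
#     session.commit()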

class DagBag(BaseDagBag, LoggingMixin):
    """
    A dagbag is a collection of dags, parsed out of a folder tree, with high
    level configuration settings like what database to use as a backend and
    what executor to use to fire off tasks. This makes it easier to run
    distinct environments for say production and development, tests, or for
    different teams or security profiles. What would have been system level
    settings are now dagbag level so that one system can run multiple,
    independent settings sets.

    :param dag_folder: the folder to scan to find DAGs
    :type dag_folder: unicode
    :param executor: the executor to use when executing task instances
        in this DagBag
    :param include_examples: whether to include the examples that ship
        with airflow or not
    :type include_examples: bool
    """
    def __init__(
            self,
            dag_folder=None,
            executor=DEFAULT_EXECUTOR,
            include_examples=configuration.getboolean('core', 'LOAD_EXAMPLES')):

        dag_folder = dag_folder or settings.DAGS_FOLDER
        self.logger.info("Filling up the DagBag from {}".format(dag_folder))
        self.dag_folder = dag_folder
        self.dags = {}
        # the file's last modified timestamp when we last read it
        self.file_last_changed = {}
        self.executor = executor
        self.import_errors = {}

        if include_examples:
            example_dag_folder = os.path.join(
                os.path.dirname(__file__),
                'example_dags')
            self.collect_dags(example_dag_folder)
        self.collect_dags(dag_folder)

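    # A minimal usage sketch ('my_dag' is a hypothetical dag_id); builds a
    # DagBag from the configured DAGS_FOLDER and pulls one DAG out of it:
    #
    #     dagbag = DagBag()
    #     print(dagbag.size())
    #     dag = dagbag.get_dag('my_dag')
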
    def size(self):
        """
        :return: the number of dags contained in this dagbag
        """
        return len(self.dags)

    def get_dag(self, dag_id):
        """
        Gets the DAG out of the dictionary, and refreshes it if expired
        """
        # If asking for a known subdag, we want to refresh the parent
        root_dag_id = dag_id
        if dag_id in self.dags:
            dag = self.dags[dag_id]
            if dag.is_subdag:
                root_dag_id = dag.parent_dag.dag_id

        # If the dag corresponding to root_dag_id is absent or expired
        orm_dag = DagModel.get_current(root_dag_id)
        if orm_dag and (
                root_dag_id not in self.dags or
                (
                    orm_dag.last_expired and
                    dag.last_loaded < orm_dag.last_expired
                )
        ):
            # Reprocess source file
            found_dags = self.process_file(
                filepath=orm_dag.fileloc, only_if_updated=False)

            # If the source file no longer exports `dag_id`, delete it from self.dags
            if found_dags and dag_id in [dag.dag_id for dag in found_dags]:
                return self.dags[dag_id]
            elif dag_id in self.dags:
                del self.dags[dag_id]
        return self.dags.get(dag_id)

    def process_file(self, filepath, only_if_updated=True, safe_mode=True):
        """
        Given a path to a python module or zip file, this method imports
        the module and looks for dag objects within it.
        """
        found_dags = []

        # todo: raise exception?
        if not os.path.isfile(filepath):
            return found_dags

        try:
            # This failed before in what may have been a git sync
            # race condition
            file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath))
            if only_if_updated \
                    and filepath in self.file_last_changed \
                    and file_last_changed_on_disk == self.file_last_changed[filepath]:
                return found_dags

        except Exception as e:
            logging.exception(e)
            return found_dags

        mods = []
        if not zipfile.is_zipfile(filepath):
            if safe_mode and os.path.isfile(filepath):
                with open(filepath, 'rb') as f:
                    content = f.read()
                    if not all(s in content for s in (b'DAG', b'airflow')):
                        self.file_last_changed[filepath] = file_last_changed_on_disk
                        return found_dags

            self.logger.debug("Importing {}".format(filepath))
            org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
            mod_name = ('unusual_prefix_' +
                        hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
                        '_' + org_mod_name)

            if mod_name in sys.modules:
                del sys.modules[mod_name]

            with timeout(configuration.getint('core', "DAGBAG_IMPORT_TIMEOUT")):
                try:
                    m = imp.load_source(mod_name, filepath)
                    mods.append(m)
                except Exception as e:
                    self.logger.exception("Failed to import: " + filepath)
                    self.import_errors[filepath] = str(e)
                    self.file_last_changed[filepath] = file_last_changed_on_disk

        else:
            zip_file = zipfile.ZipFile(filepath)
            for mod in zip_file.infolist():
                head, _ = os.path.split(mod.filename)
                mod_name, ext = os.path.splitext(mod.filename)
                if not head and (ext == '.py' or ext == '.pyc'):
                    if mod_name == '__init__':
                        self.logger.warning("Found __init__.{0} at root of {1}".
                                            format(ext, filepath))

                    if safe_mode:
                        with zip_file.open(mod.filename) as zf:
                            self.logger.debug("Reading {} from {}".
                                              format(mod.filename, filepath))
                            content = zf.read()
                            if not all(s in content for s in (b'DAG', b'airflow')):
                                self.file_last_changed[filepath] = (
                                    file_last_changed_on_disk)
                                # todo: create ignore list
                                return found_dags

                    if mod_name in sys.modules:
                        del sys.modules[mod_name]

                    try:
                        sys.path.insert(0, filepath)
                        m = importlib.import_module(mod_name)
                        mods.append(m)
                    except Exception as e:
                        self.logger.exception("Failed to import: " + filepath)
                        self.import_errors[filepath] = str(e)
                        self.file_last_changed[filepath] = file_last_changed_on_disk

        for m in mods:
            for dag in list(m.__dict__.values()):
                if isinstance(dag, DAG):
                    if not dag.full_filepath:
                        dag.full_filepath = filepath
                    dag.is_subdag = False
                    self.bag_dag(dag, parent_dag=dag, root_dag=dag)
                    found_dags.append(dag)
                    found_dags += dag.subdags

        self.file_last_changed[filepath] = file_last_changed_on_disk
        return found_dags

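    # A minimal sketch of processing a single file and inspecting the import
    # errors collected above ('/path/to/dags/my_dag.py' is hypothetical):
    #
    #     dags = dagbag.process_file('/path/to/dags/my_dag.py',
    #                                only_if_updated=False)
    #     for path, err in dagbag.import_errors.items():
    #         print(path, err)
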
    @provide_session
    def kill_zombies(self, session=None):
        """
        Fails tasks that haven't had a heartbeat in too long
        """
        from airflow.jobs import LocalTaskJob as LJ
        self.logger.info("Finding 'running' jobs without a recent heartbeat")
        TI = TaskInstance
        secs = (
            configuration.getint('scheduler', 'scheduler_zombie_task_threshold'))
        limit_dttm = datetime.now() - timedelta(seconds=secs)
        self.logger.info(
            "Failing jobs without heartbeat after {}".format(limit_dttm))

        tis = (
            session.query(TI)
            .join(LJ, TI.job_id == LJ.id)
            .filter(TI.state == State.RUNNING)
            .filter(
                or_(
                    LJ.state != State.RUNNING,
                    LJ.latest_heartbeat < limit_dttm,
                ))
            .all()
        )

        for ti in tis:
            if ti and ti.dag_id in self.dags:
                dag = self.dags[ti.dag_id]
                if ti.task_id in dag.task_ids:
                    task = dag.get_task(ti.task_id)
                    ti.task = task
                    ti.handle_failure("{} killed as zombie".format(ti))
                    self.logger.info(
                        'Marked zombie job {} as failed'.format(ti))
                    Stats.incr('zombies_killed')
        session.commit()

    def bag_dag(self, dag, parent_dag, root_dag):
        """
        Adds the DAG into the bag, recurses into sub dags.
        """
        self.dags[dag.dag_id] = dag
        dag.resolve_template_files()
        dag.last_loaded = datetime.now()

        for task in dag.tasks:
            settings.policy(task)

        for subdag in dag.subdags:
            subdag.full_filepath = dag.full_filepath
            subdag.parent_dag = dag
            subdag.is_subdag = True
            self.bag_dag(subdag, parent_dag=dag, root_dag=root_dag)
        self.logger.debug('Loaded DAG {dag}'.format(**locals()))

    def collect_dags(
            self,
            dag_folder=None,
            only_if_updated=True):
        """
        Given a file path or a folder, this method looks for python modules,
        imports them and adds them to the dagbag collection.

        Note that if a .airflowignore file is found while processing
        the directory, it behaves much like a .gitignore, ignoring files
        that match any of the regex patterns specified in the file.
        """
        start_dttm = datetime.now()
        dag_folder = dag_folder or self.dag_folder

        # Used to store stats around DagBag processing
        stats = []
        FileLoadStat = namedtuple(
            'FileLoadStat', "file duration dag_num task_num dags")
        if os.path.isfile(dag_folder):
            self.process_file(dag_folder, only_if_updated=only_if_updated)
        elif os.path.isdir(dag_folder):
            patterns = []
            for root, dirs, files in os.walk(dag_folder, followlinks=True):
                ignore_file = [f for f in files if f == '.airflowignore']
                if ignore_file:
                    with open(os.path.join(root, ignore_file[0]), 'r') as f:
                        patterns += [p for p in f.read().split('\n') if p]
                for f in files:
                    try:
                        filepath = os.path.join(root, f)
                        if not os.path.isfile(filepath):
                            continue
                        mod_name, file_ext = os.path.splitext(
                            os.path.split(filepath)[-1])
                        if file_ext != '.py' and not zipfile.is_zipfile(filepath):
                            continue
                        if not any(
                                [re.findall(p, filepath) for p in patterns]):
                            ts = datetime.now()
                            found_dags = self.process_file(
                                filepath, only_if_updated=only_if_updated)

                            # total_seconds() already includes the fractional
                            # (microsecond) part, so no extra term is needed
                            td = (datetime.now() - ts).total_seconds()
                            stats.append(FileLoadStat(
                                filepath.replace(dag_folder, ''),
                                td,
                                len(found_dags),
                                sum([len(dag.tasks) for dag in found_dags]),
                                str([dag.dag_id for dag in found_dags]),
                            ))
                    except Exception as e:
                        logging.warning(e)
        Stats.gauge(
            'collect_dags', (datetime.now() - start_dttm).total_seconds(), 1)
        Stats.gauge(
            'dagbag_size', len(self.dags), 1)
        Stats.gauge(
            'dagbag_import_errors', len(self.import_errors), 1)
        self.dagbag_stats = sorted(
            stats, key=lambda x: x.duration, reverse=True)

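    # A sketch of an .airflowignore file (contents are illustrative); each
    # non-empty line is treated as a regex and matched against full file
    # paths with re.findall, as described in the docstring above:
    #
    #     scratch
    #     .*_backup\.py
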
    def dagbag_report(self):
        """Prints a report around DagBag loading stats"""
        report = textwrap.dedent("""\n
        -------------------------------------------------------------------
        DagBag loading stats for {dag_folder}
        -------------------------------------------------------------------
        Number of DAGs: {dag_num}
        Total task number: {task_num}
        DagBag parsing time: {duration}
        {table}
        """)
        stats = self.dagbag_stats
        return report.format(
            dag_folder=self.dag_folder,
            duration=sum([o.duration for o in stats]),
            dag_num=sum([o.dag_num for o in stats]),
            task_num=sum([o.task_num for o in stats]),
            table=pprinttable(stats),
        )

    def deactivate_inactive_dags(self):
        active_dag_ids = [dag.dag_id for dag in list(self.dags.values())]
        session = settings.Session()
        for dag in session.query(
                DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
            dag.is_active = False
            session.merge(dag)
        session.commit()
        session.close()

    def paused_dags(self):
        session = settings.Session()
        dag_ids = [dp.dag_id for dp in session.query(DagModel).filter(
            DagModel.is_paused.is_(True))]
        session.commit()
        session.close()
        return dag_ids


class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    username = Column(String(ID_LEN), unique=True)
    email = Column(String(500))
    superuser = False

    def __repr__(self):
        return self.username

    def get_id(self):
        return str(self.id)

    def is_superuser(self):
        return self.superuser


class Connection(Base):
    """
    Placeholder to store connection information for different database
    instances. The idea here is that scripts use references to database
    instances (conn_id) instead of hard coding hostnames, logins and
    passwords when using operators or hooks.
    """
    __tablename__ = "connection"

    id = Column(Integer(), primary_key=True)
    conn_id = Column(String(ID_LEN))
    conn_type = Column(String(500))
    host = Column(String(500))
    schema = Column(String(500))
    login = Column(String(500))
    _password = Column('password', String(5000))
    port = Column(Integer())
    is_encrypted = Column(Boolean, unique=False, default=False)
    is_extra_encrypted = Column(Boolean, unique=False, default=False)
    _extra = Column('extra', String(5000))

    _types = [
        ('fs', 'File (path)'),
        ('ftp', 'FTP'),
        ('google_cloud_platform', 'Google Cloud Platform'),
        ('hdfs', 'HDFS'),
        ('http', 'HTTP'),
        ('hive_cli', 'Hive Client Wrapper'),
        ('hive_metastore', 'Hive Metastore Thrift'),
        ('hiveserver2', 'Hive Server 2 Thrift'),
        ('jdbc', 'Jdbc Connection'),
        ('mysql', 'MySQL'),
        ('postgres', 'Postgres'),
        ('oracle', 'Oracle'),
        ('vertica', 'Vertica'),
        ('presto', 'Presto'),
        ('s3', 'S3'),
        ('samba', 'Samba'),
        ('sqlite', 'Sqlite'),
        ('ssh', 'SSH'),
        ('cloudant', 'IBM Cloudant'),
        ('mssql', 'Microsoft SQL Server'),
        ('mesos_framework-id', 'Mesos Framework ID'),
        ('jira', 'JIRA'),
    ]

    def __init__(
            self, conn_id=None, conn_type=None,
            host=None, login=None, password=None,
            schema=None, port=None, extra=None,
            uri=None):
        self.conn_id = conn_id
        if uri:
            self.parse_from_uri(uri)
        else:
            self.conn_type = conn_type
            self.host = host
            self.login = login
            self.password = password
            self.schema = schema
            self.port = port
            self.extra = extra

    def parse_from_uri(self, uri):
        temp_uri = urlparse(uri)
        hostname = temp_uri.hostname or ''
        if '%2f' in hostname:
            hostname = hostname.replace('%2f', '/').replace('%2F', '/')
        conn_type = temp_uri.scheme
        if conn_type == 'postgresql':
            conn_type = 'postgres'
        self.conn_type = conn_type
        self.host = hostname
        self.schema = temp_uri.path[1:]
        self.login = temp_uri.username
        self.password = temp_uri.password
        self.port = temp_uri.port

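    # A minimal sketch of building a Connection from a URI (URI values are
    # illustrative); note parse_from_uri normalizes a 'postgresql' scheme
    # to 'postgres':
    #
    #     c = Connection(conn_id='my_db',
    #                    uri='postgresql://user:pass@dbhost:5432/mydb')
    #     assert c.conn_type == 'postgres'
    #     assert (c.host, c.port, c.schema) == ('dbhost', 5432, 'mydb')
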
    def get_password(self):
        if self._password and self.is_encrypted:
            if not ENCRYPTION_ON:
                raise AirflowException(
                    "Can't decrypt encrypted password for login={}, "
                    "FERNET_KEY configuration is missing".format(self.login))
            return FERNET.decrypt(bytes(self._password, 'utf-8')).decode()
        else:
            return self._password

    def set_password(self, value):
        if value:
            try:
                self._password = FERNET.encrypt(bytes(value, 'utf-8')).decode()
                self.is_encrypted = True
            except NameError:
                # FERNET is undefined when encryption is not configured
                self._password = value
                self.is_encrypted = False

    @declared_attr
    def password(cls):
        return synonym('_password',
                       descriptor=property(cls.get_password, cls.set_password))

    def get_extra(self):
        if self._extra and self.is_extra_encrypted:
            if not ENCRYPTION_ON:
                raise AirflowException(
                    "Can't decrypt `extra` params for login={}, "
                    "FERNET_KEY configuration is missing".format(self.login))
            return FERNET.decrypt(bytes(self._extra, 'utf-8')).decode()
        else:
            return self._extra

    def set_extra(self, value):
        if value:
            try:
                self._extra = FERNET.encrypt(bytes(value, 'utf-8')).decode()
                self.is_extra_encrypted = True
            except NameError:
                self._extra = value
                self.is_extra_encrypted = False

    @declared_attr
    def extra(cls):
        return synonym('_extra',
                       descriptor=property(cls.get_extra, cls.set_extra))

    def get_hook(self):
        try:
            if self.conn_type == 'mysql':
                from airflow.hooks.mysql_hook import MySqlHook
                return MySqlHook(mysql_conn_id=self.conn_id)
            elif self.conn_type == 'google_cloud_platform':
                from airflow.contrib.hooks.bigquery_hook import BigQueryHook
                return BigQueryHook(bigquery_conn_id=self.conn_id)
            elif self.conn_type == 'postgres':
                from airflow.hooks.postgres_hook import PostgresHook
                return PostgresHook(postgres_conn_id=self.conn_id)
            elif self.conn_type == 'hive_cli':
                from airflow.hooks.hive_hooks import HiveCliHook
                return HiveCliHook(hive_cli_conn_id=self.conn_id)
            elif self.conn_type == 'presto':
                from airflow.hooks.presto_hook import PrestoHook
                return PrestoHook(presto_conn_id=self.conn_id)
            elif self.conn_type == 'hiveserver2':
                from airflow.hooks.hive_hooks import HiveServer2Hook
                return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
            elif self.conn_type == 'sqlite':
                from airflow.hooks.sqlite_hook import SqliteHook
                return SqliteHook(sqlite_conn_id=self.conn_id)
            elif self.conn_type == 'jdbc':
                from airflow.hooks.jdbc_hook import JdbcHook
                return JdbcHook(jdbc_conn_id=self.conn_id)
            elif self.conn_type == 'mssql':
                from airflow.hooks.mssql_hook import MsSqlHook
                return MsSqlHook(mssql_conn_id=self.conn_id)
            elif self.conn_type == 'oracle':
                from airflow.hooks.oracle_hook import OracleHook
                return OracleHook(oracle_conn_id=self.conn_id)
            elif self.conn_type == 'vertica':
                from airflow.contrib.hooks.vertica_hook import VerticaHook
                return VerticaHook(vertica_conn_id=self.conn_id)
            elif self.conn_type == 'cloudant':
                from airflow.contrib.hooks.cloudant_hook import CloudantHook
                return CloudantHook(cloudant_conn_id=self.conn_id)
            elif self.conn_type == 'jira':
                from airflow.contrib.hooks.jira_hook import JiraHook
                return JiraHook(jira_conn_id=self.conn_id)
        except Exception:
            # a hook import or construction failed; fall through and return
            # None, as for unknown connection types
            pass

    def __repr__(self):
        return self.conn_id

    @property
    def extra_dejson(self):
        """Returns the extra property by deserializing json."""
        obj = {}
        if self.extra:
            try:
                obj = json.loads(self.extra)
            except Exception as e:
                logging.exception(e)
                logging.error("Failed parsing the json for conn_id %s", self.conn_id)

        return obj


class DagPickle(Base):
    """
    Dags can originate from different places (user repos, master repo, ...)
    and also get executed in different places (different executors). This
    object represents a version of a DAG and becomes a source of truth for
    a BackfillJob execution. A pickle is a native python serialized object,
    and in this case gets stored in the database for the duration of the job.

    The executors pick up the DagPickle id and read the dag definition from
    the database.
    """
    id = Column(Integer, primary_key=True)
    pickle = Column(PickleType(pickler=dill))
    created_dttm = Column(DateTime, default=func.now())
    pickle_hash = Column(Text)

    __tablename__ = "dag_pickle"

    def __init__(self, dag):
        self.dag_id = dag.dag_id
        if hasattr(dag, 'template_env'):
            dag.template_env = None
        self.pickle_hash = hash(dag)
        self.pickle = dag


class TaskInstance(Base):
    """
    Task instances store the state of a task instance. This table is the
    authority and single source of truth around what tasks have run and the
    state they are in.

    The SqlAlchemy model deliberately has no SqlAlchemy foreign key to the
    task or dag model, to allow more control over transactions.

    Database transactions on this table should guard against double triggers
    and any confusion around what task instances are or aren't ready to run
    even while multiple schedulers may be firing task instances.
    """

    __tablename__ = "task_instance"

    task_id = Column(String(ID_LEN), primary_key=True)
    dag_id = Column(String(ID_LEN), primary_key=True)
    execution_date = Column(DateTime, primary_key=True)
    start_date = Column(DateTime)
    end_date = Column(DateTime)
    duration = Column(Float)
    state = Column(String(20))
    try_number = Column(Integer, default=0)
    hostname = Column(String(1000))
    unixname = Column(String(1000))
    job_id = Column(Integer)
    pool = Column(String(50))
    queue = Column(String(50))
    priority_weight = Column(Integer)
    operator = Column(String(1000))
    queued_dttm = Column(DateTime)
    pid = Column(Integer)

    __table_args__ = (
        Index('ti_dag_state', dag_id, state),
        Index('ti_state', state),
        Index('ti_state_lkp', dag_id, task_id, execution_date, state),
        Index('ti_pool', pool, state, priority_weight),
    )

    def __init__(self, task, execution_date, state=None):
        self.dag_id = task.dag_id
        self.task_id = task.task_id
        self.execution_date = execution_date
        self.task = task
        self.queue = task.queue
        self.pool = task.pool
        self.priority_weight = task.priority_weight_total
        self.try_number = 0
        self.unixname = getpass.getuser()
        self.run_as_user = task.run_as_user
        if state:
            self.state = state
        self.hostname = ''
        self.init_on_load()

    @reconstructor
    def init_on_load(self):
        """ Initialize the attributes that aren't stored in the DB. """
        self.test_mode = False  # can be changed when calling 'run'

    def command(
            self,
            mark_success=False,
            ignore_all_deps=False,
            ignore_depends_on_past=False,
            ignore_task_deps=False,
            ignore_ti_state=False,
            local=False,
            pickle_id=None,
            raw=False,
            job_id=None,
            pool=None,
            cfg_path=None):
        """
        Returns a command that can be executed anywhere where airflow is
        installed. This command is part of the message sent to executors by
        the orchestrator.
        """
        return " ".join(self.command_as_list(
            mark_success=mark_success,
            ignore_all_deps=ignore_all_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            ignore_task_deps=ignore_task_deps,
            ignore_ti_state=ignore_ti_state,
            local=local,
            pickle_id=pickle_id,
            raw=raw,
            job_id=job_id,
            pool=pool,
            cfg_path=cfg_path))

    def command_as_list(
            self,
            mark_success=False,
            ignore_all_deps=False,
            ignore_task_deps=False,
            ignore_depends_on_past=False,
            ignore_ti_state=False,
            local=False,
            pickle_id=None,
            raw=False,
            job_id=None,
            pool=None,
            cfg_path=None):
        """
        Returns a command that can be executed anywhere where airflow is
        installed. This command is part of the message sent to executors by
        the orchestrator.
        """
        dag = self.task.dag

        should_pass_filepath = not pickle_id and dag
        if should_pass_filepath and dag.full_filepath != dag.filepath:
            path = "DAGS_FOLDER/{}".format(dag.filepath)
        elif should_pass_filepath and dag.full_filepath:
            path = dag.full_filepath
        else:
            path = None

        return TaskInstance.generate_command(
            self.dag_id,
            self.task_id,
            self.execution_date,
            mark_success=mark_success,
            ignore_all_deps=ignore_all_deps,
            ignore_task_deps=ignore_task_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            ignore_ti_state=ignore_ti_state,
            local=local,
            pickle_id=pickle_id,
            file_path=path,
            raw=raw,
            job_id=job_id,
            pool=pool,
            cfg_path=cfg_path)

    @staticmethod
    def generate_command(dag_id,
                         task_id,
                         execution_date,
                         mark_success=False,
                         ignore_all_deps=False,
                         ignore_depends_on_past=False,
                         ignore_task_deps=False,
                         ignore_ti_state=False,
                         local=False,
                         pickle_id=None,
                         file_path=None,
                         raw=False,
                         job_id=None,
                         pool=None,
                         cfg_path=None
                         ):
        """
        Generates the shell command required to execute this task instance.

        :param dag_id: DAG ID
        :type dag_id: unicode
        :param task_id: Task ID
        :type task_id: unicode
        :param execution_date: Execution date for the task
        :type execution_date: datetime
        :param mark_success: Whether to mark the task as successful
        :type mark_success: bool
        :param ignore_all_deps: Ignore all ignorable dependencies.
            Overrides the other ignore_* parameters.
        :type ignore_all_deps: boolean
        :param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs
            (e.g. for Backfills)
        :type ignore_depends_on_past: boolean
        :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past
            and trigger rule
        :type ignore_task_deps: boolean
        :param ignore_ti_state: Ignore the task instance's previous failure/success
        :type ignore_ti_state: boolean
        :param local: Whether to run the task locally
        :type local: bool
        :param pickle_id: If the DAG was serialized to the DB, the ID
            associated with the pickled DAG
        :type pickle_id: unicode
        :param file_path: path to the file containing the DAG definition
        :param raw: raw mode (needs more details)
        :param job_id: job ID (needs more details)
        :param pool: the Airflow pool that the task should run in
        :type pool: unicode
        :return: shell command that can be used to run the task instance
        """
        iso = execution_date.isoformat()
        cmd = ["airflow", "run", str(dag_id), str(task_id), str(iso)]
        if mark_success:
            cmd.extend(["--mark_success"])
        if pickle_id:
            cmd.extend(["--pickle", str(pickle_id)])
        if job_id:
            cmd.extend(["--job_id", str(job_id)])
        if ignore_all_deps:
            cmd.extend(["-A"])
        if ignore_task_deps:
            cmd.extend(["-i"])
        if ignore_depends_on_past:
            cmd.extend(["-I"])
        if ignore_ti_state:
            cmd.extend(["--force"])
        if local:
            cmd.extend(["--local"])
        if pool:
            cmd.extend(["--pool", pool])
        if raw:
            cmd.extend(["--raw"])
        if file_path:
            cmd.extend(["-sd", file_path])
        if cfg_path:
            cmd.extend(["--cfg_path", cfg_path])
        return cmd

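    # A minimal sketch of the command list produced above (ids and date are
    # illustrative):
    #
    #     TaskInstance.generate_command(
    #         'my_dag', 'my_task', datetime(2018, 1, 11), local=True)
    #     # -> ['airflow', 'run', 'my_dag', 'my_task',
    #     #     '2018-01-11T00:00:00', '--local']
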
    @property
    def log_filepath(self):
        iso = self.execution_date.isoformat()
        log = os.path.expanduser(configuration.get('core', 'BASE_LOG_FOLDER'))
        return (
            "{log}/{self.dag_id}/{self.task_id}/{iso}.log".format(**locals()))

    @property
    def log_url(self):
        iso = self.execution_date.isoformat()
        BASE_URL = configuration.get('webserver', 'BASE_URL')
        return BASE_URL + (
            "/admin/airflow/log"
            "?dag_id={self.dag_id}"
            "&task_id={self.task_id}"
            "&execution_date={iso}"
        ).format(**locals())

    @property
    def mark_success_url(self):
        iso = self.execution_date.isoformat()
        BASE_URL = configuration.get('webserver', 'BASE_URL')
        return BASE_URL + (
            "/admin/airflow/action"
            "?action=success"
            "&task_id={self.task_id}"
            "&dag_id={self.dag_id}"
            "&execution_date={iso}"
            "&upstream=false"
            "&downstream=false"
        ).format(**locals())

    @provide_session
    def current_state(self, session=None):
        """
        Get the very latest state from the database. If a session is passed,
        we use it and looking up the state becomes part of the session;
        otherwise a new session is used.
        """
        TI = TaskInstance
        ti = session.query(TI).filter(
            TI.dag_id == self.dag_id,
            TI.task_id == self.task_id,
            TI.execution_date == self.execution_date,
        ).all()
        if ti:
            state = ti[0].state
        else:
            state = None
        return state

    @provide_session
    def error(self, session=None):
        """
        Forces the task instance's state to FAILED in the database.
        """
        logging.error("Recording the task instance as FAILED")
        self.state = State.FAILED
        session.merge(self)
        session.commit()

    @provide_session
    def refresh_from_db(self, session=None, lock_for_update=False):
        """
        Refreshes the task instance from the database based on the primary key

        :param lock_for_update: if True, indicates that the database should
            lock the TaskInstance (issuing a FOR UPDATE clause) until the
            session is committed.
        """
        TI = TaskInstance

        qry = session.query(TI).filter(
            TI.dag_id == self.dag_id,
            TI.task_id == self.task_id,
            TI.execution_date == self.execution_date)

        if lock_for_update:
            ti = qry.with_for_update().first()
        else:
            ti = qry.first()
        if ti:
            self.state = ti.state
            self.start_date = ti.start_date
            self.end_date = ti.end_date
            self.try_number = ti.try_number
            self.hostname = ti.hostname
            self.pid = ti.pid
        else:
            self.state = None

    @provide_session
    def clear_xcom_data(self, session=None):
        """
        Clears all XCom data from the database for the task instance
        """
        session.query(XCom).filter(
            XCom.dag_id == self.dag_id,
            XCom.task_id == self.task_id,
            XCom.execution_date == self.execution_date
        ).delete()
        session.commit()

    @property
    def key(self):
        """
        Returns a tuple that identifies the task instance uniquely
        """
        return self.dag_id, self.task_id, self.execution_date

    def set_state(self, state, session):
        self.state = state
        self.start_date = datetime.now()
        self.end_date = datetime.now()
        session.merge(self)
        session.commit()

    @property
    def is_premature(self):
        """
        Returns whether a task is in UP_FOR_RETRY state and its retry interval
        has not yet elapsed.
        """
        # is the task still in the retry waiting period?
        return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()

    @provide_session
    def are_dependents_done(self, session=None):
        """
        Checks whether the dependents of this task instance have all succeeded.
        This is meant to be used by wait_for_downstream.

        This is useful when you do not want to start processing the next
        schedule of a task until the dependents are done. For instance,
        if the task DROPs and recreates a table.
        """
        task = self.task

        if not task.downstream_task_ids:
            return True

        ti = session.query(func.count(TaskInstance.task_id)).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id.in_(task.downstream_task_ids),
            TaskInstance.execution_date == self.execution_date,
            TaskInstance.state == State.SUCCESS,
        )
        count = ti[0][0]
        return count == len(task.downstream_task_ids)

    @property
    @provide_session
    def previous_ti(self, session=None):
        """ The task instance for the task that ran before this task instance """

        dag = self.task.dag
        if dag:
            dr = self.get_dagrun(session=session)

            # LEGACY: most likely running from unit tests
            if not dr:
                # Means that this TI is NOT being run from a DR, but from a catchup
                previous_scheduled_date = dag.previous_schedule(self.execution_date)
                if not previous_scheduled_date:
                    return None

                return TaskInstance(task=self.task,
                                    execution_date=previous_scheduled_date)

            dr.dag = dag
            if dag.catchup:
                last_dagrun = dr.get_previous_scheduled_dagrun(session=session)
            else:
                last_dagrun = dr.get_previous_dagrun(session=session)

            if last_dagrun:
                return last_dagrun.get_task_instance(self.task_id, session=session)

        return None

    @provide_session
    def are_dependencies_met(
            self,
            dep_context=None,
            session=None,
            verbose=False):
        """
        Returns whether or not all the conditions are met for this task instance to be run
        given the context for the dependencies (e.g. a task instance being force run from
        the UI will ignore some dependencies).

        :param dep_context: The execution context that determines the dependencies that
            should be evaluated.
        :type dep_context: DepContext
        :param session: database session
        :type session: Session
        :param verbose: whether or not to print details on failed dependencies
        :type verbose: boolean
        """
        dep_context = dep_context or DepContext()
        failed = False
        for dep_status in self.get_failed_dep_statuses(
                dep_context=dep_context,
                session=session):
            failed = True
            if verbose:
                logging.info("Dependencies not met for {}, dependency '{}' FAILED: {}"
                             .format(self, dep_status.dep_name, dep_status.reason))

        if failed:
            return False

        if verbose:
            logging.info("Dependencies all met for {}".format(self))

        return True

    @provide_session
    def get_failed_dep_statuses(
            self,
            dep_context=None,
            session=None):
        dep_context = dep_context or DepContext()
        for dep in dep_context.deps | self.task.deps:
            for dep_status in dep.get_dep_statuses(
                    self,
                    session,
                    dep_context):

                logging.debug("{} dependency '{}' PASSED: {}, {}"
                              .format(self,
                                      dep_status.dep_name,
                                      dep_status.passed,
                                      dep_status.reason))

                if not dep_status.passed:
                    yield dep_status

    def __repr__(self):
        return (
            "<TaskInstance: {ti.dag_id}.{ti.task_id} "
            "{ti.execution_date} [{ti.state}]>"
        ).format(ti=self)

    def next_retry_datetime(self):
        """
        Get datetime of the next retry if the task instance fails. For exponential
        backoff, retry_delay is used as base and will be converted to seconds.
        """
        delay = self.task.retry_delay
        if self.task.retry_exponential_backoff:
            delay_backoff_in_seconds = delay.total_seconds() ** self.try_number
            delay = timedelta(seconds=delay_backoff_in_seconds)
            if self.task.max_retry_delay:
                delay = min(self.task.max_retry_delay, delay)
        return self.end_date + delay

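    # A worked example of the backoff above (numbers are illustrative): with
    # retry_delay=timedelta(seconds=300) and try_number=2, the delay becomes
    # 300 ** 2 = 90000 seconds, and is then capped by max_retry_delay when
    # one is set.
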
  1170.     def ready_for_retry(self):
  1171.         """
  1172.        Checks on whether the task instance is in the right state and timeframe
  1173.        to be retried.
  1174.        """
  1175.         return (self.state == State.UP_FOR_RETRY and
  1176.                 self.next_retry_datetime() < datetime.now())
  1177.  
  1178.     @provide_session
  1179.     def pool_full(self, session):
  1180.         """
  1181.        Returns a boolean as to whether the slot pool has room for this
  1182.        task to run
  1183.        """
  1184.         if not self.task.pool:
  1185.             return False
  1186.  
  1187.         pool = (
  1188.             session
  1189.             .query(Pool)
  1190.             .filter(Pool.pool == self.task.pool)
  1191.             .first()
  1192.         )
  1193.         if not pool:
  1194.             return False
  1195.         open_slots = pool.open_slots(session=session)
  1196.  
  1197.         return open_slots <= 0
  1198.  
  1199.     @provide_session
  1200.     def get_dagrun(self, session):
  1201.         """
  1202.        Returns the DagRun for this TaskInstance
  1203.        :param session:
  1204.        :return: DagRun
  1205.        """
  1206.         dr = session.query(DagRun).filter(
  1207.             DagRun.dag_id == self.dag_id,
  1208.             DagRun.execution_date == self.execution_date
  1209.         ).first()
  1210.  
  1211.         return dr
  1212.  
  1213.     @provide_session
  1214.     def run(
  1215.             self,
  1216.             verbose=True,
  1217.             ignore_all_deps=False,
  1218.             ignore_depends_on_past=False,
  1219.             ignore_task_deps=False,
  1220.             ignore_ti_state=False,
  1221.             mark_success=False,
  1222.             test_mode=False,
  1223.             job_id=None,
  1224.             pool=None,
  1225.             session=None):
  1226.         """
  1227.        Runs the task instance.
  1228.  
  1229.        :param verbose: whether to turn on more verbose loggin
  1230.        :type verbose: boolean
  1231.        :param ignore_all_deps: Ignore all of the non-critical dependencies, just runs
  1232.        :type ignore_all_deps: boolean
  1233.        :param ignore_depends_on_past: Ignore depends_on_past DAG attribute
  1234.        :type ignore_depends_on_past: boolean
  1235.        :param ignore_task_deps: Don't check the dependencies of this TI's task
  1236.        :type ignore_task_deps: boolean
  1237.        :param ignore_ti_state: Disregards previous task instance state
  1238.        :type ignore_ti_state: boolean
  1239.        :param mark_success: Don't run the task, mark its state as success
  1240.        :type mark_success: boolean
  1241.        :param test_mode: Doesn't record success or failure in the DB
  1242.        :type test_mode: boolean
  1243.        :param pool: specifies the pool to use to run the task instance
  1244.        :type pool: str
  1245.        """
  1246.         task = self.task
  1247.         self.pool = pool or task.pool
  1248.         self.test_mode = test_mode
  1249.         self.refresh_from_db(session=session, lock_for_update=True)
  1250.         self.job_id = job_id
  1251.         self.hostname = socket.getfqdn()
  1252.         self.operator = task.__class__.__name__
  1253.  
  1254.         if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
  1255.             Stats.incr('previously_succeeded', 1, 1)
  1256.  
  1257.         queue_dep_context = DepContext(
  1258.             deps=QUEUE_DEPS,
  1259.             ignore_all_deps=ignore_all_deps,
  1260.             ignore_ti_state=ignore_ti_state,
  1261.             ignore_depends_on_past=ignore_depends_on_past,
  1262.             ignore_task_deps=ignore_task_deps)
  1263.         if not self.are_dependencies_met(
  1264.                 dep_context=queue_dep_context,
  1265.                 session=session,
  1266.                 verbose=True):
  1267.             session.commit()
  1268.             return
  1269.  
  1270.         hr = "\n" + ("-" * 80) + "\n"  # Line break
  1271.  
  1272.         # For reporting purposes, we report based on 1-indexed,
  1273.         # not 0-indexed lists (i.e. Attempt 1 instead of
  1274.         # Attempt 0 for the first attempt).
  1275.         msg = "Starting attempt {attempt} of {total}".format(
  1276.             attempt=self.try_number % (task.retries + 1) + 1,
  1277.             total=task.retries + 1)
  1278.         self.start_date = datetime.now()
  1279.  
  1280.         dep_context = DepContext(
  1281.             deps=RUN_DEPS - QUEUE_DEPS,
  1282.             ignore_all_deps=ignore_all_deps,
  1283.             ignore_depends_on_past=ignore_depends_on_past,
  1284.             ignore_task_deps=ignore_task_deps,
  1285.             ignore_ti_state=ignore_ti_state)
  1286.         runnable = self.are_dependencies_met(
  1287.             dep_context=dep_context,
  1288.             session=session,
  1289.             verbose=True)
  1290.  
  1291.         if not runnable and not mark_success:
  1292.             # FIXME: we might have hit concurrency limits, which means we probably
  1293.             # have been running prematurely. This should be handled in the
  1294.             # scheduling mechanism.
  1295.             self.state = State.NONE
  1296.             msg = ("FIXME: Rescheduling due to concurrency limits reached at task "
  1297.                    "runtime. Attempt {attempt} of {total}. State set to NONE.").format(
  1298.                 attempt=self.try_number % (task.retries + 1) + 1,
  1299.                 total=task.retries + 1)
  1300.             logging.warning(hr + msg + hr)
  1301.  
  1302.             self.queued_dttm = datetime.now()
  1303.             msg = "Queuing into pool {}".format(self.pool)
  1304.             logging.info(msg)
  1305.             session.merge(self)
  1306.             session.commit()
  1307.             return
  1308.  
  1309.         # Another worker might have started running this task instance while
  1310.         # the current worker process was blocked on refresh_from_db
  1311.         if self.state == State.RUNNING:
  1312.             msg = "Task Instance already running {}".format(self)
  1313.             logging.warning(msg)
  1314.             session.commit()
  1315.             return
  1316.  
  1317.         # print status message
  1318.         logging.info(hr + msg + hr)
  1319.         self.try_number += 1
  1320.  
  1321.         if not test_mode:
  1322.             session.add(Log(State.RUNNING, self))
  1323.         self.state = State.RUNNING
  1324.         self.pid = os.getpid()
  1325.         self.end_date = None
  1326.         if not test_mode:
  1327.             session.merge(self)
  1328.         session.commit()
  1329.  
  1330.         # Closing all pooled connections to prevent
  1331.         # "max number of connections reached"
  1332.         settings.engine.dispose()
  1333.         if verbose:
  1334.             if mark_success:
  1335.                 msg = "Marking success for "
  1336.             else:
  1337.                 msg = "Executing "
  1338.             msg += "{self.task} on {self.execution_date}"
  1339.  
  1340.         context = {}
  1341.         try:
  1342.             logging.info(msg.format(self=self))
  1343.             if not mark_success:
  1344.                 context = self.get_template_context()
  1345.  
  1346.                 task_copy = copy.copy(task)
  1347.                 self.task = task_copy
  1348.  
  1349.                 def signal_handler(signum, frame):
  1350.                     '''Setting kill signal handler'''
  1351.                     logging.error("Killing subprocess")
  1352.                     task_copy.on_kill()
  1353.                     raise AirflowException("Task received SIGTERM signal")
  1354.                 signal.signal(signal.SIGTERM, signal_handler)
  1355.  
  1356.                 # Don't clear Xcom until the task is certain to execute
  1357.                 self.clear_xcom_data()
  1358.  
  1359.                 self.render_templates()
  1360.                 task_copy.pre_execute(context=context)
  1361.  
  1362.                 # If a timeout is specified for the task, make it fail
  1363.                 # if it goes beyond
  1364.                 result = None
  1365.                 if task_copy.execution_timeout:
  1366.                     try:
  1367.                         with timeout(int(
  1368.                                 task_copy.execution_timeout.total_seconds())):
  1369.                             result = task_copy.execute(context=context)
  1370.                     except AirflowTaskTimeout:
  1371.                         task_copy.on_kill()
  1372.                         raise
  1373.                 else:
  1374.                     result = task_copy.execute(context=context)
  1375.  
  1376.                 # If the task returns a result, push an XCom containing it
  1377.                 if result is not None:
  1378.                     self.xcom_push(key=XCOM_RETURN_KEY, value=result)
  1379.  
  1380.                 task_copy.post_execute(context=context)
  1381.                 Stats.incr('operator_successes_{}'.format(
  1382.                     self.task.__class__.__name__), 1, 1)
  1383.             self.state = State.SUCCESS
  1384.         except AirflowSkipException:
  1385.             self.state = State.SKIPPED
  1386.         except (Exception, KeyboardInterrupt) as e:
  1387.             self.handle_failure(e, test_mode, context)
  1388.             raise
  1389.  
  1390.         # Recording SUCCESS
  1391.         self.end_date = datetime.now()
  1392.         self.set_duration()
  1393.         if not test_mode:
  1394.             session.add(Log(self.state, self))
  1395.             session.merge(self)
  1396.         session.commit()
  1397.  
  1398.         # Success callback
  1399.         try:
  1400.             if task.on_success_callback:
  1401.                 task.on_success_callback(context)
  1402.         except Exception as e3:
  1403.             logging.error("Failed when executing success callback")
  1404.             logging.exception(e3)
  1405.  
  1406.         session.commit()
  1407.  
  1408.     def dry_run(self):
  1409.         task = self.task
  1410.         task_copy = copy.copy(task)
  1411.         self.task = task_copy
  1412.  
  1413.         self.render_templates()
  1414.         task_copy.dry_run()
  1415.  
  1416.     def handle_failure(self, error, test_mode=False, context=None):
  1417.         logging.exception(error)
  1418.         task = self.task
  1419.         session = settings.Session()
  1420.         self.end_date = datetime.now()
  1421.         self.set_duration()
  1422.         Stats.incr('operator_failures_{}'.format(task.__class__.__name__), 1, 1)
  1423.         if not test_mode:
  1424.             session.add(Log(State.FAILED, self))
  1425.  
  1426.         # Log failure duration
  1427.         session.add(TaskFail(task, self.execution_date, self.start_date, self.end_date))
  1428.  
  1429.         # Decide whether to retry or mark the task as failed
  1430.         try:
  1431.             if task.retries and self.try_number % (task.retries + 1) != 0:
  1432.                 self.state = State.UP_FOR_RETRY
  1433.                 logging.info('Marking task as UP_FOR_RETRY')
  1434.                 if task.email_on_retry and task.email:
  1435.                     self.email_alert(error, is_retry=True)
  1436.             else:
  1437.                 self.state = State.FAILED
  1438.                 if task.retries:
  1439.                     logging.info('All retries failed; marking task as FAILED')
  1440.                 else:
  1441.                     logging.info('Marking task as FAILED.')
  1442.                 if task.email_on_failure and task.email:
  1443.                     self.email_alert(error, is_retry=False)
  1444.         except Exception as e2:
  1445.             logging.error(
  1446.                 'Failed to send email to: ' + str(task.email))
  1447.             logging.exception(e2)
  1448.  
  1449.         # Handling callbacks pessimistically
  1450.         try:
  1451.             if self.state == State.UP_FOR_RETRY and task.on_retry_callback:
  1452.                 task.on_retry_callback(context)
  1453.             if self.state == State.FAILED and task.on_failure_callback:
  1454.                 task.on_failure_callback(context)
  1455.         except Exception as e3:
  1456.             logging.error("Failed at executing callback")
  1457.             logging.exception(e3)
  1458.  
  1459.         if not test_mode:
  1460.             session.merge(self)
  1461.         session.commit()
  1462.         logging.error(str(error))
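
    # Editorial worked example of the retry decision above (not original
    # code): run() increments try_number before execute(), so with retries=2
    # the failing try_numbers are 1, 2, 3; 1 % 3 and 2 % 3 are nonzero and
    # yield UP_FOR_RETRY, while 3 % 3 == 0 marks the task FAILED.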
  1463.  
  1464.     @provide_session
  1465.     def get_template_context(self, session=None):
  1466.         task = self.task
  1467.         from airflow import macros
  1468.         tables = None
  1469.         if 'tables' in task.params:
  1470.             tables = task.params['tables']
  1471.  
  1472.         ds = self.execution_date.isoformat()[:10]
  1473.         ts = self.execution_date.isoformat()
  1474.         yesterday_ds = (self.execution_date - timedelta(1)).isoformat()[:10]
  1475.         tomorrow_ds = (self.execution_date + timedelta(1)).isoformat()[:10]
  1476.  
  1477.         prev_execution_date = task.dag.previous_schedule(self.execution_date)
  1478.         next_execution_date = task.dag.following_schedule(self.execution_date)
  1479.  
  1480.         ds_nodash = ds.replace('-', '')
  1481.         ts_nodash = ts.replace('-', '').replace(':', '')
  1482.         yesterday_ds_nodash = yesterday_ds.replace('-', '')
  1483.         tomorrow_ds_nodash = tomorrow_ds.replace('-', '')
  1484.  
  1485.         ti_key_str = "{task.dag_id}__{task.task_id}__{ds_nodash}"
  1486.         ti_key_str = ti_key_str.format(**locals())
  1487.  
  1488.         params = {}
  1489.         run_id = ''
  1490.         dag_run = None
  1491.         if hasattr(task, 'dag'):
  1492.             if task.dag.params:
  1493.                 params.update(task.dag.params)
  1494.             dag_run = (
  1495.                 session.query(DagRun)
  1496.                 .filter_by(
  1497.                     dag_id=task.dag.dag_id,
  1498.                     execution_date=self.execution_date)
  1499.                 .first()
  1500.             )
  1501.             run_id = dag_run.run_id if dag_run else None
  1502.             session.expunge_all()
  1503.             session.commit()
  1504.  
  1505.         if task.params:
  1506.             params.update(task.params)
  1507.  
  1508.         class VariableAccessor:
  1509.             """
  1510.            Wrapper around Variable. This way you can get variables in
  1511.            templates by using {{ var.value.variable_name }}.
  1512.            """
  1513.             def __init__(self):
  1514.                 self.var = None
  1515.  
  1516.             def __getattr__(self, item):
  1517.                 self.var = Variable.get(item)
  1518.                 return self.var
  1519.  
  1520.             def __repr__(self):
  1521.                 return str(self.var)
  1522.  
  1523.         class VariableJsonAccessor:
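            """
            Wrapper around Variable that deserializes the stored JSON. This
            way you can get variables in templates by using
            {{ var.json.variable_name }}.
            """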
  1524.             def __init__(self):
  1525.                 self.var = None
  1526.  
  1527.             def __getattr__(self, item):
  1528.                 self.var = Variable.get(item, deserialize_json=True)
  1529.                 return self.var
  1530.  
  1531.             def __repr__(self):
  1532.                 return str(self.var)
  1533.  
  1534.         return {
  1535.             'dag': task.dag,
  1536.             'ds': ds,
  1537.             'ds_nodash': ds_nodash,
  1538.             'ts': ts,
  1539.             'ts_nodash': ts_nodash,
  1540.             'yesterday_ds': yesterday_ds,
  1541.             'yesterday_ds_nodash': yesterday_ds_nodash,
  1542.             'tomorrow_ds': tomorrow_ds,
  1543.             'tomorrow_ds_nodash': tomorrow_ds_nodash,
  1544.             'END_DATE': ds,
  1545.             'end_date': ds,
  1546.             'dag_run': dag_run,
  1547.             'run_id': run_id,
  1548.             'execution_date': self.execution_date,
  1549.             'prev_execution_date': prev_execution_date,
  1550.             'next_execution_date': next_execution_date,
  1551.             'latest_date': ds,
  1552.             'macros': macros,
  1553.             'params': params,
  1554.             'tables': tables,
  1555.             'task': task,
  1556.             'task_instance': self,
  1557.             'ti': self,
  1558.             'task_instance_key_str': ti_key_str,
  1559.             'conf': configuration,
  1560.             'test_mode': self.test_mode,
  1561.             'var': {
  1562.                 'value': VariableAccessor(),
  1563.                 'json': VariableJsonAccessor()
  1564.             }
  1565.         }
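
    # Editorial sketch (assumption, not original code): the dict above is
    # what Jinja templates see. A hypothetical templated field such as
    #
    #     templated_command = "process --date {{ ds }} --run {{ run_id }}"
    #
    # renders '{{ ds }}' to e.g. '2018-01-11', while '{{ var.value.foo }}'
    # fetches the Airflow Variable named 'foo' through VariableAccessor.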
  1566.  
  1567.     def render_templates(self):
  1568.         task = self.task
  1569.         jinja_context = self.get_template_context()
  1570.         if hasattr(self, 'task') and hasattr(self.task, 'dag'):
  1571.             if self.task.dag.user_defined_macros:
  1572.                 jinja_context.update(
  1573.                     self.task.dag.user_defined_macros)
  1574.  
  1575.         rt = self.task.render_template  # shortcut to method
  1576.         for attr in task.__class__.template_fields:
  1577.             content = getattr(task, attr)
  1578.             if content:
  1579.                 rendered_content = rt(attr, content, jinja_context)
  1580.                 setattr(task, attr, rendered_content)
  1581.  
  1582.     def email_alert(self, exception, is_retry=False):
  1583.         task = self.task
  1584.         title = "Airflow alert: {self}".format(**locals())
  1585.         exception = str(exception).replace('\n', '<br>')
  1586.         try_ = task.retries + 1
  1587.         body = (
  1588.             "Try {self.try_number} out of {try_}<br>"
  1589.             "Exception:<br>{exception}<br>"
  1590.             "Log: <a href='{self.log_url}'>Link</a><br>"
  1591.             "Host: {self.hostname}<br>"
  1592.             "Log file: {self.log_filepath}<br>"
  1593.             "Mark success: <a href='{self.mark_success_url}'>Link</a><br>"
  1594.         ).format(**locals())
  1595.         send_email(task.email, title, body)
  1596.  
  1597.     def set_duration(self):
  1598.         if self.end_date and self.start_date:
  1599.             self.duration = (self.end_date - self.start_date).total_seconds()
  1600.         else:
  1601.             self.duration = None
  1602.  
  1603.     def xcom_push(
  1604.             self,
  1605.             key,
  1606.             value,
  1607.             execution_date=None):
  1608.         """
  1609.        Make an XCom available for tasks to pull.
  1610.  
  1611.        :param key: A key for the XCom
  1612.        :type key: string
  1613.        :param value: A value for the XCom. The value is pickled and stored
  1614.            in the database.
  1615.        :type value: any pickleable object
  1616.        :param execution_date: if provided, the XCom will not be visible until
  1617.            this date. This can be used, for example, to send a message to a
  1618.            task on a future date without it being immediately visible.
  1619.        :type execution_date: datetime
  1620.        """
  1621.  
  1622.         if execution_date and execution_date < self.execution_date:
  1623.             raise ValueError(
  1624.                 'execution_date can not be in the past (current '
  1625.                 'execution_date is {}; received {})'.format(
  1626.                     self.execution_date, execution_date))
  1627.  
  1628.         XCom.set(
  1629.             key=key,
  1630.             value=value,
  1631.             task_id=self.task_id,
  1632.             dag_id=self.dag_id,
  1633.             execution_date=execution_date or self.execution_date)
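
    # Editorial sketch of xcom_push from inside an operator's execute()
    # (hypothetical key and value, not original code). Values returned from
    # execute() are pushed automatically under XCOM_RETURN_KEY:
    #
    #     def execute(self, context):
    #         context['ti'].xcom_push(key='row_count', value=42)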
  1634.  
  1635.     def xcom_pull(
  1636.             self,
  1637.             task_ids,
  1638.             dag_id=None,
  1639.             key=XCOM_RETURN_KEY,
  1640.             include_prior_dates=False):
  1641.         """
  1642.        Pull XComs that optionally meet certain criteria.
  1643.  
  1644.        The default value for `key` limits the search to XComs
  1645.        that were returned by other tasks (as opposed to those that were pushed
  1646.        manually). To remove this filter, pass key=None (or any desired value).
  1647.  
  1648.        If a single task_id string is provided, the result is the value of the
  1649.        most recent matching XCom from that task_id. If multiple task_ids are
  1650.        provided, a tuple of matching values is returned. None is returned
  1651.        whenever no matches are found.
  1652.  
  1653.        :param key: A key for the XCom. If provided, only XComs with matching
  1654.            keys will be returned. The default key is 'return_value', also
  1655.            available as a constant XCOM_RETURN_KEY. This key is automatically
  1656.            given to XComs returned by tasks (as opposed to being pushed
  1657.            manually). To remove the filter, pass key=None.
  1658.        :type key: string
  1659.        :param task_ids: Only XComs from tasks with matching ids will be
  1660.            pulled. Can pass None to remove the filter.
  1661.        :type task_ids: string or iterable of strings (representing task_ids)
  1662.        :param dag_id: If provided, only pulls XComs from this DAG.
  1663.            If None (default), the DAG of the calling task is used.
  1664.        :type dag_id: string
  1665.        :param include_prior_dates: If False, only XComs from the current
  1666.            execution_date are returned. If True, XComs from previous dates
  1667.            are returned as well.
  1668.        :type include_prior_dates: bool
  1669.        """
  1670.  
  1671.         if dag_id is None:
  1672.             dag_id = self.dag_id
  1673.  
  1674.         pull_fn = functools.partial(
  1675.             XCom.get_one,
  1676.             execution_date=self.execution_date,
  1677.             key=key,
  1678.             dag_id=dag_id,
  1679.             include_prior_dates=include_prior_dates)
  1680.  
  1681.         if is_container(task_ids):
  1682.             return tuple(pull_fn(task_id=t) for t in task_ids)
  1683.         else:
  1684.             return pull_fn(task_id=task_ids)
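
    # Editorial sketch of the xcom_pull semantics documented above, where ti
    # is an operator's context['ti'] and the task ids are hypothetical (not
    # original code):
    #
    #     value = ti.xcom_pull(task_ids='extract')           # latest return_value
    #     pair = ti.xcom_pull(task_ids=['extract', 'load'])  # tuple, same order
    #     raw = ti.xcom_pull(task_ids='extract', key=None)   # drop the key filter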
  1685.  
  1686.  
  1687. class TaskFail(Base):
  1688.     """
  1689.    TaskFail tracks the failed run durations of each task instance.
  1690.    """
  1691.  
  1692.     __tablename__ = "task_fail"
  1693.  
  1694.     task_id = Column(String(ID_LEN), primary_key=True)
  1695.     dag_id = Column(String(ID_LEN), primary_key=True)
  1696.     execution_date = Column(DateTime, primary_key=True)
  1697.     start_date = Column(DateTime)
  1698.     end_date = Column(DateTime)
  1699.     duration = Column(Float)
  1700.  
  1701.     def __init__(self, task, execution_date, start_date, end_date):
  1702.         self.dag_id = task.dag_id
  1703.         self.task_id = task.task_id
  1704.         self.execution_date = execution_date
  1705.         self.start_date = start_date
  1706.         self.end_date = end_date
  1707.         self.duration = (self.end_date - self.start_date).total_seconds()
  1708.  
  1709.  
  1710. class Log(Base):
  1711.     """
  1712.    Used to actively log events to the database
  1713.    """
  1714.  
  1715.     __tablename__ = "log"
  1716.  
  1717.     id = Column(Integer, primary_key=True)
  1718.     dttm = Column(DateTime)
  1719.     dag_id = Column(String(ID_LEN))
  1720.     task_id = Column(String(ID_LEN))
  1721.     event = Column(String(30))
  1722.     execution_date = Column(DateTime)
  1723.     owner = Column(String(500))
  1724.     extra = Column(Text)
  1725.  
  1726.     def __init__(self, event, task_instance, owner=None, extra=None, **kwargs):
  1727.         self.dttm = datetime.now()
  1728.         self.event = event
  1729.         self.extra = extra
  1730.  
  1731.         task_owner = None
  1732.  
  1733.         if task_instance:
  1734.             self.dag_id = task_instance.dag_id
  1735.             self.task_id = task_instance.task_id
  1736.             self.execution_date = task_instance.execution_date
  1737.             task_owner = task_instance.task.owner
  1738.  
  1739.         if 'task_id' in kwargs:
  1740.             self.task_id = kwargs['task_id']
  1741.         if 'dag_id' in kwargs:
  1742.             self.dag_id = kwargs['dag_id']
  1743.         if 'execution_date' in kwargs:
  1744.             if kwargs['execution_date']:
  1745.                 self.execution_date = kwargs['execution_date']
  1746.  
  1747.         self.owner = owner or task_owner
  1748.  
  1749.  
  1750. @functools.total_ordering
  1751. class BaseOperator(object):
  1752.     """
  1753.    Abstract base class for all operators. Since operators create objects that
  1754.    become nodes in the DAG, BaseOperator contains many recursive methods for
  1755.    DAG crawling behavior. To derive from this class, you are expected to override
  1756.    the constructor as well as the 'execute' method.
  1757.  
  1758.    Operators derived from this class should perform or trigger certain tasks
  1759.    synchronously (wait for completion). Examples of operators could be an
  1760.    operator that runs a Pig job (PigOperator), a sensor operator that
  1761.    waits for a partition to land in Hive (HiveSensorOperator), or one that
  1762.    moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these
  1763.    operators (tasks) target specific operations, running specific scripts,
  1764.    functions or data transfers.
  1765.  
  1766.    This class is abstract and shouldn't be instantiated. Instantiating a
  1767.    class derived from this one results in the creation of a task object,
  1768.    which ultimately becomes a node in DAG objects. Task dependencies should
  1769.    be set by using the set_upstream and/or set_downstream methods.
  1770.  
  1771.    Note that this class is derived from SQLAlchemy's Base class, which
  1772.    allows us to push metadata regarding tasks to the database. Deriving this
  1773.    allows us to push metadata regarding tasks to the database. Classes
  1774.    deriving from this one need to implement the polymorphic specificities documented in
  1775.    operators.
  1776.  
  1777.    :param task_id: a unique, meaningful id for the task
  1778.    :type task_id: string
  1779.    :param owner: the owner of the task, using the unix username is recommended
  1780.    :type owner: string
  1781.    :param retries: the number of retries that should be performed before
  1782.        failing the task
  1783.    :type retries: int
  1784.    :param retry_delay: delay between retries
  1785.    :type retry_delay: timedelta
  1786.    :param retry_exponential_backoff: allow progressive longer waits between
  1787.        retries by using exponential backoff algorithm on retry delay (delay
  1788.        will be converted into seconds)
  1789.    :type retry_exponential_backoff: bool
  1790.    :param max_retry_delay: maximum delay interval between retries
  1791.    :type max_retry_delay: timedelta
  1792.    :param start_date: The ``start_date`` for the task, determines
  1793.        the ``execution_date`` for the first task instance. The best practice
  1794.        is to have the start_date rounded
  1795.        to your DAG's ``schedule_interval``. Daily jobs have their start_date
  1796.        some day at 00:00:00, hourly jobs have their start_date at 00:00
  1797.        of a specific hour. Note that Airflow simply looks at the latest
  1798.        ``execution_date`` and adds the ``schedule_interval`` to determine
  1799.        the next ``execution_date``. It is also very important
  1800.        to note that different tasks' dependencies
  1801.        need to line up in time. If task A depends on task B and their
  1802.        start_date are offset in a way that their execution_date don't line
  1803.        up, A's dependencies will never be met. If you are looking to delay
  1804.        a task, for example running a daily task at 2AM, look into the
  1805.        ``TimeSensor`` and ``TimeDeltaSensor``. We advise against using
  1806.        dynamic ``start_date`` and recommend using fixed ones. Read the
  1807.        FAQ entry about start_date for more information.
  1808.    :type start_date: datetime
  1809.    :param end_date: if specified, the scheduler won't go beyond this date
  1810.    :type end_date: datetime
  1811.    :param depends_on_past: when set to true, task instances will run
  1812.        sequentially while relying on the previous task's schedule to
  1813.        succeed. The task instance for the start_date is allowed to run.
  1814.    :type depends_on_past: bool
  1815.    :param wait_for_downstream: when set to true, an instance of task
  1816.        X will wait for tasks immediately downstream of the previous instance
  1817.        of task X to finish successfully before it runs. This is useful if the
  1818.        different instances of a task X alter the same asset, and this asset
  1819.        is used by tasks downstream of task X. Note that depends_on_past
  1820.        is forced to True wherever wait_for_downstream is used.
  1821.    :type wait_for_downstream: bool
  1822.    :param queue: which queue to target when running this job. Not
  1823.        all executors implement queue management; the CeleryExecutor
  1824.        does support targeting specific queues.
  1825.    :type queue: str
  1826.    :param dag: a reference to the dag the task is attached to (if any)
  1827.    :type dag: DAG
  1828.    :param priority_weight: priority weight of this task against other tasks.
  1829.        This allows the executor to trigger higher priority tasks before
  1830.        others when things get backed up.
  1831.    :type priority_weight: int
  1832.    :param pool: the slot pool this task should run in, slot pools are a
  1833.        way to limit concurrency for certain tasks
  1834.    :type pool: str
  1835.    :param sla: time by which the job is expected to succeed. Note that
  1836.        this represents the ``timedelta`` after the period is closed. For
  1837.        example if you set an SLA of 1 hour, the scheduler would send an email
  1838.        soon after 1:00AM on the ``2016-01-02`` if the ``2016-01-01`` instance
  1839.        has not succeeded yet.
  1840.        The scheduler pays special attention for jobs with an SLA and
  1841.        sends alert
  1842.        emails for sla misses. SLA misses are also recorded in the database
  1843.        for future reference. All tasks that share the same SLA time
  1844.        get bundled in a single email, sent soon after that time. SLA
  1845.        notifications are sent once and only once for each task instance.
  1846.    :type sla: datetime.timedelta
  1847.    :param execution_timeout: max time allowed for the execution of
  1848.        this task instance; if the task runs longer, it will raise and fail.
  1849.    :type execution_timeout: datetime.timedelta
  1850.    :param on_failure_callback: a function to be called when a task instance
  1851.        of this task fails. A context dictionary is passed as a single
  1852.        parameter to this function. Context contains references to objects
  1853.        related to the task instance and is documented under the macros
  1854.        section of the API.
  1855.    :type on_failure_callback: callable
  1856.    :param on_retry_callback: much like the ``on_failure_callback`` except
  1857.        that it is executed when retries occur.
  1858.    :param on_success_callback: much like the ``on_failure_callback`` except
  1859.        that it is executed when the task succeeds.
  1860.    :type on_success_callback: callable
  1861.    :param trigger_rule: defines the rule by which dependencies are applied
  1862.        for the task to get triggered. Options are:
  1863.        ``{ all_success | all_failed | all_done | one_success |
  1864.        one_failed | dummy }``
  1865.        default is ``all_success``. Options can be set as string or
  1866.        using the constants defined in the static class
  1867.        ``airflow.utils.TriggerRule``
  1868.    :type trigger_rule: str
  1869.    :param resources: A map of resource parameter names (the argument names of the
  1870.        Resources constructor) to their values.
  1871.    :type resources: dict
  1872.    :param run_as_user: unix username to impersonate while running the task
  1873.    :type run_as_user: str
  1874.    """
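
    # Editorial sketch of the subclassing contract described above, using a
    # hypothetical operator (not part of the original module); only the
    # constructor and execute() are overridden:
    #
    #     class HelloOperator(BaseOperator):
    #         template_fields = ('name',)
    #
    #         @apply_defaults
    #         def __init__(self, name, *args, **kwargs):
    #             super(HelloOperator, self).__init__(*args, **kwargs)
    #             self.name = name
    #
    #         def execute(self, context):
    #             logging.info('Hello %s on %s', self.name, context['ds'])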
  1875.  
  1876.     # For derived classes to define which fields will get jinjaified
  1877.     template_fields = []
  1878.     # Defines which file extensions to look for in the templated fields
  1879.     template_ext = []
  1880.     # Defines the color in the UI
  1881.     ui_color = '#fff'
  1882.     ui_fgcolor = '#000'
  1883.  
  1884.     @apply_defaults
  1885.     def __init__(
  1886.             self,
  1887.             task_id,
  1888.             owner=configuration.get('operators', 'DEFAULT_OWNER'),
  1889.             email=None,
  1890.             email_on_retry=True,
  1891.             email_on_failure=True,
  1892.             retries=0,
  1893.             retry_delay=timedelta(seconds=300),
  1894.             retry_exponential_backoff=False,
  1895.             max_retry_delay=None,
  1896.             start_date=None,
  1897.             end_date=None,
  1898.             schedule_interval=None,  # not hooked as of now
  1899.             depends_on_past=False,
  1900.             wait_for_downstream=False,
  1901.             dag=None,
  1902.             params=None,
  1903.             default_args=None,
  1904.             adhoc=False,
  1905.             priority_weight=1,
  1906.             queue=configuration.get('celery', 'default_queue'),
  1907.             pool=None,
  1908.             sla=None,
  1909.             execution_timeout=None,
  1910.             on_failure_callback=None,
  1911.             on_success_callback=None,
  1912.             on_retry_callback=None,
  1913.             trigger_rule=TriggerRule.ALL_SUCCESS,
  1914.             resources=None,
  1915.             run_as_user=None,
  1916.             *args,
  1917.             **kwargs):
  1918.  
  1919.         if args or kwargs:
  1920.             # TODO remove *args and **kwargs in Airflow 2.0
  1921.             warnings.warn(
  1922.                 'Invalid arguments were passed to {c}. Support for '
  1923.                 'passing such arguments will be dropped in Airflow 2.0. '
  1924.                 'Invalid arguments were:'
  1925.                 '\n*args: {a}\n**kwargs: {k}'.format(
  1926.                     c=self.__class__.__name__, a=args, k=kwargs),
  1927.                 category=PendingDeprecationWarning
  1928.             )
  1929.  
  1930.         validate_key(task_id)
  1931.         self.task_id = task_id
  1932.         self.owner = owner
  1933.         self.email = email
  1934.         self.email_on_retry = email_on_retry
  1935.         self.email_on_failure = email_on_failure
  1936.         self.start_date = start_date
  1937.         if start_date and not isinstance(start_date, datetime):
  1938.             logging.warning(
  1939.                 "start_date for {} isn't datetime.datetime".format(self))
  1940.         self.end_date = end_date
  1941.         if not TriggerRule.is_valid(trigger_rule):
  1942.             raise AirflowException(
  1943.                 "The trigger_rule must be one of {all_triggers}, "
  1944.                 "for task '{d}.{t}'; received '{tr}'."
  1945.                 .format(all_triggers=TriggerRule.all_triggers,
  1946.                         d=dag.dag_id, t=task_id, tr=trigger_rule))
  1947.  
  1948.         self.trigger_rule = trigger_rule
  1949.         self.depends_on_past = depends_on_past
  1950.         self.wait_for_downstream = wait_for_downstream
  1951.         if wait_for_downstream:
  1952.             self.depends_on_past = True
  1953.  
  1954.         if schedule_interval:
  1955.             logging.warning(
  1956.                 "schedule_interval is used for {}, though it has "
  1957.                 "been deprecated as a task parameter, you need to "
  1958.                 "specify it as a DAG parameter instead".format(self))
  1959.         self._schedule_interval = schedule_interval
  1960.         self.retries = retries
  1961.         self.queue = queue
  1962.         self.pool = pool
  1963.         self.sla = sla
  1964.         self.execution_timeout = execution_timeout
  1965.         self.on_failure_callback = on_failure_callback
  1966.         self.on_success_callback = on_success_callback
  1967.         self.on_retry_callback = on_retry_callback
  1968.         if isinstance(retry_delay, timedelta):
  1969.             self.retry_delay = retry_delay
  1970.         else:
  1971.             logging.debug("retry_delay isn't timedelta object, assuming secs")
  1972.             self.retry_delay = timedelta(seconds=retry_delay)
  1973.         self.retry_exponential_backoff = retry_exponential_backoff
  1974.         self.max_retry_delay = max_retry_delay
  1975.         self.params = params or {}  # Available in templates!
  1976.         self.adhoc = adhoc
  1977.         self.priority_weight = priority_weight
  1978.         self.resources = Resources(**(resources or {}))
  1979.         self.run_as_user = run_as_user
  1980.  
  1981.         # Private attributes
  1982.         self._upstream_task_ids = []
  1983.         self._downstream_task_ids = []
  1984.  
  1985.         if not dag and _CONTEXT_MANAGER_DAG:
  1986.             dag = _CONTEXT_MANAGER_DAG
  1987.         if dag:
  1988.             self.dag = dag
  1989.  
  1990.         self._comps = {
  1991.             'task_id',
  1992.             'dag_id',
  1993.             'owner',
  1994.             'email',
  1995.             'email_on_retry',
  1996.             'retry_delay',
  1997.             'retry_exponential_backoff',
  1998.             'max_retry_delay',
  1999.             'start_date',
  2000.             'schedule_interval',
  2001.             'depends_on_past',
  2002.             'wait_for_downstream',
  2003.             'adhoc',
  2004.             'priority_weight',
  2005.             'sla',
  2006.             'execution_timeout',
  2007.             'on_failure_callback',
  2008.             'on_success_callback',
  2009.             'on_retry_callback',
  2010.         }
  2011.  
  2012.     def __eq__(self, other):
  2013.         return (
  2014.             type(self) == type(other) and
  2015.             all(self.__dict__.get(c, None) == other.__dict__.get(c, None)
  2016.                 for c in self._comps))
  2017.  
  2018.     def __ne__(self, other):
  2019.         return not self == other
  2020.  
  2021.     def __lt__(self, other):
  2022.         return self.task_id < other.task_id
  2023.  
  2024.     def __hash__(self):
  2025.         hash_components = [type(self)]
  2026.         for c in self._comps:
  2027.             val = getattr(self, c, None)
  2028.             try:
  2029.                 hash(val)
  2030.                 hash_components.append(val)
  2031.             except TypeError:
  2032.                 hash_components.append(repr(val))
  2033.         return hash(tuple(hash_components))
  2034.  
  2035.     # Composing Operators -----------------------------------------------
  2036.  
  2037.     def __rshift__(self, other):
  2038.         """
  2039.        Implements Self >> Other == self.set_downstream(other)
  2040.  
  2041.        If "Other" is a DAG, the DAG is assigned to the Operator.
  2042.        """
  2043.         if isinstance(other, DAG):
  2044.             # if this dag is already assigned, do nothing
  2045.             # otherwise, do normal dag assignment
  2046.             if not (self.has_dag() and self.dag is other):
  2047.                 self.dag = other
  2048.         else:
  2049.             self.set_downstream(other)
  2050.         return other
  2051.  
  2052.     def __lshift__(self, other):
  2053.         """
  2054.        Implements Self << Other == self.set_upstream(other)
  2055.  
  2056.        If "Other" is a DAG, the DAG is assigned to the Operator.
  2057.        """
  2058.         if isinstance(other, DAG):
  2059.             # if this dag is already assigned, do nothing
  2060.             # otherwise, do normal dag assignment
  2061.             if not (self.has_dag() and self.dag is other):
  2062.                 self.dag = other
  2063.         else:
  2064.             self.set_upstream(other)
  2065.         return other
  2066.  
  2067.     def __rrshift__(self, other):
  2068.         """
  2069.        Called for [DAG] >> [Operator] because DAGs don't have
  2070.        __rshift__ operators.
  2071.        """
  2072.         self.__lshift__(other)
  2073.         return self
  2074.  
  2075.     def __rlshift__(self, other):
  2076.         """
  2077.        Called for [DAG] << [Operator] because DAGs don't have
  2078.        __lshift__ operators.
  2079.        """
  2080.         self.__rshift__(other)
  2081.         return self
  2082.  
  2083.     # /Composing Operators ---------------------------------------------
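
    # Editorial sketch of the composition operators above, assuming three
    # hypothetical tasks and a DAG named dag (not original code):
    #
    #     op1 >> op2 >> op3    # op1 -> op2 -> op3
    #     op3 << op1           # adds the edge op1 -> op3
    #     dag >> op1           # assigns op1 to dag via __rrshift__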
  2084.  
  2085.     @property
  2086.     def dag(self):
  2087.         """
  2088.        Returns the Operator's DAG if set, otherwise raises an error
  2089.        """
  2090.         if self.has_dag():
  2091.             return self._dag
  2092.         else:
  2093.             raise AirflowException(
  2094.                 'Operator {} has not been assigned to a DAG yet'.format(self))
  2095.  
  2096.     @dag.setter
  2097.     def dag(self, dag):
  2098.         """
  2099.        Operators can be assigned to one DAG, one time. Repeat assignments to
  2100.        that same DAG are ok.
  2101.        """
  2102.         if not isinstance(dag, DAG):
  2103.             raise TypeError(
  2104.                 'Expected DAG; received {}'.format(dag.__class__.__name__))
  2105.         elif self.has_dag() and self.dag is not dag:
  2106.             raise AirflowException(
  2107.                 "The DAG assigned to {} can not be changed.".format(self))
  2108.         elif self.task_id not in dag.task_dict:
  2109.             dag.add_task(self)
  2110.  
  2111.         self._dag = dag
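
    # Editorial note on the setter above (hypothetical names, not original
    # code): assignment is idempotent for the same DAG but final across DAGs:
    #
    #     op.dag = dag     # ok, also registers op in dag.task_dict
    #     op.dag = dag     # ok, no-op
    #     op.dag = other   # raises AirflowException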
  2112.  
  2113.     def has_dag(self):
  2114.         """
  2115.        Returns True if the Operator has been assigned to a DAG.
  2116.        """
  2117.         return getattr(self, '_dag', None) is not None
  2118.  
  2119.     @property
  2120.     def dag_id(self):
  2121.         if self.has_dag():
  2122.             return self.dag.dag_id
  2123.         else:
  2124.             return 'adhoc_' + self.owner
  2125.  
  2126.     @property
  2127.     def deps(self):
  2128.         """
  2129.        Returns the list of dependencies for the operator. These differ from execution
  2130.        context dependencies in that they are specific to tasks and can be
  2131.        extended/overridden by subclasses.
  2132.        """
  2133.         return {
  2134.             NotInRetryPeriodDep(),
  2135.             PrevDagrunDep(),
  2136.             TriggerRuleDep(),
  2137.         }
  2138.  
  2139.     @property
  2140.     def schedule_interval(self):
  2141.         """
  2142.        The schedule interval of the DAG always wins over individual tasks so
  2143.        that tasks within a DAG always line up. The task still needs a
  2144.        schedule_interval as it may not be attached to a DAG.
  2145.        """
  2146.         if self.has_dag():
  2147.             return self.dag._schedule_interval
  2148.         else:
  2149.             return self._schedule_interval
  2150.  
  2151.     @property
  2152.     def priority_weight_total(self):
  2153.         return sum([
  2154.             t.priority_weight
  2155.             for t in self.get_flat_relatives(upstream=False)
  2156.         ]) + self.priority_weight
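
    # Editorial worked example of priority_weight_total (not original code):
    # with default weights of 1 and a chain a >> b >> c, a's total is 3
    # (itself plus downstream b and c), b's is 2 and c's is 1, so upstream
    # tasks win when the executor works through a backlog.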
  2157.  
  2158.     def pre_execute(self, context):
  2159.         """
  2160.        This is triggered right before self.execute, it's mostly a hook
  2161.        for people deriving operators.
  2162.        """
  2163.         pass
  2164.  
  2165.     def execute(self, context):
  2166.         """
  2167.        This is the main method to derive when creating an operator.
  2168.        Context is the same dictionary used as when rendering jinja templates.
  2169.  
  2170.        Refer to get_template_context for more context.
  2171.        """
  2172.         raise NotImplementedError()
  2173.  
  2174.     def post_execute(self, context):
  2175.         """
  2176.        This is triggered right after self.execute, it's mostly a hook
  2177.        for people deriving operators.
  2178.        """
  2179.         pass
  2180.  
  2181.     def on_kill(self):
  2182.         """
  2183.        Override this method to cleanup subprocesses when a task instance
  2184.        gets killed. Any use of the threading, subprocess or multiprocessing
  2185.        module within an operator needs to be cleaned up or it will leave
  2186.        ghost processes behind.
  2187.        """
  2188.         pass
  2189.  
  2190.     def __deepcopy__(self, memo):
  2191.         """
  2192.        Hack to avoid hitting the maximum recursion depth when deepcopying
  2193.        long chains of tasks.
  2194.        """
  2195.         sys.setrecursionlimit(5000)  # TODO fix this in a better way
  2196.         cls = self.__class__
  2197.         result = cls.__new__(cls)
  2198.         memo[id(self)] = result
  2199.  
  2200.         for k, v in list(self.__dict__.items()):
  2201.             if k not in ('user_defined_macros', 'params'):
  2202.                 setattr(result, k, copy.deepcopy(v, memo))
  2203.         result.params = self.params
  2204.         if hasattr(self, 'user_defined_macros'):
  2205.             result.user_defined_macros = self.user_defined_macros
  2206.         return result
  2207.  
  2208.     def render_template_from_field(self, attr, content, context, jinja_env):
  2209.         """
  2210.        Renders a template from a field. If the field is a string, it will
  2211.        simply render the string and return the result. If it is a collection or
  2212.        nested set of collections, it will traverse the structure and render
  2213.        all strings in it.
  2214.        """
  2215.         rt = self.render_template
  2216.         if isinstance(content, six.string_types):
  2217.             result = jinja_env.from_string(content).render(**context)
  2218.         elif isinstance(content, (list, tuple)):
  2219.             result = [rt(attr, e, context) for e in content]
  2220.         elif isinstance(content, dict):
  2221.             result = {
  2222.                 k: rt("{}[{}]".format(attr, k), v, context)
  2223.                 for k, v in list(content.items())}
  2224.         else:
  2225.             param_type = type(content)
  2226.             msg = (
  2227.                 "Type '{param_type}' used for parameter '{attr}' is "
  2228.                 "not supported for templating").format(**locals())
  2229.             raise AirflowException(msg)
  2230.         return result
  2231.  
  2232.     def render_template(self, attr, content, context):
  2233.         """
  2234.        Renders a template either from a file or directly in a field, and returns
  2235.        the rendered result.
  2236.        """
  2237.         jinja_env = self.dag.get_template_env() \
  2238.             if hasattr(self, 'dag') \
  2239.             else jinja2.Environment(cache_size=0)
  2240.  
  2241.         exts = self.__class__.template_ext
  2242.         if (
  2243.                 isinstance(content, six.string_types) and
  2244.                 any([content.endswith(ext) for ext in exts])):
  2245.             return jinja_env.get_template(content).render(**context)
  2246.         else:
  2247.             return self.render_template_from_field(attr, content, context, jinja_env)
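
    # Editorial sketch (hypothetical operator, not original code): whether a
    # template comes from a file is decided purely by template_ext. With
    # template_ext = ('.sql',), the value 'queries/daily.sql' is loaded via
    # the Jinja environment's loader, while 'SELECT {{ ds }}' is rendered
    # inline by render_template_from_field.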
  2248.  
  2249.     def prepare_template(self):
  2250.         """
  2251.        Hook that is triggered after the templated fields get replaced
  2252.        by their content. If you need your operator to alter the
  2253.        content of the file before the template is rendered,
  2254.        it should override this method to do so.
  2255.        """
  2256.         pass
  2257.  
  2258.     def resolve_template_files(self):
  2259.         # Getting the content of files for template_field / template_ext
  2260.         for attr in self.template_fields:
  2261.             content = getattr(self, attr)
  2262.             if content is not None and \
  2263.                     isinstance(content, six.string_types) and \
  2264.                     any([content.endswith(ext) for ext in self.template_ext]):
  2265.                 env = self.dag.get_template_env()
  2266.                 try:
  2267.                     setattr(self, attr, env.loader.get_source(env, content)[0])
  2268.                 except Exception as e:
  2269.                     logging.exception(e)
  2270.         self.prepare_template()
  2271.  
  2272.     @property
  2273.     def upstream_list(self):
  2274.         """@property: list of tasks directly upstream"""
  2275.         return [self.dag.get_task(tid) for tid in self._upstream_task_ids]
  2276.  
  2277.     @property
  2278.     def upstream_task_ids(self):
  2279.         return self._upstream_task_ids
  2280.  
  2281.     @property
  2282.     def downstream_list(self):
  2283.         """@property: list of tasks directly downstream"""
  2284.         return [self.dag.get_task(tid) for tid in self._downstream_task_ids]
  2285.  
  2286.     @property
  2287.     def downstream_task_ids(self):
  2288.         return self._downstream_task_ids
  2289.  
  2290.     def clear(
  2291.             self, start_date=None, end_date=None,
  2292.             upstream=False, downstream=False):
  2293.         """
  2294.        Clears the state of task instances associated with the task, following
  2295.        the parameters specified.
  2296.        """
  2297.         session = settings.Session()
  2298.  
  2299.         TI = TaskInstance
  2300.         qry = session.query(TI).filter(TI.dag_id == self.dag_id)
  2301.  
  2302.         if start_date:
  2303.             qry = qry.filter(TI.execution_date >= start_date)
  2304.         if end_date:
  2305.             qry = qry.filter(TI.execution_date <= end_date)
  2306.  
  2307.         tasks = [self.task_id]
  2308.  
  2309.         if upstream:
  2310.             tasks += [
  2311.                 t.task_id for t in self.get_flat_relatives(upstream=True)]
  2312.  
  2313.         if downstream:
  2314.             tasks += [
  2315.                 t.task_id for t in self.get_flat_relatives(upstream=False)]
  2316.  
  2317.         qry = qry.filter(TI.task_id.in_(tasks))
  2318.  
  2319.         count = qry.count()
  2320.  
  2321.         clear_task_instances(qry, session)
  2322.  
  2323.         session.commit()
  2324.         session.close()
  2325.         return count
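
    # Editorial sketch of clear() above, with hypothetical dates (not
    # original code); it returns the number of task instances affected:
    #
    #     n = op.clear(start_date=datetime(2018, 1, 1),
    #                  end_date=datetime(2018, 1, 7), downstream=True)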
  2326.  
  2327.     def get_task_instances(self, session, start_date=None, end_date=None):
  2328.         """
  2329.        Get a set of task instances related to this task for a specific date
  2330.        range.
  2331.        """
  2332.         TI = TaskInstance
  2333.         end_date = end_date or datetime.now()
  2334.         return session.query(TI).filter(
  2335.             TI.dag_id == self.dag_id,
  2336.             TI.task_id == self.task_id,
  2337.             TI.execution_date >= start_date,
  2338.             TI.execution_date <= end_date,
  2339.         ).order_by(TI.execution_date).all()
  2340.  
  2341.     def get_flat_relatives(self, upstream=False, l=None):
  2342.         """
  2343.        Get a flat list of relatives, either upstream or downstream.
  2344.        """
  2345.         if not l:
  2346.             l = []
  2347.         for t in self.get_direct_relatives(upstream):
  2348.             if not is_in(t, l):
  2349.                 l.append(t)
  2350.                 t.get_flat_relatives(upstream, l)
  2351.         return l
  2352.  
  2353.     def detect_downstream_cycle(self, task=None):
  2354.         """
  2355.        When invoked, this routine will raise an exception if a cycle is
  2356.        detected downstream from self. It is invoked when tasks are added to
  2357.        the DAG to detect cycles.
  2358.        """
  2359.         if not task:
  2360.             task = self
  2361.         for t in self.get_direct_relatives():
  2362.             if task is t:
  2363.                 msg = "Cycle detected in DAG. Faulty task: {0}".format(task)
  2364.                 raise AirflowException(msg)
  2365.             else:
  2366.                 t.detect_downstream_cycle(task=task)
  2367.         return False
  2368.  
  2369.     def run(
  2370.             self,
  2371.             start_date=None,
  2372.             end_date=None,
  2373.             ignore_first_depends_on_past=False,
  2374.             ignore_ti_state=False,
  2375.             mark_success=False):
  2376.         """
  2377.        Run a set of task instances for a date range.
  2378.        """
  2379.         start_date = start_date or self.start_date
  2380.         end_date = end_date or self.end_date or datetime.now()
  2381.  
  2382.         for dt in self.dag.date_range(start_date, end_date=end_date):
  2383.             TaskInstance(self, dt).run(
  2384.                 mark_success=mark_success,
  2385.                 ignore_depends_on_past=(
  2386.                     dt == start_date and ignore_first_depends_on_past),
  2387.                 ignore_ti_state=ignore_ti_state)
  2388.  
  2389.     def dry_run(self):
  2390.         logging.info('Dry run')
  2391.         for attr in self.template_fields:
  2392.             content = getattr(self, attr)
  2393.             if content and isinstance(content, six.string_types):
  2394.                 logging.info('Rendering template for {0}'.format(attr))
  2395.                 logging.info(content)
  2396.  
  2397.     def get_direct_relatives(self, upstream=False):
  2398.         """
  2399.        Get the direct relatives to the current task, upstream or
  2400.        downstream.
  2401.        """
  2402.         if upstream:
  2403.             return self.upstream_list
  2404.         else:
  2405.             return self.downstream_list
  2406.  
  2407.     def __repr__(self):
  2408.         return "<Task({self.__class__.__name__}): {self.task_id}>".format(
  2409.             self=self)
  2410.  
  2411.     @property
  2412.     def task_type(self):
  2413.         return self.__class__.__name__
  2414.  
  2415.     def append_only_new(self, l, item):
  2416.         if any([item is t for t in l]):
  2417.             raise AirflowException(
  2418.                 'Dependency {self}, {item} already registered'
  2419.                 ''.format(**locals()))
  2420.         else:
  2421.             l.append(item)
  2422.  
  2423.     def _set_relatives(self, task_or_task_list, upstream=False):
  2424.         try:
  2425.             task_list = list(task_or_task_list)
  2426.         except TypeError:
  2427.             task_list = [task_or_task_list]
  2428.  
  2429.         for t in task_list:
  2430.             if not isinstance(t, BaseOperator):
  2431.                 raise AirflowException(
  2432.                     "Relationships can only be set between "
  2433.                     "Operators; received {}".format(t.__class__.__name__))
  2434.  
  2435.         # relationships can only be set if the tasks share a single DAG. Tasks
  2436.         # without a DAG are assigned to that DAG.
  2437.         dags = set(t.dag for t in [self] + task_list if t.has_dag())
  2438.  
  2439.         if len(dags) > 1:
  2440.             raise AirflowException(
  2441.                 'Tried to set relationships between tasks in '
  2442.                 'more than one DAG: {}'.format(dags))
  2443.         elif len(dags) == 1:
  2444.             dag = list(dags)[0]
  2445.         else:
  2446.             raise AirflowException(
  2447.                 "Tried to create relationships between tasks that don't have "
  2448.                 "DAGs yet. Set the DAG for at least one "
  2449.                 "task and try again: {}".format([self] + task_list))
  2450.  
  2451.         if dag and not self.has_dag():
  2452.             self.dag = dag
  2453.  
  2454.         for task in task_list:
  2455.             if dag and not task.has_dag():
  2456.                 task.dag = dag
  2457.             if upstream:
  2458.                 task.append_only_new(task._downstream_task_ids, self.task_id)
  2459.                 self.append_only_new(self._upstream_task_ids, task.task_id)
  2460.             else:
  2461.                 self.append_only_new(self._downstream_task_ids, task.task_id)
  2462.                 task.append_only_new(task._upstream_task_ids, self.task_id)
  2463.  
  2464.         self.detect_downstream_cycle()
  2465.  
  2466.     def set_downstream(self, task_or_task_list):
  2467.         """
  2468.        Set a task or a task list to be directly downstream from the current
  2469.        task.
  2470.        """
  2471.         self._set_relatives(task_or_task_list, upstream=False)
  2472.  
  2473.     def set_upstream(self, task_or_task_list):
  2474.         """
  2475.        Set a task or a task list to be directly upstream from the current
  2476.        task.
  2477.        """
  2478.         self._set_relatives(task_or_task_list, upstream=True)
  2479.  
  2480.     def xcom_push(
  2481.             self,
  2482.             context,
  2483.             key,
  2484.             value,
  2485.             execution_date=None):
  2486.         """
  2487.        See TaskInstance.xcom_push()
  2488.        """
  2489.         context['ti'].xcom_push(
  2490.             key=key,
  2491.             value=value,
  2492.             execution_date=execution_date)
  2493.  
  2494.     def xcom_pull(
  2495.             self,
  2496.             context,
  2497.             task_ids,
  2498.             dag_id=None,
  2499.             key=XCOM_RETURN_KEY,
  2500.             include_prior_dates=None):
  2501.         """
  2502.        See TaskInstance.xcom_pull()
  2503.        """
  2504.         return context['ti'].xcom_pull(
  2505.             key=key,
  2506.             task_ids=task_ids,
  2507.             dag_id=dag_id,
  2508.             include_prior_dates=include_prior_dates)
  2509.  
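# Editorial sketch (not part of the original module): wiring the pieces above
# together with hypothetical task ids; DummyOperator is assumed to come from
# airflow.operators, and DAG's parameters are documented further below:
#
#     dag = DAG('example', start_date=datetime(2018, 1, 1),
#               schedule_interval='@daily')
#     with dag:
#         first = DummyOperator(task_id='first')
#         second = DummyOperator(task_id='second', retries=2,
#                                retry_delay=timedelta(minutes=5))
#         first >> second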
  2510.  
  2511. class DagModel(Base):
  2512.  
  2513.     __tablename__ = "dag"
  2514.     """
  2515.    These items are stored in the database for state-related information
  2516.    """
  2517.     dag_id = Column(String(ID_LEN), primary_key=True)
  2518.     # A DAG can be paused from the UI / DB
  2519.     # Set this default value of is_paused based on a configuration value!
  2520.     is_paused_at_creation = configuration.getboolean('core',
  2521.                                                      'dags_are_paused_at_creation')
  2522.     is_paused = Column(Boolean, default=is_paused_at_creation)
  2523.     # Whether the DAG is a subdag
  2524.     is_subdag = Column(Boolean, default=False)
  2525.     # Whether that DAG was seen on the last DagBag load
  2526.     is_active = Column(Boolean, default=False)
  2527.     # Last time the scheduler started
  2528.     last_scheduler_run = Column(DateTime)
  2529.     # Last time this DAG was pickled
  2530.     last_pickled = Column(DateTime)
  2531.     # Time when the DAG last received a refresh signal
  2532.     # (e.g. the DAG's "refresh" button was clicked in the web UI)
  2533.     last_expired = Column(DateTime)
  2534.     # Whether (one of) the schedulers is scheduling this DAG at the moment
  2535.     scheduler_lock = Column(Boolean)
  2536.     # Foreign key to the latest pickle_id
  2537.     pickle_id = Column(Integer)
  2538.     # The location of the file containing the DAG object
  2539.     fileloc = Column(String(2000))
  2540.     # String representing the owners
  2541.     owners = Column(String(2000))
  2542.  
  2543.     def __repr__(self):
  2544.         return "<DAG: {self.dag_id}>".format(self=self)
  2545.  
  2546.     @classmethod
  2547.     def get_current(cls, dag_id):
  2548.         session = settings.Session()
  2549.         obj = session.query(cls).filter(cls.dag_id == dag_id).first()
  2550.         session.expunge_all()
  2551.         session.commit()
  2552.         session.close()
  2553.         return obj
  2554.  
  2555.  
  2556. @functools.total_ordering
  2557. class DAG(BaseDag, LoggingMixin):
  2558.     """
  2559.    A dag (directed acyclic graph) is a collection of tasks with directional
  2560.    dependencies. A dag also has a schedule, a start date and an end date
  2561.    (optional). For each schedule (say daily or hourly), the DAG needs to run
  2562.    each individual task as its dependencies are met. Certain tasks have
  2563.    the property of depending on their own past, meaning that they can't run
  2564.    until their previous schedule (and upstream tasks) are completed.
  2565.  
  2566.    DAGs essentially act as namespaces for tasks. A task_id can only be
  2567.    added once to a DAG.
  2568.  
  2569.    :param dag_id: The id of the DAG
  2570.    :type dag_id: string
2571.    :param description: The description for the DAG, e.g. to be shown on the webserver
  2572.    :type description: string
2573.    :param schedule_interval: Defines how often the DAG runs; this
2574.        timedelta object is added to your latest task instance's
2575.        execution_date to figure out the next schedule
  2576.    :type schedule_interval: datetime.timedelta or
  2577.        dateutil.relativedelta.relativedelta or str that acts as a cron
  2578.        expression
  2579.    :param start_date: The timestamp from which the scheduler will
  2580.        attempt to backfill
  2581.    :type start_date: datetime.datetime
  2582.    :param end_date: A date beyond which your DAG won't run, leave to None
  2583.        for open ended scheduling
  2584.    :type end_date: datetime.datetime
2585.    :param template_searchpath: This list of folders (non-relative)
  2586.        defines where jinja will look for your templates. Order matters.
  2587.        Note that jinja/airflow includes the path of your DAG file by
  2588.        default
2589.    :type template_searchpath: string or list of strings
  2590.    :param user_defined_macros: a dictionary of macros that will be exposed
  2591.        in your jinja templates. For example, passing ``dict(foo='bar')``
2592.        to this argument allows you to use ``{{ foo }}`` in all jinja
  2593.        templates related to this DAG. Note that you can pass any
  2594.        type of object here.
  2595.    :type user_defined_macros: dict
  2596.    :param default_args: A dictionary of default parameters to be used
  2597.        as constructor keyword parameters when initialising operators.
2598.        Note that operators have the same hook, and that arguments passed
2599.        directly to an operator take precedence, meaning that if your dict
2600.        contains `'depends_on_past': True` here and the operator is called
2601.        with `'depends_on_past': False`, the actual value will be `False`.
  2602.    :type default_args: dict
  2603.    :param params: a dictionary of DAG level parameters that are made
  2604.        accessible in templates, namespaced under `params`. These
  2605.        params can be overridden at the task level.
  2606.    :type params: dict
  2607.    :param concurrency: the number of task instances allowed to run
  2608.        concurrently
  2609.    :type concurrency: int
  2610.    :param max_active_runs: maximum number of active DAG runs, beyond this
  2611.        number of DAG runs in a running state, the scheduler won't create
  2612.        new active DAG runs
  2613.    :type max_active_runs: int
  2614.    :param dagrun_timeout: specify how long a DagRun should be up before
  2615.        timing out / failing, so that new DagRuns can be created
  2616.    :type dagrun_timeout: datetime.timedelta
  2617.    :param sla_miss_callback: specify a function to call when reporting SLA
  2618.        timeouts.
  2619.    :type sla_miss_callback: types.FunctionType
  2620.    :param orientation: Specify DAG orientation in graph view (LR, TB, RL, BT)
  2621.    :type orientation: string
  2622.    :param catchup: Perform scheduler catchup (or only run latest)? Defaults to True
2623.    :type catchup: bool
  2624.    """
  2625.  
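    # Declaration sketch for the default_args note above (ids illustrative;
    # DummyOperator is assumed importable from
    # airflow.operators.dummy_operator): an argument passed directly to an
    # operator wins over the DAG-level default.
    #
    #     args = {'owner': 'airflow', 'depends_on_past': True}
    #     dag = DAG('example_dag', start_date=datetime(2018, 1, 1),
    #               schedule_interval='@daily', default_args=args)
    #     t = DummyOperator(task_id='noop', depends_on_past=False, dag=dag)
    #     assert t.depends_on_past is False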
  2626.     def __init__(
  2627.             self, dag_id,
  2628.             description='',
  2629.             schedule_interval=timedelta(days=1),
  2630.             start_date=None, end_date=None,
  2631.             full_filepath=None,
  2632.             template_searchpath=None,
  2633.             user_defined_macros=None,
  2634.             default_args=None,
  2635.             concurrency=configuration.getint('core', 'dag_concurrency'),
  2636.             max_active_runs=configuration.getint(
  2637.                 'core', 'max_active_runs_per_dag'),
  2638.             dagrun_timeout=None,
  2639.             sla_miss_callback=None,
  2640.             orientation=configuration.get('webserver', 'dag_orientation'),
  2641.             catchup=configuration.getboolean('scheduler', 'catchup_by_default'),
  2642.             params=None):
  2643.  
  2644.         self.user_defined_macros = user_defined_macros
  2645.         self.default_args = default_args or {}
  2646.         self.params = params or {}
  2647.  
  2648.         # merging potentially conflicting default_args['params'] into params
  2649.         if 'params' in self.default_args:
  2650.             self.params.update(self.default_args['params'])
  2651.             del self.default_args['params']
  2652.  
  2653.         validate_key(dag_id)
  2654.  
  2655.         # Properties from BaseDag
  2656.         self._dag_id = dag_id
  2657.         self._full_filepath = full_filepath if full_filepath else ''
  2658.         self._concurrency = concurrency
  2659.         self._pickle_id = None
  2660.  
  2661.         self._description = description
  2662.         # set file location to caller source path
  2663.         self.fileloc = inspect.getsourcefile(inspect.stack()[1][0])
  2664.         self.task_dict = dict()
  2665.         self.start_date = start_date
  2666.         self.end_date = end_date
  2667.         self.schedule_interval = schedule_interval
  2668.         if schedule_interval in cron_presets:
  2669.             self._schedule_interval = cron_presets.get(schedule_interval)
  2670.         elif schedule_interval == '@once':
  2671.             self._schedule_interval = None
  2672.         else:
  2673.             self._schedule_interval = schedule_interval
  2674.         if isinstance(template_searchpath, six.string_types):
  2675.             template_searchpath = [template_searchpath]
  2676.         self.template_searchpath = template_searchpath
  2677.         self.parent_dag = None  # Gets set when DAGs are loaded
  2678.         self.last_loaded = datetime.now()
  2679.         self.safe_dag_id = dag_id.replace('.', '__dot__')
  2680.         self.max_active_runs = max_active_runs
  2681.         self.dagrun_timeout = dagrun_timeout
  2682.         self.sla_miss_callback = sla_miss_callback
  2683.         self.orientation = orientation
  2684.         self.catchup = catchup
  2685.  
  2686.         self.partial = False
  2687.  
  2688.         self._comps = {
  2689.             'dag_id',
  2690.             'task_ids',
  2691.             'parent_dag',
  2692.             'start_date',
  2693.             'schedule_interval',
  2694.             'full_filepath',
  2695.             'template_searchpath',
  2696.             'last_loaded',
  2697.         }
  2698.  
  2699.     def __repr__(self):
  2700.         return "<DAG: {self.dag_id}>".format(self=self)
  2701.  
  2702.     def __eq__(self, other):
  2703.         return (
  2704.             type(self) == type(other) and
  2705.             # Use getattr() instead of __dict__ as __dict__ doesn't return
  2706.             # correct values for properties.
  2707.             all(getattr(self, c, None) == getattr(other, c, None)
  2708.                 for c in self._comps))
  2709.  
  2710.     def __ne__(self, other):
  2711.         return not self == other
  2712.  
  2713.     def __lt__(self, other):
  2714.         return self.dag_id < other.dag_id
  2715.  
  2716.     def __hash__(self):
  2717.         hash_components = [type(self)]
  2718.         for c in self._comps:
  2719.             # task_ids returns a list and lists can't be hashed
  2720.             if c == 'task_ids':
  2721.                 val = tuple(self.task_dict.keys())
  2722.             else:
  2723.                 val = getattr(self, c, None)
  2724.             try:
  2725.                 hash(val)
  2726.                 hash_components.append(val)
  2727.             except TypeError:
  2728.                 hash_components.append(repr(val))
  2729.         return hash(tuple(hash_components))
  2730.  
  2731.     # Context Manager -----------------------------------------------
  2732.  
  2733.     def __enter__(self):
  2734.         global _CONTEXT_MANAGER_DAG
  2735.         self._old_context_manager_dag = _CONTEXT_MANAGER_DAG
  2736.         _CONTEXT_MANAGER_DAG = self
  2737.         return self
  2738.  
  2739.     def __exit__(self, _type, _value, _tb):
  2740.         global _CONTEXT_MANAGER_DAG
  2741.         _CONTEXT_MANAGER_DAG = self._old_context_manager_dag
  2742.  
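    # Usage sketch (ids illustrative; DummyOperator assumed importable):
    # operators instantiated inside the ``with`` block pick up the DAG from
    # _CONTEXT_MANAGER_DAG rather than an explicit ``dag=`` argument.
    #
    #     with DAG('example_dag', start_date=datetime(2018, 1, 1)) as dag:
    #         task = DummyOperator(task_id='noop')
    #     assert task.dag is dag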
  2743.     # /Context Manager ----------------------------------------------
  2744.  
2745.     def date_range(self, start_date, num=None, end_date=None):
2746.         # "now" is resolved per call; a datetime.now() default argument
2747.         # would have been frozen at import time
2748.         if num:
2749.             end_date = None
2750.         elif end_date is None:
2751.             end_date = datetime.now()
2752.         return utils_date_range(
2753.             start_date=start_date, end_date=end_date,
2754.             num=num, delta=self._schedule_interval)
  2751.  
  2752.     def following_schedule(self, dttm):
  2753.         if isinstance(self._schedule_interval, six.string_types):
  2754.             cron = croniter(self._schedule_interval, dttm)
  2755.             return cron.get_next(datetime)
  2756.         elif isinstance(self._schedule_interval, timedelta):
  2757.             return dttm + self._schedule_interval
  2758.  
  2759.     def previous_schedule(self, dttm):
  2760.         if isinstance(self._schedule_interval, six.string_types):
  2761.             cron = croniter(self._schedule_interval, dttm)
  2762.             return cron.get_prev(datetime)
  2763.         elif isinstance(self._schedule_interval, timedelta):
  2764.             return dttm - self._schedule_interval
  2765.  
  2766.     def normalize_schedule(self, dttm):
  2767.         """
2768.        Returns dttm + interval unless dttm is the first interval, in which case it returns dttm
  2769.        """
  2770.         following = self.following_schedule(dttm)
  2771.  
  2772.         # in case of @once
  2773.         if not following:
  2774.             return dttm
  2775.  
  2776.         if self.previous_schedule(following) != dttm:
  2777.             return following
  2778.  
  2779.         return dttm
  2780.  
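    # Worked sketch of the schedule arithmetic above, assuming a cron-style
    # _schedule_interval of '0 0 * * *' (daily at midnight):
    #
    #     dag.following_schedule(datetime(2018, 1, 1))  # 2018-01-02 00:00
    #     dag.previous_schedule(datetime(2018, 1, 2))   # 2018-01-01 00:00
    #     # normalize_schedule() leaves an exact boundary untouched and
    #     # rounds anything in between up to the next boundary:
    #     dag.normalize_schedule(datetime(2018, 1, 1))         # 2018-01-01
    #     dag.normalize_schedule(datetime(2018, 1, 1, 6, 30))  # 2018-01-02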
  2781.     @provide_session
  2782.     def get_last_dagrun(self, session=None, include_externally_triggered=False):
  2783.         """
  2784.        Returns the last dag run for this dag, None if there was none.
2785.        Last dag run can be any type of run, e.g. scheduled or backfilled.
2786.        Overridden DagRuns are ignored.
  2787.        """
  2788.         DR = DagRun
  2789.         qry = session.query(DR).filter(
  2790.             DR.dag_id == self.dag_id,
  2791.         )
  2792.         if not include_externally_triggered:
  2793.             qry = qry.filter(DR.external_trigger.is_(False))
  2794.  
  2795.         qry = qry.order_by(DR.execution_date.desc())
  2796.  
  2797.         last = qry.first()
  2798.  
  2799.         return last
  2800.  
  2801.     @property
  2802.     def dag_id(self):
  2803.         return self._dag_id
  2804.  
  2805.     @dag_id.setter
  2806.     def dag_id(self, value):
  2807.         self._dag_id = value
  2808.  
  2809.     @property
  2810.     def full_filepath(self):
  2811.         return self._full_filepath
  2812.  
  2813.     @full_filepath.setter
  2814.     def full_filepath(self, value):
  2815.         self._full_filepath = value
  2816.  
  2817.     @property
  2818.     def concurrency(self):
  2819.         return self._concurrency
  2820.  
  2821.     @concurrency.setter
  2822.     def concurrency(self, value):
  2823.         self._concurrency = value
  2824.  
  2825.     @property
  2826.     def description(self):
  2827.         return self._description
  2828.  
  2829.     @property
  2830.     def pickle_id(self):
  2831.         return self._pickle_id
  2832.  
  2833.     @pickle_id.setter
  2834.     def pickle_id(self, value):
  2835.         self._pickle_id = value
  2836.  
  2837.     @property
  2838.     def tasks(self):
  2839.         return list(self.task_dict.values())
  2840.  
  2841.     @tasks.setter
  2842.     def tasks(self, val):
  2843.         raise AttributeError(
  2844.             'DAG.tasks can not be modified. Use dag.add_task() instead.')
  2845.  
  2846.     @property
  2847.     def task_ids(self):
  2848.         return list(self.task_dict.keys())
  2849.  
  2850.     @property
  2851.     def active_task_ids(self):
  2852.         return list(k for k, v in self.task_dict.items() if not v.adhoc)
  2853.  
  2854.     @property
  2855.     def active_tasks(self):
  2856.         return [t for t in self.tasks if not t.adhoc]
  2857.  
  2858.     @property
  2859.     def filepath(self):
  2860.         """
  2861.        File location of where the dag object is instantiated
  2862.        """
  2863.         fn = self.full_filepath.replace(settings.DAGS_FOLDER + '/', '')
  2864.         fn = fn.replace(os.path.dirname(__file__) + '/', '')
  2865.         return fn
  2866.  
  2867.     @property
  2868.     def folder(self):
  2869.         """
  2870.        Folder location of where the dag object is instantiated
  2871.        """
  2872.         return os.path.dirname(self.full_filepath)
  2873.  
  2874.     @property
  2875.     def owner(self):
  2876.         return ", ".join(list(set([t.owner for t in self.tasks])))
  2877.  
  2878.     @property
  2879.     @provide_session
  2880.     def concurrency_reached(self, session=None):
  2881.         """
  2882.        Returns a boolean indicating whether the concurrency limit for this DAG
  2883.        has been reached
  2884.        """
  2885.         TI = TaskInstance
  2886.         qry = session.query(func.count(TI.task_id)).filter(
  2887.             TI.dag_id == self.dag_id,
  2888.             TI.task_id.in_(self.task_ids),
  2889.             TI.state == State.RUNNING,
  2890.         )
  2891.         return qry.scalar() >= self.concurrency
  2892.  
  2893.     @property
  2894.     @provide_session
  2895.     def is_paused(self, session=None):
  2896.         """
  2897.        Returns a boolean indicating whether this DAG is paused
  2898.        """
  2899.         qry = session.query(DagModel).filter(
  2900.             DagModel.dag_id == self.dag_id)
  2901.         return qry.value('is_paused')
  2902.  
  2903.     @provide_session
  2904.     def get_active_runs(self, session=None):
  2905.         """
2906.        Returns the execution dates of this dag's currently running DagRuns
  2907.        :param session:
  2908.        :return: List of execution dates
  2909.        """
  2910.         runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING)
  2911.  
  2912.         active_dates = []
  2913.         for run in runs:
  2914.             active_dates.append(run.execution_date)
  2915.  
  2916.         return active_dates
  2917.  
  2918.     @provide_session
  2919.     def get_dagrun(self, execution_date, session=None):
  2920.         """
  2921.        Returns the dag run for a given execution date if it exists, otherwise
2922.        None.
  2923.        :param execution_date: The execution date of the DagRun to find.
  2924.        :param session:
  2925.        :return: The DagRun if found, otherwise None.
  2926.        """
  2927.         dagrun = (
  2928.             session.query(DagRun)
  2929.             .filter(
  2930.                 DagRun.dag_id == self.dag_id,
  2931.                 DagRun.execution_date == execution_date)
  2932.             .first())
  2933.  
  2934.         return dagrun
  2935.  
  2936.     @property
  2937.     def latest_execution_date(self):
  2938.         """
  2939.        Returns the latest date for which at least one dag run exists
  2940.        """
  2941.         session = settings.Session()
  2942.         execution_date = session.query(func.max(DagRun.execution_date)).filter(
  2943.             DagRun.dag_id == self.dag_id
  2944.         ).scalar()
  2945.         session.commit()
  2946.         session.close()
  2947.         return execution_date
  2948.  
  2949.     @property
  2950.     def subdags(self):
  2951.         """
2952.        Returns a list of the subdag objects associated with this DAG
  2953.        """
  2954.         # Check SubDag for class but don't check class directly, see
  2955.         # https://github.com/airbnb/airflow/issues/1168
2956.         subdag_lst = []
2957.         for task in self.tasks:
2958.             if (
2959.                     task.__class__.__name__ == 'SubDagOperator' and
2960.                     hasattr(task, 'subdag')):
2961.                 subdag_lst.append(task.subdag)
2962.                 subdag_lst += task.subdag.subdags
2963.         return subdag_lst
  2964.  
  2965.     def resolve_template_files(self):
  2966.         for t in self.tasks:
  2967.             t.resolve_template_files()
  2968.  
2969.     def crawl_for_tasks(self, objects):
2970.         """
2971.        Typically called at the end of a script by passing globals() as a
2972.        parameter. This allows you to avoid explicitly adding every single
2973.        task to the dag.
2974.        """
2975.         raise NotImplementedError("")
  2976.  
  2977.     def get_template_env(self):
  2978.         """
  2979.        Returns a jinja2 Environment while taking into account the DAGs
  2980.        template_searchpath and user_defined_macros
  2981.        """
  2982.         searchpath = [self.folder]
  2983.         if self.template_searchpath:
  2984.             searchpath += self.template_searchpath
  2985.  
  2986.         env = jinja2.Environment(
  2987.             loader=jinja2.FileSystemLoader(searchpath),
  2988.             extensions=["jinja2.ext.do"],
  2989.             cache_size=0)
  2990.         if self.user_defined_macros:
  2991.             env.globals.update(self.user_defined_macros)
  2992.  
  2993.         return env
  2994.  
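    # Usage sketch: templates resolve against the DAG's folder first, then
    # any template_searchpath entries, with user_defined_macros exposed as
    # globals.
    #
    #     env = dag.get_template_env()
    #     print(env.from_string('hello {{ name }}').render(name='world'))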
  2995.     def set_dependency(self, upstream_task_id, downstream_task_id):
  2996.         """
  2997.        Simple utility method to set dependency between two tasks that
  2998.        already have been added to the DAG using add_task()
  2999.        """
  3000.         self.get_task(upstream_task_id).set_downstream(
  3001.             self.get_task(downstream_task_id))
  3002.  
  3003.     def get_task_instances(
  3004.             self, session, start_date=None, end_date=None, state=None):
  3005.         TI = TaskInstance
  3006.         if not start_date:
  3007.             start_date = (datetime.today() - timedelta(30)).date()
  3008.             start_date = datetime.combine(start_date, datetime.min.time())
  3009.         end_date = end_date or datetime.now()
  3010.         tis = session.query(TI).filter(
  3011.             TI.dag_id == self.dag_id,
  3012.             TI.execution_date >= start_date,
  3013.             TI.execution_date <= end_date,
  3014.             TI.task_id.in_([t.task_id for t in self.tasks]),
  3015.         )
  3016.         if state:
  3017.             tis = tis.filter(TI.state == state)
  3018.         tis = tis.all()
  3019.         return tis
  3020.  
  3021.     @property
  3022.     def roots(self):
  3023.         return [t for t in self.tasks if not t.downstream_list]
  3024.  
  3025.     def topological_sort(self):
  3026.         """
3027.        Sorts tasks in topological order, such that a task comes after any of its
  3028.        upstream dependencies.
  3029.  
  3030.        Heavily inspired by:
  3031.        http://blog.jupo.org/2012/04/06/topological-sorting-acyclic-directed-graphs/
  3032.        :returns: list of tasks in topological order
  3033.        """
  3034.  
3035.         # copy the tasks so we leave the originals unmodified
  3036.         graph_unsorted = self.tasks[:]
  3037.  
  3038.         graph_sorted = []
  3039.  
  3040.         # special case
  3041.         if len(self.tasks) == 0:
  3042.             return tuple(graph_sorted)
  3043.  
  3044.         # Run until the unsorted graph is empty.
  3045.         while graph_unsorted:
3046.             # Go through each node in the unsorted graph. If none of its
3047.             # upstream edges point at nodes that are still unresolved
3048.             # (i.e. still in the unsorted graph), remove the node from the
3049.             # unsorted graph and append it to the sorted graph. Iterating
3050.             # over a copy (list(graph_unsorted)) lets us modify the
3051.             # unsorted graph as we move through it. We also keep a flag
3052.             # for checking that the graph is acyclic, which is true if any
3053.             # nodes were resolved during this pass; if not, we bail out as
3054.             # the graph therefore can't be sorted.
  3058.             acyclic = False
  3059.             for node in list(graph_unsorted):
  3060.                 for edge in node.upstream_list:
  3061.                     if edge in graph_unsorted:
  3062.                         break
  3063.                 # no edges in upstream tasks
  3064.                 else:
  3065.                     acyclic = True
  3066.                     graph_unsorted.remove(node)
  3067.                     graph_sorted.append(node)
  3068.  
  3069.             if not acyclic:
  3070.                 raise AirflowException("A cyclic dependency occurred in dag: {}"
  3071.                                        .format(self.dag_id))
  3072.  
  3073.         return tuple(graph_sorted)
  3074.  
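    # Usage sketch: with dependencies a -> b -> c (ids illustrative, wired
    # via set_downstream), every task sorts after all of its upstream tasks:
    #
    #     order = [t.task_id for t in dag.topological_sort()]
    #     assert order.index('a') < order.index('b') < order.index('c')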
  3075.     @provide_session
  3076.     def set_dag_runs_state(
  3077.             self, state=State.RUNNING, session=None):
  3078.         drs = session.query(DagModel).filter_by(dag_id=self.dag_id).all()
  3079.         dirty_ids = []
  3080.         for dr in drs:
  3081.             dr.state = state
  3082.             dirty_ids.append(dr.dag_id)
  3083.         DagStat.clean_dirty(dirty_ids, session=session)
  3084.  
  3085.     def clear(
  3086.             self, start_date=None, end_date=None,
  3087.             only_failed=False,
  3088.             only_running=False,
  3089.             confirm_prompt=False,
  3090.             include_subdags=True,
  3091.             reset_dag_runs=True,
  3092.             dry_run=False):
  3093.         """
  3094.        Clears a set of task instances associated with the current dag for
  3095.        a specified date range.
  3096.        """
  3097.         session = settings.Session()
  3098.         TI = TaskInstance
  3099.         tis = session.query(TI)
  3100.         if include_subdags:
  3101.             # Crafting the right filter for dag_id and task_ids combo
  3102.             conditions = []
  3103.             for dag in self.subdags + [self]:
  3104.                 conditions.append(
  3105.                     TI.dag_id.like(dag.dag_id) & TI.task_id.in_(dag.task_ids)
  3106.                 )
  3107.             tis = tis.filter(or_(*conditions))
  3108.         else:
  3109.             tis = session.query(TI).filter(TI.dag_id == self.dag_id)
  3110.             tis = tis.filter(TI.task_id.in_(self.task_ids))
  3111.  
  3112.         if start_date:
  3113.             tis = tis.filter(TI.execution_date >= start_date)
  3114.         if end_date:
  3115.             tis = tis.filter(TI.execution_date <= end_date)
  3116.         if only_failed:
  3117.             tis = tis.filter(TI.state == State.FAILED)
  3118.         if only_running:
  3119.             tis = tis.filter(TI.state == State.RUNNING)
  3120.  
  3121.         if dry_run:
  3122.             tis = tis.all()
  3123.             session.expunge_all()
  3124.             return tis
  3125.  
  3126.         count = tis.count()
  3127.         do_it = True
  3128.         if count == 0:
  3129.             print("Nothing to clear.")
  3130.             return 0
  3131.         if confirm_prompt:
  3132.             ti_list = "\n".join([str(t) for t in tis])
  3133.             question = (
3134.                 "You are about to clear these {count} task instances:\n"
  3135.                 "{ti_list}\n\n"
  3136.                 "Are you sure? (yes/no): ").format(**locals())
  3137.             do_it = utils.helpers.ask_yesno(question)
  3138.  
  3139.         if do_it:
  3140.             clear_task_instances(tis, session)
  3141.             if reset_dag_runs:
  3142.                 self.set_dag_runs_state(session=session)
  3143.         else:
  3144.             count = 0
  3145.             print("Bail. Nothing was cleared.")
  3146.  
  3147.         session.commit()
  3148.         session.close()
  3149.         return count
  3150.  
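    # Usage sketch: dry_run=True returns the matching TaskInstances without
    # touching the database, which is useful before a destructive clear:
    #
    #     tis = dag.clear(start_date=datetime(2018, 1, 1), dry_run=True)
    #     print("%d task instances would be cleared" % len(tis))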
  3151.     def __deepcopy__(self, memo):
3152.         # Switcharoo to avoid deepcopying objects coming through the
3153.         # backdoor
  3154.         cls = self.__class__
  3155.         result = cls.__new__(cls)
  3156.         memo[id(self)] = result
  3157.         for k, v in list(self.__dict__.items()):
  3158.             if k not in ('user_defined_macros', 'params'):
  3159.                 setattr(result, k, copy.deepcopy(v, memo))
  3160.  
  3161.         result.user_defined_macros = self.user_defined_macros
  3162.         result.params = self.params
  3163.         return result
  3164.  
  3165.     def sub_dag(self, task_regex, include_downstream=False,
  3166.                 include_upstream=True):
  3167.         """
  3168.        Returns a subset of the current dag as a deep copy of the current dag
  3169.        based on a regex that should match one or many tasks, and includes
3170.        upstream and downstream neighbours based on the flags passed.
  3171.        """
  3172.  
  3173.         dag = copy.deepcopy(self)
  3174.  
  3175.         regex_match = [
  3176.             t for t in dag.tasks if re.findall(task_regex, t.task_id)]
  3177.         also_include = []
  3178.         for t in regex_match:
  3179.             if include_downstream:
  3180.                 also_include += t.get_flat_relatives(upstream=False)
  3181.             if include_upstream:
  3182.                 also_include += t.get_flat_relatives(upstream=True)
  3183.  
  3184.         # Compiling the unique list of tasks that made the cut
  3185.         dag.task_dict = {t.task_id: t for t in regex_match + also_include}
  3186.         for t in dag.tasks:
  3187.             # Removing upstream/downstream references to tasks that did not
3188.             # make the cut
  3189.             t._upstream_task_ids = [
  3190.                 tid for tid in t._upstream_task_ids if tid in dag.task_ids]
  3191.             t._downstream_task_ids = [
  3192.                 tid for tid in t._downstream_task_ids if tid in dag.task_ids]
  3193.  
  3194.         if len(dag.tasks) < len(self.tasks):
  3195.             dag.partial = True
  3196.  
  3197.         return dag
  3198.  
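    # Usage sketch (the regex is illustrative): carve out a single task's
    # upstream lineage.
    #
    #     partial = dag.sub_dag('load', include_upstream=True,
    #                           include_downstream=False)
    #     # partial.partial is True whenever tasks were filtered out, which
    #     # e.g. DagRun.get_task_instances checks before querying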
  3199.     def has_task(self, task_id):
  3200.         return task_id in (t.task_id for t in self.tasks)
  3201.  
  3202.     def get_task(self, task_id):
  3203.         if task_id in self.task_dict:
  3204.             return self.task_dict[task_id]
  3205.         raise AirflowException("Task {task_id} not found".format(**locals()))
  3206.  
  3207.     @provide_session
  3208.     def pickle_info(self, session=None):
  3209.         d = {}
  3210.         d['is_picklable'] = True
  3211.         try:
  3212.             dttm = datetime.now()
  3213.             pickled = pickle.dumps(self)
  3214.             d['pickle_len'] = len(pickled)
  3215.             d['pickling_duration'] = "{}".format(datetime.now() - dttm)
  3216.         except Exception as e:
  3217.             logging.debug(e)
  3218.             d['is_picklable'] = False
  3219.             d['stacktrace'] = traceback.format_exc()
  3220.         return d
  3221.  
  3222.     @provide_session
  3223.     def pickle(self, session=None):
  3224.         dag = session.query(
  3225.             DagModel).filter(DagModel.dag_id == self.dag_id).first()
  3226.         dp = None
  3227.         if dag and dag.pickle_id:
  3228.             dp = session.query(DagPickle).filter(
  3229.                 DagPickle.id == dag.pickle_id).first()
  3230.         if not dp or dp.pickle != self:
  3231.             dp = DagPickle(dag=self)
  3232.             session.add(dp)
  3233.             self.last_pickled = datetime.now()
  3234.             session.commit()
  3235.             self.pickle_id = dp.id
  3236.  
  3237.         return dp
  3238.  
  3239.     def tree_view(self):
  3240.         """
3241.        Shows an ASCII tree representation of the DAG
  3242.        """
  3243.         def get_downstream(task, level=0):
  3244.             print((" " * level * 4) + str(task))
  3245.             level += 1
  3246.             for t in task.upstream_list:
  3247.                 get_downstream(t, level)
  3248.  
  3249.         for t in self.roots:
  3250.             get_downstream(t)
  3251.  
  3252.     def add_task(self, task):
  3253.         """
  3254.        Add a task to the DAG
  3255.  
  3256.        :param task: the task you want to add
  3257.        :type task: task
  3258.        """
  3259.         if not self.start_date and not task.start_date:
  3260.             raise AirflowException("Task is missing the start_date parameter")
  3261.         if not task.start_date:
  3262.             task.start_date = self.start_date
  3263.  
  3264.         if task.task_id in self.task_dict:
  3265.             # TODO: raise an error in Airflow 2.0
  3266.             warnings.warn(
  3267.                 'The requested task could not be added to the DAG because a '
  3268.                 'task with task_id {} is already in the DAG. Starting in '
  3269.                 'Airflow 2.0, trying to overwrite a task will raise an '
  3270.                 'exception.'.format(task.task_id),
  3271.                 category=PendingDeprecationWarning)
  3272.         else:
3273.             # self.tasks is derived from task_dict, so updating the dict below suffices
  3274.             self.task_dict[task.task_id] = task
  3275.             task.dag = self
  3276.  
  3277.         self.task_count = len(self.tasks)
  3278.  
  3279.     def add_tasks(self, tasks):
  3280.         """
  3281.        Add a list of tasks to the DAG
  3282.  
3283.        :param tasks: a list of tasks you want to add
3284.        :type tasks: list of tasks
  3285.        """
  3286.         for task in tasks:
  3287.             self.add_task(task)
  3288.  
  3289.     def db_merge(self):
  3290.         BO = BaseOperator
  3291.         session = settings.Session()
  3292.         tasks = session.query(BO).filter(BO.dag_id == self.dag_id).all()
  3293.         for t in tasks:
  3294.             session.delete(t)
  3295.         session.commit()
  3296.         session.merge(self)
  3297.         session.commit()
  3298.  
  3299.     def run(
  3300.             self,
  3301.             start_date=None,
  3302.             end_date=None,
  3303.             mark_success=False,
  3304.             include_adhoc=False,
  3305.             local=False,
  3306.             executor=None,
  3307.             donot_pickle=configuration.getboolean('core', 'donot_pickle'),
  3308.             ignore_task_deps=False,
  3309.             ignore_first_depends_on_past=False,
  3310.             pool=None):
  3311.         """
  3312.        Runs the DAG.
  3313.        """
  3314.         from airflow.jobs import BackfillJob
  3315.         if not executor and local:
  3316.             executor = LocalExecutor()
  3317.         elif not executor:
  3318.             executor = DEFAULT_EXECUTOR
  3319.         job = BackfillJob(
  3320.             self,
  3321.             start_date=start_date,
  3322.             end_date=end_date,
  3323.             mark_success=mark_success,
  3324.             include_adhoc=include_adhoc,
  3325.             executor=executor,
  3326.             donot_pickle=donot_pickle,
  3327.             ignore_task_deps=ignore_task_deps,
  3328.             ignore_first_depends_on_past=ignore_first_depends_on_past,
  3329.             pool=pool)
  3330.         job.run()
  3331.  
  3332.     def cli(self):
  3333.         """
  3334.        Exposes a CLI specific to this DAG
  3335.        """
  3336.         from airflow.bin import cli
  3337.         parser = cli.CLIFactory.get_parser(dag_parser=True)
  3338.         args = parser.parse_args()
  3339.         args.func(args, self)
  3340.  
  3341.     @provide_session
  3342.     def create_dagrun(self,
  3343.                       run_id,
  3344.                       state,
  3345.                       execution_date=None,
  3346.                       start_date=None,
  3347.                       external_trigger=False,
  3348.                       conf=None,
  3349.                       session=None):
  3350.         """
  3351.        Creates a dag run from this dag including the tasks associated with this dag.
  3352.        Returns the dag run.
  3353.  
3354.        :param run_id: defines the run id for this dag run
  3355.        :type run_id: string
  3356.        :param execution_date: the execution date of this dag run
  3357.        :type execution_date: datetime
  3358.        :param state: the state of the dag run
  3359.        :type state: State
  3360.        :param start_date: the date this dag run should be evaluated
  3361.        :type start_date: datetime
  3362.        :param external_trigger: whether this dag run is externally triggered
  3363.        :type external_trigger: bool
  3364.        :param session: database session
  3365.        :type session: Session
  3366.        """
  3367.         run = DagRun(
  3368.             dag_id=self.dag_id,
  3369.             run_id=run_id,
  3370.             execution_date=execution_date,
  3371.             start_date=start_date,
  3372.             external_trigger=external_trigger,
  3373.             conf=conf,
  3374.             state=state
  3375.         )
  3376.         session.add(run)
  3377.         session.commit()
  3378.  
  3379.         run.dag = self
  3380.  
  3381.         # create the associated task instances
  3382.         # state is None at the moment of creation
  3383.         run.verify_integrity(session=session)
  3384.  
  3385.         run.refresh_from_db()
  3386.         DagStat.set_dirty(self.dag_id, session=session)
  3387.  
  3388.         # add a placeholder row into DagStat table
  3389.         if not session.query(DagStat).filter(DagStat.dag_id == self.dag_id).first():
  3390.             session.add(DagStat(dag_id=self.dag_id, state=state, count=0, dirty=True))
  3391.         session.commit()
  3392.         return run
  3393.  
  3394.     @staticmethod
  3395.     @provide_session
  3396.     def sync_to_db(dag, owner, sync_time, session=None):
  3397.         """
  3398.        Save attributes about this DAG to the DB. Note that this method
  3399.        can be called for both DAGs and SubDAGs. A SubDag is actually a
  3400.        SubDagOperator.
  3401.  
  3402.        :param dag: the DAG object to save to the DB
  3403.        :type dag: DAG
3404.        :param owner: the owner of the DAG
  3405.        :param sync_time: The time that the DAG should be marked as sync'ed
  3406.        :type sync_time: datetime
  3407.        :return: None
  3408.        """
  3409.         orm_dag = session.query(
  3410.             DagModel).filter(DagModel.dag_id == dag.dag_id).first()
  3411.         if not orm_dag:
  3412.             orm_dag = DagModel(dag_id=dag.dag_id)
  3413.             logging.info("Creating ORM DAG for %s",
  3414.                          dag.dag_id)
  3415.         orm_dag.fileloc = dag.fileloc
  3416.         orm_dag.is_subdag = dag.is_subdag
  3417.         orm_dag.owners = owner
  3418.         orm_dag.is_active = True
  3419.         orm_dag.last_scheduler_run = sync_time
  3420.         session.merge(orm_dag)
  3421.         session.commit()
  3422.  
  3423.         for subdag in dag.subdags:
  3424.             DAG.sync_to_db(subdag, owner, sync_time, session=session)
  3425.  
  3426.     @staticmethod
  3427.     @provide_session
  3428.     def deactivate_unknown_dags(active_dag_ids, session=None):
  3429.         """
  3430.        Given a list of known DAGs, deactivate any other DAGs that are
  3431.        marked as active in the ORM
  3432.  
  3433.        :param active_dag_ids: list of DAG IDs that are active
  3434.        :type active_dag_ids: list[unicode]
  3435.        :return: None
  3436.        """
  3437.  
  3438.         if len(active_dag_ids) == 0:
  3439.             return
  3440.         for dag in session.query(
  3441.                 DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
  3442.             dag.is_active = False
  3443.             session.merge(dag)
  3444.  
  3445.     @staticmethod
  3446.     @provide_session
  3447.     def deactivate_stale_dags(expiration_date, session=None):
  3448.         """
  3449.        Deactivate any DAGs that were last touched by the scheduler before
  3450.        the expiration date. These DAGs were likely deleted.
  3451.  
3452.        :param expiration_date: mark DAGs inactive that were last touched
3453.        before this time
  3454.        :type expiration_date: datetime
  3455.        :return: None
  3456.        """
  3457.         for dag in session.query(
  3458.                 DagModel).filter(DagModel.last_scheduler_run < expiration_date,
  3459.                                  DagModel.is_active).all():
  3460.             logging.info("Deactivating DAG ID %s since it was last touched "
  3461.                          "by the scheduler at %s",
  3462.                          dag.dag_id,
  3463.                          dag.last_scheduler_run.isoformat())
  3464.             dag.is_active = False
  3465.             session.merge(dag)
  3466.             session.commit()
  3467.  
  3468.  
  3469. class Chart(Base):
  3470.     __tablename__ = "chart"
  3471.  
  3472.     id = Column(Integer, primary_key=True)
  3473.     label = Column(String(200))
  3474.     conn_id = Column(String(ID_LEN), nullable=False)
  3475.     user_id = Column(Integer(), ForeignKey('users.id'), nullable=True)
  3476.     chart_type = Column(String(100), default="line")
  3477.     sql_layout = Column(String(50), default="series")
  3478.     sql = Column(Text, default="SELECT series, x, y FROM table")
  3479.     y_log_scale = Column(Boolean)
  3480.     show_datatable = Column(Boolean)
  3481.     show_sql = Column(Boolean, default=True)
  3482.     height = Column(Integer, default=600)
  3483.     default_params = Column(String(5000), default="{}")
  3484.     owner = relationship(
  3485.         "User", cascade=False, cascade_backrefs=False, backref='charts')
  3486.     x_is_date = Column(Boolean, default=True)
  3487.     iteration_no = Column(Integer, default=0)
  3488.     last_modified = Column(DateTime, default=func.now())
  3489.  
  3490.     def __repr__(self):
  3491.         return self.label
  3492.  
  3493.  
  3494. class KnownEventType(Base):
  3495.     __tablename__ = "known_event_type"
  3496.  
  3497.     id = Column(Integer, primary_key=True)
  3498.     know_event_type = Column(String(200))
  3499.  
  3500.     def __repr__(self):
  3501.         return self.know_event_type
  3502.  
  3503.  
  3504. class KnownEvent(Base):
  3505.     __tablename__ = "known_event"
  3506.  
  3507.     id = Column(Integer, primary_key=True)
  3508.     label = Column(String(200))
  3509.     start_date = Column(DateTime)
  3510.     end_date = Column(DateTime)
  3511.     user_id = Column(Integer(), ForeignKey('users.id'),)
  3512.     known_event_type_id = Column(Integer(), ForeignKey('known_event_type.id'),)
  3513.     reported_by = relationship(
  3514.         "User", cascade=False, cascade_backrefs=False, backref='known_events')
  3515.     event_type = relationship(
  3516.         "KnownEventType",
  3517.         cascade=False,
  3518.         cascade_backrefs=False, backref='known_events')
  3519.     description = Column(Text)
  3520.  
  3521.     def __repr__(self):
  3522.         return self.label
  3523.  
  3524.  
  3525. class Variable(Base):
  3526.     __tablename__ = "variable"
  3527.  
  3528.     id = Column(Integer, primary_key=True)
  3529.     key = Column(String(ID_LEN), unique=True)
  3530.     _val = Column('val', Text)
  3531.     is_encrypted = Column(Boolean, unique=False, default=False)
  3532.  
  3533.     def __repr__(self):
3534.         # Hiding the decrypted value; only the raw stored _val is shown
  3535.         return '{} : {}'.format(self.key, self._val)
  3536.  
  3537.     def get_val(self):
  3538.         if self._val and self.is_encrypted:
  3539.             if not ENCRYPTION_ON:
  3540.                 raise AirflowException(
3541.                     "Can't decrypt _val for key={}, FERNET_KEY "
3542.                     "configuration missing".format(self.key))
  3543.             return FERNET.decrypt(bytes(self._val, 'utf-8')).decode()
  3544.         else:
  3545.             return self._val
  3546.  
  3547.     def set_val(self, value):
  3548.         if value:
  3549.             try:
  3550.                 self._val = FERNET.encrypt(bytes(value, 'utf-8')).decode()
  3551.                 self.is_encrypted = True
  3552.             except NameError:
  3553.                 self._val = value
  3554.                 self.is_encrypted = False
  3555.  
  3556.     @declared_attr
  3557.     def val(cls):
  3558.         return synonym('_val',
  3559.                        descriptor=property(cls.get_val, cls.set_val))
  3560.  
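    # Mechanics sketch: with the cryptography package installed and a
    # FERNET_KEY configured, set_val() stores a Fernet token and get_val()
    # decrypts it. Roughly:
    #
    #     from cryptography.fernet import Fernet
    #     f = Fernet(Fernet.generate_key())
    #     token = f.encrypt(bytes('secret', 'utf-8')).decode()  # stored _val
    #     assert f.decrypt(bytes(token, 'utf-8')).decode() == 'secret'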
  3561.     @classmethod
  3562.     def setdefault(cls, key, default, deserialize_json=False):
  3563.         """
  3564.        Like a Python builtin dict object, setdefault returns the current value
  3565.        for a key, and if it isn't there, stores the default value and returns it.
  3566.  
  3567.        :param key: Dict key for this Variable
  3568.        :type key: String
3569.        :param default: Default value to set and return if the variable
3570.        isn't already in the DB
3571.        :type default: Mixed
3572.        :param deserialize_json: Store this as a JSON encoded value in the DB
3573.        and un-encode it when retrieving a value
  3574.        :return: Mixed
  3575.        """
  3576.         default_sentinel = object()
  3577.         obj = Variable.get(key, default_var=default_sentinel, deserialize_json=False)
  3578.         if obj is default_sentinel:
  3579.             if default is not None:
  3580.                 Variable.set(key, default, serialize_json=deserialize_json)
  3581.                 return default
  3582.             else:
  3583.                 raise ValueError('Default Value must be set')
  3584.         else:
3585.             if deserialize_json:
3586.                 # Variable.get() returned the raw value, not the ORM object
3587.                 return json.loads(obj)
3588.             else:
3589.                 return obj
  3589.  
  3590.     @classmethod
  3591.     @provide_session
  3592.     def get(cls, key, default_var=None, deserialize_json=False, session=None):
  3593.         obj = session.query(cls).filter(cls.key == key).first()
  3594.         if obj is None:
  3595.             if default_var is not None:
  3596.                 return default_var
  3597.             else:
  3598.                 raise KeyError('Variable {} does not exist'.format(key))
  3599.         else:
  3600.             if deserialize_json:
  3601.                 return json.loads(obj.val)
  3602.             else:
  3603.                 return obj.val
  3604.  
  3605.     @classmethod
  3606.     @provide_session
  3607.     def set(cls, key, value, serialize_json=False, session=None):
  3608.  
  3609.         if serialize_json:
  3610.             stored_value = json.dumps(value)
  3611.         else:
  3612.             stored_value = value
  3613.  
  3614.         session.query(cls).filter(cls.key == key).delete()
  3615.         session.add(Variable(key=key, val=stored_value))
  3616.         session.flush()
  3617.  
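    # Usage sketch (key names illustrative) for the accessors above:
    #
    #     Variable.set('etl_config', {'batch_size': 100}, serialize_json=True)
    #     cfg = Variable.get('etl_config', deserialize_json=True)
    #     assert cfg['batch_size'] == 100
    #     # setdefault() only writes when the key is absent:
    #     Variable.setdefault('etl_config', {'batch_size': 1},
    #                         deserialize_json=True)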
  3618.  
  3619. class XCom(Base):
  3620.     """
  3621.    Base class for XCom objects.
  3622.    """
  3623.     __tablename__ = "xcom"
  3624.  
  3625.     id = Column(Integer, primary_key=True)
  3626.     key = Column(String(512))
  3627.     value = Column(PickleType(pickler=dill))
  3628.     timestamp = Column(
  3629.         DateTime, default=func.now(), nullable=False)
  3630.     execution_date = Column(DateTime, nullable=False)
  3631.  
  3632.     # source information
  3633.     task_id = Column(String(ID_LEN), nullable=False)
  3634.     dag_id = Column(String(ID_LEN), nullable=False)
  3635.  
  3636.     __table_args__ = (
  3637.         Index('idx_xcom_dag_task_date', dag_id, task_id, execution_date, unique=False),
  3638.     )
  3639.  
  3640.     def __repr__(self):
  3641.         return '<XCom "{key}" ({task_id} @ {execution_date})>'.format(
  3642.             key=self.key,
  3643.             task_id=self.task_id,
  3644.             execution_date=self.execution_date)
  3645.  
  3646.     @classmethod
  3647.     @provide_session
  3648.     def set(
  3649.             cls,
  3650.             key,
  3651.             value,
  3652.             execution_date,
  3653.             task_id,
  3654.             dag_id,
  3655.             session=None):
  3656.         """
  3657.        Store an XCom value.
  3658.        """
  3659.         session.expunge_all()
  3660.  
  3661.         # remove any duplicate XComs
  3662.         session.query(cls).filter(
  3663.             cls.key == key,
  3664.             cls.execution_date == execution_date,
  3665.             cls.task_id == task_id,
  3666.             cls.dag_id == dag_id).delete()
  3667.  
  3668.         session.commit()
  3669.  
  3670.         # insert new XCom
  3671.         session.add(XCom(
  3672.             key=key,
  3673.             value=value,
  3674.             execution_date=execution_date,
  3675.             task_id=task_id,
  3676.             dag_id=dag_id))
  3677.  
  3678.         session.commit()
  3679.  
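    # Usage sketch: set()/get_one() round-trip a (key, task, dag, date)
    # coordinate; values go through the dill-backed PickleType column, so
    # they must be picklable. For some datetime dt (ids illustrative):
    #
    #     XCom.set(key='rows', value=42, execution_date=dt,
    #              task_id='extract', dag_id='example_dag')
    #     assert XCom.get_one(execution_date=dt, key='rows',
    #                         task_id='extract', dag_id='example_dag') == 42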
  3680.     @classmethod
  3681.     @provide_session
  3682.     def get_one(
  3683.             cls,
  3684.             execution_date,
  3685.             key=None,
  3686.             task_id=None,
  3687.             dag_id=None,
  3688.             include_prior_dates=False,
  3689.             session=None):
  3690.         """
  3691.        Retrieve an XCom value, optionally meeting certain criteria
  3692.        """
  3693.         filters = []
  3694.         if key:
  3695.             filters.append(cls.key == key)
  3696.         if task_id:
  3697.             filters.append(cls.task_id == task_id)
  3698.         if dag_id:
  3699.             filters.append(cls.dag_id == dag_id)
  3700.         if include_prior_dates:
  3701.             filters.append(cls.execution_date <= execution_date)
  3702.         else:
  3703.             filters.append(cls.execution_date == execution_date)
  3704.  
  3705.         query = (
  3706.             session.query(cls.value)
  3707.             .filter(and_(*filters))
  3708.             .order_by(cls.execution_date.desc(), cls.timestamp.desc())
  3709.             .limit(1))
  3710.  
  3711.         result = query.first()
  3712.         if result:
  3713.             return result.value
  3714.  
  3715.     @classmethod
  3716.     @provide_session
  3717.     def get_many(
  3718.             cls,
  3719.             execution_date,
  3720.             key=None,
  3721.             task_ids=None,
  3722.             dag_ids=None,
  3723.             include_prior_dates=False,
  3724.             limit=100,
  3725.             session=None):
  3726.         """
3727.        Retrieve zero or more XCom values matching the given criteria
  3728.        """
  3729.         filters = []
  3730.         if key:
  3731.             filters.append(cls.key == key)
  3732.         if task_ids:
  3733.             filters.append(cls.task_id.in_(as_tuple(task_ids)))
  3734.         if dag_ids:
  3735.             filters.append(cls.dag_id.in_(as_tuple(dag_ids)))
  3736.         if include_prior_dates:
  3737.             filters.append(cls.execution_date <= execution_date)
  3738.         else:
  3739.             filters.append(cls.execution_date == execution_date)
  3740.  
  3741.         query = (
  3742.             session.query(cls)
  3743.             .filter(and_(*filters))
  3744.             .order_by(cls.execution_date.desc(), cls.timestamp.desc())
  3745.             .limit(limit))
  3746.  
  3747.         return query.all()
  3748.  
  3749.     @classmethod
  3750.     @provide_session
  3751.     def delete(cls, xcoms, session=None):
  3752.         if isinstance(xcoms, XCom):
  3753.             xcoms = [xcoms]
  3754.         for xcom in xcoms:
  3755.             if not isinstance(xcom, XCom):
  3756.                 raise TypeError(
  3757.                     'Expected XCom; received {}'.format(xcom.__class__.__name__)
  3758.                 )
  3759.             session.delete(xcom)
  3760.         session.commit()
  3761.  
  3762.  
  3763. class DagStat(Base):
  3764.     __tablename__ = "dag_stats"
  3765.  
  3766.     dag_id = Column(String(ID_LEN), primary_key=True)
  3767.     state = Column(String(50), primary_key=True)
  3768.     count = Column(Integer, default=0)
  3769.     dirty = Column(Boolean, default=False)
  3770.  
  3771.     def __init__(self, dag_id, state, count, dirty=False):
  3772.         self.dag_id = dag_id
  3773.         self.state = state
  3774.         self.count = count
  3775.         self.dirty = dirty
  3776.  
  3777.     @staticmethod
  3778.     @provide_session
  3779.     def set_dirty(dag_id, session=None):
  3780.         for dag in session.query(DagStat).filter(DagStat.dag_id == dag_id):
  3781.             dag.dirty = True
  3782.         session.commit()
  3783.  
  3784.     @staticmethod
  3785.     @provide_session
  3786.     def clean_dirty(dag_ids, session=None):
  3787.         """
  3788.        Cleans out the dirty/out-of-sync rows from dag_stats table
  3789.  
  3790.        :param dag_ids: dag_ids that may be dirty
  3791.        :type dag_ids: list
  3794.        """
  3795.         dag_ids = set(dag_ids)
  3796.  
  3797.         qry = (
  3798.             session.query(DagStat)
  3799.             .filter(and_(DagStat.dag_id.in_(dag_ids), DagStat.dirty == True))
  3800.         )
  3801.  
  3802.         dirty_ids = {dag.dag_id for dag in qry.all()}
  3803.         qry.delete(synchronize_session='fetch')
  3804.         session.commit()
  3805.  
  3806.         qry = (
  3807.             session.query(DagRun.dag_id, DagRun.state, func.count('*'))
  3808.             .filter(DagRun.dag_id.in_(dirty_ids))
  3809.             .group_by(DagRun.dag_id, DagRun.state)
  3810.         )
  3811.  
  3812.         for dag_id, state, count in qry:
  3813.             session.add(DagStat(dag_id=dag_id, state=state, count=count))
  3814.  
  3815.         session.commit()
  3816.  
  3817.  
  3818. class DagRun(Base):
  3819.     """
  3820.    DagRun describes an instance of a Dag. It can be created
  3821.    by the scheduler (for regular runs) or by an external trigger
  3822.    """
  3823.     __tablename__ = "dag_run"
  3824.  
  3825.     ID_PREFIX = 'scheduled__'
  3826.     ID_FORMAT_PREFIX = ID_PREFIX + '{0}'
  3827.     DEADLOCK_CHECK_DEP_CONTEXT = DepContext(ignore_in_retry_period=True)
  3828.  
  3829.     id = Column(Integer, primary_key=True)
  3830.     dag_id = Column(String(ID_LEN))
  3831.     execution_date = Column(DateTime, default=func.now())
  3832.     start_date = Column(DateTime, default=func.now())
  3833.     end_date = Column(DateTime)
  3834.     _state = Column('state', String(50), default=State.RUNNING)
  3835.     run_id = Column(String(ID_LEN))
  3836.     external_trigger = Column(Boolean, default=True)
  3837.     conf = Column(PickleType)
  3838.  
  3839.     dag = None
  3840.  
  3841.     __table_args__ = (
  3842.         Index('dr_run_id', dag_id, run_id, unique=True),
  3843.     )
  3844.  
  3845.     def __repr__(self):
  3846.         return (
  3847.             '<DagRun {dag_id} @ {execution_date}: {run_id}, '
  3848.             'externally triggered: {external_trigger}>'
  3849.         ).format(
  3850.             dag_id=self.dag_id,
  3851.             execution_date=self.execution_date,
  3852.             run_id=self.run_id,
  3853.             external_trigger=self.external_trigger)
  3854.  
  3855.     def get_state(self):
  3856.         return self._state
  3857.  
  3858.     def set_state(self, state):
  3859.         if self._state != state:
  3860.             self._state = state
  3861.             # something really weird goes on here: if you try to close the session
  3862.             # dag runs will end up detached
  3863.             session = settings.Session()
  3864.             DagStat.set_dirty(self.dag_id, session=session)
  3865.  
  3866.     @declared_attr
  3867.     def state(self):
  3868.         return synonym('_state',
  3869.                        descriptor=property(self.get_state, self.set_state))
  3870.  
  3871.     @classmethod
  3872.     def id_for_date(cls, date, prefix=ID_FORMAT_PREFIX):
  3873.         return prefix.format(date.isoformat()[:19])
  3874.  
  3875.     @provide_session
  3876.     def refresh_from_db(self, session=None):
  3877.         """
  3878.        Reloads the current dagrun from the database
  3879.        :param session: database session
  3880.        """
  3881.         DR = DagRun
  3882.  
  3883.         exec_date = func.cast(self.execution_date, DateTime)
  3884.  
  3885.         dr = session.query(DR).filter(
  3886.             DR.dag_id == self.dag_id,
  3887.             func.cast(DR.execution_date, DateTime) == exec_date,
  3888.             DR.run_id == self.run_id
  3889.         ).one()
  3890.  
  3891.         self.id = dr.id
  3892.         self.state = dr.state
  3893.  
  3894.     @staticmethod
  3895.     @provide_session
  3896.     def find(dag_id=None, run_id=None, execution_date=None,
  3897.              state=None, external_trigger=None, session=None):
  3898.         """
3899.        Returns the list of dag runs matching the given search criteria.
3900.        :param dag_id: the dag_id to find dag runs for
3901.        :type dag_id: string or list of strings
3902.        :param run_id: defines the run id for this dag run
  3903.        :type run_id: string
  3904.        :param execution_date: the execution date
  3905.        :type execution_date: datetime
  3906.        :param state: the state of the dag run
  3907.        :type state: State
  3908.        :param external_trigger: whether this dag run is externally triggered
  3909.        :type external_trigger: bool
  3910.        :param session: database session
  3911.        :type session: Session
  3912.        """
  3913.         DR = DagRun
  3914.  
  3915.         qry = session.query(DR)
  3916.         if dag_id:
  3917.             qry = qry.filter(DR.dag_id == dag_id)
  3918.         if run_id:
  3919.             qry = qry.filter(DR.run_id == run_id)
  3920.         if execution_date:
  3921.             if isinstance(execution_date, list):
  3922.                 qry = qry.filter(DR.execution_date.in_(execution_date))
  3923.             else:
  3924.                 qry = qry.filter(DR.execution_date == execution_date)
  3925.         if state:
  3926.             qry = qry.filter(DR.state == state)
3927.         if external_trigger is not None:
  3928.             qry = qry.filter(DR.external_trigger == external_trigger)
  3929.  
  3930.         dr = qry.order_by(DR.execution_date).all()
  3931.  
  3932.         return dr
  3933.  
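    # Usage sketch (dag_id illustrative): filters compose, and results come
    # back ordered by execution_date:
    #
    #     running = DagRun.find(dag_id='example_dag', state=State.RUNNING)
    #     for dr in running:
    #         print(dr.run_id, dr.execution_date)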
  3934.     @provide_session
  3935.     def get_task_instances(self, state=None, session=None):
  3936.         """
  3937.         Returns the task instances for this dag run
  3938.         """
  3939.         TI = TaskInstance
  3940.         tis = session.query(TI).filter(
  3941.             TI.dag_id == self.dag_id,
  3942.             TI.execution_date == self.execution_date,
  3943.         )
  3944.         if state:
  3945.             if isinstance(state, six.string_types):
  3946.                 tis = tis.filter(TI.state == state)
  3947.             else:
  3948.                 # this is required to deal with NULL values
  3949.                 if None in state:
  3950.                     tis = tis.filter(
  3951.                         or_(TI.state.in_(state),
  3952.                             TI.state.is_(None))
  3953.                     )
  3954.                 else:
  3955.                     tis = tis.filter(TI.state.in_(state))
  3956.  
  3957.         if self.dag and self.dag.partial:
  3958.             tis = tis.filter(TI.task_id.in_(self.dag.task_ids))
  3959.  
  3960.         return tis.all()
  3961.  
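    # Editor's illustration of the NULL handling above (not original code):
    # given a DagRun instance `run`, passing a list that contains None also
    # matches task instances whose state column is NULL, e.g.
    #
    #   tis = run.get_task_instances(state=[State.SUCCESS, None])
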
  3962.     @provide_session
  3963.     def get_task_instance(self, task_id, session=None):
  3964.         """
  3965.         Returns the task instance specified by task_id for this dag run
  3966.         :param task_id: the task id
  3967.         """
  3968.  
  3969.         TI = TaskInstance
  3970.         ti = session.query(TI).filter(
  3971.             TI.dag_id == self.dag_id,
  3972.             TI.execution_date == self.execution_date,
  3973.             TI.task_id == task_id
  3974.         ).one()
  3975.  
  3976.         return ti
  3977.  
  3978.     def get_dag(self):
  3979.         """
  3980.         Returns the DAG associated with this DagRun.
  3981. 
  3982.         :return: DAG
  3983.         """
  3984.         if not self.dag:
  3985.             raise AirflowException("The DAG (.dag) for {} needs to be set"
  3986.                                    .format(self))
  3987.  
  3988.         return self.dag
  3989.  
  3990.     @provide_session
  3991.     def get_previous_dagrun(self, session=None):
  3992.         """The previous DagRun, if there is one"""
  3993.  
  3994.         return session.query(DagRun).filter(
  3995.             DagRun.dag_id == self.dag_id,
  3996.             DagRun.execution_date < self.execution_date
  3997.         ).order_by(
  3998.             DagRun.execution_date.desc()
  3999.         ).first()
  4000.  
  4001.     @provide_session
  4002.     def get_previous_scheduled_dagrun(self, session=None):
  4003.         """The previous SCHEDULED DagRun, if there is one"""
  4004.         dag = self.get_dag()
  4005.  
  4006.         return session.query(DagRun).filter(
  4007.             DagRun.dag_id == self.dag_id,
  4008.             DagRun.execution_date == dag.previous_schedule(self.execution_date)
  4009.         ).first()
  4010.  
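    # Editor's illustration (not original code): given a DagRun instance
    # `run`, these lookups are used by, e.g., the depends_on_past checks:
    #
    #   prev = run.get_previous_dagrun()                   # any earlier run
    #   prev_sched = run.get_previous_scheduled_dagrun()   # per DAG schedule
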
  4011.     @provide_session
  4012.     def update_state(self, session=None):
  4013.         """
  4014.         Determines the overall state of the DagRun based on the state
  4015.         of its TaskInstances.
  4016.         :return: State
  4017.         """
  4018.  
  4019.         dag = self.get_dag()
  4020.  
  4021.         tis = self.get_task_instances(session=session)
  4022.  
  4023.         logging.info("Updating state for {} considering {} task(s)"
  4024.                      .format(self, len(tis)))
  4025.  
  4026.         for ti in list(tis):
  4027.             # skip task instances marked REMOVED in the db
  4028.             if ti.state == State.REMOVED:
  4029.                 tis.remove(ti)
  4030.             else:
  4031.                 ti.task = dag.get_task(ti.task_id)
  4032.  
  4033.         # pre-calculate the unfinished tasks up front;
  4034.         # a single db query is faster than filtering in Python
  4035.         start_dttm = datetime.now()
  4036.         unfinished_tasks = self.get_task_instances(
  4037.             state=State.unfinished(),
  4038.             session=session
  4039.         )
  4040.         none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
  4041.         # small speed up
  4042.         if unfinished_tasks and none_depends_on_past:
  4043.             # todo: this can actually get pretty slow: one task costs between 0.01-0.15s
  4044.             no_dependencies_met = all(
  4045.                 # Use a special dependency context that ignores task's up for retry
  4046.                 # dependency, since a task that is up for retry is not necessarily
  4047.                 # deadlocked.
  4048.                 not t.are_dependencies_met(dep_context=self.DEADLOCK_CHECK_DEP_CONTEXT,
  4049.                                            session=session)
  4050.                 for t in unfinished_tasks)
  4051.  
  4052.         duration = (datetime.now() - start_dttm).total_seconds() * 1000
  4053.         Stats.timing("dagrun.dependency-check.{}.{}".
  4054.                      format(self.dag_id, self.execution_date), duration)
  4055.  
  4056.         # future: remove the check on adhoc tasks (=active_tasks)
  4057.         if len(tis) == len(dag.active_tasks):
  4058.             root_ids = [t.task_id for t in dag.roots]
  4059.             roots = [t for t in tis if t.task_id in root_ids]
  4060.  
  4061.             # if all roots finished and at least one failed, the run failed
  4062.             if (not unfinished_tasks and
  4063.                     any(r.state in (State.FAILED, State.UPSTREAM_FAILED) for r in roots)):
  4064.                 logging.info('Marking run {} failed'.format(self))
  4065.                 self.state = State.FAILED
  4066.  
  4067.             # if all roots succeeded, the run succeeded
  4068.             elif all(r.state in (State.SUCCESS, State.SKIPPED)
  4069.                      for r in roots):
  4070.                 logging.info('Marking run {} successful'.format(self))
  4071.                 self.state = State.SUCCESS
  4072.  
  4073.             # if *all tasks* are deadlocked, the run failed
  4074.             elif unfinished_tasks and none_depends_on_past and no_dependencies_met:
  4075.                 logging.info(
  4076.                     'Deadlock; marking run {} failed'.format(self))
  4077.                 self.state = State.FAILED
  4078.  
  4079.             # finally, if the roots aren't done, the dag is still running
  4080.             else:
  4081.                 self.state = State.RUNNING
  4082.  
  4083.         # todo: determine whether we want to use with_for_update to make sure the run is locked
  4084.         session.merge(self)
  4085.         session.commit()
  4086.  
  4087.         return self.state
  4088.  
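    # Editor's summary of the decision logic above, plus a usage sketch for
    # a DagRun instance `run` (not original code):
    #
    #   all roots done, any root FAILED/UPSTREAM_FAILED -> run FAILED
    #   all roots SUCCESS or SKIPPED                    -> run SUCCESS
    #   unfinished tasks but none can run (deadlock)    -> run FAILED
    #   otherwise                                       -> run RUNNING
    #
    #   state = run.update_state()  # returns the resolved State
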
  4089.     @provide_session
  4090.     def verify_integrity(self, session=None):
  4091.         """
  4092.         Verifies the DagRun by checking for removed tasks or tasks that are not in
  4093.         the database yet. It will set the task state to REMOVED or add the task if required.
  4094.         """
  4095.         dag = self.get_dag()
  4096.         tis = self.get_task_instances(session=session)
  4097.  
  4098.         # check for removed tasks
  4099.         task_ids = []
  4100.         for ti in tis:
  4101.             task_ids.append(ti.task_id)
  4102.             try:
  4103.                 dag.get_task(ti.task_id)
  4104.             except AirflowException:
  4105.                 if self.state != State.RUNNING and not dag.partial:
  4106.                     ti.state = State.REMOVED
  4107.  
  4108.         # check for missing tasks
  4109.         for task in dag.tasks:
  4110.             if task.adhoc:
  4111.                 continue
  4112.  
  4113.             if task.task_id not in task_ids:
  4114.                 ti = TaskInstance(task, self.execution_date)
  4115.                 session.add(ti)
  4116.  
  4117.         session.commit()
  4118.  
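    # Editor's note: in this Airflow version the scheduler invokes
    # verify_integrity() on existing runs so task instances track edits to
    # the DAG file, roughly (for a DagRun instance `run`):
    #
    #   run.verify_integrity()  # adds missing TIs, marks removed ones
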
  4119.     @staticmethod
  4120.     def get_running_tasks(session, dag_id, task_ids):
  4121.         """
  4122.         Returns the number of tasks running in the given DAG.
  4123. 
  4124.         :param session: ORM session
  4125.         :param dag_id: ID of the DAG to get the task concurrency of
  4126.         :type dag_id: unicode
  4127.         :param task_ids: A list of valid task IDs for the given DAG
  4128.         :type task_ids: list[unicode]
  4129.         :return: The number of running tasks
  4130.         :rtype: int
  4131.         """
  4132.         qry = session.query(func.count(TaskInstance.task_id)).filter(
  4133.             TaskInstance.dag_id == dag_id,
  4134.             TaskInstance.task_id.in_(task_ids),
  4135.             TaskInstance.state == State.RUNNING,
  4136.         )
  4137.         return qry.scalar()
  4138.  
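    # Editor's usage sketch (assumes an open ORM session; the dag and task
    # ids are illustrative, not defined in this module):
    #
    #   session = settings.Session()
    #   n = DagRun.get_running_tasks(session, 'my_dag', ['extract', 'load'])
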
  4139.     @staticmethod
  4140.     def get_run(session, dag_id, execution_date):
  4141.         """
  4142.         :param dag_id: DAG ID
  4143.         :type dag_id: unicode
  4144.         :param execution_date: execution date
  4145.         :type execution_date: datetime
  4146.         :return: the scheduled (not externally triggered) DagRun for the given
  4147.             dag_id and execution date, if one exists; None otherwise
  4148.         :rtype: DagRun
  4149.         """
  4150.         qry = session.query(DagRun).filter(
  4151.             DagRun.dag_id == dag_id,
  4152.             DagRun.external_trigger == False,  # noqa: E712 (SQLAlchemy expression)
  4153.             DagRun.execution_date == execution_date,
  4154.         )
  4155.         return qry.first()
  4156.  
  4157.     @property
  4158.     def is_backfill(self):
  4159.         # guard against a None run_id before checking the backfill naming convention
  4160.         return self.run_id is not None and "backfill" in self.run_id
  4163.  
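    # Editor's note: is_backfill relies on a run_id naming convention; in
    # this Airflow version backfill runs get ids prefixed with 'backfill_'
    # (e.g. 'backfill_2018-01-01T00:00:00'), hence the substring test.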
  4164.  
  4165. class Pool(Base):
  4166.     __tablename__ = "slot_pool"
  4167.  
  4168.     id = Column(Integer, primary_key=True)
  4169.     pool = Column(String(50), unique=True)
  4170.     slots = Column(Integer, default=0)
  4171.     description = Column(Text)
  4172.  
  4173.     def __repr__(self):
  4174.         return self.pool
  4175.  
  4176.     @provide_session
  4177.     def used_slots(self, session):
  4178.         """
  4179.         Returns the number of slots used at the moment
  4180.         """
  4181.         running = (
  4182.             session
  4183.             .query(TaskInstance)
  4184.             .filter(TaskInstance.pool == self.pool)
  4185.             .filter(TaskInstance.state == State.RUNNING)
  4186.             .count()
  4187.         )
  4188.         return running
  4189.  
  4190.     @provide_session
  4191.     def queued_slots(self, session):
  4192.         """
  4193.         Returns the number of slots queued at the moment
  4194.         """
  4195.         return (
  4196.             session
  4197.             .query(TaskInstance)
  4198.             .filter(TaskInstance.pool == self.pool)
  4199.             .filter(TaskInstance.state == State.QUEUED)
  4200.             .count()
  4201.         )
  4202.  
  4203.     @provide_session
  4204.     def open_slots(self, session):
  4205.         """
  4206.         Returns the number of slots open at the moment
  4207.         """
  4208.         used_slots = self.used_slots(session=session)
  4209.         queued_slots = self.queued_slots(session=session)
  4210.         return self.slots - used_slots - queued_slots
  4211.  
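# Editor's sketch of how the Pool counters compose (not part of the original
# file; 'default_pool' is an illustrative pool name). The @provide_session
# decorator supplies a session when none is passed:
#
#   session = settings.Session()
#   pool = session.query(Pool).filter(Pool.pool == 'default_pool').first()
#   if pool is not None:
#       # open_slots() == slots - used_slots() - queued_slots()
#       print(pool.open_slots())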
  4212.  
  4213. class SlaMiss(Base):
  4214.     """
  4215.     Model that stores a history of the SLAs that have been missed.
  4216.     It is used to keep track of SLA failures over time and to avoid
  4217.     double-triggering alert emails.
  4218.     """
  4219.     __tablename__ = "sla_miss"
  4220.  
  4221.     task_id = Column(String(ID_LEN), primary_key=True)
  4222.     dag_id = Column(String(ID_LEN), primary_key=True)
  4223.     execution_date = Column(DateTime, primary_key=True)
  4224.     email_sent = Column(Boolean, default=False)
  4225.     timestamp = Column(DateTime)
  4226.     description = Column(Text)
  4227.     notification_sent = Column(Boolean, default=False)
  4228.  
  4229.     def __repr__(self):
  4230.         return str((
  4231.             self.dag_id, self.task_id, self.execution_date.isoformat()))
  4232.  
  4233.  
  4234. class ImportError(Base):  # note: shadows the builtin ImportError in this module
  4235.     __tablename__ = "import_error"
  4236.     id = Column(Integer, primary_key=True)
  4237.     timestamp = Column(DateTime)
  4238.     filename = Column(String(1024))
  4239.     stacktrace = Column(Text)