Advertisement
Guest User

cqlsh.py with protocol-version support for Cassandra 2.2.0

a guest
Oct 9th, 2015
1,135
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 101.87 KB | None | 0 0
  1. #!/bin/sh
  2. # -*- mode: Python -*-
  3.  
  4. # Licensed to the Apache Software Foundation (ASF) under one
  5. # or more contributor license agreements.  See the NOTICE file
  6. # distributed with this work for additional information
  7. # regarding copyright ownership.  The ASF licenses this file
  8. # to you under the Apache License, Version 2.0 (the
  9. # "License"); you may not use this file except in compliance
  10. # with the License.  You may obtain a copy of the License at
  11. #
  12. #     http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19.  
  20. """:"
  21. # bash code here; finds a suitable python interpreter and execs this file.
  22. # prefer unqualified "python" if suitable:
  23. python -c 'import sys; sys.exit(not (0x020500b0 < sys.hexversion < 0x03000000))' 2>/dev/null \
  24.    && exec python "$0" "$@"
  25. for pyver in 2.6 2.7 2.5; do
  26.    which python$pyver > /dev/null 2>&1 && exec python$pyver "$0" "$@"
  27. done
  28. echo "No appropriate python interpreter found." >&2
  29. exit 1
  30. ":"""
  31.  
  32. from __future__ import with_statement
  33. from uuid import UUID
  34.  
  35. description = "CQL Shell for Apache Cassandra"
  36. version = "5.0.1"
  37.  
  38. from StringIO import StringIO
  39. from contextlib import contextmanager
  40. from glob import glob
  41.  
  42. import cmd
  43. import sys
  44. import os
  45. import time
  46. import optparse
  47. import ConfigParser
  48. import codecs
  49. import locale
  50. import platform
  51. import warnings
  52. import csv
  53. import getpass
  54. from functools import partial
  55. import traceback
  56.  
  57.  
# Optional readline support; stays None when stdin is not a tty or the module
# cannot be imported.
readline = None
try:
    # check if tty first, cause readline doesn't check, and only cares
    # about $TERM. we don't want the funky escape code stuff to be
    # output if not a tty.
    if sys.stdin.isatty():
        import readline
except ImportError:
    pass

# Filename prefix of the bundled Python driver zip searched for by find_zip.
CQL_LIB_PREFIX = 'cassandra-driver-internal-only-'

# Parent of the directory containing this script (symlinks resolved); bundled
# zips ('lib') and cqlshlib ('pylib') are looked up relative to it.
CASSANDRA_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')

# use bundled libs for python-cql and thrift, if available. if there
# is a ../lib dir, use bundled libs there preferentially.
# NOTE(review): the zips actually searched for further down are the python
# driver, futures and six; the "python-cql and thrift" wording looks stale.
ZIPLIB_DIRS = [os.path.join(CASSANDRA_PATH, 'lib')]
myplatform = platform.system()
if myplatform == 'Linux':
    # On Linux also search the system-wide directory, after CASSANDRA_PATH/lib.
    ZIPLIB_DIRS.append('/usr/share/cassandra/lib')

# Any non-empty CQLSH_NO_BUNDLED disables the bundled zips entirely.
if os.environ.get('CQLSH_NO_BUNDLED', ''):
    ZIPLIB_DIRS = ()
  81.  
  82.  
  83. def find_zip(libprefix):
  84.     for ziplibdir in ZIPLIB_DIRS:
  85.         zips = glob(os.path.join(ziplibdir, libprefix + '*.zip'))
  86.         if zips:
  87.             return max(zips)   # probably the highest version, if multiple
  88.  
# If a bundled driver zip was found, put the 'cassandra-driver-<version>'
# directory inside that archive on sys.path. <version> is the zip's basename
# without CQL_LIB_PREFIX and the '.zip' extension.
cql_zip = find_zip(CQL_LIB_PREFIX)
if cql_zip:
    ver = os.path.splitext(os.path.basename(cql_zip))[0][len(CQL_LIB_PREFIX):]
    sys.path.insert(0, os.path.join(cql_zip, 'cassandra-driver-' + ver))

# Bundled third-party dependencies. Each zip found is prepended to sys.path,
# so these end up ahead of the driver entry inserted above.
third_parties = ('futures-', 'six-')

for lib in third_parties:
    lib_zip = find_zip(lib)
    if lib_zip:
        sys.path.insert(0, lib_zip)

# Silence any warning whose message mentions blist.
warnings.filterwarnings("ignore", r".*blist.*")
try:
    import cassandra
except ImportError, e:
    # Without the driver nothing else can work: exit with a diagnostic that
    # shows the interpreter and the module search path that was tried.
    sys.exit("\nPython Cassandra driver not installed, or not on PYTHONPATH.\n"
             'You might try "pip install cassandra-driver".\n\n'
             'Python: %s\n'
             'Module load path: %r\n\n'
             'Error: %s\n' % (sys.executable, sys.path, e))
  110.  
  111. from cassandra.cluster import Cluster, PagedResult
  112. from cassandra.query import SimpleStatement, ordered_dict_factory
  113. from cassandra.policies import WhiteListRoundRobinPolicy
  114. from cassandra.protocol import QueryMessage, ResultMessage
  115. from cassandra.metadata import protect_name, protect_names, protect_value, KeyspaceMetadata, TableMetadata, ColumnMetadata
  116. from cassandra.auth import PlainTextAuthProvider
  117.  
# cqlsh should run correctly when run out of a Cassandra source tree,
# out of an unpacked Cassandra tarball, and after a proper package install.
# When CASSANDRA_PATH/pylib exists it is put first on sys.path, so the
# cqlshlib package is imported from there in preference to any installed copy.
cqlshlibdir = os.path.join(CASSANDRA_PATH, 'pylib')
if os.path.isdir(cqlshlibdir):
    sys.path.insert(0, cqlshlibdir)
  123.  
  124. from cqlshlib import cqlhandling, cql3handling, pylexotron, sslhandling
  125. from cqlshlib.displaying import (RED, BLUE, CYAN, ANSI_RESET, COLUMN_NAME_COLORS,
  126.                                  FormattedValue, colorme)
  127. from cqlshlib.formatting import format_by_type, formatter_for, format_value_utype
  128. from cqlshlib.util import trim_if_present, get_file_encoding_bomsize
  129. from cqlshlib.formatting import DateTimeFormat
  130. from cqlshlib.formatting import DEFAULT_TIMESTAMP_FORMAT
  131. from cqlshlib.formatting import DEFAULT_DATE_FORMAT
  132. from cqlshlib.formatting import DEFAULT_NANOTIME_FORMAT
  133. from cqlshlib.tracing import print_trace_session, print_trace
  134.  
# Connection defaults (also interpolated into the --help epilog below).
DEFAULT_HOST = '127.0.0.1'
DEFAULT_PORT = 9042
DEFAULT_CQLVER = '3.3.0'
DEFAULT_PROTOCOL_VERSION = 4
DEFAULT_CONNECT_TIMEOUT_SECONDS = 5

# Display defaults.
DEFAULT_FLOAT_PRECISION = 5
DEFAULT_MAX_TRACE_WAIT = 10

# A libedit-backed readline (detected via its module docstring) is given a
# literal tab character as the completion key; otherwise the name 'tab'.
if readline is not None and readline.__doc__ is not None and 'libedit' in readline.__doc__:
    DEFAULT_COMPLETEKEY = '\t'
else:
    DEFAULT_COMPLETEKEY = 'tab'

# Module-level placeholders; they are presumably assigned later in the file
# (outside this section).
cqldocs = None
cqlruleset = None

epilog = """Connects to %(DEFAULT_HOST)s:%(DEFAULT_PORT)d by default. These
defaults can be changed by setting $CQLSH_HOST and/or $CQLSH_PORT. When a
host (and optional port number) are given on the command line, they take
precedence over any defaults.""" % globals()

parser = optparse.OptionParser(description=description, epilog=epilog,
                               usage="Usage: %prog [options] [host [port]]",
                               version='cqlsh ' + version)
parser.add_option("-C", "--color", action='store_true', dest='color',
                  help='Always use color output')
parser.add_option("--no-color", action='store_false', dest='color',
                  help='Never use color output')
parser.add_option('--ssl', action='store_true', help='Use SSL', default=False)
parser.add_option("-u", "--username", help="Authenticate as user.")
parser.add_option("-p", "--password", help="Authenticate using password.")
parser.add_option('-k', '--keyspace', help='Authenticate to the given keyspace.')
parser.add_option("-f", "--file", help="Execute commands from FILE, then exit")
parser.add_option('--debug', action='store_true',
                  help='Show additional debugging information')
parser.add_option("--cqlshrc", help="Specify an alternative cqlshrc file location.")
parser.add_option('--cqlversion', default=DEFAULT_CQLVER,
                  help='Specify a particular CQL version (default: %default).'
                       ' Examples: "3.0.3", "3.1.0"')
# NOTE(review): --protocolversion and --connect-timeout declare no `type=`, so
# values given on the command line arrive as strings while the defaults are
# ints; confirm every consumer converts them.
parser.add_option("--protocolversion", default=DEFAULT_PROTOCOL_VERSION, help='Specify protocol version (default: %default).')
parser.add_option("-e", "--execute", help='Execute the statement and quit.')
parser.add_option("--connect-timeout", default=DEFAULT_CONNECT_TIMEOUT_SECONDS, dest='connect_timeout',
                  help='Specify the connection timeout in seconds (default: %default seconds).')

# Parsing into a caller-supplied Values object means optparse does NOT fill in
# the option defaults, so `options` only carries options actually given on the
# command line (the history/config code relies on this via hasattr).
optvalues = optparse.Values()
(options, arguments) = parser.parse_args(sys.argv[1:], values=optvalues)
  182.  
  183. #BEGIN history/config definition
  184. HISTORY_DIR = os.path.expanduser(os.path.join('~', '.cassandra'))
  185.  
  186. if hasattr(options, 'cqlshrc'):
  187.     CONFIG_FILE = options.cqlshrc
  188.     if not os.path.exists(CONFIG_FILE):
  189.         print '\nWarning: Specified cqlshrc location `%s` does not exist.  Using `%s` instead.\n' % (CONFIG_FILE, HISTORY_DIR)
  190.         CONFIG_FILE = os.path.join(HISTORY_DIR, 'cqlshrc')
  191. else:
  192.     CONFIG_FILE = os.path.join(HISTORY_DIR, 'cqlshrc')
  193.  
  194. HISTORY = os.path.join(HISTORY_DIR, 'cqlsh_history')
  195. if not os.path.exists(HISTORY_DIR):
  196.     try:
  197.         os.mkdir(HISTORY_DIR)
  198.     except OSError:
  199.         print '\nWarning: Cannot create directory at `%s`. Command history will not be saved.\n' % HISTORY_DIR
  200.  
  201. OLD_CONFIG_FILE = os.path.expanduser(os.path.join('~', '.cqlshrc'))
  202. if os.path.exists(OLD_CONFIG_FILE):
  203.     os.rename(OLD_CONFIG_FILE, CONFIG_FILE)
  204. OLD_HISTORY = os.path.expanduser(os.path.join('~', '.cqlsh_history'))
  205. if os.path.exists(OLD_HISTORY):
  206.     os.rename(OLD_HISTORY, HISTORY)
  207. #END history/config definition
  208.  
# Driver exception types grouped as "CQL errors"; presumably caught around
# statement execution elsewhere in this file (not visible here).
CQL_ERRORS = (
    cassandra.AlreadyExists, cassandra.AuthenticationFailed, cassandra.InvalidRequest,
    cassandra.Timeout, cassandra.Unauthorized, cassandra.OperationTimedOut,
    cassandra.cluster.NoHostAvailable,
    cassandra.connection.ConnectionBusy, cassandra.connection.ProtocolError, cassandra.connection.ConnectionException,
    cassandra.protocol.ErrorMessage, cassandra.protocol.InternalError, cassandra.query.TraceUnavailable
)

# True only when CQLSH_DEBUG_COMPLETION is exactly 'YES'.
debug_completion = bool(os.environ.get('CQLSH_DEBUG_COMPLETION', '') == 'YES')

# we want the cql parser to understand our cqlsh-specific commands too
# (these shell commands may be terminated by a newline instead of ';').
my_commands_ending_with_newline = (
    'help',
    '?',
    'consistency',
    'serial',
    'describe',
    'desc',
    'show',
    'source',
    'capture',
    'login',
    'debug',
    'tracing',
    'expand',
    'paging',
    'exit',
    'quit'
)


# (rulename, termname, completer_function) triples registered through the
# cqlsh_syntax_completer decorator.
cqlsh_syntax_completers = []
  241.  
  242.  
  243. def cqlsh_syntax_completer(rulename, termname):
  244.     def registrator(f):
  245.         cqlsh_syntax_completers.append((rulename, termname, f))
  246.         return f
  247.     return registrator
  248.  
  249.  
# Grammar for the cqlsh-only commands, in the rule syntax used by cqlshlib.
# <cqlshCommand> is either a CQL statement or one of the special commands
# below; special commands may end with ';' or a newline. Names such as
# ksname=, cf=, fname=, [colnames]= and [optnames]= are bindings that the
# @cqlsh_syntax_completer functions read back via ctxt.get_binding().
cqlsh_extra_syntax_rules = r'''
<cqlshCommand> ::= <CQL_Statement>
                | <specialCommand> ( ";" | "\n" )
                ;

<specialCommand> ::= <describeCommand>
                  | <consistencyCommand>
                  | <serialConsistencyCommand>
                  | <showCommand>
                  | <sourceCommand>
                  | <captureCommand>
                  | <copyCommand>
                  | <loginCommand>
                  | <debugCommand>
                  | <helpCommand>
                  | <tracingCommand>
                  | <expandCommand>
                  | <exitCommand>
                  | <pagingCommand>
                  ;

<describeCommand> ::= ( "DESCRIBE" | "DESC" )
                                 ( "FUNCTIONS" ksname=<keyspaceName>?
                                 | "FUNCTION" udf=<anyFunctionName>
                                 | "AGGREGATES" ksname=<keyspaceName>?
                                 | "AGGREGATE" uda=<userAggregateName>
                                 | "KEYSPACES"
                                 | "KEYSPACE" ksname=<keyspaceName>?
                                 | ( "COLUMNFAMILY" | "TABLE" ) cf=<columnFamilyName>
                                 | "INDEX" idx=<indexName>
                                 | ( "COLUMNFAMILIES" | "TABLES" )
                                 | "FULL"? "SCHEMA"
                                 | "CLUSTER"
                                 | "TYPES"
                                 | "TYPE" ut=<userTypeName>
                                 | (ksname=<keyspaceName> | cf=<columnFamilyName> | idx=<indexName>))
                   ;

<consistencyCommand> ::= "CONSISTENCY" ( level=<consistencyLevel> )?
                      ;

<consistencyLevel> ::= "ANY"
                    | "ONE"
                    | "TWO"
                    | "THREE"
                    | "QUORUM"
                    | "ALL"
                    | "LOCAL_QUORUM"
                    | "EACH_QUORUM"
                    | "SERIAL"
                    | "LOCAL_SERIAL"
                    | "LOCAL_ONE"
                    ;

<serialConsistencyCommand> ::= "SERIAL" "CONSISTENCY" ( level=<serialConsistencyLevel> )?
                            ;

<serialConsistencyLevel> ::= "SERIAL"
                          | "LOCAL_SERIAL"
                          ;

<showCommand> ::= "SHOW" what=( "VERSION" | "HOST" | "SESSION" sessionid=<uuid> )
               ;

<sourceCommand> ::= "SOURCE" fname=<stringLiteral>
                 ;

<captureCommand> ::= "CAPTURE" ( fname=( <stringLiteral> | "OFF" ) )?
                  ;

<copyCommand> ::= "COPY" cf=<columnFamilyName>
                        ( "(" [colnames]=<colname> ( "," [colnames]=<colname> )* ")" )?
                        ( dir="FROM" ( fname=<stringLiteral> | "STDIN" )
                        | dir="TO"   ( fname=<stringLiteral> | "STDOUT" ) )
                        ( "WITH" <copyOption> ( "AND" <copyOption> )* )?
               ;

<copyOption> ::= [optnames]=<identifier> "=" [optvals]=<copyOptionVal>
              ;

<copyOptionVal> ::= <identifier>
                 | <stringLiteral>
                 ;

# avoiding just "DEBUG" so that this rule doesn't get treated as a terminal
<debugCommand> ::= "DEBUG" "THINGS"?
                ;

<helpCommand> ::= ( "HELP" | "?" ) [topic]=( /[a-z_]*/ )*
               ;

<tracingCommand> ::= "TRACING" ( switch=( "ON" | "OFF" ) )?
                  ;

<expandCommand> ::= "EXPAND" ( switch=( "ON" | "OFF" ) )?
                  ;

<pagingCommand> ::= "PAGING" ( switch=( "ON" | "OFF" ) )?
                 ;

<loginCommand> ::= "LOGIN" username=<username> (password=<stringLiteral>)?
                ;

<exitCommand> ::= "exit" | "quit"
               ;

<qmark> ::= "?" ;
'''
  358.  
  359.  
  360. @cqlsh_syntax_completer('helpCommand', 'topic')
  361. def complete_help(ctxt, cqlsh):
  362.     return sorted([t.upper() for t in cqldocs.get_help_topics() + cqlsh.get_help_topics()])
  363.  
  364.  
  365. def complete_source_quoted_filename(ctxt, cqlsh):
  366.     partial_path = ctxt.get_binding('partial', '')
  367.     head, tail = os.path.split(partial_path)
  368.     exhead = os.path.expanduser(head)
  369.     try:
  370.         contents = os.listdir(exhead or '.')
  371.     except OSError:
  372.         return ()
  373.     matches = filter(lambda f: f.startswith(tail), contents)
  374.     annotated = []
  375.     for f in matches:
  376.         match = os.path.join(head, f)
  377.         if os.path.isdir(os.path.join(exhead, f)):
  378.             match += '/'
  379.         annotated.append(match)
  380.     return annotated
  381.  
  382.  
# SOURCE and CAPTURE complete their file-name argument with the same path
# completer (registered directly rather than via decorator syntax).
cqlsh_syntax_completer('sourceCommand', 'fname')(complete_source_quoted_filename)
cqlsh_syntax_completer('captureCommand', 'fname')(complete_source_quoted_filename)
  385.  
  386.  
  387. @cqlsh_syntax_completer('copyCommand', 'fname')
  388. def copy_fname_completer(ctxt, cqlsh):
  389.     lasttype = ctxt.get_binding('*LASTTYPE*')
  390.     if lasttype == 'unclosedString':
  391.         return complete_source_quoted_filename(ctxt, cqlsh)
  392.     partial_path = ctxt.get_binding('partial')
  393.     if partial_path == '':
  394.         return ["'"]
  395.     return ()
  396.  
  397.  
  398. @cqlsh_syntax_completer('copyCommand', 'colnames')
  399. def complete_copy_column_names(ctxt, cqlsh):
  400.     existcols = map(cqlsh.cql_unprotect_name, ctxt.get_binding('colnames', ()))
  401.     ks = cqlsh.cql_unprotect_name(ctxt.get_binding('ksname', None))
  402.     cf = cqlsh.cql_unprotect_name(ctxt.get_binding('cfname'))
  403.     colnames = cqlsh.get_column_names(ks, cf)
  404.     if len(existcols) == 0:
  405.         return [colnames[0]]
  406.     return set(colnames[1:]) - set(existcols)
  407.  
  408.  
# Option names offered by the COPY ... WITH option-name completer.
COPY_OPTIONS = ('DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'ENCODING', 'NULL')
  410.  
  411.  
  412. @cqlsh_syntax_completer('copyOption', 'optnames')
  413. def complete_copy_options(ctxt, cqlsh):
  414.     optnames = map(str.upper, ctxt.get_binding('optnames', ()))
  415.     direction = ctxt.get_binding('dir').upper()
  416.     opts = set(COPY_OPTIONS) - set(optnames)
  417.     if direction == 'FROM':
  418.         opts -= ('ENCODING',)
  419.     return opts
  420.  
  421.  
  422. @cqlsh_syntax_completer('copyOption', 'optvals')
  423. def complete_copy_opt_values(ctxt, cqlsh):
  424.     optnames = ctxt.get_binding('optnames', ())
  425.     lastopt = optnames[-1].lower()
  426.     if lastopt == 'header':
  427.         return ['true', 'false']
  428.     return [cqlhandling.Hint('<single_character_string>')]
  429.  
  430.  
  431. class NoKeyspaceError(Exception):
  432.     pass
  433.  
  434.  
  435. class KeyspaceNotFound(Exception):
  436.     pass
  437.  
  438.  
  439. class ColumnFamilyNotFound(Exception):
  440.     pass
  441.  
  442. class IndexNotFound(Exception):
  443.     pass
  444.  
  445. class ObjectNotFound(Exception):
  446.     pass
  447.  
  448. class VersionNotSupported(Exception):
  449.     pass
  450.  
  451.  
  452. class UserTypeNotFound(Exception):
  453.     pass
  454.  
  455. class FunctionNotFound(Exception):
  456.     pass
  457.  
  458. class AggregateNotFound(Exception):
  459.     pass
  460.  
  461.  
  462. class DecodeError(Exception):
  463.     verb = 'decode'
  464.  
  465.     def __init__(self, thebytes, err, colname=None):
  466.         self.thebytes = thebytes
  467.         self.err = err
  468.         self.colname = colname
  469.  
  470.     def __str__(self):
  471.         return str(self.thebytes)
  472.  
  473.     def message(self):
  474.         what = 'value %r' % (self.thebytes,)
  475.         if self.colname is not None:
  476.             what = 'value %r (for column %r)' % (self.thebytes, self.colname)
  477.         return 'Failed to %s %s : %s' \
  478.                % (self.verb, what, self.err)
  479.  
  480.     def __repr__(self):
  481.         return '<%s %s>' % (self.__class__.__name__, self.message())
  482.  
  483.  
  484. class FormatError(DecodeError):
  485.     verb = 'format'
  486.  
  487.  
  488. def full_cql_version(ver):
  489.     while ver.count('.') < 2:
  490.         ver += '.0'
  491.     ver_parts = ver.split('-', 1) + ['']
  492.     vertuple = tuple(map(int, ver_parts[0].split('.')) + [ver_parts[1]])
  493.     return ver, vertuple
  494.  
  495.  
  496. def format_value(val, output_encoding, addcolor=False, date_time_format=None,
  497.                  float_precision=None, colormap=None, nullval=None):
  498.     if isinstance(val, DecodeError):
  499.         if addcolor:
  500.             return colorme(repr(val.thebytes), colormap, 'error')
  501.         else:
  502.             return FormattedValue(repr(val.thebytes))
  503.     return format_by_type(type(val), val, output_encoding, colormap=colormap,
  504.                           addcolor=addcolor, nullval=nullval, date_time_format=date_time_format,
  505.                           float_precision=float_precision)
  506.  
  507.  
  508. def show_warning_without_quoting_line(message, category, filename, lineno, file=None, line=None):
  509.     if file is None:
  510.         file = sys.stderr
  511.     try:
  512.         file.write(warnings.formatwarning(message, category, filename, lineno, line=''))
  513.     except IOError:
  514.         pass
# Install the quieter warning printer process-wide, and show
# UnexpectedTableStructure warnings every time ('always') instead of only the
# first occurrence per location.
warnings.showwarning = show_warning_without_quoting_line
warnings.filterwarnings('always', category=cql3handling.UnexpectedTableStructure)
  517.  
  518.  
  519. def describe_interval(seconds):
  520.     desc = []
  521.     for length, unit in ((86400, 'day'), (3600, 'hour'), (60, 'minute')):
  522.         num = int(seconds) / length
  523.         if num > 0:
  524.             desc.append('%d %s' % (num, unit))
  525.             if num > 1:
  526.                 desc[-1] += 's'
  527.         seconds %= length
  528.     words = '%.03f seconds' % seconds
  529.     if len(desc) > 1:
  530.         words = ', '.join(desc) + ', and ' + words
  531.     elif len(desc) == 1:
  532.         words = desc[0] + ' and ' + words
  533.     return words
  534.  
  535.  
def auto_format_udts():
    """Monkey-patch the driver's UserType so new UDT classes get a formatter.

    Wraps ``cassandra.cqltypes.UserType.make_udt_class`` so that every class it
    creates has ``format_value_utype`` registered via ``formatter_for`` under
    the class's ``tuple_type.__name__``. The current attribute is read and
    replaced on each call, so calling this repeatedly (e.g. once per Shell)
    stacks another wrapper each time.
    """
    # when we see a new user defined type, set up the shell formatting for it
    udt_apply_params = cassandra.cqltypes.UserType.apply_parameters

    # `cls` is unused: the saved original is called exactly as it was
    # retrieved from UserType.
    def new_apply_params(cls, *args, **kwargs):
        udt_class = udt_apply_params(*args, **kwargs)
        formatter_for(udt_class.typename)(format_value_utype)
        return udt_class

    # NOTE(review): the original is read from `apply_parameters` above, but the
    # wrapper is installed under a different name, `udt_apply_parameters`.
    # Unless something calls that name, this wrapper never runs. Confirm
    # whether `apply_parameters` was the intended target.
    cassandra.cqltypes.UserType.udt_apply_parameters = classmethod(new_apply_params)

    make_udt_class = cassandra.cqltypes.UserType.make_udt_class

    def new_make_udt_class(cls, *args, **kwargs):
        udt_class = make_udt_class(*args, **kwargs)
        formatter_for(udt_class.tuple_type.__name__)(format_value_utype)
        return udt_class

    cassandra.cqltypes.UserType.make_udt_class = classmethod(new_make_udt_class)
  555.  
  556.  
  557. class FrozenType(cassandra.cqltypes._ParameterizedType):
  558.     """
  559.    Needed until the bundled python driver adds FrozenType.
  560.    """
  561.     typename = "frozen"
  562.     num_subtypes = 1
  563.  
  564.     @classmethod
  565.     def deserialize_safe(cls, byts, protocol_version):
  566.         subtype, = cls.subtypes
  567.         return subtype.from_binary(byts)
  568.  
  569.     @classmethod
  570.     def serialize_safe(cls, val, protocol_version):
  571.         subtype, = cls.subtypes
  572.         return subtype.to_binary(val, protocol_version)
  573.  
  574. class Shell(cmd.Cmd):
  575.     custom_prompt = os.getenv('CQLSH_PROMPT', '')
  576.     if custom_prompt is not '':
  577.         custom_prompt += "\n"
  578.     default_prompt = custom_prompt + "cqlsh> "
  579.     continue_prompt = "   ... "
  580.     keyspace_prompt = custom_prompt + "cqlsh:%s> "
  581.     keyspace_continue_prompt = "%s    ... "
  582.     show_line_nums = False
  583.     debug = False
  584.     stop = False
  585.     last_hist = None
  586.     shunted_query_out = None
  587.     use_paging = True
  588.     csv_dialect_defaults = dict(delimiter=',', doublequote=False,
  589.                                 escapechar='\\', quotechar='"')
  590.     default_page_size = 100
  591.  
    def __init__(self, hostname, port, color=False,
                 username=None, password=None, encoding=None, stdin=None, tty=True,
                 completekey=DEFAULT_COMPLETEKEY, use_conn=None,
                 cqlver=DEFAULT_CQLVER, keyspace=None,
                 tracing_enabled=False, expand_enabled=False,
                 display_nanotime_format=DEFAULT_NANOTIME_FORMAT,
                 display_timestamp_format=DEFAULT_TIMESTAMP_FORMAT,
                 display_date_format=DEFAULT_DATE_FORMAT,
                 display_float_precision=DEFAULT_FLOAT_PRECISION,
                 max_trace_wait=DEFAULT_MAX_TRACE_WAIT,
                 ssl=False,
                 single_statement=None,
                 client_timeout=10,
                 protocol_version=DEFAULT_PROTOCOL_VERSION,
                 connect_timeout=DEFAULT_CONNECT_TIMEOUT_SECONDS):
        """Connect to Cassandra and initialise the interactive shell state.

        Unless ``use_conn`` supplies an existing Cluster, a new Cluster is
        built for ``hostname``/``port``. It uses a WhiteListRoundRobinPolicy
        restricted to ``hostname``, the given CQL/protocol versions and
        connect timeout, and SSL settings from CONFIG_FILE when ``ssl`` is
        true. A session is then opened, in ``keyspace`` if one is given. If
        ``username`` is given without ``password``, the password is read
        with getpass.

        ``client_timeout`` becomes the session's default_timeout.
        ``encoding`` defaults to the locale's preferred encoding. ``stdin``
        defaults to sys.stdin. With ``tty`` true the prompt is set and the
        connection banner is printed; otherwise line numbers are shown.

        Side effects beyond this instance: patches the driver's
        BytesType.deserialize and CassandraType.support_empty_values class
        attributes and calls auto_format_udts().

        Exceptions from Cluster.connect() propagate; presumably these include
        NoHostAvailable when the node is unreachable (TODO confirm).
        """
        cmd.Cmd.__init__(self, completekey=completekey)
        self.hostname = hostname
        self.port = port
        self.auth_provider = None
        if username:
            if not password:
                password = getpass.getpass()
            self.auth_provider = PlainTextAuthProvider(username=username, password=password)
        self.username = username
        self.keyspace = keyspace
        self.ssl = ssl
        self.tracing_enabled = tracing_enabled
        self.expand_enabled = expand_enabled
        if use_conn:
            self.conn = use_conn
        else:
            self.conn = Cluster(contact_points=(self.hostname,), port=self.port, cql_version=cqlver,
                                protocol_version=protocol_version,
                                auth_provider=self.auth_provider,
                                ssl_options=sslhandling.ssl_settings(hostname, CONFIG_FILE) if ssl else None,
                                load_balancing_policy=WhiteListRoundRobinPolicy([self.hostname]),
                                connect_timeout=connect_timeout)
        # Only a Cluster created above (no use_conn given) counts as owned.
        self.owns_connection = not use_conn
        self.set_expanded_cql_version(cqlver)

        if keyspace:
            self.session = self.conn.connect(keyspace)
        else:
            self.session = self.conn.connect()

        self.color = color

        self.display_nanotime_format = display_nanotime_format
        self.display_timestamp_format = display_timestamp_format
        self.display_date_format = display_date_format

        self.display_float_precision = display_float_precision

        # Workaround for CASSANDRA-8521 until PYTHON-205 is resolved.
        # If there is no schema metadata present (due to a schema mismatch),
        # get rid of the code that checks for a schema mismatch and force
        # the schema metadata to be built.
        if not self.conn.metadata.keyspaces:
            self.printerr("Warning: schema version mismatch detected; check the schema versions of your "
                          "nodes in system.local and system.peers.")
            original_method = self.conn.control_connection._get_schema_mismatches
            try:
                self.conn.control_connection._get_schema_mismatches = lambda *args, **kwargs: None
                future = self.conn.submit_schema_refresh()
                future.result(timeout=10)
            finally:
                # Always restore the driver's mismatch check, even on timeout.
                self.conn.control_connection._get_schema_mismatches = original_method

        self.session.default_timeout = client_timeout
        self.session.row_factory = ordered_dict_factory
        self.get_connection_versions()

        self.current_keyspace = keyspace

        # NOTE(review): these three attributes were already assigned the same
        # values above; the repetition is redundant.
        self.display_timestamp_format = display_timestamp_format
        self.display_nanotime_format = display_nanotime_format
        self.display_date_format = display_date_format

        self.max_trace_wait = max_trace_wait
        # Also stored on the driver session object, presumably for the tracing
        # helpers, which receive the session (TODO confirm).
        self.session.max_trace_wait = max_trace_wait
        if encoding is None:
            encoding = locale.getpreferredencoding()
        self.encoding = encoding
        self.output_codec = codecs.lookup(encoding)

        # Buffer for the statement being accumulated across input lines.
        self.statement = StringIO()
        self.lineno = 1
        self.in_comment = False

        self.prompt = ''
        if stdin is None:
            stdin = sys.stdin
        self.tty = tty
        if tty:
            self.reset_prompt()
            self.report_connection()
            print 'Use HELP for help.'
        else:
            self.show_line_nums = True
        self.stdin = stdin
        self.query_out = sys.stdout
        self.consistency_level = cassandra.ConsistencyLevel.ONE
        self.serial_consistency_level = cassandra.ConsistencyLevel.SERIAL;
        # the python driver returns BLOBs as string, but we expect them as bytearrays
        # NOTE: these two patches modify driver classes, so they apply to every
        # session in the process, not just this shell.
        cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts))
        cassandra.cqltypes.CassandraType.support_empty_values = True

        auto_format_udts()

        self.empty_lines = 0
        self.statement_error = False
        self.single_statement = single_statement
  704.  
  705.     def set_expanded_cql_version(self, ver):
  706.         ver, vertuple = full_cql_version(ver)
  707.         self.cql_version = ver
  708.         self.cql_ver_tuple = vertuple
  709.  
  710.     def cqlver_atleast(self, major, minor=0, patch=0):
  711.         return self.cql_ver_tuple[:3] >= (major, minor, patch)
  712.  
  713.     def myformat_value(self, val, **kwargs):
  714.         if isinstance(val, DecodeError):
  715.             self.decoding_errors.append(val)
  716.         try:
  717.             dtformats = DateTimeFormat(timestamp_format=self.display_timestamp_format,
  718.                                        date_format=self.display_date_format, nanotime_format=self.display_nanotime_format)
  719.             return format_value(val, self.output_codec.name,
  720.                                 addcolor=self.color, date_time_format=dtformats,
  721.                                 float_precision=self.display_float_precision, **kwargs)
  722.         except Exception, e:
  723.             err = FormatError(val, e)
  724.             self.decoding_errors.append(err)
  725.             return format_value(err, self.output_codec.name, addcolor=self.color)
  726.  
  727.     def myformat_colname(self, name, table_meta=None):
  728.         column_colors = COLUMN_NAME_COLORS.copy()
  729.         # check column role and color appropriately
  730.         if table_meta:
  731.             if name in [col.name for col in table_meta.partition_key]:
  732.                 column_colors.default_factory = lambda: RED
  733.             elif name in [col.name for col in table_meta.clustering_key]:
  734.                 column_colors.default_factory = lambda: CYAN
  735.         return self.myformat_value(name, colormap=column_colors)
  736.  
  737.     def report_connection(self):
  738.         self.show_host()
  739.         self.show_version()
  740.  
  741.     def show_host(self):
  742.         print "Connected to %s at %s:%d." % \
  743.                (self.applycolor(self.get_cluster_name(), BLUE),
  744.                 self.hostname,
  745.                 self.port)
  746.  
  747.     def show_version(self):
  748.         vers = self.connection_versions.copy()
  749.         vers['shver'] = version
  750.         # system.Versions['cql'] apparently does not reflect changes with
  751.         # set_cql_version.
  752.         vers['cql'] = self.cql_version
  753.         print "[cqlsh %(shver)s | Cassandra %(build)s | CQL spec %(cql)s | Native protocol v%(protocol)s]" % vers
  754.  
  755.     def show_session(self, sessionid):
  756.         print_trace_session(self, self.session, sessionid)
  757.  
  758.     def get_connection_versions(self):
  759.         result, = self.session.execute("select * from system.local where key = 'local'")
  760.         vers = {
  761.             'build': result['release_version'],
  762.             'protocol': result['native_protocol_version'],
  763.             'cql': result['cql_version'],
  764.         }
  765.         self.connection_versions = vers
  766.  
  767.     def get_keyspace_names(self):
  768.         return map(str, self.conn.metadata.keyspaces.keys())
  769.  
  770.     def get_columnfamily_names(self, ksname=None):
  771.         if ksname is None:
  772.             ksname = self.current_keyspace
  773.  
  774.         return map(str, self.get_keyspace_meta(ksname).tables.keys())
  775.  
  776.     def get_index_names(self, ksname=None):
  777.         if ksname is None:
  778.             ksname = self.current_keyspace
  779.  
  780.         return map(str, self.get_keyspace_meta(ksname).indexes.keys())
  781.  
  782.     def get_column_names(self, ksname, cfname):
  783.         if ksname is None:
  784.             ksname = self.current_keyspace
  785.         layout = self.get_table_meta(ksname, cfname)
  786.         return [str(col) for col in layout.columns]
  787.  
  788.     def get_usertype_names(self, ksname=None):
  789.         if ksname is None:
  790.             ksname = self.current_keyspace
  791.  
  792.         return self.get_keyspace_meta(ksname).user_types.keys()
  793.  
  794.     def get_usertype_layout(self, ksname, typename):
  795.         if ksname is None:
  796.             ksname = self.current_keyspace
  797.  
  798.         ks_meta = self.get_keyspace_meta(ksname)
  799.  
  800.         try:
  801.             user_type = ks_meta.user_types[typename]
  802.         except KeyError:
  803.             raise UserTypeNotFound("User type %r not found" % typename)
  804.  
  805.         return [(field_name, field_type.cql_parameterized_type())
  806.                 for field_name, field_type in zip(user_type.field_names, user_type.field_types)]
  807.  
  808.     def get_userfunction_names(self, ksname=None):
  809.         if ksname is None:
  810.             ksname = self.current_keyspace
  811.  
  812.         return map(lambda f: f.name, self.get_keyspace_meta(ksname).functions.values())
  813.  
  814.     def get_useraggregate_names(self, ksname=None):
  815.         if ksname is None:
  816.             ksname = self.current_keyspace
  817.  
  818.         return map(lambda f: f.name, self.get_keyspace_meta(ksname).aggregates.values())
  819.  
  820.     def get_cluster_name(self):
  821.         return self.conn.metadata.cluster_name
  822.  
  823.     def get_partitioner(self):
  824.         return self.conn.metadata.partitioner
  825.  
  826.     def get_keyspace_meta(self, ksname):
  827.         if not ksname in self.conn.metadata.keyspaces:
  828.             raise KeyspaceNotFound('Keyspace %r not found.' % ksname)
  829.         return self.conn.metadata.keyspaces[ksname]
  830.  
  831.     def get_keyspaces(self):
  832.         return self.conn.metadata.keyspaces.values()
  833.  
  834.     def get_ring(self):
  835.         if self.current_keyspace is None or self.current_keyspace == 'system':
  836.             raise NoKeyspaceError("Ring view requires a current non-system keyspace")
  837.         self.conn.metadata.token_map.rebuild_keyspace(self.current_keyspace, build_if_absent=True)
  838.         return self.conn.metadata.token_map.tokens_to_hosts_by_ks[self.current_keyspace]
  839.  
  840.     def get_table_meta(self, ksname, tablename):
  841.         if ksname is None:
  842.             ksname = self.current_keyspace
  843.         ksmeta = self.get_keyspace_meta(ksname)
  844.  
  845.         if tablename not in ksmeta.tables:
  846.             if ksname == 'system_auth' and tablename in ['roles', 'role_permissions']:
  847.                 self.get_fake_auth_table_meta(ksname, tablename)
  848.             else:
  849.                 raise ColumnFamilyNotFound("Column family %r not found" % tablename)
  850.         else:
  851.             return ksmeta.tables[tablename]
  852.  
  853.     def get_fake_auth_table_meta(self, ksname, tablename):
  854.         # may be using external auth implementation so internal tables
  855.         # aren't actually defined in schema. In this case, we'll fake
  856.         # them up
  857.         if tablename == 'roles':
  858.             ks_meta = KeyspaceMetadata(ksname, True, None, None)
  859.             table_meta = TableMetadata(ks_meta, 'roles')
  860.             table_meta.columns['role'] = ColumnMetadata(table_meta, 'role', cassandra.cqltypes.UTF8Type)
  861.             table_meta.columns['is_superuser'] = ColumnMetadata(table_meta, 'is_superuser', cassandra.cqltypes.BooleanType)
  862.             table_meta.columns['can_login'] = ColumnMetadata(table_meta, 'can_login', cassandra.cqltypes.BooleanType)
  863.         elif tablename == 'role_permissions':
  864.             ks_meta = KeyspaceMetadata(ksname, True, None, None)
  865.             table_meta = TableMetadata(ks_meta, 'role_permissions')
  866.             table_meta.columns['role'] = ColumnMetadata(table_meta, 'role', cassandra.cqltypes.UTF8Type)
  867.             table_meta.columns['resource'] = ColumnMetadata(table_meta, 'resource', cassandra.cqltypes.UTF8Type)
  868.             table_meta.columns['permission'] = ColumnMetadata(table_meta, 'permission', cassandra.cqltypes.UTF8Type)
  869.         else:
  870.             raise ColumnFamilyNotFound("Column family %r not found" % tablename)
  871.  
  872.     def get_index_meta(self, ksname, idxname):
  873.         if ksname is None:
  874.             ksname = self.current_keyspace
  875.         ksmeta = self.get_keyspace_meta(ksname)
  876.  
  877.         if idxname not in ksmeta.indexes:
  878.             raise IndexNotFound("Index %r not found" % idxname)
  879.  
  880.         return ksmeta.indexes[idxname]
  881.  
  882.     def get_object_meta(self, ks, name):
  883.         if name is None:
  884.             if ks and ks in self.conn.metadata.keyspaces:
  885.                 return self.conn.metadata.keyspaces[ks]
  886.             elif self.current_keyspace is None:
  887.                 raise ObjectNotFound("%r not found in keyspaces" % (ks))
  888.             else:
  889.                 name = ks
  890.                 ks = self.current_keyspace
  891.  
  892.         if ks is None:
  893.             ks = self.current_keyspace
  894.  
  895.         ksmeta = self.get_keyspace_meta(ks)
  896.  
  897.         if name in ksmeta.tables:
  898.             return ksmeta.tables[name]
  899.         elif name in ksmeta.indexes:
  900.             return ksmeta.indexes[name]
  901.  
  902.         raise ObjectNotFound("%r not found in keyspace %r" % (name, ks))
  903.  
  904.     def get_usertypes_meta(self):
  905.         data = self.session.execute("select * from system.schema_usertypes")
  906.         if not data:
  907.             return cql3handling.UserTypesMeta({})
  908.  
  909.         return cql3handling.UserTypesMeta.from_layout(data)
  910.  
  911.     def get_trigger_names(self, ksname=None):
  912.         if ksname is None:
  913.             ksname = self.current_keyspace
  914.  
  915.         return [trigger.name
  916.                 for table in self.get_keyspace_meta(ksname).tables.values()
  917.                 for trigger in table.triggers.values()]
  918.  
  919.     def reset_statement(self):
  920.         self.reset_prompt()
  921.         self.statement.truncate(0)
  922.         self.empty_lines = 0
  923.  
  924.     def reset_prompt(self):
  925.         if self.current_keyspace is None:
  926.             self.set_prompt(self.default_prompt, True)
  927.         else:
  928.             self.set_prompt(self.keyspace_prompt % self.current_keyspace, True)
  929.  
  930.     def set_continue_prompt(self):
  931.         if self.empty_lines >= 3:
  932.             self.set_prompt("Statements are terminated with a ';'.  You can press CTRL-C to cancel an incomplete statement.")
  933.             self.empty_lines = 0
  934.             return
  935.         if self.current_keyspace is None:
  936.             self.set_prompt(self.continue_prompt)
  937.         else:
  938.             spaces = ' ' * len(str(self.current_keyspace))
  939.             self.set_prompt(self.keyspace_continue_prompt % spaces)
  940.         self.empty_lines = self.empty_lines + 1 if not self.lastcmd else 0
  941.  
  942.     @contextmanager
  943.     def prepare_loop(self):
  944.         readline = None
  945.         if self.tty and self.completekey:
  946.             try:
  947.                 import readline
  948.             except ImportError:
  949.                 if platform.system() == 'Windows':
  950.                     print "WARNING: pyreadline dependency missing.  Install to enable tab completion."
  951.                 pass
  952.             else:
  953.                 old_completer = readline.get_completer()
  954.                 readline.set_completer(self.complete)
  955.                 if readline.__doc__ is not None and 'libedit' in readline.__doc__:
  956.                     readline.parse_and_bind("bind -e")
  957.                     readline.parse_and_bind("bind '" + self.completekey + "' rl_complete")
  958.                     readline.parse_and_bind("bind ^R em-inc-search-prev")
  959.                 else:
  960.                     readline.parse_and_bind(self.completekey + ": complete")
  961.         try:
  962.             yield
  963.         finally:
  964.             if readline is not None:
  965.                 readline.set_completer(old_completer)
  966.  
  967.     def get_input_line(self, prompt=''):
  968.         if self.tty:
  969.             self.lastcmd = raw_input(prompt)
  970.             line = self.lastcmd + '\n'
  971.         else:
  972.             self.lastcmd = self.stdin.readline()
  973.             line = self.lastcmd
  974.             if not len(line):
  975.                 raise EOFError
  976.         self.lineno += 1
  977.         return line
  978.  
  979.     def use_stdin_reader(self, until='', prompt=''):
  980.         until += '\n'
  981.         while True:
  982.             try:
  983.                 newline = self.get_input_line(prompt=prompt)
  984.             except EOFError:
  985.                 return
  986.             if newline == until:
  987.                 return
  988.             yield newline
  989.  
  990.     def cmdloop(self):
  991.         """
  992.        Adapted from cmd.Cmd's version, because there is literally no way with
  993.        cmd.Cmd.cmdloop() to tell the difference between "EOF" showing up in
  994.        input and an actual EOF.
  995.        """
  996.         with self.prepare_loop():
  997.             while not self.stop:
  998.                 try:
  999.                     if self.single_statement:
  1000.                         line = self.single_statement
  1001.                         self.stop = True
  1002.                     else:
  1003.                         line = self.get_input_line(self.prompt)
  1004.                     self.statement.write(line)
  1005.                     if self.onecmd(self.statement.getvalue()):
  1006.                         self.reset_statement()
  1007.                 except EOFError:
  1008.                     self.handle_eof()
  1009.                 except CQL_ERRORS, cqlerr:
  1010.                     self.printerr(str(cqlerr))
  1011.                 except KeyboardInterrupt:
  1012.                     self.reset_statement()
  1013.                     print
  1014.  
  1015.     def onecmd(self, statementtext):
  1016.         """
  1017.        Returns true if the statement is complete and was handled (meaning it
  1018.        can be reset).
  1019.        """
  1020.  
  1021.         try:
  1022.             statements, in_batch = cqlruleset.cql_split_statements(statementtext)
  1023.         except pylexotron.LexingError, e:
  1024.             if self.show_line_nums:
  1025.                 self.printerr('Invalid syntax at char %d' % (e.charnum,))
  1026.             else:
  1027.                 self.printerr('Invalid syntax at line %d, char %d'
  1028.                               % (e.linenum, e.charnum))
  1029.             statementline = statementtext.split('\n')[e.linenum - 1]
  1030.             self.printerr('  %s' % statementline)
  1031.             self.printerr(' %s^' % (' ' * e.charnum))
  1032.             return True
  1033.  
  1034.         while statements and not statements[-1]:
  1035.             statements = statements[:-1]
  1036.         if not statements:
  1037.             return True
  1038.         if in_batch or statements[-1][-1][0] != 'endtoken':
  1039.             self.set_continue_prompt()
  1040.             return
  1041.         for st in statements:
  1042.             try:
  1043.                 self.handle_statement(st, statementtext)
  1044.             except Exception, e:
  1045.                 if self.debug:
  1046.                     traceback.print_exc()
  1047.                 else:
  1048.                     self.printerr(e)
  1049.         return True
  1050.  
  1051.     def handle_eof(self):
  1052.         if self.tty:
  1053.             print
  1054.         statement = self.statement.getvalue()
  1055.         if statement.strip():
  1056.             if not self.onecmd(statement):
  1057.                 self.printerr('Incomplete statement at end of file')
  1058.         self.do_exit()
  1059.  
  1060.     def handle_statement(self, tokens, srcstr):
  1061.         # Concat multi-line statements and insert into history
  1062.         if readline is not None:
  1063.             nl_count = srcstr.count("\n")
  1064.  
  1065.             new_hist = srcstr.replace("\n", " ").rstrip()
  1066.  
  1067.             if nl_count > 1 and self.last_hist != new_hist:
  1068.                 readline.add_history(new_hist)
  1069.  
  1070.             self.last_hist = new_hist
  1071.         cmdword = tokens[0][1]
  1072.         if cmdword == '?':
  1073.             cmdword = 'help'
  1074.         custom_handler = getattr(self, 'do_' + cmdword.lower(), None)
  1075.         if custom_handler:
  1076.             parsed = cqlruleset.cql_whole_parse_tokens(tokens, srcstr=srcstr,
  1077.                                                        startsymbol='cqlshCommand')
  1078.             if parsed and not parsed.remainder:
  1079.                 # successful complete parse
  1080.                 return custom_handler(parsed)
  1081.             else:
  1082.                 return self.handle_parse_error(cmdword, tokens, parsed, srcstr)
  1083.         return self.perform_statement(cqlruleset.cql_extract_orig(tokens, srcstr))
  1084.  
  1085.     def handle_parse_error(self, cmdword, tokens, parsed, srcstr):
  1086.         if cmdword.lower() in ('select', 'insert', 'update', 'delete', 'truncate',
  1087.                                'create', 'drop', 'alter', 'grant', 'revoke',
  1088.                                'batch', 'list'):
  1089.             # hey, maybe they know about some new syntax we don't. type
  1090.             # assumptions won't work, but maybe the query will.
  1091.             return self.perform_statement(cqlruleset.cql_extract_orig(tokens, srcstr))
  1092.         if parsed:
  1093.             self.printerr('Improper %s command (problem at %r).' % (cmdword, parsed.remainder[0]))
  1094.         else:
  1095.             self.printerr('Improper %s command.' % cmdword)
  1096.  
  1097.     def do_use(self, parsed):
  1098.         ksname = parsed.get_binding('ksname')
  1099.         if self.perform_simple_statement(SimpleStatement(parsed.extract_orig())):
  1100.             if ksname[0] == '"' and ksname[-1] == '"':
  1101.                 self.current_keyspace = self.cql_unprotect_name(ksname)
  1102.             else:
  1103.                 self.current_keyspace = ksname.lower()
  1104.  
  1105.     def do_select(self, parsed):
  1106.         tracing_was_enabled = self.tracing_enabled
  1107.         ksname = parsed.get_binding('ksname')
  1108.         stop_tracing = ksname == 'system_traces' or (ksname is None and self.current_keyspace == 'system_traces')
  1109.         self.tracing_enabled = self.tracing_enabled and not stop_tracing
  1110.         statement = parsed.extract_orig()
  1111.         self.perform_statement(statement)
  1112.         self.tracing_enabled = tracing_was_enabled
  1113.  
  1114.     def perform_statement(self, statement):
  1115.         stmt = SimpleStatement(statement, consistency_level=self.consistency_level, serial_consistency_level=self.serial_consistency_level, fetch_size=self.default_page_size if self.use_paging else None)
  1116.         result, future = self.perform_simple_statement(stmt)
  1117.  
  1118.         if future:
  1119.             if future.warnings:
  1120.                 self.print_warnings(future.warnings)
  1121.  
  1122.             if self.tracing_enabled:
  1123.                 try:
  1124.                     trace = future.get_query_trace(self.max_trace_wait)
  1125.                     if trace:
  1126.                         print_trace(self, trace)
  1127.                     else:
  1128.                         msg = "Statement trace did not complete within %d seconds" % (self.session.max_trace_wait)
  1129.                         self.writeresult(msg, color=RED)
  1130.                 except Exception, err:
  1131.                     self.printerr("Unable to fetch query trace: %s" % (str(err),))
  1132.  
  1133.         return result
  1134.  
  1135.     def parse_for_table_meta(self, query_string):
  1136.         try:
  1137.             parsed = cqlruleset.cql_parse(query_string)[1]
  1138.         except IndexError:
  1139.             return None
  1140.         ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
  1141.         cf = self.cql_unprotect_name(parsed.get_binding('cfname'))
  1142.         return self.get_table_meta(ks, cf)
  1143.  
  1144.     def perform_simple_statement(self, statement):
  1145.         if not statement:
  1146.             return False, None
  1147.         rows = None
  1148.         while True:
  1149.             try:
  1150.                 future = self.session.execute_async(statement, trace=self.tracing_enabled)
  1151.                 rows = future.result(self.session.default_timeout)
  1152.                 break
  1153.             except CQL_ERRORS, err:
  1154.                 self.printerr(str(err.__class__.__name__) + ": " + str(err))
  1155.                 return False, None
  1156.             except Exception, err:
  1157.                 import traceback
  1158.                 self.printerr(traceback.format_exc())
  1159.                 return False, None
  1160.  
  1161.         if statement.query_string[:6].lower() == 'select':
  1162.             self.print_result(rows, self.parse_for_table_meta(statement.query_string))
  1163.         elif statement.query_string.lower().startswith("list users") or statement.query_string.lower().startswith("list roles"):
  1164.             self.print_result(rows, self.get_table_meta('system_auth', 'roles'))
  1165.         elif statement.query_string.lower().startswith("list"):
  1166.             self.print_result(rows, self.get_table_meta('system_auth', 'role_permissions'))
  1167.         elif rows:
  1168.             # CAS INSERT/UPDATE
  1169.             self.writeresult("")
  1170.             self.print_static_result(rows, self.parse_for_table_meta(statement.query_string))
  1171.         self.flush_output()
  1172.         return True, future
  1173.  
  1174.     def print_result(self, rows, table_meta):
  1175.         self.decoding_errors = []
  1176.  
  1177.         self.writeresult("")
  1178.         if isinstance(rows, PagedResult) and self.tty:
  1179.             num_rows = 0
  1180.             while True:
  1181.                 page = list(rows.current_response)
  1182.                 if not page:
  1183.                     break
  1184.                 num_rows += len(page)
  1185.                 self.print_static_result(page, table_meta)
  1186.                 if not rows.response_future.has_more_pages:
  1187.                     break
  1188.                 raw_input("---MORE---")
  1189.  
  1190.                 rows.response_future.start_fetching_next_page()
  1191.                 result = rows.response_future.result()
  1192.                 if rows.response_future.has_more_pages:
  1193.                     rows.current_response = result.current_response
  1194.                 else:
  1195.                     rows.current_response = iter(result)
  1196.         else:
  1197.             rows = list(rows or [])
  1198.             num_rows = len(rows)
  1199.             self.print_static_result(rows, table_meta)
  1200.         self.writeresult("(%d rows)" % num_rows)
  1201.  
  1202.         if self.decoding_errors:
  1203.             for err in self.decoding_errors[:2]:
  1204.                 self.writeresult(err.message(), color=RED)
  1205.             if len(self.decoding_errors) > 2:
  1206.                 self.writeresult('%d more decoding errors suppressed.'
  1207.                                  % (len(self.decoding_errors) - 2), color=RED)
  1208.  
  1209.     def print_static_result(self, rows, table_meta):
  1210.         if not rows:
  1211.             if not table_meta:
  1212.                 return
  1213.             # print header only
  1214.             colnames = table_meta.columns.keys()  # full header
  1215.             formatted_names = [self.myformat_colname(name, table_meta) for name in colnames]
  1216.             self.print_formatted_result(formatted_names, None)
  1217.             return
  1218.  
  1219.         colnames = rows[0].keys()
  1220.         formatted_names = [self.myformat_colname(name, table_meta) for name in colnames]
  1221.         formatted_values = [map(self.myformat_value, row.values()) for row in rows]
  1222.  
  1223.         if self.expand_enabled:
  1224.             self.print_formatted_result_vertically(formatted_names, formatted_values)
  1225.         else:
  1226.             self.print_formatted_result(formatted_names, formatted_values)
  1227.  
  1228.     def print_formatted_result(self, formatted_names, formatted_values):
  1229.         # determine column widths
  1230.         widths = [n.displaywidth for n in formatted_names]
  1231.         if formatted_values is not None:
  1232.             for fmtrow in formatted_values:
  1233.                 for num, col in enumerate(fmtrow):
  1234.                     widths[num] = max(widths[num], col.displaywidth)
  1235.  
  1236.         # print header
  1237.         header = ' | '.join(hdr.ljust(w, color=self.color) for (hdr, w) in zip(formatted_names, widths))
  1238.         self.writeresult(' ' + header.rstrip())
  1239.         self.writeresult('-%s-' % '-+-'.join('-' * w for w in widths))
  1240.  
  1241.         # stop if there are no rows
  1242.         if formatted_values is None:
  1243.             self.writeresult("")
  1244.             return
  1245.  
  1246.         # print row data
  1247.         for row in formatted_values:
  1248.             line = ' | '.join(col.rjust(w, color=self.color) for (col, w) in zip(row, widths))
  1249.             self.writeresult(' ' + line)
  1250.  
  1251.         self.writeresult("")
  1252.  
  1253.     def print_formatted_result_vertically(self, formatted_names, formatted_values):
  1254.         max_col_width = max([n.displaywidth for n in formatted_names])
  1255.         max_val_width = max([n.displaywidth for row in formatted_values for n in row])
  1256.  
  1257.         # for each row returned, list all the column-value pairs
  1258.         for row_id, row in enumerate(formatted_values):
  1259.             self.writeresult("@ Row %d" % (row_id + 1))
  1260.             self.writeresult('-%s-' % '-+-'.join(['-' * max_col_width, '-' * max_val_width]))
  1261.             for field_id, field in enumerate(row):
  1262.                 column = formatted_names[field_id].ljust(max_col_width, color=self.color)
  1263.                 value = field.ljust(field.displaywidth, color=self.color)
  1264.                 self.writeresult(' ' + " | ".join([column, value]))
  1265.             self.writeresult('')
  1266.  
  1267.     def print_warnings(self, warnings):
  1268.         if warnings is None or len(warnings) == 0:
  1269.             return;
  1270.  
  1271.         self.writeresult('')
  1272.         self.writeresult('Warnings :')
  1273.         for warning in warnings:
  1274.             self.writeresult(warning)
  1275.             self.writeresult('')
  1276.  
  1277.     def emptyline(self):
  1278.         pass
  1279.  
  1280.     def parseline(self, line):
  1281.         # this shouldn't be needed
  1282.         raise NotImplementedError
  1283.  
  1284.     def complete(self, text, state):
  1285.         if readline is None:
  1286.             return
  1287.         if state == 0:
  1288.             try:
  1289.                 self.completion_matches = self.find_completions(text)
  1290.             except Exception:
  1291.                 if debug_completion:
  1292.                     import traceback
  1293.                     traceback.print_exc()
  1294.                 else:
  1295.                     raise
  1296.         try:
  1297.             return self.completion_matches[state]
  1298.         except IndexError:
  1299.             return None
  1300.  
  1301.     def find_completions(self, text):
  1302.         curline = readline.get_line_buffer()
  1303.         prevlines = self.statement.getvalue()
  1304.         wholestmt = prevlines + curline
  1305.         begidx = readline.get_begidx() + len(prevlines)
  1306.         stuff_to_complete = wholestmt[:begidx]
  1307.         return cqlruleset.cql_complete(stuff_to_complete, text, cassandra_conn=self,
  1308.                                        debug=debug_completion, startsymbol='cqlshCommand')
  1309.  
  1310.     def set_prompt(self, prompt, prepend_user=False):
  1311.         if prepend_user and self.username:
  1312.             self.prompt = "%s@%s" % (self.username, prompt)
  1313.             return
  1314.         self.prompt = prompt
  1315.  
  1316.     def cql_unprotect_name(self, namestr):
  1317.         if namestr is None:
  1318.             return
  1319.         return cqlruleset.dequote_name(namestr)
  1320.  
  1321.     def cql_unprotect_value(self, valstr):
  1322.         if valstr is not None:
  1323.             return cqlruleset.dequote_value(valstr)
  1324.  
  1325.     def print_recreate_keyspace(self, ksdef, out):
  1326.         out.write(ksdef.export_as_string())
  1327.         out.write("\n")
  1328.  
  1329.     def print_recreate_columnfamily(self, ksname, cfname, out):
  1330.         """
  1331.        Output CQL commands which should be pasteable back into a CQL session
  1332.        to recreate the given table.
  1333.  
  1334.        Writes output to the given out stream.
  1335.        """
  1336.         out.write(self.get_table_meta(ksname, cfname).export_as_string())
  1337.         out.write("\n")
  1338.  
  1339.     def print_recreate_index(self, ksname, idxname, out):
  1340.         """
  1341.        Output CQL commands which should be pasteable back into a CQL session
  1342.        to recreate the given index.
  1343.  
  1344.        Writes output to the given out stream.
  1345.        """
  1346.         out.write(self.get_index_meta(ksname, idxname).export_as_string())
  1347.         out.write("\n")
  1348.  
  1349.     def print_recreate_object(self, ks, name, out):
  1350.         """
  1351.        Output CQL commands which should be pasteable back into a CQL session
  1352.        to recreate the given object (ks, table or index).
  1353.  
  1354.        Writes output to the given out stream.
  1355.        """
  1356.         out.write(self.get_object_meta(ks, name).export_as_string())
  1357.         out.write("\n")
  1358.  
  1359.     def describe_keyspaces(self):
  1360.         print
  1361.         cmd.Cmd.columnize(self, protect_names(self.get_keyspace_names()))
  1362.         print
  1363.  
  1364.     def describe_keyspace(self, ksname):
  1365.         print
  1366.         self.print_recreate_keyspace(self.get_keyspace_meta(ksname), sys.stdout)
  1367.         print
  1368.  
  1369.     def describe_columnfamily(self, ksname, cfname):
  1370.         if ksname is None:
  1371.             ksname = self.current_keyspace
  1372.         if ksname is None:
  1373.             raise NoKeyspaceError("No keyspace specified and no current keyspace")
  1374.         print
  1375.         self.print_recreate_columnfamily(ksname, cfname, sys.stdout)
  1376.         print
  1377.  
  1378.     def describe_index(self, ksname, idxname):
  1379.         print
  1380.         self.print_recreate_index(ksname, idxname, sys.stdout)
  1381.         print
  1382.  
  1383.     def describe_object(self, ks, name):
  1384.         print
  1385.         self.print_recreate_object(ks, name, sys.stdout)
  1386.         print
  1387.  
  1388.     def describe_columnfamilies(self, ksname):
  1389.         print
  1390.         if ksname is None:
  1391.             for k in self.get_keyspaces():
  1392.                 name = protect_name(k.name)
  1393.                 print 'Keyspace %s' % (name,)
  1394.                 print '---------%s' % ('-' * len(name))
  1395.                 cmd.Cmd.columnize(self, protect_names(self.get_columnfamily_names(k.name)))
  1396.                 print
  1397.         else:
  1398.             cmd.Cmd.columnize(self, protect_names(self.get_columnfamily_names(ksname)))
  1399.             print
  1400.  
  1401.     def describe_functions(self, ksname=None):
  1402.         print
  1403.         if ksname is None:
  1404.             for ksmeta in self.get_keyspaces():
  1405.                 name = protect_name(ksmeta.name)
  1406.                 print 'Keyspace %s' % (name,)
  1407.                 print '---------%s' % ('-' * len(name))
  1408.                 cmd.Cmd.columnize(self, protect_names(ksmeta.functions.keys()))
  1409.                 print
  1410.         else:
  1411.             ksmeta = self.get_keyspace_meta(ksname)
  1412.             cmd.Cmd.columnize(self, protect_names(ksmeta.functions.keys()))
  1413.             print
  1414.  
  1415.     def describe_function(self, ksname, functionname):
  1416.         if ksname is None:
  1417.             ksname = self.current_keyspace
  1418.         if ksname is None:
  1419.             raise NoKeyspaceError("No keyspace specified and no current keyspace")
  1420.         print
  1421.         ksmeta = self.get_keyspace_meta(ksname)
  1422.         functions = filter(lambda f: f.name == functionname, ksmeta.functions.values())
  1423.         if len(functions) == 0:
  1424.             raise FunctionNotFound("User defined function %r not found" % functionname)
  1425.         print "\n\n".join(func.as_cql_query(formatted=True) for func in functions)
  1426.         print
  1427.  
  1428.     def describe_aggregates(self, ksname=None):
  1429.         print
  1430.         if ksname is None:
  1431.             for ksmeta in self.get_keyspaces():
  1432.                 name = protect_name(ksmeta.name)
  1433.                 print 'Keyspace %s' % (name,)
  1434.                 print '---------%s' % ('-' * len(name))
  1435.                 cmd.Cmd.columnize(self, protect_names(ksmeta.aggregates.keys()))
  1436.                 print
  1437.         else:
  1438.             ksmeta = self.get_keyspace_meta(ksname)
  1439.             cmd.Cmd.columnize(self, protect_names(ksmeta.aggregates.keys()))
  1440.             print
  1441.  
  1442.     def describe_aggregate(self, ksname, aggregatename):
  1443.         if ksname is None:
  1444.             ksname = self.current_keyspace
  1445.         if ksname is None:
  1446.             raise NoKeyspaceError("No keyspace specified and no current keyspace")
  1447.         print
  1448.         ksmeta = self.get_keyspace_meta(ksname)
  1449.         aggregates = filter(lambda f: f.name == aggregatename, ksmeta.aggregates.values())
  1450.         if len(aggregates) == 0:
  1451.             raise FunctionNotFound("User defined aggregate %r not found" % aggregatename)
  1452.         print "\n\n".join(aggr.as_cql_query(formatted=True) for aggr in aggregates)
  1453.         print
  1454.  
  1455.     def describe_usertypes(self, ksname):
  1456.         print
  1457.         if ksname is None:
  1458.             for ksmeta in self.get_keyspaces():
  1459.                 name = protect_name(ksmeta.name)
  1460.                 print 'Keyspace %s' % (name,)
  1461.                 print '---------%s' % ('-' * len(name))
  1462.                 cmd.Cmd.columnize(self, protect_names(ksmeta.user_types.keys()))
  1463.                 print
  1464.         else:
  1465.             ksmeta = self.get_keyspace_meta(ksname)
  1466.             cmd.Cmd.columnize(self, protect_names(ksmeta.user_types.keys()))
  1467.             print
  1468.  
  1469.     def describe_usertype(self, ksname, typename):
  1470.         if ksname is None:
  1471.             ksname = self.current_keyspace
  1472.         if ksname is None:
  1473.             raise NoKeyspaceError("No keyspace specified and no current keyspace")
  1474.         print
  1475.         ksmeta = self.get_keyspace_meta(ksname)
  1476.         try:
  1477.             usertype = ksmeta.user_types[typename]
  1478.         except KeyError:
  1479.             raise UserTypeNotFound("User type %r not found" % typename)
  1480.         print usertype.as_cql_query(formatted=True)
  1481.         print
  1482.  
  1483.     def describe_cluster(self):
  1484.         print '\nCluster: %s' % self.get_cluster_name()
  1485.         p = trim_if_present(self.get_partitioner(), 'org.apache.cassandra.dht.')
  1486.         print 'Partitioner: %s\n' % p
  1487.         # TODO: snitch?
  1488.         #snitch = trim_if_present(self.get_snitch(), 'org.apache.cassandra.locator.')
  1489.         #print 'Snitch: %s\n' % snitch
  1490.         if self.current_keyspace is not None \
  1491.         and self.current_keyspace != 'system':
  1492.             print "Range ownership:"
  1493.             ring = self.get_ring()
  1494.             for entry in ring.items():
  1495.                 print ' %39s  [%s]' % (str(entry[0].value), ', '.join([host.address for host in entry[1]]))
  1496.             print
  1497.  
  1498.     def describe_schema(self, include_system=False):
  1499.         print
  1500.         for k in self.get_keyspaces():
  1501.             if include_system or not k.name in cql3handling.SYSTEM_KEYSPACES:
  1502.                 self.print_recreate_keyspace(k, sys.stdout)
  1503.                 print
  1504.  
    def do_describe(self, parsed):
        """
       DESCRIBE [cqlsh only]

       (DESC may be used as a shorthand.)

         Outputs information about the connected Cassandra cluster, or about
         the data stored on it. Use in one of the following ways:

       DESCRIBE KEYSPACES

         Output the names of all keyspaces.

       DESCRIBE KEYSPACE [<keyspacename>]

         Output CQL commands that could be used to recreate the given
         keyspace, and the tables in it. In some cases, as the CQL interface
         matures, there will be some metadata about a keyspace that is not
         representable with CQL. That metadata will not be shown.

         The '<keyspacename>' argument may be omitted when using a non-system
         keyspace; in that case, the current keyspace will be described.

       DESCRIBE TABLES

         Output the names of all tables in the current keyspace, or in all
         keyspaces if there is no current keyspace.

       DESCRIBE TABLE <tablename>

         Output CQL commands that could be used to recreate the given table.
         In some cases, as above, there may be table metadata which is not
         representable and which will not be shown.

       DESCRIBE INDEX <indexname>

         Output CQL commands that could be used to recreate the given index.
         In some cases, there may be index metadata which is not representable
         and which will not be shown.

       DESCRIBE CLUSTER

         Output information about the connected Cassandra cluster, such as the
         cluster name, and the partitioner and snitch in use. When you are
         connected to a non-system keyspace, also shows endpoint-range
         ownership information for the Cassandra ring.

       DESCRIBE [FULL] SCHEMA

         Output CQL commands that could be used to recreate the entire (non-system) schema.
         Works as though "DESCRIBE KEYSPACE k" was invoked for each non-system keyspace
         k. Use DESCRIBE FULL SCHEMA to include the system keyspaces.

       DESCRIBE FUNCTIONS <keyspace>

         Output the names of all user defined functions in the given keyspace.

       DESCRIBE FUNCTION [<keyspace>.]<function>

         Describe the given user defined function.

       DESCRIBE AGGREGATES <keyspace>

         Output the names of all user defined aggregates in the given keyspace.

       DESCRIBE AGGREGATE [<keyspace>.]<aggregate>

         Describe the given user defined aggregate.

       DESCRIBE <objname>

         Output CQL commands that could be used to recreate the entire object schema,
         where object can be either a keyspace or a table or an index (in this order).
    """
        # Dispatch on the keyword that follows DESCRIBE/DESC, compared
        # case-insensitively.
        # NOTE(review): the 'types' and 'type' branches below are handled but
        # are not mentioned in the help text above.
        what = parsed.matched[1][1].lower()
        if what == 'functions':
            # ksname may be None, in which case functions of all keyspaces are listed
            ksname = self.cql_unprotect_name(parsed.get_binding('ksname', None))
            self.describe_functions(ksname)
        elif what == 'function':
            ksname = self.cql_unprotect_name(parsed.get_binding('ksname', None))
            functionname = self.cql_unprotect_name(parsed.get_binding('udfname'))
            self.describe_function(ksname, functionname)
        elif what == 'aggregates':
            ksname = self.cql_unprotect_name(parsed.get_binding('ksname', None))
            self.describe_aggregates(ksname)
        elif what == 'aggregate':
            ksname = self.cql_unprotect_name(parsed.get_binding('ksname', None))
            aggregatename = self.cql_unprotect_name(parsed.get_binding('udaname'))
            self.describe_aggregate(ksname, aggregatename)
        elif what == 'keyspaces':
            self.describe_keyspaces()
        elif what == 'keyspace':
            # A missing/empty keyspace name falls back to the current keyspace;
            # with no current keyspace this is reported and nothing is printed.
            ksname = self.cql_unprotect_name(parsed.get_binding('ksname', ''))
            if not ksname:
                ksname = self.current_keyspace
                if ksname is None:
                    self.printerr('Not in any keyspace.')
                    return
            self.describe_keyspace(ksname)
        elif what in ('columnfamily', 'table'):
            ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
            cf = self.cql_unprotect_name(parsed.get_binding('cfname'))
            self.describe_columnfamily(ks, cf)
        elif what == 'index':
            ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
            idx = self.cql_unprotect_name(parsed.get_binding('idxname', None))
            self.describe_index(ks, idx)
        elif what in ('columnfamilies', 'tables'):
            # Lists only the current keyspace; a None current keyspace makes
            # describe_columnfamilies list every keyspace.
            self.describe_columnfamilies(self.current_keyspace)
        elif what == 'types':
            # Same current-keyspace/None behaviour as TABLES above.
            self.describe_usertypes(self.current_keyspace)
        elif what == 'type':
            ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
            ut = self.cql_unprotect_name(parsed.get_binding('utname'))
            self.describe_usertype(ks, ut)
        elif what == 'cluster':
            self.describe_cluster()
        elif what == 'schema':
            self.describe_schema(False)
        elif what == 'full' and parsed.matched[2][1].lower() == 'schema':
            # DESCRIBE FULL SCHEMA: second matched token must be SCHEMA
            self.describe_schema(True)
        elif what:
            # DESCRIBE <objname>: the name is bound as cfname or, failing that,
            # as idxname; describe_object decides what kind of object it is.
            ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
            name = self.cql_unprotect_name(parsed.get_binding('cfname'))
            if not name:
                name = self.cql_unprotect_name(parsed.get_binding('idxname', None))
            self.describe_object(ks, name)
    # DESC is accepted as an alias for DESCRIBE.
    do_desc = do_describe
  1633.  
  1634.     def do_copy(self, parsed):
  1635.         r"""
  1636.        COPY [cqlsh only]
  1637.  
  1638.          COPY x FROM: Imports CSV data into a Cassandra table
  1639.          COPY x TO: Exports data from a Cassandra table in CSV format.
  1640.  
  1641.        COPY <table_name> [ ( column [, ...] ) ]
  1642.             FROM ( '<filename>' | STDIN )
  1643.             [ WITH <option>='value' [AND ...] ];
  1644.  
  1645.        COPY <table_name> [ ( column [, ...] ) ]
  1646.             TO ( '<filename>' | STDOUT )
  1647.             [ WITH <option>='value' [AND ...] ];
  1648.  
  1649.        Available options and defaults:
  1650.  
  1651.          DELIMITER=','    - character that appears between records
  1652.          QUOTE='"'        - quoting character to be used to quote fields
  1653.          ESCAPE='\'       - character to appear before the QUOTE char when quoted
  1654.          HEADER=false     - whether to ignore the first line
  1655.          NULL=''          - string that represents a null value
  1656.          ENCODING='utf8'  - encoding for CSV output (COPY TO only)
  1657.  
  1658.        When entering CSV data on STDIN, you can use the sequence "\."
  1659.        on a line by itself to end the data input.
  1660.        """
  1661.         ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
  1662.         if ks is None:
  1663.             ks = self.current_keyspace
  1664.             if ks is None:
  1665.                 raise NoKeyspaceError("Not in any keyspace.")
  1666.         cf = self.cql_unprotect_name(parsed.get_binding('cfname'))
  1667.         columns = parsed.get_binding('colnames', None)
  1668.         if columns is not None:
  1669.             columns = map(self.cql_unprotect_name, columns)
  1670.         else:
  1671.             # default to all known columns
  1672.             columns = self.get_column_names(ks, cf)
  1673.         fname = parsed.get_binding('fname', None)
  1674.         if fname is not None:
  1675.             fname = os.path.expanduser(self.cql_unprotect_value(fname))
  1676.         copyoptnames = map(str.lower, parsed.get_binding('optnames', ()))
  1677.         copyoptvals = map(self.cql_unprotect_value, parsed.get_binding('optvals', ()))
  1678.         cleancopyoptvals = [optval.decode('string-escape') for optval in copyoptvals]
  1679.         opts = dict(zip(copyoptnames, cleancopyoptvals))
  1680.  
  1681.         timestart = time.time()
  1682.  
  1683.         direction = parsed.get_binding('dir').upper()
  1684.         if direction == 'FROM':
  1685.             rows = self.perform_csv_import(ks, cf, columns, fname, opts)
  1686.             verb = 'imported'
  1687.         elif direction == 'TO':
  1688.             rows = self.perform_csv_export(ks, cf, columns, fname, opts)
  1689.             verb = 'exported'
  1690.         else:
  1691.             raise SyntaxError("Unknown direction %s" % direction)
  1692.  
  1693.         timeend = time.time()
  1694.         print "\n%d rows %s in %s." % (rows, verb, describe_interval(timeend - timestart))
  1695.  
  1696.     def perform_csv_import(self, ks, cf, columns, fname, opts):
  1697.         dialect_options = self.csv_dialect_defaults.copy()
  1698.         if 'quote' in opts:
  1699.             dialect_options['quotechar'] = opts.pop('quote')
  1700.         if 'escape' in opts:
  1701.             dialect_options['escapechar'] = opts.pop('escape')
  1702.         if 'delimiter' in opts:
  1703.             dialect_options['delimiter'] = opts.pop('delimiter')
  1704.         nullval = opts.pop('null', '')
  1705.         header = bool(opts.pop('header', '').lower() == 'true')
  1706.         if dialect_options['quotechar'] == dialect_options['escapechar']:
  1707.             dialect_options['doublequote'] = True
  1708.             del dialect_options['escapechar']
  1709.         if opts:
  1710.             self.printerr('Unrecognized COPY FROM options: %s'
  1711.                           % ', '.join(opts.keys()))
  1712.             return 0
  1713.  
  1714.         if fname is None:
  1715.             do_close = False
  1716.             print "[Use \. on a line by itself to end input]"
  1717.             linesource = self.use_stdin_reader(prompt='[copy] ', until=r'\.')
  1718.         else:
  1719.             do_close = True
  1720.             try:
  1721.                 linesource = open(fname, 'rb')
  1722.             except IOError, e:
  1723.                 self.printerr("Can't open %r for reading: %s" % (fname, e))
  1724.                 return 0
  1725.  
  1726.         current_record = None
  1727.  
  1728.         try:
  1729.             if header:
  1730.                 linesource.next()
  1731.             reader = csv.reader(linesource, **dialect_options)
  1732.  
  1733.             from multiprocessing import Process, Pipe, cpu_count
  1734.  
  1735.             # Pick a resonable number of child processes. We need to leave at
  1736.             # least one core for the parent process.  This doesn't necessarily
  1737.             # need to be capped at 4, but it's currently enough to keep
  1738.             # a single local Cassandra node busy, and I see lower throughput
  1739.             # with more processes.
  1740.             try:
  1741.                 num_processes = max(1, min(4, cpu_count() - 1))
  1742.             except NotImplementedError:
  1743.                 num_processes = 1
  1744.  
  1745.             processes, pipes = [], [],
  1746.             for i in range(num_processes):
  1747.                 parent_conn, child_conn = Pipe()
  1748.                 pipes.append(parent_conn)
  1749.                 processes.append(ImportProcess(self, child_conn, ks, cf, columns, nullval))
  1750.  
  1751.             for process in processes:
  1752.                 process.start()
  1753.  
  1754.             meter = RateMeter(10000)
  1755.             for current_record, row in enumerate(reader, start=1):
  1756.                 # write to the child process
  1757.                 pipes[current_record % num_processes].send((current_record, row))
  1758.  
  1759.                 # update the progress and current rate periodically
  1760.                 meter.increment()
  1761.  
  1762.                 # check for any errors reported by the children
  1763.                 if (current_record % 100) == 0:
  1764.                     if self._check_child_pipes(current_record, pipes):
  1765.                         # no errors seen, continue with outer loop
  1766.                         continue
  1767.                     else:
  1768.                         # errors seen, break out of outer loop
  1769.                         break
  1770.         except Exception, exc:
  1771.             if current_record is None:
  1772.                 # we failed before we started
  1773.                 self.printerr("\nError starting import process:\n")
  1774.                 self.printerr(str(exc))
  1775.                 if self.debug:
  1776.                     traceback.print_exc()
  1777.             else:
  1778.                 self.printerr("\n" + str(exc))
  1779.                 self.printerr("\nAborting import at record #%d. "
  1780.                               "Previously inserted records and some records after "
  1781.                               "this number may be present."
  1782.                               % (current_record,))
  1783.                 if self.debug:
  1784.                     traceback.print_exc()
  1785.         finally:
  1786.             # send a message that indicates we're done
  1787.             for pipe in pipes:
  1788.                 pipe.send((None, None))
  1789.  
  1790.             for process in processes:
  1791.                 process.join()
  1792.  
  1793.             self._check_child_pipes(current_record, pipes)
  1794.  
  1795.             for pipe in pipes:
  1796.                 pipe.close()
  1797.  
  1798.             if do_close:
  1799.                 linesource.close()
  1800.             elif self.tty:
  1801.                 print
  1802.  
  1803.         return current_record
  1804.  
  1805.     def _check_child_pipes(self, current_record, pipes):
  1806.         # check the pipes for errors from child processes
  1807.         for pipe in pipes:
  1808.             if pipe.poll():
  1809.                 try:
  1810.                     (record_num, error) = pipe.recv()
  1811.                     self.printerr("\n" + str(error))
  1812.                     self.printerr(
  1813.                         "Aborting import at record #%d. "
  1814.                         "Previously inserted records are still present, "
  1815.                         "and some records after that may be present as well."
  1816.                         % (record_num,))
  1817.                     return False
  1818.                 except EOFError:
  1819.                     # pipe is closed, nothing to read
  1820.                     self.printerr("\nChild process died without notification, "
  1821.                                   "aborting import at record #%d. Previously "
  1822.                                   "inserted records are probably still present, "
  1823.                                   "and some records after that may be present "
  1824.                                   "as well." % (current_record,))
  1825.                     return False
  1826.         return True
  1827.  
  1828.     def perform_csv_export(self, ks, cf, columns, fname, opts):
  1829.         dialect_options = self.csv_dialect_defaults.copy()
  1830.         if 'quote' in opts:
  1831.             dialect_options['quotechar'] = opts.pop('quote')
  1832.         if 'escape' in opts:
  1833.             dialect_options['escapechar'] = opts.pop('escape')
  1834.         if 'delimiter' in opts:
  1835.             dialect_options['delimiter'] = opts.pop('delimiter')
  1836.         encoding = opts.pop('encoding', 'utf8')
  1837.         nullval = opts.pop('null', '')
  1838.         header = bool(opts.pop('header', '').lower() == 'true')
  1839.         if dialect_options['quotechar'] == dialect_options['escapechar']:
  1840.             dialect_options['doublequote'] = True
  1841.             del dialect_options['escapechar']
  1842.  
  1843.         if opts:
  1844.             self.printerr('Unrecognized COPY TO options: %s'
  1845.                           % ', '.join(opts.keys()))
  1846.             return 0
  1847.  
  1848.         if fname is None:
  1849.             do_close = False
  1850.             csvdest = sys.stdout
  1851.         else:
  1852.             do_close = True
  1853.             try:
  1854.                 csvdest = open(fname, 'wb')
  1855.             except IOError, e:
  1856.                 self.printerr("Can't open %r for writing: %s" % (fname, e))
  1857.                 return 0
  1858.  
  1859.         meter = RateMeter(10000)
  1860.         try:
  1861.             dtformats = DateTimeFormat(self.display_timestamp_format, self.display_date_format, self.display_nanotime_format)
  1862.             dump = self.prep_export_dump(ks, cf, columns)
  1863.             writer = csv.writer(csvdest, **dialect_options)
  1864.             if header:
  1865.                 writer.writerow(columns)
  1866.             for row in dump:
  1867.                 fmt = lambda v: \
  1868.                     format_value(v, output_encoding=encoding, nullval=nullval,
  1869.                                  date_time_format=dtformats,
  1870.                                  float_precision=self.display_float_precision).strval
  1871.                 writer.writerow(map(fmt, row.values()))
  1872.                 meter.increment()
  1873.         finally:
  1874.             if do_close:
  1875.                 csvdest.close()
  1876.         return meter.current_record
  1877.  
  1878.     def prep_export_dump(self, ks, cf, columns):
  1879.         if columns is None:
  1880.             columns = self.get_column_names(ks, cf)
  1881.         columnlist = ', '.join(protect_names(columns))
  1882.         query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(ks), protect_name(cf))
  1883.         return self.session.execute(query)
  1884.  
  1885.     def do_show(self, parsed):
  1886.         """
  1887.        SHOW [cqlsh only]
  1888.  
  1889.          Displays information about the current cqlsh session. Can be called in
  1890.          the following ways:
  1891.  
  1892.        SHOW VERSION
  1893.  
  1894.          Shows the version and build of the connected Cassandra instance, as
  1895.          well as the versions of the CQL spec and the Thrift protocol that
  1896.          the connected Cassandra instance understands.
  1897.  
  1898.        SHOW HOST
  1899.  
  1900.          Shows where cqlsh is currently connected.
  1901.  
  1902.        SHOW SESSION <sessionid>
  1903.  
  1904.          Pretty-prints the requested tracing session.
  1905.        """
  1906.         showwhat = parsed.get_binding('what').lower()
  1907.         if showwhat == 'version':
  1908.             self.get_connection_versions()
  1909.             self.show_version()
  1910.         elif showwhat == 'host':
  1911.             self.show_host()
  1912.         elif showwhat.startswith('session'):
  1913.             session_id = parsed.get_binding('sessionid').lower()
  1914.             self.show_session(UUID(session_id))
  1915.         else:
  1916.             self.printerr('Wait, how do I show %r?' % (showwhat,))
  1917.  
  1918.     def do_source(self, parsed):
  1919.         """
  1920.        SOURCE [cqlsh only]
  1921.  
  1922.        Executes a file containing CQL statements. Gives the output for each
  1923.        statement in turn, if any, or any errors that occur along the way.
  1924.  
  1925.        Errors do NOT abort execution of the CQL source file.
  1926.  
  1927.        Usage:
  1928.  
  1929.          SOURCE '<file>';
  1930.  
  1931.        That is, the path to the file to be executed must be given inside a
  1932.        string literal. The path is interpreted relative to the current working
  1933.        directory. The tilde shorthand notation ('~/mydir') is supported for
  1934.        referring to $HOME.
  1935.  
  1936.        See also the --file option to cqlsh.
  1937.        """
  1938.         fname = parsed.get_binding('fname')
  1939.         fname = os.path.expanduser(self.cql_unprotect_value(fname))
  1940.         try:
  1941.             encoding, bom_size = get_file_encoding_bomsize(fname)
  1942.             f = codecs.open(fname, 'r', encoding)
  1943.             f.seek(bom_size)
  1944.         except IOError, e:
  1945.             self.printerr('Could not open %r: %s' % (fname, e))
  1946.             return
  1947.         subshell = Shell(self.hostname, self.port,
  1948.                          color=self.color, encoding=self.encoding, stdin=f,
  1949.                          tty=False, use_conn=self.conn, cqlver=self.cql_version,
  1950.                          display_timestamp_format=self.display_timestamp_format,
  1951.                          display_date_format=self.display_date_format,
  1952.                          display_nanotime_format=self.display_nanotime_format,
  1953.                          display_float_precision=self.display_float_precision,
  1954.                          max_trace_wait=self.max_trace_wait)
  1955.         subshell.cmdloop()
  1956.         f.close()
  1957.  
  1958.     def do_capture(self, parsed):
  1959.         """
  1960.        CAPTURE [cqlsh only]
  1961.  
  1962.        Begins capturing command output and appending it to a specified file.
  1963.        Output will not be shown at the console while it is captured.
  1964.  
  1965.        Usage:
  1966.  
  1967.          CAPTURE '<file>';
  1968.          CAPTURE OFF;
  1969.          CAPTURE;
  1970.  
  1971.        That is, the path to the file to be appended to must be given inside a
  1972.        string literal. The path is interpreted relative to the current working
  1973.        directory. The tilde shorthand notation ('~/mydir') is supported for
  1974.        referring to $HOME.
  1975.  
  1976.        Only query result output is captured. Errors and output from cqlsh-only
  1977.        commands will still be shown in the cqlsh session.
  1978.  
  1979.        To stop capturing output and show it in the cqlsh session again, use
  1980.        CAPTURE OFF.
  1981.  
  1982.        To inspect the current capture configuration, use CAPTURE with no
  1983.        arguments.
  1984.        """
  1985.         fname = parsed.get_binding('fname')
  1986.         if fname is None:
  1987.             if self.shunted_query_out is not None:
  1988.                 print "Currently capturing query output to %r." % (self.query_out.name,)
  1989.             else:
  1990.                 print "Currently not capturing query output."
  1991.             return
  1992.  
  1993.         if fname.upper() == 'OFF':
  1994.             if self.shunted_query_out is None:
  1995.                 self.printerr('Not currently capturing output.')
  1996.                 return
  1997.             self.query_out.close()
  1998.             self.query_out = self.shunted_query_out
  1999.             self.color = self.shunted_color
  2000.             self.shunted_query_out = None
  2001.             del self.shunted_color
  2002.             return
  2003.  
  2004.         if self.shunted_query_out is not None:
  2005.             self.printerr('Already capturing output to %s. Use CAPTURE OFF'
  2006.                           ' to disable.' % (self.query_out.name,))
  2007.             return
  2008.  
  2009.         fname = os.path.expanduser(self.cql_unprotect_value(fname))
  2010.         try:
  2011.             f = open(fname, 'a')
  2012.         except IOError, e:
  2013.             self.printerr('Could not open %r for append: %s' % (fname, e))
  2014.             return
  2015.         self.shunted_query_out = self.query_out
  2016.         self.shunted_color = self.color
  2017.         self.query_out = f
  2018.         self.color = False
  2019.         print 'Now capturing query output to %r.' % (fname,)
  2020.  
  2021.     def do_tracing(self, parsed):
  2022.         """
  2023.        TRACING [cqlsh]
  2024.  
  2025.          Enables or disables request tracing.
  2026.  
  2027.        TRACING ON
  2028.  
  2029.          Enables tracing for all further requests.
  2030.  
  2031.        TRACING OFF
  2032.  
  2033.          Disables tracing.
  2034.  
  2035.        TRACING
  2036.  
  2037.          TRACING with no arguments shows the current tracing status.
  2038.        """
  2039.         self.tracing_enabled = SwitchCommand("TRACING", "Tracing").execute(self.tracing_enabled, parsed, self.printerr)
  2040.  
  2041.     def do_expand(self, parsed):
  2042.         """
  2043.        EXPAND [cqlsh]
  2044.  
  2045.          Enables or disables expanded (vertical) output.
  2046.  
  2047.        EXPAND ON
  2048.  
  2049.          Enables expanded (vertical) output.
  2050.  
  2051.        EXPAND OFF
  2052.  
  2053.          Disables expanded (vertical) output.
  2054.  
  2055.        EXPAND
  2056.  
  2057.          EXPAND with no arguments shows the current value of expand setting.
  2058.        """
  2059.         self.expand_enabled = SwitchCommand("EXPAND", "Expanded output").execute(self.expand_enabled, parsed, self.printerr)
  2060.  
  2061.     def do_consistency(self, parsed):
  2062.         """
  2063.        CONSISTENCY [cqlsh only]
  2064.  
  2065.           Overrides default consistency level (default level is ONE).
  2066.  
  2067.        CONSISTENCY <level>
  2068.  
  2069.           Sets consistency level for future requests.
  2070.  
  2071.           Valid consistency levels:
  2072.  
  2073.           ANY, ONE, TWO, THREE, QUORUM, ALL, LOCAL_ONE, LOCAL_QUORUM, EACH_QUORUM, SERIAL and LOCAL_SERIAL.
  2074.  
  2075.           SERIAL and LOCAL_SERIAL may be used only for SELECTs; will be rejected with updates.
  2076.  
  2077.        CONSISTENCY
  2078.  
  2079.           CONSISTENCY with no arguments shows the current consistency level.
  2080.        """
  2081.         level = parsed.get_binding('level')
  2082.         if level is None:
  2083.             print 'Current consistency level is %s.' % (cassandra.ConsistencyLevel.value_to_name[self.consistency_level])
  2084.             return
  2085.  
  2086.         self.consistency_level = cassandra.ConsistencyLevel.name_to_value[level.upper()]
  2087.         print 'Consistency level set to %s.' % (level.upper(),)
  2088.  
  2089.     def do_serial(self, parsed):
  2090.         """
  2091.        SERIAL CONSISTENCY [cqlsh only]
  2092.  
  2093.           Overrides serial consistency level (default level is SERIAL).
  2094.  
  2095.        SERIAL CONSISTENCY <level>
  2096.  
  2097.           Sets consistency level for future conditional updates.
  2098.  
  2099.           Valid consistency levels:
  2100.  
  2101.           SERIAL, LOCAL_SERIAL.
  2102.  
  2103.        SERIAL CONSISTENCY
  2104.  
  2105.           SERIAL CONSISTENCY with no arguments shows the current consistency level.
  2106.        """
  2107.         level = parsed.get_binding('level')
  2108.         if level is None:
  2109.             print 'Current serial consistency level is %s.' % (cassandra.ConsistencyLevel.value_to_name[self.serial_consistency_level])
  2110.             return
  2111.  
  2112.         self.serial_consistency_level = cassandra.ConsistencyLevel.name_to_value[level.upper()]
  2113.         print 'Serial consistency level set to %s.' % (level.upper(),)
  2114.  
  2115.     def do_login(self, parsed):
  2116.         """
  2117.        LOGIN [cqlsh only]
  2118.  
  2119.           Changes login information without requiring restart.
  2120.  
  2121.        LOGIN <username> (<password>)
  2122.  
  2123.           Login using the specified username. If password is specified, it will be used
  2124.           otherwise, you will be prompted to enter.
  2125.        """
  2126.         username = parsed.get_binding('username')
  2127.         password = parsed.get_binding('password')
  2128.         if password is None:
  2129.             password = getpass.getpass()
  2130.         else:
  2131.             password = password[1:-1]
  2132.  
  2133.         auth_provider = PlainTextAuthProvider(username=username, password=password)
  2134.  
  2135.         conn = Cluster(contact_points=(self.hostname,), port=self.port, cql_version=self.conn.cql_version,
  2136.                        protocol_version=self.conn.protocol_version,
  2137.                        auth_provider=auth_provider,
  2138.                        ssl_options=self.conn.ssl_options,
  2139.                        load_balancing_policy=WhiteListRoundRobinPolicy([self.hostname]),
  2140.                        connect_timeout=self.conn.connect_timeout)
  2141.  
  2142.         if self.current_keyspace:
  2143.             session = conn.connect(self.current_keyspace)
  2144.         else:
  2145.             session = conn.connect()
  2146.  
  2147.         # Update after we've connected in case we fail to authenticate
  2148.         self.conn = conn
  2149.         self.auth_provider = auth_provider
  2150.         self.username = username
  2151.         self.session = session
  2152.  
  2153.     def do_exit(self, parsed=None):
  2154.         """
  2155.        EXIT/QUIT [cqlsh only]
  2156.  
  2157.        Exits cqlsh.
  2158.        """
  2159.         self.stop = True
  2160.         if self.owns_connection:
  2161.             self.conn.shutdown()
  2162.     do_quit = do_exit
  2163.  
  2164.     def do_debug(self, parsed):
  2165.         import pdb
  2166.         pdb.set_trace()
  2167.  
  2168.     def get_help_topics(self):
  2169.         topics = [t[3:] for t in dir(self) if t.startswith('do_') and getattr(self, t, None).__doc__]
  2170.         for hide_from_help in ('quit',):
  2171.             topics.remove(hide_from_help)
  2172.         return topics
  2173.  
  2174.     def columnize(self, slist, *a, **kw):
  2175.         return cmd.Cmd.columnize(self, sorted([u.upper() for u in slist]), *a, **kw)
  2176.  
  2177.     def do_help(self, parsed):
  2178.         """
  2179.        HELP [cqlsh only]
  2180.  
  2181.        Gives information about cqlsh commands. To see available topics,
  2182.        enter "HELP" without any arguments. To see help on a topic,
  2183.        use "HELP <topic>".
  2184.        """
  2185.         topics = parsed.get_binding('topic', ())
  2186.         if not topics:
  2187.             shell_topics = [t.upper() for t in self.get_help_topics()]
  2188.             self.print_topics("\nDocumented shell commands:", shell_topics, 15, 80)
  2189.             cql_topics = [t.upper() for t in cqldocs.get_help_topics()]
  2190.             self.print_topics("CQL help topics:", cql_topics, 15, 80)
  2191.             return
  2192.         for t in topics:
  2193.             if t.lower() in self.get_help_topics():
  2194.                 doc = getattr(self, 'do_' + t.lower()).__doc__
  2195.                 self.stdout.write(doc + "\n")
  2196.             elif t.lower() in cqldocs.get_help_topics():
  2197.                 cqldocs.print_help_topic(t)
  2198.             else:
  2199.                 self.printerr("*** No help on %s" % (t,))
  2200.  
  2201.     def do_paging(self, parsed):
  2202.         """
  2203.        PAGING [cqlsh]
  2204.  
  2205.          Enables or disables query paging.
  2206.  
  2207.        PAGING ON
  2208.  
  2209.          Enables query paging for all further queries.
  2210.  
  2211.        PAGING OFF
  2212.  
  2213.          Disables paging.
  2214.  
  2215.        PAGING
  2216.  
  2217.          PAGING with no arguments shows the current query paging status.
  2218.        """
  2219.         self.use_paging = SwitchCommand("PAGING", "Query paging").execute(self.use_paging, parsed, self.printerr)
  2220.  
  2221.     def applycolor(self, text, color=None):
  2222.         if not color or not self.color:
  2223.             return text
  2224.         return color + text + ANSI_RESET
  2225.  
  2226.     def writeresult(self, text, color=None, newline=True, out=None):
  2227.         if out is None:
  2228.             out = self.query_out
  2229.         out.write(self.applycolor(str(text), color) + ('\n' if newline else ''))
  2230.  
  2231.     def flush_output(self):
  2232.         self.query_out.flush()
  2233.  
  2234.     def printerr(self, text, color=RED, newline=True, shownum=None):
  2235.         self.statement_error = True
  2236.         if shownum is None:
  2237.             shownum = self.show_line_nums
  2238.         if shownum:
  2239.             text = '%s:%d:%s' % (self.stdin.name, self.lineno, text)
  2240.         self.writeresult(text, color, newline=newline, out=sys.stderr)
  2241.  
  2242. import multiprocessing
  2243. class ImportProcess(multiprocessing.Process):
  2244.     def __init__(self, parent, pipe, ks, cf, columns, nullval):
  2245.         multiprocessing.Process.__init__(self)
  2246.         self.pipe = pipe
  2247.         self.nullval = nullval
  2248.         self.ks = ks
  2249.         self.cf = cf
  2250.  
  2251.         #validate we can fetch metdata but don't store it since win32 needs to pickle
  2252.         parent.get_table_meta(ks, cf)
  2253.  
  2254.         self.columns = columns
  2255.         self.consistency_level = parent.consistency_level
  2256.         self.connect_timeout = parent.conn.connect_timeout
  2257.         self.hostname = parent.hostname
  2258.         self.port = parent.port
  2259.         self.ssl = parent.ssl
  2260.         self.auth_provider = parent.auth_provider
  2261.         self.cql_version = parent.conn.cql_version
  2262.         self.debug = parent.debug
  2263.  
  2264.     def run(self):
  2265.         new_cluster = Cluster(
  2266.                 contact_points=(self.hostname,),
  2267.                 port=self.port,
  2268.                 cql_version=self.cql_version,
  2269.                 protocol_version=DEFAULT_PROTOCOL_VERSION,
  2270.                 auth_provider=self.auth_provider,
  2271.                 ssl_options=sslhandling.ssl_settings(hostname, CONFIG_FILE) if self.ssl else None,
  2272.                 load_balancing_policy=WhiteListRoundRobinPolicy([self.hostname]),
  2273.                 compression=None,
  2274.                 connect_timeout=self.connect_timeout)
  2275.         session = new_cluster.connect(self.ks)
  2276.         conn = session._pools.values()[0]._connection
  2277.  
  2278.         table_meta = new_cluster.metadata.keyspaces[self.ks].tables[self.cf]
  2279.  
  2280.         pk_cols = [col.name for col in table_meta.primary_key]
  2281.         cqltypes = [table_meta.columns[name].typestring for name in self.columns]
  2282.         pk_indexes = [self.columns.index(col.name) for col in table_meta.primary_key]
  2283.         query = 'INSERT INTO %s.%s (%s) VALUES (%%s)' % (
  2284.             protect_name(table_meta.keyspace.name),
  2285.             protect_name(table_meta.name),
  2286.             ', '.join(protect_names(self.columns)))
  2287.  
  2288.         # we need to handle some types specially
  2289.         should_escape = [t in ('ascii', 'text', 'timestamp', 'date', 'time', 'inet') for t in cqltypes]
  2290.  
  2291.         insert_timestamp = int(time.time() * 1e6)
  2292.  
  2293.         def callback(record_num, response):
  2294.             # This is the callback we register for all inserts.  Because this
  2295.             # is run on the event-loop thread, we need to hold a lock when
  2296.             # adjusting in_flight.
  2297.             with conn.lock:
  2298.                 conn.in_flight -= 1
  2299.  
  2300.             if not isinstance(response, ResultMessage):
  2301.                 # It's an error. Notify the parent process and let it send
  2302.                 # a stop signal to all child processes (including this one).
  2303.                 self.pipe.send((record_num, str(response)))
  2304.                 if isinstance(response, Exception) and self.debug:
  2305.                     traceback.print_exc(response)
  2306.  
  2307.         current_record = 0
  2308.         insert_num = 0
  2309.         try:
  2310.             while True:
  2311.                 # To avoid totally maxing out the connection,
  2312.                 # defer to the reactor thread when we're close
  2313.                 # to capacity
  2314.                 if conn.in_flight > (conn.max_request_id * 0.9):
  2315.                     conn._readable = True
  2316.                     time.sleep(0.05)
  2317.                     continue
  2318.  
  2319.                 try:
  2320.                     (current_record, row) = self.pipe.recv()
  2321.                 except EOFError:
  2322.                     # the pipe was closed and there's nothing to receive
  2323.                     sys.stdout.write('Failed to read from pipe:\n\n')
  2324.                     sys.stdout.flush()
  2325.                     conn._writable = True
  2326.                     conn._readable = True
  2327.                     break
  2328.  
  2329.                 # see if the parent process has signaled that we are done
  2330.                 if (current_record, row) == (None, None):
  2331.                     conn._writable = True
  2332.                     conn._readable = True
  2333.                     self.pipe.close()
  2334.                     break
  2335.  
  2336.                 # format the values in the row
  2337.                 for i, value in enumerate(row):
  2338.                     if value != self.nullval:
  2339.                         if should_escape[i]:
  2340.                             row[i] = protect_value(value)
  2341.                     elif i in pk_indexes:
  2342.                         # By default, nullval is an empty string. See CASSANDRA-7792 for details.
  2343.                         message = "Cannot insert null value for primary key column '%s'." % (pk_cols[i],)
  2344.                         if self.nullval == '':
  2345.                             message += " If you want to insert empty strings, consider using " \
  2346.                                        "the WITH NULL=<marker> option for COPY."
  2347.                         self.pipe.send((current_record, message))
  2348.                         return
  2349.                     else:
  2350.                         row[i] = 'null'
  2351.  
  2352.                 full_query = query % (','.join(row),)
  2353.                 query_message = QueryMessage(
  2354.                         full_query, self.consistency_level, serial_consistency_level=None,
  2355.                         fetch_size=None, paging_state=None, timestamp=insert_timestamp)
  2356.  
  2357.                 request_id = conn.get_request_id()
  2358.                 binary_message = query_message.to_binary(
  2359.                     stream_id=request_id, protocol_version=DEFAULT_PROTOCOL_VERSION, compression=None)
  2360.  
  2361.                 # add the message directly to the connection's queue
  2362.                 with conn.lock:
  2363.                     conn.in_flight += 1
  2364.  
  2365.                 conn._callbacks[request_id] = partial(callback, current_record)
  2366.                 conn.deque.append(binary_message)
  2367.  
  2368.                 # every 50 records, clear the pending writes queue and read
  2369.                 # any responses we have
  2370.                 if insert_num % 50 == 0:
  2371.                     conn._writable = True
  2372.                     conn._readable = True
  2373.  
  2374.                 insert_num += 1
  2375.         except Exception, exc:
  2376.             self.pipe.send((current_record, str(exc)))
  2377.         finally:
  2378.             # wait for any pending requests to finish
  2379.             while conn.in_flight > 0:
  2380.                 conn._readable = True
  2381.                 time.sleep(0.1)
  2382.  
  2383.             new_cluster.shutdown()
  2384.  
  2385.     def stop(self):
  2386.         self.terminate()
  2387.  
  2388.  
  2389.  
  2390. class RateMeter(object):
  2391.  
  2392.     def __init__(self, log_rate):
  2393.         self.log_rate = log_rate
  2394.         self.last_checkpoint_time = time.time()
  2395.         self.current_rate = 0.0
  2396.         self.current_record = 0
  2397.  
  2398.     def increment(self):
  2399.         self.current_record += 1
  2400.  
  2401.         if (self.current_record % self.log_rate) == 0:
  2402.             new_checkpoint_time = time.time()
  2403.             new_rate = self.log_rate / (new_checkpoint_time - self.last_checkpoint_time)
  2404.             self.last_checkpoint_time = new_checkpoint_time
  2405.  
  2406.             # smooth the rate a bit
  2407.             if self.current_rate == 0.0:
  2408.                 self.current_rate = new_rate
  2409.             else:
  2410.                 self.current_rate = (self.current_rate + new_rate) / 2.0
  2411.  
  2412.             output = 'Processed %s rows; Write: %.2f rows/s\r' % \
  2413.                      (self.current_record, self.current_rate)
  2414.             sys.stdout.write(output)
  2415.             sys.stdout.flush()
  2416.  
  2417.  
  2418. class SwitchCommand(object):
  2419.     command = None
  2420.     description = None
  2421.  
  2422.     def __init__(self, command, desc):
  2423.         self.command = command
  2424.         self.description = desc
  2425.  
  2426.     def execute(self, state, parsed, printerr):
  2427.         switch = parsed.get_binding('switch')
  2428.         if switch is None:
  2429.             if state:
  2430.                 print "%s is currently enabled. Use %s OFF to disable" \
  2431.                       % (self.description, self.command)
  2432.             else:
  2433.                 print "%s is currently disabled. Use %s ON to enable." \
  2434.                       % (self.description, self.command)
  2435.             return state
  2436.  
  2437.         if switch.upper() == 'ON':
  2438.             if state:
  2439.                 printerr('%s is already enabled. Use %s OFF to disable.'
  2440.                          % (self.description, self.command))
  2441.                 return state
  2442.             print 'Now %s is enabled' % (self.description,)
  2443.             return True
  2444.  
  2445.         if switch.upper() == 'OFF':
  2446.             if not state:
  2447.                 printerr('%s is not enabled.' % (self.description,))
  2448.                 return state
  2449.             print 'Disabled %s.' % (self.description,)
  2450.             return False
  2451.  
  2452.  
  2453. def option_with_default(cparser_getter, section, option, default=None):
  2454.     try:
  2455.         return cparser_getter(section, option)
  2456.     except ConfigParser.Error:
  2457.         return default
  2458.  
  2459. def raw_option_with_default(configs, section, option, default=None):
  2460.     """
  2461.    Same (almost) as option_with_default() but won't do any string interpolation.
  2462.    Useful for config values that include '%' symbol, e.g. time format string.
  2463.    """
  2464.     try:
  2465.         return configs.get(section, option, raw=True)
  2466.     except ConfigParser.Error:
  2467.         return default
  2468.  
  2469.  
  2470. def should_use_color():
  2471.     if not sys.stdout.isatty():
  2472.         return False
  2473.     if os.environ.get('TERM', '') in ('dumb', ''):
  2474.         return False
  2475.     try:
  2476.         import subprocess
  2477.         p = subprocess.Popen(['tput', 'colors'], stdout=subprocess.PIPE)
  2478.         stdout, _ = p.communicate()
  2479.         if int(stdout.strip()) < 8:
  2480.             return False
  2481.     except (OSError, ImportError, ValueError):
  2482.         # oh well, we tried. at least we know there's a $TERM and it's
  2483.         # not "dumb".
  2484.         pass
  2485.     return True
  2486.  
  2487.  
  2488. def read_options(cmdlineargs, environment):
  2489.     configs = ConfigParser.SafeConfigParser()
  2490.     configs.read(CONFIG_FILE)
  2491.  
  2492.     rawconfigs = ConfigParser.RawConfigParser()
  2493.     rawconfigs.read(CONFIG_FILE)
  2494.  
  2495.     optvalues = optparse.Values()
  2496.     optvalues.username = option_with_default(configs.get, 'authentication', 'username')
  2497.     optvalues.password = option_with_default(rawconfigs.get, 'authentication', 'password')
  2498.     optvalues.keyspace = option_with_default(configs.get, 'authentication', 'keyspace')
  2499.     optvalues.completekey = option_with_default(configs.get, 'ui', 'completekey',
  2500.                                                 DEFAULT_COMPLETEKEY)
  2501.     optvalues.color = option_with_default(configs.getboolean, 'ui', 'color')
  2502.     optvalues.time_format = raw_option_with_default(configs, 'ui', 'time_format',
  2503.                                                     DEFAULT_TIMESTAMP_FORMAT)
  2504.     optvalues.nanotime_format = raw_option_with_default(configs, 'ui', 'nanotime_format',
  2505.                                                     DEFAULT_NANOTIME_FORMAT)
  2506.     optvalues.date_format = raw_option_with_default(configs, 'ui', 'date_format',
  2507.                                                     DEFAULT_DATE_FORMAT)
  2508.     optvalues.float_precision = option_with_default(configs.getint, 'ui', 'float_precision',
  2509.                                                     DEFAULT_FLOAT_PRECISION)
  2510.     optvalues.field_size_limit = option_with_default(configs.getint, 'csv', 'field_size_limit', csv.field_size_limit())
  2511.     optvalues.max_trace_wait = option_with_default(configs.getfloat, 'tracing', 'max_trace_wait',
  2512.                                                    DEFAULT_MAX_TRACE_WAIT)
  2513.  
  2514.     optvalues.debug = False
  2515.     optvalues.file = None
  2516.     optvalues.ssl = False
  2517.  
  2518.     optvalues.tty = sys.stdin.isatty()
  2519.     optvalues.cqlversion = option_with_default(configs.get, 'cql', 'version', DEFAULT_CQLVER)
  2520.     optvalues.protocolversion =  option_with_default(configs.get, 'cql', 'protocolversion', DEFAULT_PROTOCOL_VERSION)
  2521.     optvalues.connect_timeout = option_with_default(configs.getint, 'connection', 'timeout', DEFAULT_CONNECT_TIMEOUT_SECONDS)
  2522.     optvalues.execute = None
  2523.  
  2524.     (options, arguments) = parser.parse_args(cmdlineargs, values=optvalues)
  2525.  
  2526.     hostname = option_with_default(configs.get, 'connection', 'hostname', DEFAULT_HOST)
  2527.     port = option_with_default(configs.get, 'connection', 'port', DEFAULT_PORT)
  2528.  
  2529.     try:
  2530.         options.connect_timeout = int(options.connect_timeout)
  2531.     except ValueError:
  2532.         parser.error('{} is not a valid timeout.'.format(options.connect_timeout))
  2533.         options.connect_timeout = DEFAULT_CONNECT_TIMEOUT_SECONDS
  2534.  
  2535.     options.client_timeout = option_with_default(configs.get, 'connection', 'client_timeout', '10')
  2536.     if options.client_timeout.lower() == 'none':
  2537.         options.client_timeout = None
  2538.     else:
  2539.         options.client_timeout = int(options.client_timeout)
  2540.  
  2541.     hostname = environment.get('CQLSH_HOST', hostname)
  2542.     port = environment.get('CQLSH_PORT', port)
  2543.  
  2544.     if len(arguments) > 0:
  2545.         hostname = arguments[0]
  2546.     if len(arguments) > 1:
  2547.         port = arguments[1]
  2548.  
  2549.     if options.file or options.execute:
  2550.         options.tty = False
  2551.  
  2552.     if options.execute and not options.execute.endswith(';'):
  2553.         options.execute += ';'
  2554.  
  2555.     if optvalues.color in (True, False):
  2556.         options.color = optvalues.color
  2557.     else:
  2558.         if options.file is not None:
  2559.             options.color = False
  2560.         else:
  2561.             options.color = should_use_color()
  2562.  
  2563.     options.cqlversion, cqlvertup = full_cql_version(options.cqlversion)
  2564.     if cqlvertup[0] < 3:
  2565.         parser.error('%r is not a supported CQL version.' % options.cqlversion)
  2566.     else:
  2567.         options.cqlmodule = cql3handling
  2568.  
  2569.     try:
  2570.         port = int(port)
  2571.     except ValueError:
  2572.         parser.error('%r is not a valid port number.' % port)
  2573.  
  2574.     if options.protocolversion:
  2575.         try:
  2576.             options.protocolversion = int(optvalues.protocolversion)
  2577.         except ValueError:
  2578.             options.protocolversion=DEFAULT_PROTOCOL_VERSION
  2579.  
  2580.     return options, hostname, port
  2581.  
  2582.  
  2583. def setup_cqlruleset(cqlmodule):
  2584.     global cqlruleset
  2585.     cqlruleset = cqlmodule.CqlRuleSet
  2586.     cqlruleset.append_rules(cqlsh_extra_syntax_rules)
  2587.     for rulename, termname, func in cqlsh_syntax_completers:
  2588.         cqlruleset.completer_for(rulename, termname)(func)
  2589.     cqlruleset.commands_end_with_newline.update(my_commands_ending_with_newline)
  2590.  
  2591.  
  2592. def setup_cqldocs(cqlmodule):
  2593.     global cqldocs
  2594.     cqldocs = cqlmodule.cqldocs
  2595.  
  2596.  
  2597. def init_history():
  2598.     if readline is not None:
  2599.         try:
  2600.             readline.read_history_file(HISTORY)
  2601.         except IOError:
  2602.             pass
  2603.         delims = readline.get_completer_delims()
  2604.         delims.replace("'", "")
  2605.         delims += '.'
  2606.         readline.set_completer_delims(delims)
  2607.  
  2608.  
  2609. def save_history():
  2610.     if readline is not None:
  2611.         try:
  2612.             readline.write_history_file(HISTORY)
  2613.         except IOError:
  2614.             pass
  2615.  
  2616.  
  2617. def main(options, hostname, port):
  2618.     setup_cqlruleset(options.cqlmodule)
  2619.     setup_cqldocs(options.cqlmodule)
  2620.     init_history()
  2621.     csv.field_size_limit(options.field_size_limit)
  2622.  
  2623.     if options.file is None:
  2624.         stdin = None
  2625.     else:
  2626.         try:
  2627.             encoding, bom_size = get_file_encoding_bomsize(options.file)
  2628.             stdin = codecs.open(options.file, 'r', encoding)
  2629.             stdin.seek(bom_size)
  2630.         except IOError, e:
  2631.             sys.exit("Can't open %r: %s" % (options.file, e))
  2632.  
  2633.     if options.debug:
  2634.         sys.stderr.write("Using CQL driver: {}\n".format(cassandra))
  2635.         sys.stderr.write("Using connect timeout: {} seconds\n".format(options.connect_timeout))
  2636.  
  2637.     try:
  2638.         shell = Shell(hostname,
  2639.                       port,
  2640.                       color=options.color,
  2641.                       username=options.username,
  2642.                       password=options.password,
  2643.                       stdin=stdin,
  2644.                       tty=options.tty,
  2645.                       completekey=options.completekey,
  2646.                       cqlver=options.cqlversion,
  2647.                       keyspace=options.keyspace,
  2648.                       display_timestamp_format=options.time_format,
  2649.                       display_nanotime_format=options.nanotime_format,
  2650.                       display_date_format=options.date_format,
  2651.                       display_float_precision=options.float_precision,
  2652.                       max_trace_wait=options.max_trace_wait,
  2653.                       ssl=options.ssl,
  2654.                       single_statement=options.execute,
  2655.                       client_timeout=options.client_timeout,
  2656.                       connect_timeout=options.connect_timeout,
  2657.                       protocol_version=options.protocolversion)
  2658.     except KeyboardInterrupt:
  2659.         sys.exit('Connection aborted.')
  2660.     except CQL_ERRORS, e:
  2661.         sys.exit('Connection error: %s' % (e,))
  2662.     except VersionNotSupported, e:
  2663.         sys.exit('Unsupported CQL version: %s' % (e,))
  2664.     if options.debug:
  2665.         shell.debug = True
  2666.  
  2667.     shell.cmdloop()
  2668.     save_history()
  2669.     batch_mode = options.file or options.execute
  2670.     if batch_mode and shell.statement_error:
  2671.         sys.exit(2)
  2672.  
  2673.  
if __name__ == '__main__':
    # Script entry point: read_options() merges the config file, the
    # environment and argv into (options, hostname, port), which main()
    # receives positionally.
    main(*read_options(sys.argv[1:], os.environ))
  2676.  
  2677. # vim: set ft=python et ts=4 sw=4 :
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement