Source code for gemini.GeminiQuery

#!/usr/bin/env python

import os
import sys
import sqlite3
import re
import compiler
import collections
import json
import abc
import numpy as np
from itertools import chain
flatten = chain.from_iterable
import itertools as it

# gemini imports
import gemini_utils as util
from gemini_constants import *
from gemini_utils import (OrderedSet, OrderedDict, itersubclasses)
from .pdict import PDict
import compression
from sql_utils import ensure_columns, get_select_cols_and_rest
from gemini_subjects import get_subjects

class GeminiError(Exception):
    pass
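
# Alternative sqlite3 row_factory that returns each row as a plain dict.
# The default connection uses sqlite3.Row instead; see
# GeminiQuery._connect_to_database, where this factory is commented out.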

def RowFactory(cursor, row):
    return dict(it.izip((c[0] for c in cursor.description), row))

class RowFormat:
    """A row formatter to output rows in a custom format.  To provide
    a new output format 'foo', implement the class methods and set the
    name field to foo.  This will automatically add support for 'foo' to
    anything accepting the --format option via --format foo.
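
    A minimal sketch of a hypothetical 'bed'-style formatter (illustrative
    only; not a format shipped with this module) might look like::

        class BedRowFormat(RowFormat):
            name = "bed"

            def __init__(self, args):
                pass

            def format(self, row):
                # chrom/start/end are guaranteed by format_query() below
                return "\t".join(str(row[c]) for c in ("chrom", "start", "end"))

            def format_query(self, query):
                # make sure the printed columns are SELECTed
                return ensure_columns(query, ["chrom", "start", "end"])

            def predicate(self, row):
                return True

            def header(self, fields):
                return "\t".join(fields)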
    """

    __metaclass__ = abc.ABCMeta

    @abc.abstractproperty
    def name(self):
        return

    @abc.abstractmethod
    def format(self, row):
        """ return a string representation of a GeminiRow object
        """
        return str(row.print_fields)

    @abc.abstractmethod
    def format_query(self, query):
        """ augment the query with columns necessary for the format or else just
        return the untouched query
        """
        return query

    @abc.abstractmethod
    def predicate(self, row):
        """ the row must pass this additional predicate to be output. Just
        return True if there is no additional predicate"""
        return True

    @abc.abstractmethod
    def header(self, fields):
        """ return a header for the row """
        return "\t".join(fields)


class DefaultRowFormat(RowFormat):

    name = "default"

    def __init__(self, args):
        pass

    def format(self, row):
        r = row.print_fields
        return '\t'.join(str(v) if not isinstance(v, np.ndarray) else ",".join(map(str, v)) for v in r._vals)

    def format_query(self, query):
        return query

    def predicate(self, row):
        return True

    def header(self, fields):
        """ return a header for the row """
        return "\t".join(fields)


class CarrierSummary(RowFormat):
    """
    Generates a count of the carrier/noncarrier status of each feature in a given
    column of the sample table

    Assumes None == unknown.
    """
    name = "carrier_summary"

    def __init__(self, args):
        subjects = get_subjects(args)
        self.carrier_summary = args.carrier_summary

        # get the list of all possible values in the column
        # but don't include None, since we are treating that as unknown.
        self.column_types = list(set([getattr(x, self.carrier_summary)
                                      for x in subjects.values()]))
        self.column_types = [i for i in self.column_types if i is not None]
        self.column_counters = {None: set()}
        for ct in self.column_types:
            self.column_counters[ct] = set([k for (k, v) in subjects.items() if
                                            getattr(v, self.carrier_summary) == ct])


    def format(self, row):
        have_variant = set(row['variant_samples'])
        have_reference = set(row['hom_ref_samples'])
        unknown = len(set(row['unknown_samples']).union(self.column_counters[None]))
        carrier_counts = []
        for ct in self.column_types:
            counts = len(self.column_counters[ct].intersection(have_variant))
            carrier_counts.append(counts)
        for ct in self.column_types:
            counts = len(self.column_counters[ct].intersection(have_reference))
            carrier_counts.append(counts)

        carrier_counts.append(unknown)
        carrier_counts = map(str, carrier_counts)
        r = row.print_fields
        return '\t'.join(str(r[c]) if not isinstance(r[c], np.ndarray) else
                         ",".join(str(s) for s in r[c]) for c in r) \
               + "\t" + "\t".join(carrier_counts)

    def format_query(self, query):
        return query

    def predicate(self, row):
        return True

    def header(self, fields):
        """ return a header for the row """
        header_columns = self.column_types
        if self.carrier_summary == "affected":
            header_columns = self._rename_affected()
        carriers = [x + "_carrier" for x in map(str, header_columns)]
        noncarriers = [x + "_noncarrier" for x in map(str, header_columns)]
        fields += carriers
        fields += noncarriers
        fields += ["unknown"]
        return "\t".join(fields)

    def _rename_affected(self):
        header_columns = []
        for ct in self.column_types:
            if ct is True:
                header_columns.append("affected")
            elif ct is False:
                header_columns.append("unaffected")
        return header_columns



class TPEDRowFormat(RowFormat):
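
    # The X_PAR_REGIONS / Y_PAR_REGIONS coordinates below are the X and Y
    # pseudoautosomal regions; they appear to be GRCh37/hg19 coordinates
    # (assumption, not stated in the source).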

    X_PAR_REGIONS = [(60001, 2699520), (154931044, 155260560)]
    Y_PAR_REGIONS = [(10001, 2649520), (59034050, 59363566)]

    name = "tped"
    NULL_GENOTYPES = ["."]
    PED_MISSING = ["0", "0"]
    VALID_CHROMOSOMES = map(str, range(1, 23)) + ["X", "Y", "XY", "MT"]
    POSSIBLE_HAPLOID = ["X", "Y"]

    def __init__(self, args):
        gq = GeminiQuery(args.db)
        subjects = get_subjects(args, skip_filter=True)
        # get samples in order of genotypes
        self.samples = [gq.idx_to_sample_object[x] for x in range(len(subjects))]

    def format(self, row):
        VALID_CHROMOSOMES = map(str, range(1, 23)) + ["X", "Y", "XY", "MT"]
        chrom = row['chrom'].split("chr")[1]
        chrom = chrom if chrom in VALID_CHROMOSOMES else "0"
        start = str(row.row['start'])
        end = str(row.row['end'])
        ref = row['ref']
        alt = row['alt']
        geno = [re.split('\||/', x) for x in row['gts']]
        geno = [self._fix_genotype(chrom, start, genotype, self.samples[i].sex)
                for i, genotype in enumerate(geno)]
        genotypes = " ".join(list(flatten(geno)))
        name = str(row['variant_id'])
        return " ".join([chrom, name, "0", start, genotypes])

    def format_query(self, query):
        NEED_COLUMNS = ["chrom", "rs_ids", "start", "ref", "alt", "gts", "type", "variant_id"]
        return ensure_columns(query, NEED_COLUMNS)


    def predicate(self, row, _splitter=re.compile("\||/")):
        geno = [_splitter.split(x) for x in row['gts']]
        geno = list(flatten(geno))
        num_alleles = len(set(geno).difference(self.NULL_GENOTYPES))
        return num_alleles > 0 and num_alleles <= 2 and row['type'] != "sv"

    def _is_haploid(self, genotype):
        return len(genotype) < 2

    def _has_missing(self, genotype):
        return any([allele in self.NULL_GENOTYPES for allele in genotype])

    def _is_heterozygote(self, genotype):
        return len(genotype) == 2 and (genotype[0] != genotype[1])

    def _in_PAR(self, chrom, start):
        if chrom == "X":
            for region in self.X_PAR_REGIONS:
                if start > region[0] and start < region[1]:
                    return True
        elif chrom == "Y":
            for region in self.Y_PAR_REGIONS:
                if start > region[0] and start < region[1]:
                    return True
        return False

    def _fix_genotype(self, chrom, start, genotype, sex):
        """
        the TPED format has to have both alleles set, even if it is haploid.
        this fixes that setting Y calls on the female to missing,
        heterozygotic calls on the male non PAR regions to missing and haploid
        calls on non-PAR regions to be the haploid call for both alleles
        """
        if sex == "2":
            # set female Y calls and haploid calls to missing
            if self._is_haploid(genotype) or chrom == "Y" or self._has_missing(genotype):
                return self.PED_MISSING
            return genotype
        if chrom in self.POSSIBLE_HAPLOID and sex == "1":
            # remove the missing genotype calls
            genotype = [x for x in genotype if x not in self.NULL_GENOTYPES]
            # if all genotypes are missing skip
            if not genotype:
                return self.PED_MISSING
            # heterozygote males in non PAR regions are a mistake
            if self._is_heterozygote(genotype) and not self._in_PAR(chrom, start):
                return self.PED_MISSING
            # set haploid males to be homozygous for the allele
            if self._is_haploid(genotype):
                return [genotype[0], genotype[0]]

        # if a genotype is missing or is haploid set it to missing
        if self._has_missing(genotype) or self._is_haploid(genotype):
            return self.PED_MISSING
        else:
            return genotype
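
    # Illustrative behavior of _fix_genotype (assumed from the rules above;
    # sex codes follow PED conventions: "1" = male, "2" = female):
    #   _fix_genotype("Y", "100000", ["A", "G"], "2")  -> ["0", "0"]  (female Y call set to missing)
    #   _fix_genotype("Y", "100000", ["A"], "1")       -> ["A", "A"]  (haploid male made homozygous)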

    def header(self, fields):
        return None

class JSONRowFormat(RowFormat):

    name = "json"

    def __init__(self, args):
        pass

    def format(self, row):
        """Emit a JSON representation of a given row
        """
        from .pdict import to_json
        return json.dumps(row.print_fields, default=to_json)

    def format_query(self, query):
        return query

    def predicate(self, row):
        return True

    def header(self, fields):
        return None

class VCFRowFormat(RowFormat):

    name = "vcf"

    def __init__(self, args):
        self.gq = GeminiQuery(args.db)

    def format(self, row):
        """Emit a VCF representation of a given row

           TODO: handle multiple alleles
        """
        istr = self.gq._info_dict_to_string

        # core VCF fields
        vcf_rec = [row.row['chrom'], row.row['start'] + 1]
        if row.row['vcf_id'] is None:
            vcf_rec.append('.')
        else:
            vcf_rec.append(row.row['vcf_id'])
        vcf_rec += [row.row['ref'], row.row['alt'], row.row['qual']]
        if row.row['filter'] is None:
            vcf_rec.append('PASS')
        else:
            vcf_rec.append(row.row['filter'])
        vcf_rec += [istr(row['info']), 'GT']

        # construct genotypes
        gts = list(row['gts'])
        gt_types = list(row['gt_types'])
        gt_phases = list(row['gt_phases'])
        for idx, gt_type in enumerate(gt_types):
            phase_char = '/' if not gt_phases[idx] else '|'
            gt = gts[idx]
            alleles = gt.split(phase_char)
            if gt_type == HOM_REF:
                vcf_rec.append('0' + phase_char + '0')
            elif gt_type == HET:
                # if the genotype is phased, need to check for 1|0 vs. 0|1
                if gt_phases[idx] and alleles[0] != row.row['ref']:
                    vcf_rec.append('1' + phase_char + '0')
                else:
                    vcf_rec.append('0' + phase_char + '1')
            elif gt_type == HOM_ALT:
                vcf_rec.append('1' + phase_char + '1')
            elif gt_type == UNKNOWN:
                vcf_rec.append('.' + phase_char + '.')

        return '\t'.join([str(c) if c is not None else "." for c in vcf_rec])

    def format_query(self, query):
        return query

    def predicate(self, row):
        return True

    def header(self, fields):
        """Return the original VCF's header
        """
        try:
            self.gq.run('select vcf_header from vcf_header')
            return str(self.gq.next()).strip()
        except:
            sys.exit("Your database does not contain the vcf_header table. Therefore, you cannot use --header.\n")

class SampleDetailRowFormat(RowFormat):
    """Retrieve queries with flattened sample information for samples present.

    This melts/tidys a single line result to have a separate line for every sample
    with that call and adds in sample metadata from the samples table.
    """
    name = "sampledetail"

    def __init__(self, args):
        self.gq = GeminiQuery(args.db)
        self.gq.run("SELECT * from samples")
        self.cols = self.gq.header.split()[1:]
        self.args = args

        self.samples = {}
        for row in self.gq:
            vals = [row[x] for x in self.cols]
            self.samples[row["name"]] = vals

    def format(self, row):
        r = row.print_fields
        samples = [s for s in r["variant_samples"].split(self.args.sample_delim) if s]
        for x in self.gq.sample_show_fields:
            del r[x]
        out = []
        for sample in samples:
            out.append('\t'.join([str(r[c]) for c in r] + self.samples[sample]))
        return "\n".join(out)

    def format_query(self, query):
        return query

    def predicate(self, row):
        return True

    def header(self, fields):
        for x in self.gq.sample_show_fields:
            if x in fields:
                fields.remove(x)
        return "\t".join(fields + self.cols)

class GeminiRow(object):
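    """Wrap a single row of query results.

    A brief summary of the behavior implemented below: genotype BLOB columns
    (those listed in query.gt_cols) are decompressed lazily on first access
    and cached, and derived keys such as 'het_samples' and 'variant_samples'
    are grouped from 'gt_types' on demand.
    """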
    __slots__ = ('cache', 'genotype_dict', 'row', 'formatter', 'query',
                 'print_fields')

    def __init__(self, row, query, formatter=DefaultRowFormat(None), print_fields=None):
        # row can be a dict() from the database or another GeminiRow (from the
        # same db entry). we try to re-use the cached stuff if possible.
        self.cache = {}
        self.genotype_dict = {}
        self.row = getattr(row, "row", row)
        self.cache = getattr(row, "cache", {})
        self.genotype_dict = getattr(row, "genotype_dict", {})

        # for the eval.
        #self.cache['sample_info'] = dict(query.sample_info)

        self.formatter = formatter
        self.query = query
        self.print_fields = print_fields or {}

    def __getitem__(self, key):
        # we cache what we can.
        key = str(key)
        if key in ('het_samples', 'hom_alt_samples', 'unknown_samples',
                'variant_samples', 'hom_ref_samples'):
            if self.genotype_dict == {}:
                self.genotype_dict = self.query._group_samples_by_genotype(self['gt_types'])
            if key == 'het_samples':
                return self.genotype_dict[HET]
            if key == 'hom_alt_samples':
                return self.genotype_dict[HOM_ALT]
            if key == 'hom_ref_samples':
                return self.genotype_dict[HOM_REF]
            if key == 'unknown_samples':
                return self.genotype_dict[UNKNOWN]
            if key == 'variant_samples':
                return self.genotype_dict[HET] + self.genotype_dict[HOM_ALT]

        if key in self.cache:
            return self.cache[key]

        if key == 'info':
            if 'info' not in self.cache:
                self.cache['info'] = compression.unpack_ordereddict_blob(self.row['info'])
            return self.cache['info']
        if key not in self.query.gt_cols:
            return self.row[key]
        elif key in self.query.gt_cols:
            if key not in self.cache:
                self.cache[key] = compression.unpack_genotype_blob(self.row[key])
            return self.cache[key]
        raise KeyError(key)

    def keys(self):
        return self.row.keys() + self.cache.keys()

    def __iter__(self):
        return self

    def __repr__(self):
        return self.formatter.format(self)

    def next(self):
        try:
            return self.row.keys()
        except:
            raise StopIteration


class GeminiQuery(object):
    """
    An interface to submit queries to an existing Gemini database
    and iterate over the results of the query.

    We create a GeminiQuery object by specifying the database to which to
    connect::

        from gemini import GeminiQuery
        gq = GeminiQuery("my.db")

    We can then issue a query against the database and iterate through
    the results by using the ``run()`` method::

        for row in gq:
            print row

    Instead of printing the entire row, one can print specific columns::

        gq.run("select chrom, start, end from variants")
        for row in gq:
            print row['chrom']

    Also, all of the underlying numpy genotype arrays are
    always available::

        gq.run("select chrom, start, end from variants")
        for row in gq:
            gts = row['gts']
            print row['chrom'], gts
            # yields "chr1" ['A/G' 'G/G' ... 'A/G']

    The ``run()`` method also accepts a genotype filter::

        query = "select chrom, start, end from variants"
        gt_filter = "gt_types.NA20814 == HET"
        gq.run(query, gt_filter)
        for row in gq:
            print row

    Lastly, one can use the ``sample_to_idx`` and ``idx_to_sample``
    dictionaries to gain access to sample-level genotype information
    either by sample name or by sample index::

        # grab dict mapping sample to genotype array indices
        smp2idx = gq.sample_to_idx

        query = "select chrom, start, end from variants"
        gt_filter = "gt_types.NA20814 == HET"
        gq.run(query, gt_filter)

        # print a header listing the selected columns
        print gq.header
        for row in gq:
            # access a NUMPY array of the sample genotypes.
            gts = row['gts']
            # use the smp2idx dict to access sample genotypes
            idx = smp2idx['NA20814']
            print row, gts[idx]
    """

    def __init__(self, db, include_gt_cols=False,
                 out_format=DefaultRowFormat(None),
                 variant_id_getter=None):

        assert os.path.exists(db), "%s does not exist." % db
        self.db = db
        self.query_executed = False
        self.for_browser = False
        self.include_gt_cols = include_gt_cols
        self.variant_id_getter = variant_id_getter

        # try to connect to the provided database
        self._connect_to_database()

        # save the gt_cols in the database and don't hard-code them anywhere.
        self.gt_cols = util.get_gt_cols(self.conn)

        # extract the column names from the sample table.
        # needed for gt-filter wildcard support.
        self._collect_sample_table_columns()

        # list of samples ids for each clause in the --gt-filter
        self.sample_info = collections.defaultdict(list)

        # map sample names to indices. e.g. self.sample_to_idx[NA20814] -> 323
        self.sample_to_idx = util.map_samples_to_indices(self.c)
        # and vice versa. e.g., self.idx_to_sample[323] -> NA20814
        self.idx_to_sample = util.map_indices_to_samples(self.c)
        self.idx_to_sample_object = util.map_indices_to_sample_objects(self.c)
        self.sample_to_sample_object = util.map_samples_to_sample_objects(self.c)
        self.formatter = out_format
        self.predicates = [self.formatter.predicate]
        self.sample_show_fields = ["variant_samples", "het_samples",
                                   "hom_alt_samples"]

    def _set_gemini_browser(self, for_browser):
        self.for_browser = for_browser
    def run(self, query, gt_filter=None, show_variant_samples=False,
            variant_samples_delim=',', predicates=None,
            needs_genotypes=False, needs_genes=False,
            show_families=False, subjects=None):
        """
        Execute a query against a Gemini database. The user may
        specify:

            1. (reqd.) an SQL `query`.
            2. (opt.) a genotype filter.
        """
        self.query = self.formatter.format_query(query)
        self.gt_filter = gt_filter
        if self._is_gt_filter_safe() is False:
            sys.exit("ERROR: invalid --gt-filter command.")

        self.show_variant_samples = show_variant_samples
        self.variant_samples_delim = variant_samples_delim
        self.needs_genotypes = needs_genotypes
        self.needs_vcf_columns = False
        if self.formatter.name == 'vcf':
            self.needs_vcf_columns = True

        self.needs_genes = needs_genes
        self.show_families = show_families
        self.subjects = subjects

        if predicates:
            self.predicates += predicates

        # make sure the SELECT columns are separated by a
        # comma and a space. then tokenize by spaces.
        self.query = self.query.replace(',', ', ')
        self.query_pieces = self.query.split()
        if not any(s.startswith(("gt", "(gt")) for s in self.query_pieces) and \
           not any(".gt" in s for s in self.query_pieces):
            if self.gt_filter is None:
                self.query_type = "no-genotypes"
            else:
                self.gt_filter = self._correct_genotype_filter()
                self.query_type = "filter-genotypes"
        else:
            if self.gt_filter is None:
                self.query_type = "select-genotypes"
            else:
                self.gt_filter = self._correct_genotype_filter()
                self.query_type = "filter-genotypes"

        if self.gt_filter:
            # here's how we use the fast path (the bcolz variant_id_getter),
            # when it is available
            if self.variant_id_getter:
                if os.environ.get('GEMINI_DEBUG') == 'TRUE':
                    print >>sys.stderr, "bcolz: using index"
                user_dict = dict(HOM_REF=0, HET=1, UNKNOWN=2, HOM_ALT=3,
                                 sample_info=self.sample_info,
                                 MISSING=None, UNAFFECTED=1, AFFECTED=2)
                import time
                t0 = time.time()
                vids = self.variant_id_getter(self.db, self.gt_filter, user_dict)
                if vids is None:
                    print >>sys.stderr, "bcolz: can't parse this filter (falling back to gemini): %s" % self.gt_filter
                else:
                    if os.environ.get('GEMINI_DEBUG') == 'TRUE':
                        print >>sys.stderr, "bcolz: %.2f seconds to get %d rows." % (time.time() - t0, len(vids))
                    self.add_vids_to_query(vids)

            if self.gt_filter:
                self.gt_filter_compiled = compiler.compile(self.gt_filter,
                                                           self.gt_filter,
                                                           'eval')
        self._apply_query()
        self.query_executed = True
    def __iter__(self):
        return self

    def add_vids_to_query(self, vids):
        #extra = " variant_id IN (%s)" % ",".join(map(str, vids))
        # faster way to convert to string.
        # http://stackoverflow.com/a/13861407
        q = add_variant_ids_to_query(self.query, vids)
        if q:
            self.query = q
            self.gt_filter = None

    @property
    def header(self):
        """
        Return a header describing the columns that
        were selected in the query issued to a GeminiQuery object.
        """
        if self.query_type == "no-genotypes":
            h = [col for col in self.all_query_cols]
        else:
            h = [col for col in self.all_query_cols] + \
                [col for col in OrderedSet(self.all_columns_orig) -
                 OrderedSet(self.select_columns)]
        if self.show_variant_samples:
            h += self.sample_show_fields
        if self.show_families:
            h += ["families"]

        return self.formatter.header(h)

    @property
    def sample2index(self):
        """
        Return a dictionary mapping sample names to
        genotype array offsets::

            gq = GeminiQuery("my.db")
            s2i = gq.sample2index
            print s2i['NA20814']
            # yields 1088
        """
        return self.sample_to_idx

    @property
    def index2sample(self):
        """
        Return a dictionary mapping genotype array offsets to
        sample names::

            gq = GeminiQuery("my.db")
            i2s = gq.index2sample
            print i2s[1088]
            # yields "NA20814"
        """
        return self.idx_to_sample

    def next(self):
        """
        Return the GeminiRow object for the next query result.
        """
        # we use a while loop since we may skip records based upon
        # genotype filters. if we need to skip a record, we just
        # throw a continue and keep trying. the alternative is to just
        # recursively call self.next() if we need to skip, but this
        # can quickly exceed the stack.
        while (1):
            try:
                row = GeminiRow(self.c.next(), self)
            except Exception:
                self.conn.close()
                raise StopIteration

            # skip the record if it does not meet the user's genotype filter
            # short circuit some expensive ops
            if self.gt_filter:
                try:
                    if 'False' == self.gt_filter:
                        continue
                    unpacked = {'sample_info': self.sample_info}
                    for col in self.gt_cols:
                        if col in self.gt_filter:
                            unpacked[col] = row[col]
                    if not eval(self.gt_filter_compiled, unpacked):
                        continue
                # eval on a phred_ll column that was None
                except TypeError:
                    continue

            fields = PDict()
            for idx, col in enumerate(self.report_cols):
                if col == "*":
                    continue
                if not col[:2] in ("gt", "GT"):
                    # need to use add in case of duplicated fields (from
                    # variants and variant_impacts)
                    fields.add(col, row[col])
                else:
                    # reuse the original column name user requested
                    # e.g. replace gts[1085] with gts.NA20814
                    if '[' in col:
                        orig_col = self.gt_idx_to_name_map[col]
                        source, extra = col.split('[', 1)
                        assert extra[-1] == ']'
                        if source.startswith('gt_phred_ll') and row[source] is None:
                            fields[orig_col] = None
                            continue
                        idx = int(extra[:-1])
                        val = row[source][idx]
                        if type(val) in (np.int8, np.int32, np.bool_):
                            fields.add(orig_col, int(val))
                        elif type(val) in (np.float32,):
                            fields.add(orig_col, float(val))
                        else:
                            fields.add(orig_col, val)
                    else:
                        # asked for "gts" or "gt_types", e.g.
                        if row[col] is not None:
                            fields[col] = row[col]
                        else:
                            fields[col] = str(None)

            if self.show_variant_samples:
                fields["variant_samples"] = \
                    self.variant_samples_delim.join(self._filter_samples(row['variant_samples']))
                fields["het_samples"] = \
                    self.variant_samples_delim.join(self._filter_samples(row['het_samples']))
                fields["hom_alt_samples"] = \
                    self.variant_samples_delim.join(self._filter_samples(row['hom_alt_samples']))
            if self.show_families:
                families = map(str, list(set([self.sample_to_sample_object[x].family_id
                                              for x in row['variant_samples']])))
                fields["families"] = self.variant_samples_delim.join(families)

            if not all(predicate(row) for predicate in self.predicates):
                continue

            if not self.for_browser:
                # need to use new row for formatter.
                return GeminiRow(row, self, formatter=self.formatter,
                                 print_fields=fields)
            else:
                return fields

    def _filter_samples(self, samples):
        """Respect --sample-filter when outputting lists of sample information.
        """
        if self.subjects is not None:
            return [x for x in samples if x in self.subjects]
        else:
            return samples

    def _group_samples_by_genotype(self, gt_types):
        """
        make a list, keyed by genotype, of the samples with that genotype,
        so index 0 is HOM_REF, 1 is HET, 2 is UNKNOWN, 3 is HOM_ALT.
        """
        d = [[], [], [], []]
        lookup = self.idx_to_sample
        for i, x in enumerate(gt_types):
            d[x].append(lookup[i])
        return d

    def _connect_to_database(self):
        """
        Establish a connection to the requested Gemini database.
        """
        # open up a new database
        if os.path.exists(self.db):
            self.conn = sqlite3.connect(self.db)
            self.conn.text_factory = str
            self.conn.isolation_level = None
            # allow us to refer to columns by name
            #self.conn.row_factory = RowFactory
            self.conn.row_factory = sqlite3.Row
            self.c = self.conn.cursor()

    def _collect_sample_table_columns(self):
        """
        extract the column names in the samples table into a list
        """
        self.c.execute('select * from samples limit 1')
        self.sample_column_names = [tup[0] for tup in self.c.description]

    def _is_gt_filter_safe(self, gt_filter=None):
        """
        Test to see if the gt_filter string is potentially malicious.
        A future improvement would be to use pyparsing to traverse and
        directly validate the string.
        """
        gt_filter = gt_filter or self.gt_filter
        if gt_filter is None or len(gt_filter.strip()) == 0:
            return True

        # avoid builtins
        # http://nedbatchelder.com/blog/201206/eval_really_is_dangerous.html
        if "__" in gt_filter:
            return False

        # avoid malicious commands
        evil = [" rm ", "os.system"]
        if any(s in gt_filter for s in evil):
            return False

        # make sure a "gt" col is in the string
        valid_cols = list(flatten(("%s." % gtc, "(%s)." % gtc)
                                  for gtc in self.gt_cols))
        if any(s in gt_filter for s in valid_cols):
            return True

        # assume the worst
        return False

    def _execute_query(self):
        try:
            self.c.execute(self.query)
        except sqlite3.OperationalError as e:
            msg = "SQLite error: {0}\n".format(e)
            print msg
            sys.stderr.write(msg)
            sys.exit("The query issued (%s) has a syntax error." % self.query)

    def _apply_query(self):
        """
        Execute a query. Intercept gt* columns and
        replace sample names with indices where necessary.
""" if self.needs_genes: self.query = self._add_gene_col_to_query() if self.needs_vcf_columns: self.query = self._add_vcf_cols_to_query() if self._query_needs_genotype_info(): # break up the select statement into individual # pieces and replace genotype columns using sample # names with sample indices self._split_select() # we only need genotype information if the user is # querying the variants table self.query = self._add_gt_cols_to_query() self._execute_query() self.all_query_cols = [ str(tuple[0]) for tuple in self.c.description if not tuple[0][:2] == "gt" and ".gt" not in tuple[0] ] if "*" in self.select_columns: self.select_columns.remove("*") self.all_columns_orig.remove("*") self.all_columns_new.remove("*") self.select_columns += self.all_query_cols self.report_cols = self.all_query_cols + \ list(OrderedSet(self.all_columns_new) - OrderedSet(self.select_columns)) # the query does not involve the variants table # and as such, we don't need to do anything fancy. else: self._execute_query() self.all_query_cols = [str(tuple[0]) for tuple in self.c.description if not tuple[0][:2] == "gt"] self.report_cols = self.all_query_cols def _correct_genotype_col(self, raw_col): """ Convert a _named_ genotype index to a _numerical_ genotype index so that the appropriate value can be extracted for the sample from the genotype numpy arrays. These lookups will be eval()'ed on the resuting rows to extract the appropriate information. For example, convert gt_types.1478PC0011 to gt_types[11] """ if raw_col == "*": return raw_col # e.g., "gts.NA12878" elif '.' in raw_col: (column, sample) = raw_col.split('.', 1) corrected = "%s[%d]" % (column.lower(), self.sample_to_idx[sample]) else: # e.g. "gts" - do nothing corrected = raw_col return corrected def _get_matching_sample_ids(self, wildcard): """ Helper function to convert a sample wildcard to a list of tuples reflecting the sample indices and sample names so that the wildcard query can be applied to the gt_* columns. """ query = 'SELECT sample_id, name FROM samples ' if wildcard.strip() != "*": query += ' WHERE ' + wildcard sample_info = [] # list of sample_id/name tuples self.c.execute(query) for row in self.c: # sample_ids are 1-based but gt_* indices are 0-based sample_info.append((int(row['sample_id']) - 1, str(row['name']))) return sample_info def _correct_genotype_filter(self, gt_filter=None): """ This converts a raw genotype filter that contains 'wildcard' statements into a filter that can be eval()'ed. Specifically, we must convert a _named_ genotype index to a _numerical_ genotype index so that the appropriate value can be extracted for the sample from the genotype numpy arrays. For example, without WILDCARDS, this converts: --gt-filter "(gt_types.1478PC0011 == 1)" to: (gt_types[11] == 1) With WILDCARDS, this converts things like: "(gt_types).(phenotype==1).(==HET)" to: "gt_types[2] == HET and gt_types[5] == HET" """ def _swap_genotype_for_number(token): """ This is a bit of a hack to get around the fact that eval() doesn't handle the imported constants well when also having to find local variables. This requires some eval/globals()/locals() fu that has evaded me thus far. Just replacing HET, etc. with 1, etc. works. """ if any(g in token for g in ['HET', 'HOM_ALT', 'HOM_REF', 'UNKNOWN']): token = token.replace('HET', str(HET)) token = token.replace('HOM_ALT', str(HOM_ALT)) token = token.replace('HOM_REF', str(HOM_REF)) token = token.replace('UNKNOWN', str(UNKNOWN)) return token corrected_gt_filter = [] # first try to identify wildcard rules. 
        # (\s*gt\w+\) handles both
        # (gt_types).(*).(!=HOM_REF).(all)
        # and
        # ( gt_types).(*).(!=HOM_REF).(all)
        seen_count = False
        wildcard_tokens = re.split(r'(\(\s*gt\w+\s*\)\.\(.+?\)\.\(.+?\)\.\(.+?\))',
                                   str(gt_filter or self.gt_filter))
        for token_idx, token in enumerate(wildcard_tokens):

            # NOT a WILDCARD
            # We must then split on whitespace and
            # correct the gt_* columns:
            # e.g., "gts.NA12878" or "and gt_types.M10500 == HET"
            if (token.find("gt") >= 0 or token.find("GT") >= 0) \
               and not '.(' in token and not ')self.' in token:
                tokens = re.split(r'[\s+]+', str(token))
                for t in tokens:
                    if len(t) == 0:
                        continue
                    if (t.find("gt") >= 0 or t.find("GT") >= 0):
                        corrected = self._correct_genotype_col(t)
                        corrected_gt_filter.append(corrected)
                    else:
                        t = _swap_genotype_for_number(t)
                        if t.strip() in ("AND", "OR"):
                            t = t.lower()
                        corrected_gt_filter.append(t)

            # IS a WILDCARD
            # e.g., "gt_types.(affected==1).(==HET)"
            elif (token.find("gt") >= 0 or token.find("GT") >= 0) \
                 and '.(' in token and ').' in token:

                # break the wildcard into its pieces. That is:
                # (COLUMN).(WILDCARD).(WILDCARD_RULE).(WILDCARD_OP)
                # e.g, (gts).(phenotype==2).(==HET).(any)
                if token.count('.') != 3 or \
                   token.count('(') != 4 or \
                   token.count(')') != 4:
                    sys.exit("Wildcard filter should consist of 4 elements. Exiting.")

                (column, wildcard, wildcard_rule, wildcard_op) = token.split('.')

                # remove the syntactic parentheses
                column = column.strip('(').strip(')').strip()
                wildcard = wildcard.strip('(').strip(')').strip()
                wildcard_rule = wildcard_rule.strip('(').strip(')').strip()
                wildcard_op = wildcard_op.strip('(').strip(')').strip()

                # collect and save all of the samples that meet the wildcard
                # criteria for each clause.
                # these will be used in the list comprehension for the eval
                # expression constructed below.
                self.sample_info[token_idx] = self._get_matching_sample_ids(wildcard)

                # Replace HET, etc. with 1, etc. to avoid eval() issues.
                wildcard_rule = _swap_genotype_for_number(wildcard_rule)

                # build the rule based on the wildcard the user has supplied.
                if wildcard_op in ["all", "any"]:
                    if self.variant_id_getter:
                        joiner = " and " if wildcard_op == "all" else " or "
                        rule = joiner.join("%s[%s]%s" % (column, s[0], wildcard_rule)
                                           for s in self.sample_info[token_idx])
                        rule = "(" + rule + ")"
                    else:
                        rule = wildcard_op + "(" + column + '[sample[0]]' + wildcard_rule + \
                               " for sample in sample_info[" + str(token_idx) + "])"
                elif wildcard_op == "none":
                    if self.variant_id_getter:
                        rule = " or ".join("%s[%s]%s" % (column, s[0], wildcard_rule)
                                           for s in self.sample_info[token_idx])
                        rule = "~ ((" + rule + "))"
                    else:
                        rule = "not any(" + column + '[sample[0]]' + wildcard_rule + \
                               " for sample in sample_info[" + str(token_idx) + "])"
                elif "count" in wildcard_op:
                    # break "count>=2" into ['', '>=2']
                    tokens = wildcard_op.split('count')
                    count_comp = tokens[len(tokens) - 1]
                    if self.variant_id_getter:
                        rule = "|count|".join("((%s[%s]%s))" % (column, s[0], wildcard_rule)
                                              for s in self.sample_info[token_idx])
                        rule = "%s|count|%s" % (rule, count_comp.strip())
                        seen_count = True
                    else:
                        rule = "sum(" + column + '[sample[0]]' + wildcard_rule + \
                               " for sample in sample_info[" + str(token_idx) + "])" + count_comp
                else:
                    sys.exit("Unsupported wildcard operation: (%s). Exiting."
                             % wildcard_op)
                corrected_gt_filter.append(rule)
            else:
                if len(token) > 0:
                    if token.strip() in ("AND", "OR"):
                        token = token.lower()
                    corrected_gt_filter.append(token)

        if seen_count and len(corrected_gt_filter) > 1 and self.variant_id_getter:
            raise GeminiError("count operations can not be combined with other operations")

        return " ".join(corrected_gt_filter)

    def _add_gt_cols_to_query(self):
        """
        We have to modify the raw query to select the genotype
        columns in order to support the genotype filters.  That is,
        if the user wants to limit the rows returned based upon, for example,
        "gts.joe == 1", then we need to select the full gts BLOB column in
        order to enforce that limit. The user wouldn't have selected gts as a
        column, so therefore, we have to modify the select statement to add
        it.

        In essence, when a genotype filter has been requested, we always add
        the gts, gt_types and gt_phases columns.
        """
        if "from" not in self.query.lower():
            sys.exit("Malformed query: expected a FROM keyword.")

        (select_tokens, rest_of_query) = get_select_cols_and_rest(self.query)

        # remove any GT columns
        select_clause_list = []
        for token in select_tokens:
            if not token[:2] in ("GT", "gt") and \
               not token[:3] in ("(gt", "(GT") and \
               not ".gt" in token and \
               not ".GT" in token:
                select_clause_list.append(token)

        # reconstruct the query with the GT* columns added
        select_clause = (", ".join(select_clause_list)).strip()
        if select_clause:
            select_clause += ","
        select_clause += ",".join(self.gt_cols)

        self.query = "select %s %s" % (select_clause, rest_of_query)

        # extract the original select columns
        return self.query

    def _add_gene_col_to_query(self):
        """
        Add the gene column to the list of SELECT'ed columns
        in a query.
        """
        if "from" not in self.query.lower():
            sys.exit("Malformed query: expected a FROM keyword.")

        (select_tokens, rest_of_query) = get_select_cols_and_rest(self.query)
        if not any("gene" in s for s in select_tokens):
            select_clause = ",".join(select_tokens) + \
                            ", gene "
            self.query = "select " + select_clause + rest_of_query
        return self.query

    def _add_vcf_cols_to_query(self):
        """
        Add the VCF columns to the list of SELECT'ed columns
        in a query.

        NOTE: Should only be called if using VCFRowFormat()
        """
        if "from" not in self.query.lower():
            sys.exit("Malformed query: expected a FROM keyword.")

        (select_tokens, rest_of_query) = get_select_cols_and_rest(self.query)

        cols_to_add = []
        for col in ['chrom', 'start', 'vcf_id', 'ref', 'alt', 'qual', 'filter', 'info',
                    'gts', 'gt_types', 'gt_phases']:
            if not any(col in s for s in select_tokens):
                cols_to_add.append(col)

        select_clause = ",".join(select_tokens + cols_to_add)
        self.query = "select " + select_clause + rest_of_query
        return self.query

    def _split_select(self):
        """
        Build a list of _all_ columns in the SELECT statement
        and segregate the non-genotype-specific SELECT columns.

        This is used to control how to report the results, as the
        genotype-specific columns need to be eval()'ed whereas others do not.
        For example:
        "SELECT chrom, start, end, gt_types.1478PC0011"
        will populate the lists as follows:

        select_columns = ['chrom', 'start', 'end']
        all_columns = ['chrom', 'start', 'end', 'gt_types[11]']
        """
        self.select_columns = []
        self.all_columns_new = []
        self.all_columns_orig = []
        self.gt_name_to_idx_map = {}
        self.gt_idx_to_name_map = {}

        # iterate through all of the select columns and
        # distinguish the genotype-specific columns from the base columns
        if "from" not in self.query.lower():
            sys.exit("Malformed query: expected a FROM keyword.")

        (select_tokens, rest_of_query) = get_select_cols_and_rest(self.query)
        for token in select_tokens:

            # it is a WILDCARD
            if (token.find("gt") >= 0 or token.find("GT") >= 0) \
               and '.(' in token and ').' in token:
                # break the wildcard into its pieces. That is:
                # (COLUMN).(WILDCARD)
                (column, wildcard) = token.split('.')

                # remove the syntactic parentheses
                wildcard = wildcard.strip('(').strip(')')
                column = column.strip('(').strip(')')

                # convert "gt_types.(affected==1)"
                # to: gt_types[3] == HET and gt_types[9] == HET
                sample_info = self._get_matching_sample_ids(wildcard)

                # maintain a list of the sample indices that should
                # be displayed as a result of the SELECT'ed wildcard
                wildcard_indices = []
                for (idx, sample) in enumerate(sample_info):
                    wildcard_display_col = column + '.' + str(sample[1])
                    wildcard_mask_col = column + '[' + str(sample[0]) + ']'
                    wildcard_indices.append(sample[0])

                    new_col = wildcard_mask_col
                    self.all_columns_new.append(new_col)
                    self.all_columns_orig.append(wildcard_display_col)
                    self.gt_name_to_idx_map[wildcard_display_col] = wildcard_mask_col
                    self.gt_idx_to_name_map[wildcard_mask_col] = wildcard_display_col

            # it is a basic genotype column
            elif (token.find("gt") >= 0 or token.find("GT") >= 0) \
                 and '.(' not in token and not ').' in token \
                 and "length" not in token:  # e.g., no aa_lenGTh, etc. false positives
                new_col = self._correct_genotype_col(token)

                self.all_columns_new.append(new_col)
                self.all_columns_orig.append(token)
                self.gt_name_to_idx_map[token] = new_col
                self.gt_idx_to_name_map[new_col] = token

            # it is neither
            else:
                self.select_columns.append(token)
                self.all_columns_new.append(token)
                self.all_columns_orig.append(token)

    def _info_dict_to_string(self, info):
        """
        Flatten the VCF info-field OrderedDict into a string,
        including all arrays for allelic-level info.
        """
        if info is not None:
            return ';'.join('%s=%s' % (key, value) if not isinstance(value, list)
                            else '%s=%s' % (key, ','.join(str(v) for v in value))
                            for (key, value) in info.items())
        else:
            return None

    def _tokenize_query(self):
        return list(flatten(x.split(",") for x in self.query.split(" ")))

    def _query_needs_genotype_info(self):
        if self.include_gt_cols or self.show_variant_samples or self.needs_genotypes:
            return True
        tokens = self._tokenize_query()
        requested_genotype = "variants" in tokens and \
            (any(x.startswith(("gt", "(gt")) for x in tokens) or
             any(".gt" in x for x in tokens))
        return requested_genotype
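
# A minimal usage sketch for select_formatter() below (hypothetical
# argparse-style namespace; the real gemini CLI passes its parsed args):
#
#     import argparse
#     args = argparse.Namespace(db="my.db", format="default", carrier_summary=None)
#     formatter = select_formatter(args)
#     gq = GeminiQuery(args.db, out_format=formatter)
#     gq.run("select chrom, start, end from variants limit 5")
#     for row in gq:
#         print row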

def select_formatter(args):
    SUPPORTED_FORMATS = {x.name.lower(): x for x in itersubclasses(RowFormat)}
    if hasattr(args, 'carrier_summary') and args.carrier_summary:
        return SUPPORTED_FORMATS["carrier_summary"](args)
    if args.format not in SUPPORTED_FORMATS:
        raise NotImplementedError("Conversion to %s not supported. Valid "
                                  "formats are %s." % (args.format,
                                                       SUPPORTED_FORMATS))
    else:
        return SUPPORTED_FORMATS[args.format](args)


def add_variant_ids_to_query(query, vids):
    """
    >>> vids = range(1, 4)
    >>> add_variant_ids_to_query("select gene, chrom, start, end from variants limit 10", vids)
    'select gene, chrom, start, end from variants where variant_id IN (1,2,3) limit 10'

    >>> add_variant_ids_to_query("select gene, chrom, start, end from variants where gene = 'asdf' limit 10", vids)
    "select gene, chrom, start, end from variants where gene = 'asdf' and variant_id IN (1,2,3) limit 10"

    >>> add_variant_ids_to_query("select gene, chrom, start, end from variants where gene = 'asdf' order by gene limit 10", vids)
    "select gene, chrom, start, end from variants where gene = 'asdf' and variant_id IN (1,2,3) order by gene limit 10"

    >>> add_variant_ids_to_query("select gene, chrom, start, end from variants", vids)
    'select gene, chrom, start, end from variants where variant_id IN (1,2,3)'
    """
    if len(vids) == 0:
        return vids
    extra = " variant_id IN (%s)" % ",".join(np.char.mod("%i", vids))

    # order by, then limit.
    limit_idx = query.lower().index(" limit ") if " limit " in query.lower() else None
    if limit_idx:
        query, qlimit = query[:limit_idx].strip(), query[limit_idx:].strip()
        assert qlimit.lower().startswith("limit ")
    else:
        qlimit = ""

    order_idx = query.lower().index(" order by ") if " order by " in query.lower() else None
    if order_idx:
        query, qorder = query[:order_idx].strip(), query[order_idx:].strip()
        assert qorder.lower().startswith("order by ")
    else:
        qorder = ""

    if " where " in query.lower():
        extra = " and " + extra
    else:
        extra = " where " + extra

    return " ".join([x.strip() for x in (query, extra, qorder, qlimit)]).strip()


if __name__ == "__main__":

    db = sys.argv[1]

    gq = GeminiQuery(db)

    print "test a basic query with no genotypes"
    query = "select chrom, start, end from variants limit 5"
    gq.run(query)
    for row in gq:
        print row

    print "test a basic query with no genotypes using a header"
    query = "select chrom, start, end from variants limit 5"
    gq.run(query)
    print gq.header
    for row in gq:
        print row

    print "test query that selects a sample genotype"
    query = "select chrom, start, end, gts.NA20814 from variants limit 5"
    gq.run(query)
    for row in gq:
        print row

    print "test query that selects a sample genotype and uses a header"
    query = "select chrom, start, end, gts.NA20814 from variants limit 5"
    gq.run(query)
    print gq.header
    for row in gq:
        print row

    print "test query that selects and _filters_ on a sample genotype"
    query = "select chrom, start, end, gts.NA20814 from variants limit 50"
    db_filter = "gt_types.NA20814 == HET"
    gq.run(query, db_filter)
    for row in gq:
        print row

    print "test query that selects and _filters_ on a sample genotype and uses a filter"
    query = "select chrom, start, end, gts.NA20814 from variants limit 50"
    db_filter = "gt_types.NA20814 == HET"
    gq.run(query, db_filter)
    print gq.header
    for row in gq:
        print row

    print "test query that selects and _filters_ on a sample genotype and uses a filter and a header"
    query = "select chrom, start, end, gts.NA20814 from variants limit 50"
    db_filter = "gt_types.NA20814 == HET"
    gq.run(query, db_filter)
    print gq.header
    for row in gq:
        print row

    print "demonstrate accessing individual columns"
    query = "select chrom, start, end, gts.NA20814 from variants limit 50"
    db_filter = "gt_types.NA20814 == HET"
    gq.run(query, db_filter)
    for row in gq:
        print row['chrom'], row['start'], row['end'], row['gts.NA20814']