1
2 r"""
3 ==========================
4 Schema module generation
5 ==========================
6
7 Schema module generation code.
8
9 :Copyright:
10
11 Copyright 2010 - 2017
12 Andr\xe9 Malo or his licensors, as applicable
13
14 :License:
15
16 Licensed under the Apache License, Version 2.0 (the "License");
17 you may not use this file except in compliance with the License.
18 You may obtain a copy of the License at
19
20 http://www.apache.org/licenses/LICENSE-2.0
21
22 Unless required by applicable law or agreed to in writing, software
23 distributed under the License is distributed on an "AS IS" BASIS,
24 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 See the License for the specific language governing permissions and
26 limitations under the License.
27
28 """
# The module docstring and the author name are written as raw ASCII strings
# containing a literal "\xe9" escape sequence. Round-tripping them through
# ``encode('ascii').decode('unicode_escape')`` turns that escape into the
# real character at import time.
if __doc__:
    # Guarded because ``__doc__`` may be empty/None (presumably when
    # docstrings are stripped); there is nothing to decode in that case.
    __doc__ = __doc__.encode('ascii').decode('unicode_escape')
__author__ = r"Andr\xe9 Malo".encode('ascii').decode('unicode_escape')
__docformat__ = "restructuredtext en"
34
35 import sqlalchemy as _sa
36
37 from . import _table
38 from . import _template
39
40
42 """
43 Schema container
44
45 :CVariables:
46 `_MODULE_TPL` : ``Template``
47 Template for the module
48
49 :IVariables:
50 `_dialect` : ``str``
51 Dialect name
52
53 `_tables` : `TableCollection`
54 Table collection
55
56 `_schemas` : ``dict``
57 Schema -> module mapping
58
59 `_symbols` : `Symbols`
60 Symbol table
61
62 `_dbname` : ``str`` or ``None``
63 DB identifier
64 """
65
# Skeleton of the generated schema module. ``dump`` expands it with
# ``dbspec``, ``dialect``, ``imports`` and ``lines`` passed explicitly; all
# remaining placeholders (``sa``, ``type``, ``meta``, ``table``, ``column``,
# ``default``) are expected to be supplied by the symbol table entries.
# The helper names bound at the top of the generated module are deleted
# again at its end (see the ``del`` line).
_MODULE_TPL = _template.Template('''
# -*- coding: ascii -*-
# flake8: noqa pylint: skip-file
"""
==============================
SQLAlchemy schema definition
==============================

SQLAlchemy schema definition%(dbspec)s.

:Warning: DO NOT EDIT, this file is generated
"""
__docformat__ = "restructuredtext en"

import sqlalchemy as %(sa)s
from sqlalchemy.dialects import %(dialect)s as %(type)s
%(imports)s
%(meta)s = %(sa)s.MetaData()
%(table)s = %(sa)s.Table
%(column)s = %(sa)s.Column
%(default)s = %(sa)s.DefaultClause
%(lines)s
del %(sa)s, %(table)s, %(column)s, %(default)s, %(meta)s

# vim: nowrap tw=0
''')
92
def __init__(self, conn, tables, schemas, symbols, dbname=None,
             types=None):
    """
    Set up the schema container and collect the requested tables

    :Parameters:
      `conn` : ``Connection`` or ``Engine``
        SQLAlchemy connection or engine. The metadata used for the table
        collection is bound to it, and the dialect name is taken from it.

      `tables` : ``list``
        Tables to reflect, given as (local name, table name) pairs

      `schemas` : ``dict``
        Mapping of schema to module

      `symbols` : `Symbols`
        Symbol table

      `dbname` : ``str``
        Optional DB identifier, purely informational. If it is omitted
        or ``None``, no such information is emitted.

      `types` : callable
        Optional extra type loader. It is called with the type name, the
        (bound) metadata and the symbol table whenever SQLAlchemy cannot
        resolve a type during reflection. The loader is responsible for
        modifying the symbols and imports *and* the dialect's
        ``ischema_names``. If it is omitted or ``None``, the reflector
        always fails on unknown types.
    """
    # Plain bookkeeping first; none of it depends on the reflection below.
    self._dbname = dbname
    self._schemas = schemas
    self._symbols = symbols

    bound_meta = _sa.MetaData(conn)
    self._dialect = bound_meta.bind.dialect.name

    collect = _table.TableCollection.by_names
    self._tables = collect(
        bound_meta, tables, schemas, symbols, types=types
    )
131
def dump(self, fp):
    """
    Dump schema module to fp

    The module template is expanded with the sorted import lines, the
    custom type definitions (if any) and one ``varname = <table repr>``
    assignment per table that is not a mere reference. The result is
    written to `fp`, followed by a newline.

    :Parameters:
      `fp` : ``file``
        File to write to. Only its ``write`` method is used.
    """
    imports = sorted(item % self._symbols for item in self._symbols.imports)
    if imports:
        # The trailing empty entry yields a newline after the import
        # section once the list is joined.
        imports.append('')

    lines = []
    defines = self._symbols.types.defines
    if defines:
        defined = []
        for define in defines:
            defined.extend(define(self._dialect, self._symbols))
        if defined:
            lines.append('')
            lines.append('# Custom type definitions')
            lines.extend(defined)
            lines.append('')

    for table in self._tables:
        if table.is_reference:
            continue
        if not lines:
            lines.append('')
        # The generated module is ASCII-only, so non-ASCII characters of
        # the table name are written as backslash escapes. The handler
        # name must be 'backslashreplace'; an unknown handler name raises
        # LookupError as soon as a non-ASCII character is encountered.
        name = table.sa_table.name.encode('ascii', 'backslashreplace')
        if bytes is not str:
            # Python 3: encode() returned bytes, get back a text string
            name = name.decode('ascii')
        lines.append('# Table "%s"' % (name,))
        lines.append('%s = %r' % (table.varname, table))
        lines.append('')
        lines.append('')

    # Iterating the symbol table yields (key, value) pairs; they provide
    # the template placeholders which are not passed explicitly here.
    param = dict(
        ((str(key), value) for key, value in self._symbols),
        dbspec=" for %s" % self._dbname if self._dbname else "",
        dialect=self._dialect,
        imports='\n'.join(imports),
        lines='\n'.join(lines)
    )
    fp.write(self._MODULE_TPL.expand(**param))
    fp.write('\n')
177