import collections
import functools
import operator
import os
import re
from time import time
from ase.atoms import Atoms, symbols2numbers
from ase.db.row import atoms2dict, AtomsRow
from ase.calculators.calculator import all_properties, all_changes
from ase.data import atomic_numbers
from ase.parallel import world, broadcast, DummyMPI
from ase.utils import Lock, basestring
# Zero point of the database time scale as a Unix timestamp.  The value is
# 1999-12-31 23:00:00 UTC, i.e. midnight of January 1. 2000 at UTC+1.
T2000 = 946681200.0  # January 1. 2000

# Length of one year in seconds: 365.25 days of 86400 seconds each.
YEAR = 365.25 * 86400.0  # 365.25 days
def now():
    """Return the current time in years elapsed since the T2000 reference.

    The wall-clock time from ``time()`` is shifted by ``T2000`` and then
    scaled from seconds to years with ``YEAR``.
    """
    elapsed_seconds = time() - T2000
    return elapsed_seconds / YEAR
# Length in seconds of every single-letter time unit understood by
# time_string_to_float() and float_to_time_string().
seconds = {'s': 1,
           'm': 60,
           'h': 60 * 60,
           'd': 24 * 60 * 60,
           'w': 7 * 24 * 60 * 60,
           'M': 2629800,  # one twelfth of the 365.25-day YEAR
           'y': YEAR}

# Singular long name of each unit letter; float_to_time_string() appends
# an 's' to these when asked for the long format.
longwords = {
    's': 'second',
    'm': 'minute',
    'h': 'hour',
    'd': 'day',
    'w': 'week',
    'M': 'month',
    'y': 'year',
}
# Comparison operators accepted in selection strings, mapped to the
# function that evaluates them.
ops = {
    '<': operator.lt,
    '<=': operator.le,
    '=': operator.eq,
    '>=': operator.ge,
    '>': operator.gt,
    '!=': operator.ne,
}

# Each operator mapped to its logical complement (``a < b`` is false exactly
# when ``a >= b``).  Database.parse_selection() applies this table when it
# rewrites a comparison on 'age' into one on 'ctime'.
invop = {
    '<': '>=',
    '<=': '>',
    '>=': '<',
    '>': '<=',
    '=': '!=',
    '!=': '=',
}
# Pattern for a valid key: an identifier made of letters, digits and
# underscores that does not start with a digit.  The pattern is anchored with
# \Z rather than $ because $ also matches just before a trailing newline,
# which would let a key such as 'abc\n' pass validation.
word = re.compile(r'[_a-zA-Z][_0-9a-zA-Z]*\Z')
# Names that check() refuses as user-supplied keys: every calculator
# property, every entry of all_changes, and the column names listed below.
reserved_keys = set(all_properties + all_changes)
reserved_keys.update([
    'id', 'unique_id', 'ctime', 'mtime', 'user',
    'momenta', 'constraints',
    'calculator', 'calculator_parameters',
    'key_value_pairs', 'data',
])

# Keys whose comparison value must be a number; enforced by
# Database.parse_selection().
numeric_keys = set(['id', 'energy', 'magmom', 'charge', 'natoms'])
def check(key_value_pairs):
    """Validate user-supplied key-value pairs.

    key_value_pairs: dict
        Mapping to validate.

    Raises ValueError for the first key that does not match the ``word``
    pattern or is one of ``reserved_keys``, and for the first value that is
    not an int, a float or a string.
    """
    allowed_value_types = (int, float, basestring)
    for key, value in key_value_pairs.items():
        key_is_valid = word.match(key) and key not in reserved_keys
        if not key_is_valid:
            raise ValueError('Bad key: {0}'.format(key))
        if isinstance(value, allowed_value_types):
            continue
        raise ValueError('Bad value: {0}'.format(value))
def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True):
    """Create connection to database.

    name: str
        Filename or address of database.  A name starting with 'pg://'
        selects PostgreSQL; ``None`` gives a plain in-memory ``Database``.
    type: str
        One of 'json', 'db' or 'postgresql' (JSON, SQLite, PostgreSQL).
        Default is 'extract_from_name', which will guess the type from
        the name (the 'pg://' prefix or the filename extension).
    create_indices: bool
        Passed on to the SQLite backend.
    use_lock_file: bool
        You can turn this off if you know what you are doing ...
    append: bool
        Use append=False to start a new database.

    Raises ValueError if the type is not one of the supported ones.
    """
    if type == 'extract_from_name':
        if name is None:
            type = None
        elif name.startswith('pg://'):
            type = 'postgresql'
        else:
            # Filename extension without the leading dot.
            type = os.path.splitext(name)[1][1:]

    if type is None:
        return Database()

    # Only the master removes an existing file when starting afresh.
    if not append and world.rank == 0 and os.path.isfile(name):
        os.remove(name)

    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file)

    if type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file)

    if type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        # Strip the scheme only when it is actually there: with an explicit
        # type='postgresql' the name may be given without the 'pg://' prefix,
        # and blindly slicing would eat the first five characters of it.
        if name.startswith('pg://'):
            name = name[5:]
        return PostgreSQLDatabase(name)

    raise ValueError('Unknown database type: ' + type)
def lock(method):
    """Decorator for using a lock-file.

    The wrapped method runs inside ``with self.lock:`` when the instance
    has a lock, and runs unprotected when ``self.lock`` is None.
    """

    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        guard = self.lock
        if guard is not None:
            with guard:
                return method(self, *args, **kwargs)
        return method(self, *args, **kwargs)

    return new_method
def parallel(method):
    """Decorator for broadcasting from master to slaves using MPI.

    With a single process the method is returned unchanged.  Otherwise the
    method is only executed on rank 0; its return value, or the exception
    it raised, is broadcast so that every rank returns the same result or
    raises the same exception.
    """
    if world.size == 1:
        return method

    @functools.wraps(method)
    def new_method(*args, **kwargs):
        ex = None
        result = None
        if world.rank == 0:
            try:
                result = method(*args, **kwargs)
            except Exception as err:
                # Python 3 unbinds the ``as`` target when the handler
                # exits, so the exception must be copied to another name
                # to survive until the broadcast below.
                ex = err
        ex, result = broadcast((ex, result))
        if ex is not None:
            raise ex
        return result

    return new_method
def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    With a single process the generator function is returned unchanged.
    Otherwise rank 0 runs the real generator and broadcasts each item,
    followed by a final None that tells the other ranks to stop; the other
    ranks yield whatever they receive until that None arrives.
    """
    if world.size == 1:
        return generator

    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        if world.rank != 0:
            # Slave: replay the items sent by the master.
            item = broadcast(None)
            while item is not None:
                yield item
                item = broadcast(None)
            return

        # Master: produce the items and share each one before yielding it.
        for item in generator(*args, **kwargs):
            yield broadcast(item)
        broadcast(None)  # end-of-stream marker

    return new_generator
def convert_str_to_float_or_str(value):
    """Safe eval()

    Return ``float(value)`` when the conversion succeeds.  Otherwise map
    the strings 'True' and 'False' to 1.0 and 0.0 and return any other
    value unchanged.
    """
    booleans = {'True': 1.0, 'False': 0.0}
    try:
        return float(value)
    except ValueError:
        return booleans.get(value, value)
class Database:
    """Base class for all databases.

    This class parses selections and handles locking and MPI broadcasting.
    The storage hooks ``_select()`` and ``_update()`` are called here but
    not defined in this class, ``delete()`` raises NotImplementedError and
    ``_write()`` only validates its input, so a concrete backend has to
    supply them.
    """

    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False):
        """Store the settings and set up the optional lock-file.

        filename: str or None
            Name of the database file; a leading '~' is expanded.
        create_indices: bool
            Stored on the instance as ``create_indices``.
        use_lock_file: bool
            Guard access with a ``filename + '.lock'`` lock.  Only
            honoured when filename is a string.
        """
        if isinstance(filename, str):
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        if use_lock_file and isinstance(filename, str):
            self.lock = Lock(filename + '.lock', world=DummyMPI())
        else:
            self.lock = None

    @parallel
    @lock
    def write(self, atoms, key_value_pairs=None, data=None, **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions.  If a calculator is attached, write also already
            calculated properties such as the energy and forces.
            None means an empty Atoms object.
        key_value_pairs: dict
            Dictionary of key-value pairs.  Values must be strings or
            numbers.
        data: dict
            Extra stuff (not for searching).

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """
        if atoms is None:
            atoms = Atoms()

        # None replaces the former mutable {} defaults, so that no dict
        # object is shared between calls.
        if key_value_pairs is None:
            key_value_pairs = {}
        if data is None:
            data = {}

        kvp = dict(key_value_pairs)  # modify a copy
        kvp.update(kwargs)

        return self._write(atoms, kvp, data)

    def _write(self, atoms, key_value_pairs, data):
        """Validate the key-value pairs and return the row id.

        This base implementation stores nothing and always returns 1.
        """
        check(key_value_pairs)
        return 1

    @parallel
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id.  If such a row already exists, don't write
        anything and return None.
        """
        comparisons = [(key, '=', value)
                       for key, value in key_value_pairs.items()]
        for _ in self._select([], comparisons):
            # At least one matching row exists.
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        return self._write(atoms, key_value_pairs, {})

    def __delitem__(self, id):
        """Delete the row with the given id (``del db[id]``)."""
        self.delete([id])

    def get_atoms(self, selection=None, attach_calculator=False,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        attach_calculator: bool
            Attach calculator object to Atoms object (default value is
            False).
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """
        row = self.get(selection, **kwargs)
        return row.toatoms(attach_calculator, add_additional_information)

    def __getitem__(self, selection):
        """Return the single row matching selection (``db[selection]``)."""
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it.

        selection: int, str or list
            See the select() method.

        Keyword arguments are passed on to select().

        Raises KeyError if nothing matches and AssertionError if more than
        one row matches.
        """
        # Two rows are enough to tell "exactly one" from "several".
        rows = list(self.select(selection, limit=2, **kwargs))
        if not rows:
            raise KeyError('no match')
        assert len(rows) == 1, 'more than one row matched'
        return rows[0]

    def parse_selection(self, selection, **kwargs):
        """Translate a selection into keys and comparisons.

        selection: int, str or list
            See the select() method.

        Keyword arguments are added as ``key=value`` comparisons.

        Returns a tuple ``(keys, cmps)``: keys is a list of bare key names
        and cmps is a list of ``(key, op, value)`` tuples.

        Raises ValueError when a key in ``numeric_keys`` is compared with
        something that is not a number.
        """
        if selection is None or selection == '':
            expressions = []
        elif isinstance(selection, int):
            expressions = [('id', '=', selection)]
        elif isinstance(selection, list):
            expressions = selection
        else:
            expressions = [w.strip() for w in selection.split(',')]

        keys = []
        comparisons = []
        for expression in expressions:
            if isinstance(expression, (list, tuple)):
                # Already a (key, op, value) triple.
                comparisons.append(expression)
                continue

            if expression.count('<') == 2:
                # Range such as 'a<key<b' or 'a<=key<b': emit the lower
                # bound here and leave 'key<b' for the generic code below.
                value, expression = expression.split('<', 1)
                if expression[0] == '=':
                    op = '>='
                    expression = expression[1:]
                else:
                    op = '>'
                key = expression.split('<', 1)[0]
                comparisons.append((key, op, value))

            # Two-character operators are tried first so that '<=' is not
            # mistaken for '<' or '='.
            for op in ['!=', '<=', '>=', '<', '>', '=']:
                if op in expression:
                    break
            else:
                # No operator: a chemical symbol means "contains that
                # element", anything else is a bare key.
                if expression in atomic_numbers:
                    comparisons.append((expression, '>', 0))
                else:
                    keys.append(expression)
                continue

            # NOTE: this unpacking fails with ValueError when the operator
            # occurs more than once in the expression.
            key, value = expression.split(op)
            comparisons.append((key, op, value))

        cmps = []
        for key, value in kwargs.items():
            comparisons.append((key, '=', value))

        for key, op, value in comparisons:
            if key == 'age':
                # An age is rewritten as a creation time: the operator is
                # swapped through invop and the value becomes a point on
                # the now() time scale.
                key = 'ctime'
                op = invop[op]
                value = now() - time_string_to_float(value)
            elif key == 'formula':
                # One equality per element plus the total atom count.
                assert op == '='
                numbers = symbols2numbers(value)
                count = collections.defaultdict(int)
                for Z in numbers:
                    count[Z] += 1
                cmps.extend((Z, '=', count[Z]) for Z in count)
                key = 'natoms'
                value = len(numbers)
            elif key in atomic_numbers:
                # Element counts are keyed by atomic number.
                key = atomic_numbers[key]
                value = int(value)
            elif isinstance(value, basestring):
                value = convert_str_to_float_or_str(value)
            if key in numeric_keys and not isinstance(value, (int, float)):
                msg = 'Wrong type for "{0}{1}{2}" - must be a number'
                raise ValueError(msg.format(key, op, value))
            cmps.append((key, op, value))

        return keys, cmps

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None, **kwargs):
        """Select rows.

        Return AtomsRow iterator with results.  Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Passed on to the backend's _select() together with limit.
        sort: str
            Key to sort by, passed on to the backend's _select().
            'age' and '-age' are translated to '-ctime' and 'ctime', and
            'user' / '-user' to 'username' / '-username'.
        """
        if sort:
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = self.parse_selection(selection, **kwargs)
        for row in self._select(keys, cmps, explain=explain,
                                verbosity=verbosity,
                                limit=limit, offset=offset, sort=sort):
            if filter is None or filter(row):
                yield row

    def count(self, selection=None, **kwargs):
        """Return the number of rows that select() yields.

        selection and the keyword arguments are passed on to select().
        """
        return sum(1 for _ in self.select(selection, **kwargs))

    @parallel
    @lock
    def update(self, ids, delete_keys=None, block_size=1000,
               **add_key_value_pairs):
        """Update row(s).

        ids: int or list of int
            ID's of rows to update.
        delete_keys: list of str
            Keys to remove.
        block_size: int
            Maximum number of ids handed to the backend in one call.

        Use keyword arguments to add new keys-value pairs.

        Returns number of key-value pairs added and removed.
        """
        check(add_key_value_pairs)

        # None replaces the former mutable [] default.
        if delete_keys is None:
            delete_keys = []

        if isinstance(ids, int):
            ids = [ids]

        B = block_size
        nblocks = (len(ids) - 1) // B + 1  # ceil(len(ids) / B)
        M = 0
        N = 0
        for b in range(nblocks):
            m, n = self._update(ids[b * B:(b + 1) * B], delete_keys,
                                add_key_value_pairs)
            M += m
            N += n
        return M, N

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError
def time_string_to_float(s):
    """Convert a time string such as '3d' or '1h+30m' to years.

    s: str, float or int
        Numbers are returned unchanged.  Strings consist of an integer
        followed by one of the unit letters in ``seconds``; spaces are
        ignored, terms can be added with '+', and an 's' that follows
        another letter at the end of the string is dropped.
    """
    if isinstance(s, (float, int)):
        return s

    compact = s.replace(' ', '')
    terms = compact.split('+')
    if len(terms) > 1:
        return sum(map(time_string_to_float, terms))

    if compact[-2].isalpha() and compact[-1] == 's':
        compact = compact[:-1]

    # The first character is taken to be a digit; find where the digits end.
    split = 1
    while compact[split].isdigit():
        split += 1
    number, unit = compact[:split], compact[split:]
    return seconds[unit] * int(number) / YEAR
def float_to_time_string(t, long=False):
    """Format a time given in years using a suitable unit.

    t: float
        Time in years.
    long: bool
        Return e.g. '7.500 days' instead of the short form '8d'.

    Units are tried from years down to seconds and the first one that
    gives a value above 5 is used; seconds are the fallback.
    """
    total_seconds = t * YEAR
    for unit in 'yMwdhms':
        amount = total_seconds / seconds[unit]
        if amount > 5:
            break
    if not long:
        return '{0:.0f}{1}'.format(round(amount), unit)
    return '{0:.3f} {1}s'.format(amount, longwords[unit])